{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE RankNTypes                 #-}
module Network.SSH.Builder where

import           Control.Monad                  ( void )
import           Data.Bits
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Short         as SBS
import qualified Data.ByteString.Short.Internal
                                               as SBS
import qualified Data.ByteArray                as BA
import           Foreign.Ptr
import           Foreign.Storable
import           Data.Memory.PtrMethods
import           Data.Word
import           Data.Semigroup
import           Data.List.NonEmpty             ( NonEmpty((:|)) )
import           Prelude                 hiding ( length )

-- | A sink for big-endian binary serialisation.
--
-- Only 'word8' is required. Every other method has a default expressed in
-- terms of 'word8' and the 'Monoid' instance, and an instance may override
-- any of them with a faster equivalent.
class Monoid a => Builder a where
    -- | Emit a single byte.
    word8      :: Word8 -> a
    -- | Emit a 16-bit word, most significant byte first.
    word16BE   :: Word16 -> a
    word16BE x  = word8 (fromIntegral $ x `unsafeShiftR` 8)
               <> word8 (fromIntegral   x)
    -- | Emit a 32-bit word, most significant byte first.
    word32BE   :: Word32 -> a
    word32BE x  = word8 (fromIntegral $ x `unsafeShiftR` 24)
               <> word8 (fromIntegral $ x `unsafeShiftR` 16)
               <> word8 (fromIntegral $ x `unsafeShiftR`  8)
               <> word8 (fromIntegral   x)
    -- | Emit a 64-bit word, most significant byte first.
    word64BE   :: Word64 -> a
    word64BE x  = word8 (fromIntegral $ x `unsafeShiftR` 56)
               <> word8 (fromIntegral $ x `unsafeShiftR` 48)
               <> word8 (fromIntegral $ x `unsafeShiftR` 40)
               <> word8 (fromIntegral $ x `unsafeShiftR` 32)
               <> word8 (fromIntegral $ x `unsafeShiftR` 24)
               <> word8 (fromIntegral $ x `unsafeShiftR` 16)
               <> word8 (fromIntegral $ x `unsafeShiftR`  8)
               <> word8 (fromIntegral   x)
    -- | Emit every byte of a byte array, in index order.
    byteArray :: forall ba. BA.ByteArrayAccess ba => ba -> a
    -- 'foldMap' builds the '<>' chain directly instead of threading a lazy
    -- 'foldl' accumulator, which would pile up one thunk per byte.
    byteArray x = foldMap (word8 . BA.index x) [0 .. BA.length x - 1]
    -- | Emit every byte of a strict 'BS.ByteString', in order.
    byteString :: BS.ByteString -> a
    byteString = byteArray
    -- | Emit every byte of a 'SBS.ShortByteString', in order.
    shortByteString :: SBS.ShortByteString -> a
    shortByteString x = foldMap (word8 . SBS.index x) [0 .. SBS.length x - 1]
    -- | Emit the given number of zero bytes; a count @<= 0@ emits nothing.
    zeroes     :: Int -> a
    zeroes i    = mconcat (replicate i (word8 0))
    {-# MINIMAL word8 #-}

-- | A plain byte count. Its 'Builder' instance measures how many bytes a
-- construction would occupy without producing them.
newtype Length = Length
    { length :: Int -- ^ the number of bytes
    } deriving (Show, Eq, Ord, Num)

-- | A raw-memory writer. Given the address to start at, it performs its
-- writes and returns the address where the next write should begin.
newtype PtrWriter = PtrWriter
    { runPtrWriter :: Ptr Word8 -> IO (Ptr Word8)
    }

-- | Lengths combine by addition.
instance Semigroup Length where
    Length a <> Length b = Length (a + b)
    -- Sum a non-empty run left to right, forcing the running total at each
    -- step so no chain of '(+)' thunks is built.
    sconcat (Length start :| rest) = Length (sumFrom start rest)
        where
            sumFrom total []              = total
            sumFrom total (Length n : ns) =
                let total' = total + n
                in  total' `seq` sumFrom total' ns

-- | The empty length is zero; 'mconcat' sums with a strict accumulator.
instance Monoid Length where
    mempty  = Length 0
    mconcat = Length . go 0
        where
            go total []              = total
            go total (Length n : ns) =
                let total' = total + n
                in  total' `seq` go total' ns

-- | Counting interpretation: every primitive contributes the number of
-- bytes it represents instead of the bytes themselves.
--
-- The former top-level SPECIALIZE pragmas targeted class-method selectors,
-- which makes them no-ops (and several had ill-formed types). The INLINE
-- pragmas below express the intended "make these free" optimisation.
instance Builder Length where
    word8           = const 1
    word16BE        = const 2
    word32BE        = const 4
    word64BE        = const 8
    byteArray       = Length . BA.length
    byteString      = Length . BS.length
    shortByteString = Length . SBS.length
    -- Clamp negative counts to 0, matching the class default (which emits
    -- nothing for them). A negative 'Length' would otherwise propagate into
    -- buffer-size computations.
    zeroes          = Length . max 0
    {-# INLINE word8           #-}
    {-# INLINE word16BE        #-}
    {-# INLINE word32BE        #-}
    {-# INLINE word64BE        #-}
    {-# INLINE byteArray       #-}
    {-# INLINE byteString      #-}
    {-# INLINE shortByteString #-}
    {-# INLINE zeroes          #-}

-- | Sequencing: run the left writer, then start the right writer at the
-- address the left one returned.
instance Semigroup PtrWriter where
    PtrWriter writeA <> PtrWriter writeB = PtrWriter $ \start -> do
        next <- writeA start
        writeB next

-- | The empty writer writes nothing and returns the address it was given.
instance Monoid PtrWriter where
    mempty = PtrWriter return

-- | Direct-to-memory interpretation: each primitive writes its bytes at the
-- given address and returns the address just past them.
--
-- No bounds checking is performed. The caller must supply a buffer with room
-- for every byte the writer emits.
--
-- The former top-level SPECIALIZE pragmas targeted class-method selectors,
-- which makes them no-ops. The INLINE pragmas below express the intended
-- optimisation.
instance Builder PtrWriter where
    word8 x = PtrWriter $ \ptr -> do
        poke ptr x
        pure (plusPtr ptr 1)
    -- Written directly, like the 32- and 64-bit variants, rather than via
    -- the class default that composes two single-byte writers.
    word16BE x = PtrWriter $ \ptr -> do
        pokeByteOff ptr 0 (fromIntegral $ x `unsafeShiftR`  8 :: Word8)
        pokeByteOff ptr 1 (fromIntegral   x                   :: Word8)
        pure (plusPtr ptr 2)
    word32BE x = PtrWriter $ \ptr -> do
        pokeByteOff ptr 0 (fromIntegral $ x `unsafeShiftR` 24 :: Word8)
        pokeByteOff ptr 1 (fromIntegral $ x `unsafeShiftR` 16 :: Word8)
        pokeByteOff ptr 2 (fromIntegral $ x `unsafeShiftR`  8 :: Word8)
        pokeByteOff ptr 3 (fromIntegral   x                   :: Word8)
        pure (plusPtr ptr 4)
    word64BE x = PtrWriter $ \ptr -> do
        pokeByteOff ptr 0 (fromIntegral $ x `unsafeShiftR` 56 :: Word8)
        pokeByteOff ptr 1 (fromIntegral $ x `unsafeShiftR` 48 :: Word8)
        pokeByteOff ptr 2 (fromIntegral $ x `unsafeShiftR` 40 :: Word8)
        pokeByteOff ptr 3 (fromIntegral $ x `unsafeShiftR` 32 :: Word8)
        pokeByteOff ptr 4 (fromIntegral $ x `unsafeShiftR` 24 :: Word8)
        pokeByteOff ptr 5 (fromIntegral $ x `unsafeShiftR` 16 :: Word8)
        pokeByteOff ptr 6 (fromIntegral $ x `unsafeShiftR`  8 :: Word8)
        pokeByteOff ptr 7 (fromIntegral   x                   :: Word8)
        pure (plusPtr ptr 8)
    byteArray x = PtrWriter $ \ptr -> do
        BA.copyByteArrayToPtr x ptr
        pure (plusPtr ptr $ BA.length x)
    byteString x = PtrWriter $ \ptr -> do
        BA.copyByteArrayToPtr x ptr
        pure (plusPtr ptr $ BA.length x)
    shortByteString x = PtrWriter $ \ptr -> do
        let l = SBS.length x
        SBS.copyToPtr x 0 ptr l
        pure (plusPtr ptr l)
    -- Clamp negative counts to 0, matching the class default. A negative
    -- count is meaningless here. If it reaches the C memset as an unsigned
    -- size, it would overrun the buffer. NOTE(review): confirm memSet's
    -- conversion.
    zeroes n = PtrWriter $ \ptr -> do
        let count = max 0 n
        memSet ptr 0 count
        pure (plusPtr ptr count)
    {-# INLINE word8           #-}
    {-# INLINE word16BE        #-}
    {-# INLINE word32BE        #-}
    {-# INLINE word64BE        #-}
    {-# INLINE byteArray       #-}
    {-# INLINE byteString      #-}
    {-# INLINE shortByteString #-}
    {-# INLINE zeroes          #-}

-- | A 'PtrWriter' paired with the number of bytes it writes, so that a
-- buffer of exactly that size can be allocated before running it.
data ByteArrayBuilder
    = ByteArrayBuilder
        Int       -- ^ number of bytes the writer emits
        PtrWriter -- ^ writer producing those bytes

-- | Concatenation: sizes add and writers run one after the other. Both
-- components are forced eagerly so long chains do not accumulate thunks.
instance Semigroup ByteArrayBuilder where
    ByteArrayBuilder lenA writeA <> ByteArrayBuilder lenB writeB =
        let total  = lenA + lenB
            writer = writeA <> writeB
        in  total `seq` writer `seq` ByteArrayBuilder total writer

-- | The empty builder: zero bytes and a writer that writes nothing.
instance Monoid ByteArrayBuilder where
    mempty = ByteArrayBuilder 0 (PtrWriter return)

-- | Each primitive records its exact byte count alongside the matching
-- 'PtrWriter', so 'toByteArray' can allocate the output buffer up front.
instance Builder ByteArrayBuilder where
    word8           x = ByteArrayBuilder                  1  (word8           x)
    word16BE        x = ByteArrayBuilder                  2  (word16BE        x)
    word32BE        x = ByteArrayBuilder                  4  (word32BE        x)
    word64BE        x = ByteArrayBuilder                  8  (word64BE        x)
    byteArray       x = ByteArrayBuilder (BA.length       x) (byteArray       x)
    byteString      x = ByteArrayBuilder (BS.length       x) (byteString      x)
    shortByteString x = ByteArrayBuilder (SBS.length      x) (shortByteString x)
    -- Clamp negative counts to 0 so the recorded size never goes negative
    -- (it is used as an allocation size) and the writer is never asked to
    -- write a negative number of bytes.
    zeroes          n = ByteArrayBuilder                  m  (zeroes          m)
        where
            m = max 0 n

-- | Run the builder into a freshly allocated byte array whose size is the
-- builder's recorded byte count.
toByteArray :: BA.ByteArray ba => ByteArrayBuilder -> ba
toByteArray (ByteArrayBuilder size writer) =
    BA.allocAndFreeze size (\dst -> runPtrWriter writer dst >> return ())

-- | Write the builder's bytes starting at the given address, discarding the
-- end pointer. Nothing is checked here: the destination presumably needs
-- room for at least 'babLength' bytes.
copyToPtr :: ByteArrayBuilder -> Ptr Word8 -> IO ()
copyToPtr (ByteArrayBuilder _ writer) =
    \dst -> runPtrWriter writer dst >> return ()

-- | The byte count recorded in the builder.
babLength :: ByteArrayBuilder -> Int
babLength builder = case builder of
    ByteArrayBuilder size _ -> size