{-# LANGUAGE FlexibleInstances, ScopedTypeVariables, Rank2Types #-}
module General.Binary(
    BinaryOp(..), binaryOpMap,
    binarySplit, binarySplit2, binarySplit3, unsafeBinarySplit,
    Builder(..), runBuilder, sizeBuilder,
    BinaryEx(..),
    Storable, putExStorable, getExStorable, putExStorableList, getExStorableList,
    putExList, getExList, putExN, getExN
    ) where
import Development.Shake.Classes
import Control.Monad
import Data.Binary
import Data.List.Extra
import Data.Tuple.Extra
import Foreign.Storable
import Foreign.Ptr
import System.IO.Unsafe as U
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString.Unsafe as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.UTF8 as UTF8
import Data.Functor
import Data.Semigroup (Semigroup (..))
import Data.Monoid hiding ((<>))
import Prelude
data BinaryOp v = BinaryOp
    {putOp :: v -> Builder
    ,getOp :: BS.ByteString -> v
    }
binaryOpMap :: (Eq a, Hashable a, BinaryEx a) => (a -> BinaryOp b) -> BinaryOp (a, b)
binaryOpMap mp = BinaryOp
    {putOp = \(a, b) -> putExN (putEx a) <> putOp (mp a) b
    ,getOp = \bs -> let (bs1,bs2) = getExN bs; a = getEx bs1 in (a, getOp (mp a) bs2)
    }
binarySplit :: forall a . Storable a => BS.ByteString -> (a, BS.ByteString)
binarySplit bs | BS.length bs < sizeOf (undefined :: a) = error "Reading from ByteString, insufficient left"
               | otherwise = unsafeBinarySplit bs
binarySplit2 :: forall a b . (Storable a, Storable b) => BS.ByteString -> (a, b, BS.ByteString)
binarySplit2 bs | BS.length bs < sizeOf (undefined :: a) + sizeOf (undefined :: b) = error "Reading from ByteString, insufficient left"
                | (a,bs) <- unsafeBinarySplit bs, (b,bs) <- unsafeBinarySplit bs = (a,b,bs)
binarySplit3 :: forall a b c . (Storable a, Storable b, Storable c) => BS.ByteString -> (a, b, c, BS.ByteString)
binarySplit3 bs | BS.length bs < sizeOf (undefined :: a) + sizeOf (undefined :: b) + sizeOf (undefined :: c) = error "Reading from ByteString, insufficient left"
                | (a,bs) <- unsafeBinarySplit bs, (b,bs) <- unsafeBinarySplit bs, (c,bs) <- unsafeBinarySplit bs = (a,b,c,bs)
unsafeBinarySplit :: Storable a => BS.ByteString -> (a, BS.ByteString)
unsafeBinarySplit bs = (v, BS.unsafeDrop (sizeOf v) bs)
    where v = unsafePerformIO $ BS.unsafeUseAsCString bs $ \ptr -> peek (castPtr ptr)
for2M_ as bs f = zipWithM_ f as bs
data Builder = Builder {-# UNPACK #-} !Int (forall a . Ptr a -> Int -> IO ())
sizeBuilder :: Builder -> Int
sizeBuilder (Builder i _) = i
runBuilder :: Builder -> BS.ByteString
runBuilder (Builder i f) = unsafePerformIO $ BS.create i $ \ptr -> f ptr 0
instance Semigroup Builder where
    (Builder x1 x2) <> (Builder y1 y2) = Builder (x1+y1) $ \p i -> do x2 p i; y2 p $ i+x1
instance Monoid Builder where
    mempty = Builder 0 $ \_ _ -> return ()
    mappend = (<>)
class BinaryEx a where
    putEx :: a -> Builder
    getEx :: BS.ByteString -> a
instance BinaryEx BS.ByteString where
    putEx x = Builder n $ \ptr i -> BS.useAsCString x $ \bs -> BS.memcpy (ptr `plusPtr` i) (castPtr bs) (fromIntegral n)
        where n = BS.length x
    getEx = id
instance BinaryEx LBS.ByteString where
    putEx x = Builder (fromIntegral $ LBS.length x) $ \ptr i -> do
        let go _ [] = return ()
            go i (x:xs) = do
                let n = BS.length x
                BS.useAsCString x $ \bs -> BS.memcpy (ptr `plusPtr` i) (castPtr bs) (fromIntegral n)
                go (i+n) xs
        go i $ LBS.toChunks x
    getEx = LBS.fromChunks . return
instance BinaryEx [BS.ByteString] where
    
    
    
    
    putEx xs = Builder (4 + (n * 4) + sum ns) $ \p i -> do
        pokeByteOff p i (fromIntegral n :: Word32)
        for2M_ [4+i,8+i..] ns $ \i x -> pokeByteOff p i (fromIntegral x :: Word32)
        p <- return $ p `plusPtr` (i + 4 + (n * 4))
        for2M_ (scanl (+) 0 ns) xs $ \i x -> BS.useAsCStringLen x $ \(bs, n) ->
            BS.memcpy (p `plusPtr` i) (castPtr bs) (fromIntegral n)
        where ns = map BS.length xs
              n = length ns
    getEx bs = unsafePerformIO $ BS.useAsCString bs $ \p -> do
        n <- fromIntegral <$> (peekByteOff p 0 :: IO Word32)
        ns :: [Word32] <- forM [1..fromIntegral n] $ \i -> peekByteOff p (i * 4)
        return $ snd $ mapAccumL (\bs i -> swap $ BS.splitAt (fromIntegral i) bs) (BS.drop (4 + (n * 4)) bs) ns
instance BinaryEx () where
    putEx () = mempty
    getEx _ = ()
instance BinaryEx String where
    putEx = putEx . UTF8.fromString
    getEx = UTF8.toString
instance BinaryEx (Maybe String) where
    putEx Nothing = mempty
    putEx (Just xs) = putEx $ UTF8.fromString $ '\0' : xs
    getEx = fmap snd . uncons . UTF8.toString
instance BinaryEx [String] where
    putEx = putEx . map UTF8.fromString
    getEx = map UTF8.toString . getEx
instance BinaryEx (String, [String]) where
    putEx (a,bs) = putEx $ a:bs
    getEx x = let a:bs = getEx x in (a,bs)
instance BinaryEx Bool where
    putEx False = Builder 1 $ \ptr i -> pokeByteOff ptr i (0 :: Word8)
    putEx True = mempty
    getEx = BS.null
instance BinaryEx Word8 where
    putEx = putExStorable
    getEx = getExStorable
instance BinaryEx Word16 where
    putEx = putExStorable
    getEx = getExStorable
instance BinaryEx Word32 where
    putEx = putExStorable
    getEx = getExStorable
instance BinaryEx Int where
    putEx = putExStorable
    getEx = getExStorable
instance BinaryEx Float where
    putEx = putExStorable
    getEx = getExStorable
putExStorable :: forall a . Storable a => a -> Builder
putExStorable x = Builder (sizeOf x) $ \p i -> pokeByteOff p i x
getExStorable :: forall a . Storable a => BS.ByteString -> a
getExStorable = \bs -> unsafePerformIO $ BS.useAsCStringLen bs $ \(p, size) ->
        if size /= n then error "size mismatch" else peek (castPtr p)
    where n = sizeOf (undefined :: a)
putExStorableList :: forall a . Storable a => [a] -> Builder
putExStorableList xs = Builder (n * length xs) $ \ptr i ->
    for2M_ [i,i+n..] xs $ \i x -> pokeByteOff ptr i x
    where n = sizeOf (undefined :: a)
getExStorableList :: forall a . Storable a => BS.ByteString -> [a]
getExStorableList = \bs -> unsafePerformIO $ BS.useAsCStringLen bs $ \(p, size) ->
    let (d,m) = size `divMod` n in
    if m /= 0 then error "size mismatch" else forM [0..d-1] $ \i -> peekElemOff (castPtr p) i
    where n = sizeOf (undefined :: a)
putExList :: [Builder] -> Builder
putExList xs = Builder (sum $ map (\b -> sizeBuilder b + 4) xs) $ \p i -> do
    let go _ [] = return ()
        go i (Builder n b:xs) = do
            pokeByteOff p i (fromIntegral n :: Word32)
            b p (i+4)
            go (i+4+n) xs
    go i xs
getExList :: BS.ByteString -> [BS.ByteString]
getExList bs
    | len == 0 = []
    | len >= 4
    , (n :: Word32, bs) <- unsafeBinarySplit bs
    , n <- fromIntegral n
    , (len - 4) >= n
    = BS.unsafeTake n bs : getExList (BS.unsafeDrop n bs)
    | otherwise = error "getList, corrupted binary"
    where len = BS.length bs
putExN :: Builder -> Builder
putExN (Builder n old) = Builder (n+4) $ \p i -> do
    pokeByteOff p i (fromIntegral n :: Word32)
    old p $ i+4
getExN :: BS.ByteString -> (BS.ByteString, BS.ByteString)
getExN bs
    | len >= 4
    , (n :: Word32, bs) <- unsafeBinarySplit bs
    , n <- fromIntegral n
    , (len - 4) >= n
    = (BS.unsafeTake n bs, BS.unsafeDrop n bs)
    | otherwise = error "getList, corrupted binary"
    where len = BS.length bs