{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE CPP                       #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RecordWildCards           #-}
{-# LANGUAGE ScopedTypeVariables       #-}

-- |Strict encoder
module Flat.Encoder.Strict where

import           Control.Monad        (when)
import qualified Data.ByteString      as B
import qualified Data.ByteString.Lazy as L
import           Data.Foldable
import           Flat.Encoder.Prim
import qualified Flat.Encoder.Size    as S
import           Flat.Encoder.Types
import           Flat.Memory
import           Flat.Types

-- import           Data.Semigroup
-- import           Data.Semigroup          (Semigroup (..))

#if !MIN_VERSION_base(4,11,0)
import           Data.Semigroup       (Semigroup (..))
#endif

#ifdef ETA_VERSION
-- import Data.Function(trampoline)
import           GHC.IO               (trampolineIO)
-- |Wrap an encoder so that every run of its primitive goes through 'trampolineIO' (Eta builds only).
trampolineEncoding :: Encoding -> Encoding
trampolineEncoding (Encoding step) = Encoding (trampolineIO . step)
#else

-- trampolineIO = id
#endif

-- |Strict encoder.
--
-- Runs the encoder via 'strictEncoderPartial' with a buffer sized for
-- @numBits@ bits and returns the resulting strict bytestring.
--
-- Calls 'error' if the number of bits actually written is not a multiple
-- of 8, i.e. the encoding did not end on a byte boundary.
strictEncoder :: NumBits -> Encoding -> B.ByteString
strictEncoder numBits enc =
  let (bs, numBitsUsed) = strictEncoderPartial numBits enc
      bitsInLastByte = numBitsUsed `mod` 8
  in if bitsInLastByte /= 0
      then error $ unwords ["encoder: did not end on byte boundary, bits used in last byte=", show bitsInLastByte]
      else bs

-- |Number of bits actually written by the encoder, obtained by running
-- 'strictEncoderPartial' with the given maximum size (in bits) and keeping
-- only the bit count.
numEncodedBits :: Int -> Encoding -> NumBits
numEncodedBits numBits enc = snd $ strictEncoderPartial numBits enc

-- |Run an encoder into a freshly created buffer of at most
-- @S.bitsToBytes numBits@ bytes.
--
-- The primitive is started from the state @S ptr 0 0@. Calls 'error' if
-- the encoder reports more used bits than @numBits@.
strictEncoderPartial ::
  Int                        -- ^ the maximum size in bits of the encoding
  -> Encoding                -- ^ the encoder
  -> (B.ByteString, NumBits) -- ^ the encoded bytestring + the actual number of encoded bits
strictEncoderPartial numBits (Encoding op)
  = let bufSize = S.bitsToBytes numBits
    in unsafeCreateUptoN' bufSize $ \ptr -> do
        S{..} <- op (S ptr 0 0)
        -- whole bytes written so far, plus the bits used in the current byte
        let numBitsUsed = (nextPtr `minusPtr` ptr) * 8 + usedBits
        when (numBitsUsed > numBits) $ error $ unwords ["encoder: size mismatch, expected <=", show numBits, "actual=", show numBitsUsed, "bits"]
        return (nextPtr `minusPtr` ptr, numBitsUsed)

-- |An encoder: a thin wrapper around an encoding primitive ('Prim'),
-- which can be extracted with 'run'.
newtype Encoding =
  Encoding
    { run :: Prim
    }

-- |Encodings wrap a function and cannot be displayed; every value shows
-- as the fixed string @"Encoding"@.
instance Show Encoding where
  show _ = "Encoding"

-- |Combining two encodings sequences them with 'encodingAppend'.
instance Semigroup Encoding where
  {-# INLINE (<>) #-}
  (<>) = encodingAppend

-- |The empty encoding returns the encoder state unchanged.
instance Monoid Encoding where
  {-# INLINE mempty #-}
  mempty = Encoding return

#if !(MIN_VERSION_base(4,11,0))
  {-# INLINE mappend #-}
  mappend = encodingAppend
#endif

  -- Strict left fold over the list of encodings.
  {-# INLINE mconcat #-}
  mconcat = foldl' mappend mempty

{-# INLINE encodingAppend #-}
-- |Sequence two encodings: run the first primitive on the incoming state,
-- then the second on the state it produced.
--
-- The bang patterns force the incoming state (and its three fields) and the
-- intermediate state before continuing.
encodingAppend :: Encoding -> Encoding -> Encoding
encodingAppend (Encoding f) (Encoding g) = Encoding m
    where
      m s@(S !_ !_ !_) = do
        !s1 <- f s
        g s1

-- PROB: GHC 8.0.2 won't always apply the rules, leading to poor execution times (e.g. with lists)
-- TODO: check with newer GHC versions
{-# RULES
"encodersSN" forall h t . encodersS (h : t) =
             h `mappend` encodersS t
"encodersS0" encodersS [] = mempty
 #-}

{-# NOINLINE encodersS #-}
-- |Concatenate a list of encodings with a strict left fold.
--
-- Kept NOINLINE so that the rewrite rules on 'encodersS' get a chance to fire.
encodersS :: [Encoding] -> Encoding
-- Without the explicit parameter the rules won't fire!
encodersS ws = foldl' mappend mempty ws

-- |Accumulate the size of a container of elements.
--
-- Starting from @sz + 1@, each element adds 1 to the running total and then
-- applies @size@ to the element and that total (presumably the extra 1s
-- account for the per-element and terminating list tags; confirm against
-- 'encodeListWith').
sizeListWith :: (Foldable t1, Num t2) => (t3 -> t2 -> t2) -> t1 t3 -> t2 -> t2
sizeListWith size l sz = foldl' (\s e -> size e (s + 1)) (sz + 1) l
{-# INLINE sizeListWith #-}

-- encodersS ws = error $ unwords ["encodersS CALLED",show ws]
{-# INLINE encodeListWith #-}
-- |Encode as a List.
--
-- Every element is emitted as 'eTrue' followed by its encoding (via the
-- supplied element encoder); the end of the list is marked by 'eFalse'.
encodeListWith :: (t -> Encoding) -> [t] -> Encoding
encodeListWith enc = go
  where
    go []     = eFalse
    go (x:xs) = eTrue <> enc x <> go xs

-- {-# INLINE encodeList #-}
-- encodeList :: (Foldable t, Flat a) => t a -> Encoding
-- encodeList l = F.foldl' (\acc a -> acc <> eTrue <> encode a) mempty l <> eFalse
-- {-# INLINE encodeList2 #-}
-- encodeList2 :: (Foldable t, Flat a) => t a -> Encoding
-- encodeList2 l = foldr (\a acc -> eTrue <> encode a <> acc) mempty l <> eFalse
{-# INLINE encodeArrayWith #-}
-- |Encode as Array.
--
-- An empty list is encoded as a single zero byte ('eWord8' 0). A non-empty
-- list is written as a sequence of blocks of at most 255 elements, each
-- preceded by a one-byte element count, followed by a final zero byte.
encodeArrayWith :: (t -> Encoding) -> [t] -> Encoding
encodeArrayWith _ [] = eWord8 0
encodeArrayWith f ws = Encoding $ go ws
  where
    go l s = do
      -- write a placeholder for the number of elements in current block
      s' <- eWord8F 0 s
      (n, sn, rest) <- gol l 0 s'
      -- update actual number of elements; 's' is the state from before the
      -- placeholder was written (presumably used by 'updateWord8' to locate
      -- the byte to patch -- confirm against its definition)
      s'' <- updateWord8 n s sn
      if null rest
        then eWord8F 0 s''
        else go rest s''
    -- encode up to 255 elements and return
    -- (numberOfWrittenElements, currentState, elementsLeftToWrite)
    gol [] !n !s = return (n, s, [])
    gol l@(x:xs) !n !s
      | n == 255 = return (255, s, l)
      | otherwise = run (f x) s >>= gol xs (n + 1)

-- Encoding primitives
{-# INLINE eChar #-}
{-# INLINE eUTF8 #-}
{-# INLINE eNatural #-}
{-# INLINE eFloat #-}
{-# INLINE eDouble #-}
{-# INLINE eInteger #-}
{-# INLINE eInt64 #-}
{-# INLINE eInt32 #-}
{-# INLINE eInt16 #-}
{-# INLINE eInt8 #-}
{-# INLINE eInt #-}
{-# INLINE eWord64 #-}
{-# INLINE eWord32 #-}
{-# INLINE eWord16 #-}
{-# INLINE eWord8 #-}
{-# INLINE eWord #-}
{-# INLINE eBits #-}
{-# INLINE eFiller #-}
{-# INLINE eBool #-}
{-# INLINE eTrue #-}
{-# INLINE eFalse #-}
-- |Encoder for a 'Char', wrapping the 'eCharF' primitive.
eChar :: Char -> Encoding
eChar = Encoding . eCharF

#if! defined (ETA_VERSION)
{-# INLINE eUTF16 #-}
-- |Encoder for a 'Text', wrapping the 'eUTF16F' primitive (not available on Eta).
eUTF16 :: Text -> Encoding
eUTF16 = Encoding . eUTF16F
#endif

-- |Encoder for a 'Text', wrapping the 'eUTF8F' primitive.
eUTF8 :: Text -> Encoding
eUTF8 = Encoding . eUTF8F

-- |Encoder for a strict 'B.ByteString', wrapping the 'eBytesF' primitive.
eBytes :: B.ByteString -> Encoding
eBytes = Encoding . eBytesF

-- |Encoder for a lazy 'L.ByteString', wrapping the 'eLazyBytesF' primitive.
eLazyBytes :: L.ByteString -> Encoding
eLazyBytes = Encoding . eLazyBytesF

-- |Encoder for a 'ShortByteString', wrapping the 'eShortBytesF' primitive.
eShortBytes :: ShortByteString -> Encoding
eShortBytes = Encoding . eShortBytesF

-- |Encoder for a 'Natural', wrapping the 'eNaturalF' primitive.
eNatural :: Natural -> Encoding
eNatural = Encoding . eNaturalF

-- |Encoder for a 'Float', wrapping the 'eFloatF' primitive.
eFloat :: Float -> Encoding
eFloat = Encoding . eFloatF

-- |Encoder for a 'Double', wrapping the 'eDoubleF' primitive.
eDouble :: Double -> Encoding
eDouble = Encoding . eDoubleF

-- |Encoder for an 'Integer', wrapping the 'eIntegerF' primitive.
eInteger :: Integer -> Encoding
eInteger = Encoding . eIntegerF

-- |Encoder for an 'Int64', wrapping the 'eInt64F' primitive.
eInt64 :: Int64 -> Encoding
eInt64 = Encoding . eInt64F

-- |Encoder for an 'Int32', wrapping the 'eInt32F' primitive.
eInt32 :: Int32 -> Encoding
eInt32 = Encoding . eInt32F

-- |Encoder for an 'Int16', wrapping the 'eInt16F' primitive.
eInt16 :: Int16 -> Encoding
eInt16 = Encoding . eInt16F

-- |Encoder for an 'Int8', wrapping the 'eInt8F' primitive.
eInt8 :: Int8 -> Encoding
eInt8 = Encoding . eInt8F

-- |Encoder for an 'Int', wrapping the 'eIntF' primitive.
eInt :: Int -> Encoding
eInt = Encoding . eIntF

-- |Encoder for a 'Word64', wrapping the 'eWord64F' primitive.
eWord64 :: Word64 -> Encoding
eWord64 = Encoding . eWord64F

-- |Encoder for a 'Word32', wrapping the 'eWord32F' primitive.
eWord32 :: Word32 -> Encoding
eWord32 = Encoding . eWord32F

-- |Encoder for a 'Word16', wrapping the 'eWord16F' primitive.
eWord16 :: Word16 -> Encoding
eWord16 = Encoding . eWord16F

-- |Encoder for a 'Word8', wrapping the 'eWord8F' primitive.
eWord8 :: Word8 -> Encoding
eWord8 = Encoding . eWord8F

-- |Encoder for a 'Word', wrapping the 'eWordF' primitive.
eWord :: Word -> Encoding
eWord = Encoding . eWordF

-- |Encoder built from the 'eBits16F' primitive, given a number of bits and
-- the 'Word16' that holds them.
eBits16 :: NumBits -> Word16 -> Encoding
eBits16 n f = Encoding $ eBits16F n f

-- |Encoder built from the 'eBitsF' primitive, given a number of bits and
-- the 'Word8' that holds them.
eBits :: NumBits -> Word8 -> Encoding
eBits n f = Encoding $ eBitsF n f

-- |Encoder wrapping the 'eFillerF' primitive.
eFiller :: Encoding
eFiller = Encoding eFillerF

-- |Encoder for a 'Bool', wrapping the 'eBoolF' primitive.
eBool :: Bool -> Encoding
eBool = Encoding . eBoolF

-- |Encoder wrapping the 'eTrueF' primitive.
eTrue :: Encoding
eTrue = Encoding eTrueF

-- |Encoder wrapping the 'eFalseF' primitive.
eFalse :: Encoding
eFalse = Encoding eFalseF

-- Size Primitives
-- Variable size
{-# INLINE vsize #-}
-- |Variable size: add the size computed from the value by @f@ to the
-- running total @n@. All three arguments are forced.
vsize :: (t -> NumBits) -> t -> NumBits -> NumBits
vsize !f !t !n = f t + n

-- Constant size
{-# INLINE csize #-}
-- |Constant size: add the fixed size @n@ to the running total @s@; the
-- value itself is ignored (and not forced).
csize :: NumBits -> t -> NumBits -> NumBits
csize !n _ !s = n + s

-- |Size of a 'Char': value-dependent, computed by 'S.sChar'.
sChar :: Size Char
sChar = vsize S.sChar

-- |Size of an 'Int64': value-dependent, computed by 'S.sInt64'.
sInt64 :: Size Int64
sInt64 = vsize S.sInt64

-- |Size of an 'Int32': value-dependent, computed by 'S.sInt32'.
sInt32 :: Size Int32
sInt32 = vsize S.sInt32

-- |Size of an 'Int16': value-dependent, computed by 'S.sInt16'.
sInt16 :: Size Int16
sInt16 = vsize S.sInt16

-- |Size of an 'Int8': the constant 'S.sInt8', independent of the value.
sInt8 :: Size Int8
sInt8 = csize S.sInt8

-- |Size of an 'Int': value-dependent, computed by 'S.sInt'.
sInt :: Size Int
sInt = vsize S.sInt

-- |Size of a 'Word64': value-dependent, computed by 'S.sWord64'.
sWord64 :: Size Word64
sWord64 = vsize S.sWord64

-- |Size of a 'Word32': value-dependent, computed by 'S.sWord32'.
sWord32 :: Size Word32
sWord32 = vsize S.sWord32

-- |Size of a 'Word16': value-dependent, computed by 'S.sWord16'.
sWord16 :: Size Word16
sWord16 = vsize S.sWord16

-- |Size of a 'Word8': the constant 'S.sWord8', independent of the value.
sWord8 :: Size Word8
sWord8 = csize S.sWord8

-- |Size of a 'Word': value-dependent, computed by 'S.sWord'.
sWord :: Size Word
sWord = vsize S.sWord

-- |Size of a 'Float': the constant 'S.sFloat', independent of the value.
sFloat :: Size Float
sFloat = csize S.sFloat

-- |Size of a 'Double': the constant 'S.sDouble', independent of the value.
sDouble :: Size Double
sDouble = csize S.sDouble

-- |Size of a strict 'B.ByteString': value-dependent, computed by 'S.sBytes'.
sBytes :: Size B.ByteString
sBytes = vsize S.sBytes

-- |Size of a lazy 'L.ByteString': value-dependent, computed by 'S.sLazyBytes'.
sLazyBytes :: Size L.ByteString
sLazyBytes = vsize S.sLazyBytes

-- |Size of a 'ShortByteString': value-dependent, computed by 'S.sShortBytes'.
sShortBytes :: Size ShortByteString
sShortBytes = vsize S.sShortBytes

-- |Size of a 'Natural': value-dependent, computed by 'S.sNatural'.
sNatural :: Size Natural
sNatural = vsize S.sNatural

-- |Size of an 'Integer': value-dependent, computed by 'S.sInteger'.
sInteger :: Size Integer
sInteger = vsize S.sInteger

-- |Size of a 'Text' as computed by 'S.sUTF8Max' (value-dependent).
sUTF8Max :: Size Text
sUTF8Max = vsize S.sUTF8Max

-- |Size of a 'Text' as computed by 'S.sUTF16Max' (value-dependent).
sUTF16 :: Size Text
sUTF16 = vsize S.sUTF16Max

-- |Size of a filler: the constant 'S.sFillerMax', for a value of any type.
sFillerMax :: Size a
sFillerMax = csize S.sFillerMax

-- |Size of a 'Bool': the constant 'S.sBool', independent of the value.
sBool :: Size Bool
sBool = csize S.sBool