{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
module Flat.Decoder.Strict
  ( decodeArrayWith
  , decodeListWith
  , dByteString
  , dLazyByteString
  , dShortByteString
  , dShortByteString_
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
  , dUTF16
#endif
  , dUTF8
  , dInteger
  , dNatural
  , dChar
  , dWord8
  , dWord16
  , dWord32
  , dWord64
  , dWord
  , dInt8
  , dInt16
  , dInt32
  , dInt64
  , dInt
  ) where
import           Data.Bits
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as L
import qualified Data.ByteString.Short          as SBS
import qualified Data.ByteString.Short.Internal as SBS
import qualified Data.DList                     as DL
import           Flat.Decoder.Prim
import           Flat.Decoder.Types
import           Data.Int
import           Data.Primitive.ByteArray
import qualified Data.Text                      as T
import qualified Data.Text.Encoding             as T
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
import qualified Data.Text.Array                as TA
import qualified Data.Text.Internal             as T
#endif
import           Data.Word
import           Data.ZigZag
import           GHC.Base                       (unsafeChr)
import           Numeric.Natural
#include "MachDeps.h"
-- | Decode a list whose elements are each preceded by a boolean marker.
--
-- A marker is read with 'dBool' before every element: 'True' means one more
-- element (decoded with the supplied decoder) follows, 'False' ends the list.
decodeListWith :: Get a -> Get [a]
decodeListWith dec = elements
  where
    elements = dBool >>= \hasMore ->
      case hasMore of
        True  -> (:) <$> dec <*> elements
        False -> return []
{-# INLINE decodeListWith #-}
-- | Decode a chunked sequence of elements into a plain list.
--
-- The chunks are read by 'getAsL_' into a difference list, which is then
-- flattened with 'DL.toList'.
decodeArrayWith :: Get a -> Get [a]
decodeArrayWith dec = fmap DL.toList (getAsL_ dec)
-- | Decode a sequence stored as length-prefixed chunks.
--
-- Each chunk starts with a tag read via 'dWord8'. A tag of 0 terminates the
-- sequence; any other tag is the number of elements in the chunk, which are
-- decoded with the supplied decoder before the next chunk is read.
getAsL_ :: Get a -> Get (DL.DList a)
getAsL_ dec = dWord8 >>= chunk
  where
    -- A zero tag ends the sequence; otherwise read this chunk, then the rest.
    chunk 0   = return DL.empty
    chunk len = DL.append <$> block len <*> getAsL_ dec

    -- Decode exactly the given number of elements.
    block 0 = return DL.empty
    block n = DL.cons <$> dec <*> block (n - 1)
-- | Decode an arbitrary-size 'Natural' using the unbounded variable-length
-- decoder 'dUnsigned'.
dNatural :: Get Natural
dNatural = dUnsigned
{-# INLINE dNatural #-}
-- | Decode an arbitrary-size 'Integer'.
--
-- The value is read with 'dUnsigned' and mapped to a signed value with
-- 'zagZig' (from "Data.ZigZag"; presumably the inverse zig-zag mapping).
dInteger :: Get Integer
dInteger = fmap zagZig dUnsigned
{-# INLINE dInteger #-}
-- | Decode a platform-sized 'Word' by delegating to the fixed-width decoder
-- that matches the platform word size (64 or 32 bits).
dWord :: Get Word
{-# INLINE dWord #-}
-- | Decode a platform-sized 'Int' by delegating to the fixed-width decoder
-- that matches the platform word size (64 or 32 bits).
dInt :: Get Int
{-# INLINE dInt #-}
#if WORD_SIZE_IN_BITS == 64
dWord = fmap (fromIntegral :: Word64 -> Word) dWord64
dInt = fmap (fromIntegral :: Int64 -> Int) dInt64
#elif WORD_SIZE_IN_BITS == 32
dWord = fmap (fromIntegral :: Word32 -> Word) dWord32
dInt = fmap (fromIntegral :: Int32 -> Int) dInt32
#else
#error expected WORD_SIZE_IN_BITS to be 32 or 64
#endif
-- | Decode an 'Int8': read a 'Word8' with 'dWord8' and convert it to a signed
-- value with 'zagZig'.
dInt8 :: Get Int8
dInt8 = fmap zagZig dWord8
{-# INLINE dInt8 #-}
-- | Decode an 'Int16': read a 'Word16' with 'dWord16' and convert it to a
-- signed value with 'zagZig'.
dInt16 :: Get Int16
dInt16 = fmap zagZig dWord16
{-# INLINE dInt16 #-}
-- | Decode an 'Int32': read a 'Word32' with 'dWord32' and convert it to a
-- signed value with 'zagZig'.
dInt32 :: Get Int32
dInt32 = fmap zagZig dWord32
{-# INLINE dInt32 #-}
-- | Decode an 'Int64': read a 'Word64' with 'dWord64' and convert it to a
-- signed value with 'zagZig'.
dInt64 :: Get Int64
dInt64 = fmap zagZig dWord64
{-# INLINE dInt64 #-}
-- | Decode a 'Word16' from at most three bytes, each carrying 7 payload bits
-- (shifts 0, 7 and 14). The final byte is validated by 'lastStep'.
dWord16 :: Get Word16
dWord16 = (wordStep 0 $ wordStep 7 $ lastStep 14) 0
-- | Decode a 'Word32' from at most five bytes, each carrying 7 payload bits
-- (shifts 0, 7, 14, 21 and 28). The final byte is validated by 'lastStep'.
dWord32 :: Get Word32
dWord32 =
  (wordStep 0 $ wordStep 7 $ wordStep 14 $ wordStep 21 $ lastStep 28) 0
-- | Decode a 'Word64' from at most ten bytes, each carrying 7 payload bits.
--
-- Nine bytes (shifts 0 to 56) cover 63 bits, so the tenth byte (shift 63) may
-- contribute only a single bit. It is therefore handled by 'lastStep', which
-- fails if that byte still has its continuation flag set or carries bits that
-- do not fit in 64 bits. Continuing with further 'wordStep's past shift 63
-- would silently discard overflowing bits instead of rejecting the input.
dWord64 :: Get Word64
dWord64 =
  wordStep
    0
    (wordStep
       7
       (wordStep
          14
          (wordStep
             21
             (wordStep
                28
                (wordStep
                   35
                   (wordStep
                      42
                      (wordStep 49 (wordStep 56 (lastStep 63)))))))))
    0
-- | Decode a 'Char' from at most three bytes, each carrying 7 payload bits
-- (shifts 0, 7 and 14). The final byte is validated by 'lastCharStep'.
dChar :: Get Char
dChar = (charStep 0 $ charStep 7 $ lastCharStep 14) 0
{-# INLINE dChar #-}
-- | One non-final step of the variable-length 'Char' decoder.
--
-- Reads a byte with 'dWord8', merges its low 7 bits into the accumulator at
-- the given shift and, when the byte's high bit is set, hands the accumulator
-- to the continuation. Otherwise the accumulator is converted with
-- 'unsafeChr' (no range check is performed at this step).
charStep :: Int -> (Int -> Get Char) -> Int -> Get Char
charStep !shl !cont !n = do
  !byte <- fromIntegral <$> dWord8
  let !payload = byte .&. 127
      !acc = n .|. (payload `shift` shl)
  -- The byte differs from its 7-bit payload exactly when the high bit is set.
  if byte /= payload
    then cont acc
    else return (unsafeChr acc)
{-# INLINE charStep #-}
-- | Final step of the variable-length 'Char' decoder.
--
-- Reads a byte with 'dWord8' and merges its low 7 bits into the accumulator
-- at the given shift. Decoding fails if the byte's high bit is set (more data
-- would follow) or if the resulting code point exceeds @0x10FFFF@.
lastCharStep :: Int -> Int -> Get Char
lastCharStep !shl !n = do
  !byte <- fromIntegral <$> dWord8
  let !payload = byte .&. 127
      !acc = n .|. (payload `shift` shl)
  if byte == payload && acc <= 0x10FFFF
    then return (unsafeChr acc)
    else charErr acc
 where
  charErr v = fail $ concat ["Unexpected extra byte or non unicode char", show v]
{-# INLINE lastCharStep #-}
-- | One non-final step of the variable-length unsigned integer decoder.
--
-- Reads a byte with 'dWord8' and merges its low 7 bits into the accumulator
-- at the given shift. When the byte's high bit is set the accumulator is
-- passed to the continuation; otherwise it is the final result.
wordStep :: (Bits a, Num a) => Int -> (a -> Get a) -> a -> Get a
wordStep shl k n = do
  byte <- fromIntegral <$> dWord8
  let payload = byte .&. 127
      acc = n .|. (payload `shift` shl)
  -- The byte differs from its 7-bit payload exactly when the high bit is set.
  if byte /= payload
    then k acc
    else return acc
{-# INLINE wordStep #-}
-- | Final step of the variable-length unsigned integer decoder.
--
-- Reads a byte with 'dWord8' and merges its low 7 bits into the accumulator
-- at the given shift. Decoding fails if the byte's high bit is set (more data
-- would follow) or if the payload has fewer leading zero bits than the shift,
-- i.e. shifting it would push set bits past the width of the result type.
lastStep :: (FiniteBits b, Show b, Num b) => Int -> b -> Get b
lastStep shl n = do
  byte <- fromIntegral <$> dWord8
  let payload = byte .&. 127
      acc = n .|. (payload `shift` shl)
  if byte == payload && countLeadingZeros payload >= shl
    then return acc
    else wordErr acc
 where
   wordErr v = fail $ concat ["Unexpected extra byte in unsigned integer", show v]
{-# INLINE lastStep #-}
-- | Decode a variable-length unsigned integer of any 'Bits' type.
--
-- The value and the shift applied to its last byte come from 'dUnsigned_'.
-- For types with a fixed bit size, decoding fails when that last shift is at
-- or beyond the size; types without a fixed size are returned unchecked.
dUnsigned :: (Num b, Bits b) => Get b
dUnsigned = do
  (v, shl) <- dUnsigned_ 0 0
  case bitSizeMaybe v of
    Just s
      | shl >= s -> fail "Unexpected extra data in unsigned integer"
    _ -> return v
-- | Worker for 'dUnsigned'.
--
-- Reads bytes with 'dWord8', merging the low 7 bits of each into the
-- accumulator at the current shift, and advances the shift by 7 while the
-- byte's high bit is set. Returns the accumulated value together with the
-- shift that was used for the last byte.
dUnsigned_ :: (Bits t, Num t) => Int -> t -> Get (t, Int)
dUnsigned_ shl n = dWord8 >>= \byte ->
  let payload = byte .&. 127
      acc = n .|. (fromIntegral payload `shift` shl)
  in if byte == payload
       then return (acc, shl)
       else dUnsigned_ (shl + 7) acc
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
-- | Decode a 'T.Text' without validating its contents.
--
-- After skipping the filler, the raw byte array returned by 'dByteArray_' is
-- wrapped directly in the 'T.Text' constructor, with the length taken as half
-- the byte count (an odd trailing byte is ignored by the integer division).
--
-- NOTE(review): this presumably relies on the internal array of 'T.Text'
-- holding 16-bit code units; confirm this matches the supported versions of
-- the text library.
dUTF16 :: Get T.Text
dUTF16 =
  dFiller >> dByteArray_ >>= \(ByteArray array, lengthInBytes) ->
    return (T.Text (TA.Array array) 0 (lengthInBytes `div` 2))
#endif
-- | Decode a UTF-8 encoded 'T.Text'.
--
-- After skipping the filler, the bytes returned by 'dByteString_' are
-- validated with 'T.decodeUtf8''. Malformed input is reported as a decoder
-- failure through 'fail' rather than as an impure exception raised when the
-- resulting text is later forced.
dUTF8 :: Get T.Text
dUTF8 = do
  _ <- dFiller
  bytes <- dByteString_
  case T.decodeUtf8' bytes of
    Right txt -> return txt
    Left err  -> fail $ "Input contains invalid UTF-8 data: " ++ show err
-- | Skip filler: keep reading booleans with 'dBool', discarding every 'False',
-- and stop right after the first 'True'.
dFiller :: Get ()
dFiller = dBool >>= \done ->
  if done
    then return ()
    else dFiller
-- | Decode a lazy 'L.ByteString': skip the filler, then read the payload with
-- 'dLazyByteString_'.
dLazyByteString :: Get L.ByteString
dLazyByteString = do
  dFiller
  dLazyByteString_
-- | Decode a 'SBS.ShortByteString': skip the filler, then read the payload
-- with 'dShortByteString_'.
dShortByteString :: Get SBS.ShortByteString
dShortByteString = do
  dFiller
  dShortByteString_
-- | Decode a 'SBS.ShortByteString' payload without skipping any filler first.
--
-- The byte array returned by 'dByteArray_' is wrapped directly in the
-- 'SBS.SBS' constructor; the accompanying length is not used.
dShortByteString_ :: Get SBS.ShortByteString
dShortByteString_ =
  dByteArray_ >>= \(ByteArray array, _) -> return (SBS.SBS array)
-- | Decode a strict 'B.ByteString': skip the filler, then read the payload
-- with 'dByteString_'.
dByteString :: Get B.ByteString
dByteString = do
  dFiller
  dByteString_