{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}

-- | Strict decoders for lists, fixed-width and arbitrary-size integers,
-- characters, text and byte strings.
module Flat.Decoder.Strict
  ( decodeArrayWith
  , decodeListWith
  , dByteString
  , dLazyByteString
  , dShortByteString
  , dShortByteString_
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
  , dUTF16
#endif
  , dUTF8
  , dInteger
  , dNatural
  , dChar
  , dWord8
  , dWord16
  , dWord32
  , dWord64
  , dWord
  , dInt8
  , dInt16
  , dInt32
  , dInt64
  , dInt
  ) where

import           Data.Bits
import qualified Data.ByteString                as B
import qualified Data.ByteString.Lazy           as L
import qualified Data.ByteString.Short          as SBS
import qualified Data.ByteString.Short.Internal as SBS
import qualified Data.DList                     as DL
import           Flat.Decoder.Prim
import           Flat.Decoder.Types
import           Data.Int
import           Data.Primitive.ByteArray
import qualified Data.Text                      as T
import qualified Data.Text.Encoding             as T

#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
import qualified Data.Text.Array                as TA
import qualified Data.Text.Internal             as T
#endif

import           Data.Word
import           Data.ZigZag
import           GHC.Base                       (unsafeChr)
import           Numeric.Natural

#include "MachDeps.h"

{-# INLINE decodeListWith #-}
-- | Decode a list element by element: before each element a tag is read with
-- 'dBool'; 'True' means another element follows, 'False' ends the list.
decodeListWith :: Get a -> Get [a]
decodeListWith dec = loop
  where
    loop = dBool >>= \more ->
      if not more
        then return []
        else (:) <$> dec <*> loop

-- | Decode a list stored as a sequence of length-prefixed blocks (read by
-- 'getAsL_'), accumulated in a difference list and then turned into a list.
decodeArrayWith :: Get a -> Get [a]
decodeArrayWith dec = fmap DL.toList (getAsL_ dec)

-- TODO: test whether it would be faster with DList.unfoldr :: (b -> Maybe (a, b)) -> b -> Data.DList.DList a
--  getAsL_ :: Flat a => Get (DL.DList a)
-- | Decode a sequence of blocks. Each block starts with a 'Word8' element
-- count followed by that many elements decoded with @dec@; a block with
-- count 0 terminates the sequence.
getAsL_ :: Get a -> Get (DL.DList a)
getAsL_ dec = dWord8 >>= block
  where
    -- A zero-length block marks the end of the sequence.
    block 0 = return DL.empty
    block len = do
      chunk <- elems len
      rest  <- getAsL_ dec
      return (chunk `DL.append` rest)

    -- Decode exactly @k@ elements, preserving their order.
    elems 0 = return DL.empty
    elems k = DL.cons <$> dec <*> elems (k - 1)

{-# INLINE dNatural #-}
-- | Decode a 'Natural' stored as an unbounded varint (see 'dUnsigned').
dNatural :: Get Natural
dNatural = dUnsigned

{-# INLINE dInteger #-}
-- | Decode an 'Integer' by applying 'zagZig' to an unbounded varint.
dInteger :: Get Integer
dInteger = fmap zagZig dUnsigned

{-# INLINE dWord #-}
-- | Decode a machine 'Word' by delegating to the fixed-width decoder that
-- matches the platform word size ('dWord64' or 'dWord32').
dWord :: Get Word

{-# INLINE dInt #-}
-- | Decode a machine 'Int' by delegating to 'dInt64' or 'dInt32' according
-- to the platform word size.
dInt :: Get Int

#if WORD_SIZE_IN_BITS == 64
dWord = fmap (fromIntegral :: Word64 -> Word) dWord64
dInt  = fmap (fromIntegral :: Int64 -> Int) dInt64
#elif WORD_SIZE_IN_BITS == 32
dWord = fmap (fromIntegral :: Word32 -> Word) dWord32
dInt  = fmap (fromIntegral :: Int32 -> Int) dInt32
#else
#error expected WORD_SIZE_IN_BITS to be 32 or 64
#endif

{-# INLINE dInt8 #-}
-- | Decode an 'Int8' by applying 'zagZig' to a decoded 'Word8'.
dInt8 :: Get Int8
dInt8 = fmap zagZig dWord8

{-# INLINE dInt16 #-}
-- | Decode an 'Int16' by applying 'zagZig' to a decoded 'Word16'.
dInt16 :: Get Int16
dInt16 = fmap zagZig dWord16

{-# INLINE dInt32 #-}
-- | Decode an 'Int32' by applying 'zagZig' to a decoded 'Word32'.
dInt32 :: Get Int32
dInt32 = fmap zagZig dWord32

{-# INLINE dInt64 #-}
-- | Decode an 'Int64' by applying 'zagZig' to a decoded 'Word64'.
dInt64 :: Get Int64
dInt64 = fmap zagZig dWord64

-- {-# INLINE dWord16  #-}
-- | Decode a 'Word16' from at most three 7-bit groups, least significant
-- group first (bit offsets 0, 7, 14); the last group is overflow-checked.
dWord16 :: Get Word16
dWord16 = wordStep 0 second 0
  where
    second = wordStep 7 third
    third  = lastStep 14

-- {-# INLINE dWord32  #-}
-- | Decode a 'Word32' from at most five 7-bit groups, least significant
-- group first (bit offsets 0 .. 28); the last group is overflow-checked.
dWord32 :: Get Word32
dWord32 = wordStep 0 g1 0
  where
    g1 = wordStep 7 g2
    g2 = wordStep 14 g3
    g3 = wordStep 21 g4
    g4 = lastStep 28

-- {-# INLINE dWord64  #-}
-- | Decode a 'Word64' from at most ten 7-bit groups, least significant
-- group first (bit offsets 0, 7, ..., 63).
--
-- The tenth group (offset 63) can contribute only one bit, so it is read with
-- 'lastStep', which rejects both a set continuation bit and any bits that
-- would overflow the 64-bit result. Reading further groups (offsets 70, 77)
-- would silently discard overflowing bits and accept over-long encodings.
dWord64 :: Get Word64
dWord64 =
  wordStep 0
    (wordStep 7
      (wordStep 14
        (wordStep 21
          (wordStep 28
            (wordStep 35
              (wordStep 42
                (wordStep 49
                  (wordStep 56
                    (lastStep 63)))))))))
    0

{-# INLINE dChar #-}
-- | Decode a 'Char' from at most three 7-bit groups (21 bits), least
-- significant group first; the last group is range-checked by 'lastCharStep'.
dChar :: Get Char
-- The simpler @chr . fromIntegral <$> dWord32@ is an alternative; this
-- specialised version is not really faster than it.
dChar = charStep 0 second 0
  where
    second = charStep 7 third
    third  = lastCharStep 14

{-# INLINE charStep #-}
-- | One non-final step of 'Char' decoding: read a byte, OR its low 7 bits
-- into the accumulator at offset @shl@, and either finish (high bit clear)
-- or pass the accumulator to the continuation (high bit set).
charStep :: Int -> (Int -> Get Char) -> Int -> Get Char
charStep !shl !cont !n = do
  !byte <- fromIntegral <$> dWord8
  let !low7 = byte .&. 127
      !acc  = n .|. (low7 `shift` shl)
  if byte /= low7
    then cont acc
    else return (unsafeChr acc)

{-# INLINE lastCharStep #-}
-- | Final step of 'Char' decoding: read the group at offset @shl@ and
-- return the resulting character.
--
-- Fails if this byte still has its continuation (high) bit set, or if the
-- accumulated code point exceeds 0x10FFFF.
lastCharStep :: Int -> Int -> Get Char
lastCharStep !shl !n = do
  !tw <- fromIntegral <$> dWord8
  let !w = tw .&. 127
  let !v = n .|. (w `shift` shl)
  if tw == w
    then if v > 0x10FFFF
           then charErr v
           else return $ unsafeChr v
    else charErr v
 where
  -- Separate the message from the offending value so it stays readable.
  charErr v = fail $ concat ["Unexpected extra byte or non unicode char: ", show v]

{-# INLINE wordStep #-}
-- | One non-final step of fixed-width varint decoding: read a byte, OR its
-- low 7 bits into the accumulator at offset @shl@, and either return the
-- accumulator (high bit clear) or hand it to the next step @k@ (high bit set).
wordStep :: (Bits a, Num a) => Int -> (a -> Get a) -> a -> Get a
wordStep shl k n = dWord8 >>= \byte ->
  let raw  = fromIntegral byte
      low7 = raw .&. 127
      acc  = n .|. (low7 `shift` shl)
  in if raw == low7 then return acc else k acc

{-# INLINE lastStep #-}
-- | Final step of fixed-width varint decoding: read the group at bit offset
-- @shl@ and return the completed value.
--
-- Fails if this byte still has its continuation (high) bit set, or if the
-- 7-bit group has bits that would be shifted past the top of the type.
lastStep :: (FiniteBits b, Show b, Num b) => Int -> b -> Get b
lastStep shl n = do
  tw <- fromIntegral <$> dWord8
  let w = tw .&. 127
  let v = n .|. (w `shift` shl)
  -- Only the low (finiteBitSize - shl) bits of w fit into the result, i.e.
  -- w must have at least shl leading zeros.
  if tw == w
    then if countLeadingZeros w < shl
           then wordErr v
           else return v
    else wordErr v
 where
   -- Separate the message from the offending value so it stays readable.
   wordErr v = fail $ concat ["Unexpected extra byte in unsigned integer: ", show v]

-- {-# INLINE dUnsigned #-}
-- | Decode an unbounded varint. When the target type reports a finite bit
-- size, fail if the last group read started at or beyond that size.
dUnsigned :: (Num b, Bits b) => Get b
dUnsigned = do
  (v, shl) <- dUnsigned_ 0 0
  case bitSizeMaybe v of
    Just s | shl >= s -> fail "Unexpected extra data in unsigned integer"
    _                 -> return v

-- {-# INLINE dUnsigned_ #-}
-- | Accumulate 7-bit groups (least significant first) starting at offset
-- @shl@ until a byte with a clear high bit is read. Returns the value and
-- the offset of the last group consumed.
dUnsigned_ :: (Bits t, Num t) => Int -> t -> Get (t, Int)
dUnsigned_ shl n = dWord8 >>= \byte ->
  let low7 = byte .&. 127
      acc  = n .|. (fromIntegral low7 `shift` shl)
  in if byte == low7
       then return (acc, shl)
       else dUnsigned_ (shl + 7) acc
--encode = encode . blob UTF8Encoding . L.fromStrict . T.encodeUtf8
--decode = T.decodeUtf8 . L.toStrict . (unblob :: BLOB UTF8Encoding -> L.ByteString) <$> decode
#if! defined(ghcjs_HOST_OS) && ! defined (ETA_VERSION)
-- | Decode a 'T.Text' stored as a UTF-16 little-endian byte blob preceded by
-- a filler (BLOB UTF16Encoding).
--
-- Since text-2.0 the internal representation of 'T.Text' is UTF-8, so the
-- UTF-16 bytes must be transcoded. Wrapping them directly would build an
-- invalid 'T.Text'. With older text versions the internal array is UTF-16,
-- and the bytes are reused without validation.
dUTF16 :: Get T.Text
dUTF16 = do
  _ <- dFiller
#if MIN_VERSION_text(2,0,0)
  -- Checked decoding (transcodes UTF-16LE into text's UTF-8 representation)
  T.decodeUtf16LE <$> dByteString_
#else
  -- Unchecked decoding: the byte array is used as the Text's UTF-16 buffer;
  -- its length is measured in 16-bit code units, hence the division by 2.
  (ByteArray array, lengthInBytes) <- dByteArray_
  return (T.Text (TA.Array array) 0 (lengthInBytes `div` 2))
#endif
#endif
-- | Decode a 'T.Text' stored as a UTF-8 byte blob preceded by a filler
-- (see 'dFiller').
--
-- Invalid UTF-8 is reported as a decoder failure. It is not left as an
-- exception that would only be thrown later, when the resulting text is
-- forced outside the decoder.
dUTF8 :: Get T.Text
dUTF8 = do
  _ <- dFiller
  bytes <- dByteString_
  case T.decodeUtf8' bytes of
    Left err  -> fail ("Invalid UTF-8 text: " ++ show err)
    Right txt -> return txt

-- | Skip a filler: keep reading tags with 'dBool' until the first 'True'.
dFiller :: Get ()
dFiller = dBool >>= \done -> if done then return () else dFiller

-- | Decode a lazy 'L.ByteString' preceded by a filler.
dLazyByteString :: Get L.ByteString
dLazyByteString = do
  _ <- dFiller
  dLazyByteString_

-- | Decode a 'SBS.ShortByteString' preceded by a filler.
dShortByteString :: Get SBS.ShortByteString
dShortByteString = do
  _ <- dFiller
  dShortByteString_

-- | Decode a 'SBS.ShortByteString' (without filler) by wrapping the decoded
-- byte array directly; the length returned by 'dByteArray_' is not used.
dShortByteString_ :: Get SBS.ShortByteString
dShortByteString_ =
  dByteArray_ >>= \(ByteArray arr, _) -> return (SBS.SBS arr)

-- | Decode a strict 'B.ByteString' preceded by a filler.
dByteString :: Get B.ByteString
dByteString = do
  _ <- dFiller
  dByteString_