{- | Bytestring parsers.

Module dependency complications prevent us from placing these in
"FlatParse.Stateful.Base".
-}

module FlatParse.Stateful.Bytes
  ( bytes, bytesUnsafe
  ) where

import FlatParse.Stateful.Parser
import FlatParse.Stateful.Base ( withEnsure )
import FlatParse.Stateful.Integers ( word8Unsafe, word16Unsafe, word32Unsafe, word64Unsafe )
import qualified FlatParse.Common.Assorted as Common
import Language.Haskell.TH
import GHC.Exts

-- | Read a sequence of bytes. This is a template function, you can use it as
--   @$(bytes [3, 4, 5])@, for example, and the splice has type @Parser e
--   ()@. For a non-TH variant see 'FlatParse.Stateful.byteString'.
--
-- The generated code is @withEnsure len p@, where @len = length bs@ and @p@ is
-- the unchecked parser produced by 'bytesUnsafe'. 'withEnsure' is presumably
-- what supplies the "enough bytes" guarantee that 'bytesUnsafe' requires from
-- its caller.
bytes :: [Word] -> Q Exp
bytes bs = do
  -- Forced at splice time so only the final count is lifted into the
  -- generated expression.
  let !len = length bs
  [| withEnsure len $(bytesUnsafe bs) |]

-- | Template function, creates a @Parser e ()@ which unsafely parses a given
--   sequence of bytes.
--
-- The caller must guarantee that the input has enough bytes.
--
-- The bytes are split by 'Common.splitBytes' into a short @leading@ prefix and
-- a list @w8s@ of words, each matched with one 'word64Unsafe' (presumably each
-- word packs 8 input bytes).
--
-- * If @w8s@ is empty, the prefix (presumably the whole input) is matched
--   greedily with 32-bit, then 16-bit, then 8-bit reads.
-- * Otherwise the prefix is matched with a single read. Lengths 1, 2 and 4 use
--   an 8/16/32-bit read. Any other non-empty length uses 'scanPartial64#', a
--   masked full-word read. That read is in bounds only because 'word64Unsafe'
--   chunks follow the prefix.
bytesUnsafe :: [Word] -> Q Exp
bytesUnsafe bytes = do
  let !(leading, w8s) = Common.splitBytes bytes
      -- One 'word64Unsafe' per chunk, chained with '>>'.
      !scanw8s = chunks w8s
        where
          chunks [w8]        = [| word64Unsafe w8 |]
          chunks (w8 : rest) = [| word64Unsafe w8 >> $(chunks rest) |]
          chunks []          = [| pure () |]
  case w8s of
    -- No 64-bit chunks: consume the prefix greedily in 4-, 2- and 1-byte reads.
    [] -> prefix leading
      where
        prefix [a, b, c, d]         = let !w = Common.packBytes [a, b, c, d] in [| word32Unsafe w |]
        prefix (a : b : c : d : ws) = let !w = Common.packBytes [a, b, c, d] in [| word32Unsafe w >> $(prefix ws) |]
        prefix [a, b]               = let !w = Common.packBytes [a, b]       in [| word16Unsafe w |]
        prefix (a : b : ws)         = let !w = Common.packBytes [a, b]       in [| word16Unsafe w >> $(prefix ws) |]
        prefix [a]                  = [| word8Unsafe a |]
        prefix []                   = [| pure () |]
    -- 64-bit chunks follow: consume the prefix with a single read, then the
    -- chunks.
    _ -> case leading of
      []              -> scanw8s
      [a]             -> [| word8Unsafe a >> $scanw8s |]
      ws@[_, _]       -> let !w = Common.packBytes ws in [| word16Unsafe w >> $scanw8s |]
      ws@[_, _, _, _] -> let !w = Common.packBytes ws in [| word32Unsafe w >> $scanw8s |]
      -- Odd-sized prefix: one masked full-word read instead of several small
      -- ones. It may look past the prefix into the chunks that follow.
      ws              -> let !w = Common.packBytes ws
                             !l = length ws
                         in  [| scanPartial64# l w >> $scanw8s |]

-- | Match the next @len@ input bytes against @w@ and advance past them.
--
-- A full 'Word' is read at the current position with 'indexWordOffAddr#'.
-- All bits above its low @len * 8@ bits are then cleared with a left/right
-- shift pair. The parser succeeds, advancing @len@ bytes, iff the masked word
-- equals @w@; otherwise it returns 'Fail#'.
--
-- A whole word is always read, so at least 8 bytes must be readable at the
-- current position even though only @len@ are consumed. 'bytesUnsafe' only
-- emits this parser when 'word64Unsafe' chunks follow the prefix.
--
-- Requires @1 <= len <= 8@. For @len = 0@ the shift amount would be 64, and
-- for @len > 8@ it would be negative; unchecked shifts support neither.
--
-- NOTE(review): matching the "next @len@ bytes" assumes a 64-bit little-endian
-- 'Word', and that 'Common.packBytes' puts the first byte in the
-- least-significant position. Confirm for big-endian targets.
scanPartial64# :: Int -> Word -> ParserT st r e ()
scanPartial64# (I# len) (W# w) = ParserT \_fp !_r _eob s n st ->
  case indexWordOffAddr# s 0# of
    raw ->
      -- Number of high bits to discard: (8 - len) * 8.
      case uncheckedIShiftL# (8# -# len) 3# of
        sh -> case uncheckedShiftRL# (uncheckedShiftL# raw sh) sh of
          masked -> case eqWord# w masked of
            1# -> OK# st () (plusAddr# s len) n
            _  -> Fail# st