{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Flat.Decoder.Prim (
dBool,
dWord8,
dBE8,
dBE16,
dBE32,
dBE64,
dBEBits8,
dBEBits16,
dBEBits32,
dBEBits64,
dropBits,
dFloat,
dDouble,
getChunksInfo,
dByteString_,
dLazyByteString_,
dByteArray_,
) where
import Control.Monad
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.Flat.Decoder.Types
import Data.Flat.Memory
import Data.FloatCast
import Data.Word
import Foreign
import System.Endian
{-# INLINE ensureBits #-}
ensureBits :: Ptr Word8 -> S -> Int -> IO ()
ensureBits endPtr s n = when ((endPtr `minusPtr` currPtr s) * 8 - usedBits s < n) $ notEnoughSpace endPtr s
{-# INLINE dropBits #-}
dropBits :: Int -> Get ()
dropBits n
| n > 0 = Get $ \endPtr s -> do
ensureBits endPtr s n
return $ GetResult (dropBits_ s n) ()
| n == 0 = return ()
| otherwise = error $ unwords ["dropBits",show n]
dropBits_ :: S -> Int -> S
dropBits_ s n = let (bytes,bits) = (n+usedBits s) `divMod` 8
in S {currPtr=currPtr s `plusPtr` bytes,usedBits=bits}
{-# INLINE dBool #-}
dBool :: Get Bool
dBool = Get $ \endPtr s ->
if currPtr s >= endPtr
then notEnoughSpace endPtr s
else do
!w <- peek (currPtr s)
let !b = 0 /= (w .&. (128 `shiftR` usedBits s))
let !s' = if usedBits s == 7
then s { currPtr = currPtr s `plusPtr` 1, usedBits = 0 }
else s { usedBits = usedBits s + 1 }
return $ GetResult s' b
{-# INLINE dBEBits8 #-}
dBEBits8 :: Int -> Get Word8
dBEBits8 n = Get $ \endPtr s -> do
ensureBits endPtr s n
take8 s n
{-# INLINE dBEBits16 #-}
dBEBits16 :: Int -> Get Word16
dBEBits16 n = Get $ \endPtr s -> do
ensureBits endPtr s n
takeN n s
{-# INLINE dBEBits32 #-}
dBEBits32 :: Int -> Get Word32
dBEBits32 n = Get $ \endPtr s -> do
ensureBits endPtr s n
takeN n s
{-# INLINE dBEBits64 #-}
dBEBits64 :: Int -> Get Word64
dBEBits64 n = Get $ \endPtr s -> do
ensureBits endPtr s n
takeN n s
{-# INLINE take8 #-}
take8 :: S -> Int -> IO (GetResult Word8)
take8 s n = GetResult (dropBits_ s n) <$> read8 s n
{-# INLINE read8 #-}
read8 :: S -> Int -> IO Word8
read8 s n | n >=0 && n <=8 =
if n <= 8 - usedBits s
then do
w <- peek (currPtr s)
return $ (w `unsafeShiftL` usedBits s) `unsafeShiftR` (8 - n)
else do
w::Word16 <- toBE16 <$> peek (castPtr $ currPtr s)
return $ fromIntegral $ (w `unsafeShiftL` usedBits s) `unsafeShiftR` (16 - n)
| otherwise = error $ unwords ["read8: cannot read",show n,"bits"]
{-# INLINE takeN #-}
takeN :: (Num a, Bits a) => Int -> S -> IO (GetResult a)
takeN n s = read s 0 (n - (n `min` 8)) n
where
read s r sh n | n <=0 = return $ GetResult s r
| otherwise = do
let m = n `min` 8
GetResult s' b <- take8 s m
read s' (r .|. (fromIntegral b `unsafeShiftL` sh)) ((sh-8) `max` 0) (n-8)
dWord8 :: Get Word8
dWord8 = dBE8
{-# INLINE dBE8 #-}
dBE8 :: Get Word8
dBE8 = Get $ \endPtr s -> do
ensureBits endPtr s 8
!w1 <- peek (currPtr s)
!w <- if usedBits s == 0
then return w1
else do
!w2 <- peek (currPtr s `plusPtr` 1)
return $ (w1 `unsafeShiftL` usedBits s) .|. (w2 `unsafeShiftR` (8-usedBits s))
return $ GetResult (s {currPtr=currPtr s `plusPtr` 1}) w
{-# INLINE dBE16 #-}
dBE16 :: Get Word16
dBE16 = Get $ \endPtr s -> do
ensureBits endPtr s 16
!w1 <- toBE16 <$> peek (castPtr $ currPtr s)
!w <- if usedBits s == 0
then return w1
else do
!(w2::Word8) <- peek (currPtr s `plusPtr` 2)
return $ w1 `unsafeShiftL` usedBits s .|. fromIntegral (w2 `unsafeShiftR` (8-usedBits s))
return $ GetResult (s {currPtr=currPtr s `plusPtr` 2}) w
{-# INLINE dBE32 #-}
dBE32 :: Get Word32
dBE32 = Get $ \endPtr s -> do
ensureBits endPtr s 32
!w1 <- toBE32 <$> peek (castPtr $ currPtr s)
!w <- if usedBits s == 0
then return w1
else do
!(w2::Word8) <- peek (currPtr s `plusPtr` 4)
return $ w1 `unsafeShiftL` usedBits s .|. fromIntegral (w2 `unsafeShiftR` (8-usedBits s))
return $ GetResult (s {currPtr=currPtr s `plusPtr` 4}) w
{-# INLINE dBE64 #-}
dBE64 :: Get Word64
dBE64 = Get $ \endPtr s -> do
ensureBits endPtr s 64
!w1 <- toBE64 <$> peek (castPtr $ currPtr s)
!w <- if usedBits s == 0
then return w1
else do
!(w2::Word8) <- peek (currPtr s `plusPtr` 8)
return $ w1 `unsafeShiftL` usedBits s .|. fromIntegral (w2 `unsafeShiftR` (8-usedBits s))
return $ GetResult (s {currPtr=currPtr s `plusPtr` 8}) w
{-# INLINE dFloat #-}
dFloat :: Get Float
dFloat = wordToFloat <$> dBE32
{-# INLINE dDouble #-}
dDouble :: Get Double
dDouble = wordToDouble <$> dBE64
dLazyByteString_ :: Get L.ByteString
dLazyByteString_ = L.fromStrict <$> dByteString_
dByteString_ :: Get B.ByteString
dByteString_ = chunksToByteString <$> getChunksInfo
dByteArray_ :: Get (ByteArray,Int)
dByteArray_ = chunksToByteArray <$> getChunksInfo
getChunksInfo :: Get (Ptr Word8, [Int])
getChunksInfo = Get $ \endPtr s -> do
let getChunks srcPtr l = do
ensureBits endPtr s 8
n <- fromIntegral <$> peek srcPtr
if n==0
then return (srcPtr `plusPtr` 1,l [])
else do
ensureBits endPtr s ((n+1)*8)
getChunks (srcPtr `plusPtr` (n+1)) (l . (n:))
when (usedBits s /=0) $ badEncoding endPtr s "usedBits /= 0"
(currPtr',ns) <- getChunks (currPtr s) id
return $ GetResult (s {currPtr=currPtr'}) (currPtr s `plusPtr` 1,ns)