{-# LANGUAGE BangPatterns, CPP #-}
-- | Bit-granular parsing layered on top of 'Data.Binary.Get.Get'.
--
-- 'BitGet' is a monad for reading individual bit fields; 'runBitGet'
-- embeds it in 'Get'. 'Block' is an applicative description of a read
-- that carries its total bit width, so 'block' performs a single
-- availability check for the whole composite.
module Data.Binary.Bits.Get
(
-- * The BitGet monad
BitGet
, runBitGet
-- ** Reading single fields
, getBool
, getWord8
, getWord16be
, getWord32be
, getWord64be
-- * Fixed-size blocks
, Block
, block
, bool
, word8
, word16be
, word32be
, word64be
, byteString
-- * Byte strings and end-of-input
, Data.Binary.Bits.Get.getByteString
, Data.Binary.Bits.Get.getLazyByteString
, Data.Binary.Bits.Get.isEmpty
) where
import qualified Control.Monad.Fail as Fail
import Data.Binary.Get as B ( Get, getLazyByteString, isEmpty )
import Data.Binary.Get.Internal as B ( get, put, ensureN )
import Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe
import Data.Bits
import Data.Word
import Control.Applicative
import Prelude as P
-- | Parser state: the unconsumed input and the bit position inside its
-- first byte. Both fields are strict and unpacked.
data S
  = S
      -- Input buffer; its first byte may be partially consumed.
      {-# UNPACK #-} !ByteString
      -- Number of bits already consumed from that first byte.
      {-# UNPACK #-} !Int
  deriving (Show)
-- | A fixed-size read: the number of bits it consumes, paired with a
-- function that extracts the value from a parser state. 'block' checks
-- that this many bits are available before applying the function.
data Block a = Block Int (S -> a)
-- | Mapping post-processes the extracted value; the bit width is kept.
instance Functor Block where
  fmap f (Block width reader) = Block width (f . reader)
-- | Sequencing two blocks adds their widths; the second block reads from
-- the state advanced past the first block's bits.
instance Applicative Block where
  pure value = Block 0 (const value)

  Block i p <*> Block j q = Block (i + j) run
    where
      -- Apply the function read by the first block to the value the
      -- second block reads @i@ bits further on.
      run s = p s (q (incS i s))

  -- Only the second reader runs, but the width of the first still counts.
  Block i _ *> Block j q = Block (i + j) (\s -> q (incS i s))

  -- Only the first reader runs, but the width of the second still counts.
  Block i p <* Block j _ = Block (i + j) p
-- | Run a 'Block' inside 'BitGet': make sure all of its bits are
-- available, advance the state past them, and return the value extracted
-- from the state as it was before advancing.
block :: Block a -> BitGet a
block (Block nbits extract) = do
  ensureBits nbits
  before <- getState
  -- Both the new state and the result are forced eagerly.
  putState $! incS nbits before
  return $! extract before
-- | Advance the state by @o@ bits: whole bytes are dropped from the
-- buffer (without bounds checking) and the remainder becomes the new bit
-- position within the first byte.
incS :: Int -> S -> S
incS o (S bs n) = S (unsafeDrop wholeBytes bs) leftoverBits
  where
    !total = n + o
    !wholeBytes = byte_offset total
    !leftoverBits = bit_offset total
-- | A value with the @n@ least significant bits set, computed as
-- @(1 `shiftL` n) - 1@ in the target type.
make_mask :: (Bits a, Num a) => Int -> a
make_mask n = subtract 1 (1 `shiftL` n)
{-# SPECIALIZE make_mask :: Int -> Int #-}
{-# SPECIALIZE make_mask :: Int -> Word #-}
{-# SPECIALIZE make_mask :: Int -> Word8 #-}
{-# SPECIALIZE make_mask :: Int -> Word16 #-}
{-# SPECIALIZE make_mask :: Int -> Word32 #-}
{-# SPECIALIZE make_mask :: Int -> Word64 #-}
-- | The low three bits of a bit count, i.e. the position within a byte.
bit_offset :: Int -> Int
bit_offset n = n .&. lowThree
  where
    lowThree = make_mask 3
-- | The number of whole bytes in a bit count (arithmetic shift right by 3).
byte_offset :: Int -> Int
byte_offset = (`shiftR` 3)
-- | Test the bit at the current position of the first byte. Position 0 is
-- the most significant bit. The buffer is not bounds-checked.
readBool :: S -> Bool
readBool (S bytes bitPos) = unsafeHead bytes `testBit` (7 - bitPos)
-- | Read @n@ bits (0 to 8) starting at the current bit position, most
-- significant bit first, without bounds checking. A field that fits in
-- the rest of the first byte is taken from that byte alone; otherwise the
-- first two bytes are combined into a 'Word16' window.
--
-- Calls 'error' when @n@ is greater than 8.
{-# INLINE readWord8 #-}
readWord8 :: Int -> S -> Word8
readWord8 n (S bs o)
  | n == 0 = 0
  | n <= bitsLeft = (firstByte `shiftr_w8` (bitsLeft - n)) .&. make_mask n
  | n <= 8 = fromIntegral ((window `shiftr_w16` (16 - o - n)) .&. make_mask n)
  | otherwise = error "readWord8: tried to read more than 8 bits"
  where
    -- Unread bits remaining in the first byte.
    bitsLeft = 8 - o
    firstByte = unsafeHead bs
    -- First two bytes as one big-endian 16-bit value; only demanded when
    -- the field spans a byte boundary.
    window =
      (fromIntegral firstByte `shiftl_w16` 8)
        .|. fromIntegral (unsafeIndex bs 1)
-- | Read @n@ bits (0 to 16) starting at the current bit position as a
-- big-endian 'Word16', without bounds checking.
--
-- Calls 'error' when @n@ is greater than 16. The size check comes first
-- so that it also applies to byte-aligned input; otherwise the aligned
-- branch would be selected and shift by a negative amount.
{-# INLINE readWord16be #-}
readWord16be :: Int -> S -> Word16
readWord16be n s@(S bs o)
  | n > 16 = error "readWord16be: tried to read more than 16 bits"
  | n <= 8 = fromIntegral (readWord8 n s)
  -- Byte-aligned, 9..16 bits: all of the first byte plus the top
  -- (n - 8) bits of the second. For n == 16 the right shift is by 0.
  | o == 0 = (msb `shiftl_w16` (n - 8)) .|. (lsb `shiftr_w16` (16 - n))
  | otherwise = readWithOffset s shiftl_w16 shiftr_w16 n
  where
    msb = fromIntegral (unsafeHead bs)
    lsb = fromIntegral (unsafeIndex bs 1)
-- | Read @n@ bits (0 to 32) starting at the current bit position as a
-- big-endian 'Word32', without bounds checking.
--
-- Calls 'error' when @n@ is greater than 32. The size check comes first
-- so that byte-aligned input is rejected as well, instead of being handed
-- to the aligned reader and silently truncated to 32 bits.
{-# INLINE readWord32be #-}
readWord32be :: Int -> S -> Word32
readWord32be n s@(S _ o)
  | n > 32 = error "readWord32be: tried to read more than 32 bits"
  | n <= 8 = fromIntegral (readWord8 n s)
  | n <= 16 = fromIntegral (readWord16be n s)
  | o == 0 = readWithoutOffset s shiftl_w32 shiftr_w32 n
  | otherwise = readWithOffset s shiftl_w32 shiftr_w32 n
-- | Read @n@ bits (0 to 64) starting at the current bit position as a
-- big-endian 'Word64', without bounds checking.
--
-- Calls 'error' when @n@ is greater than 64. The size check comes first
-- so that aligned and unaligned input fail the same way, with this
-- function's own message rather than the one from the aligned helper.
{-# INLINE readWord64be #-}
readWord64be :: Int -> S -> Word64
readWord64be n s@(S _ o)
  | n > 64 = error "readWord64be: tried to read more than 64 bits"
  | n <= 8 = fromIntegral (readWord8 n s)
  | n <= 16 = fromIntegral (readWord16be n s)
  | o == 0 = readWithoutOffset s shiftl_w64 shiftr_w64 n
  | otherwise = readWithOffset s shiftl_w64 shiftr_w64 n
-- | Extract @n@ bytes from the state without bounds checking. Aligned
-- input is a slice of the buffer; otherwise each output byte is read as
-- an 8-bit field from successive states 8 bits apart.
readByteString :: Int -> S -> ByteString
readByteString n s@(S bs o)
  | o == 0 = unsafeTake n bs
  | otherwise = B.pack [readWord8 8 st | st <- P.take n (iterate (incS 8) s)]
-- | Read @n@ bits (at most 64) from byte-aligned input as a big-endian
-- number, using the supplied left and right shift functions for the
-- result type. The buffer is not bounds-checked, so the caller must have
-- ensured that @n@ bits are available.
--
-- The result is the whole leading bytes, followed by the top @n `mod` 8@
-- bits of the next byte when @n@ is not a multiple of 8. That trailing
-- byte is only indexed when such bits exist: it lies past the guaranteed
-- input when @n@ is a multiple of 8. Requests of fewer than 8 bits take
-- no whole bytes.
--
-- Calls 'error' when the state has a non-zero bit offset or @n@ exceeds 64.
readWithoutOffset :: (Bits a, Num a)
                  => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a
readWithoutOffset (S bs o) shifterL shifterR n
  | o /= 0 = error "readWithoutOffset: there is an offset"
  | n > 64 = error "readWithoutOffset: tried to read more than 64 bits"
  | otherwise = (wholeBytes `shifterL` tailBits) .|. partial
  where
    segs = byte_offset n
    tailBits = bit_offset n

    -- Big-endian accumulation of bytes 0..x.
    bn 0 = fromIntegral (unsafeHead bs)
    bn x = (bn (x - 1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs x)

    wholeBytes
      | segs > 0 = bn (segs - 1)
      | otherwise = 0

    -- Top bits of the byte after the whole ones, if any are requested.
    partial
      | tailBits > 0 =
          fromIntegral (unsafeIndex bs segs) `shifterR` (8 - tailBits)
      | otherwise = 0
-- | Read @n@ bits (at most 64) from input whose first byte is partially
-- consumed, as a big-endian number, using the supplied left and right
-- shift functions for the result type. The buffer is not bounds-checked.
--
-- The result is the OR of three parts: the unread bits of the first byte,
-- the whole bytes that follow, and the top bits of one trailing byte when
-- the field does not end on a byte boundary.
--
-- Calls 'error' when @n@ exceeds 64.
readWithOffset :: (Bits a, Num a)
               => S -> (a -> Int -> a) -> (a -> Int -> a) -> Int -> a
readWithOffset (S bs o) shifterL shifterR n
  | n > 64 = error "readWithOffset: tried to read more than 64 bits"
  | otherwise = top .|. mseg .|. lst
  where
    -- Unread bits in the first byte, and bits still needed after them.
    bitsInMsb = 8 - o
    rest = n - bitsInMsb

    top = (fromIntegral (unsafeHead bs) .&. make_mask bitsInMsb) `shifterL` rest

    segs = byte_offset rest
    tailBits = bit_offset rest

    -- Big-endian accumulation of bytes 1..x (byte 0 is covered by 'top').
    bn 0 = 0
    bn x = (bn (x - 1) `shifterL` 8) .|. fromIntegral (unsafeIndex bs x)

    mseg = bn segs `shifterL` tailBits

    -- The trailing byte is only indexed when bits from it are needed.
    lst
      | tailBits > 0 =
          fromIntegral (unsafeIndex bs (segs + 1)) `shifterR` (8 - tailBits)
      | otherwise = 0
-- | A bit-level parser: a function from the current state 'S' to an
-- action in the underlying 'Get' monad that yields the new state together
-- with the result.
newtype BitGet a = B { runState :: S -> Get (S,a) }
-- | Sequencing threads the state 'S' from the first parser into the
-- parser chosen by the continuation.
instance Monad BitGet where
  return = pure
  m >>= k = B $ \s0 ->
    runState m s0 >>= \(s1, a) -> runState (k a) s1
#if !MIN_VERSION_GLASGOW_HASKELL(8, 8, 1, 0)
  fail = Fail.fail
#endif
-- | Failing first hands the buffered input back to the underlying 'Get'
-- (via 'putBackState') and then fails there with the same message.
instance Fail.MonadFail BitGet where
  fail str = B $ \(S pending bitOff) -> do
    putBackState pending bitOff
    fail str
-- | 'fmap' is defined through the 'Monad' instance.
instance Functor BitGet where
  fmap f m = do
    a <- m
    return (f a)
-- | 'pure' leaves the state untouched; '<*>' runs the function parser
-- first and the argument parser second, through the 'Monad' instance.
instance Applicative BitGet where
  pure x = B (\s -> return (s, x))
  fm <*> m = do
    f <- fm
    v <- m
    return (f v)
-- | Run a 'BitGet' parser inside 'Get'. The bit-level state is seeded
-- from the pending input of 'Get', and whatever is left afterwards is
-- handed back with 'putBackState'.
runBitGet :: BitGet a -> Get a
runBitGet bg = do
  (S leftover bitOff, result) <- mkInitState >>= runState bg
  result <$ putBackState leftover bitOff
-- | Take the pending input of the underlying 'Get' as the initial buffer
-- (bit offset 0), leaving 'Get' with an empty pending input.
mkInitState :: Get S
mkInitState = (\pending -> S pending 0) <$> B.get <* B.put B.empty
-- | Prepend the unread part of a bit-level buffer to the pending input of
-- 'Get'. When the bit offset is non-zero, the partially consumed first
-- byte is dropped rather than returned.
putBackState :: B.ByteString -> Int -> Get ()
putBackState bs n = B.get >>= \rest -> B.put (unread `B.append` rest)
  where
    unread
      | n == 0 = bs
      | otherwise = B.drop 1 bs
-- | Return the current bit-level state without changing it.
getState :: BitGet S
getState = B (\st -> pure (st, st))
-- | Replace the current bit-level state, discarding the previous one.
putState :: S -> BitGet ()
putState st = B (const (pure (st, ())))
-- | Make sure at least @n@ unread bits are buffered. If the buffer is
-- short, the missing number of bytes (rounded up) is requested from the
-- underlying 'Get' with 'B.ensureN'; all of its pending input is then
-- appended to the buffer, leaving the bit offset unchanged.
ensureBits :: Int -> BitGet ()
ensureBits n = do
  S bs o <- getState
  let available = B.length bs * 8 - o
      missingBytes = (n - available + 7) `div` 8
  if n > available
    then B $ \_ -> do
      B.ensureN missingBytes
      extra <- B.get
      B.put B.empty
      return (S (bs `B.append` extra) o, ())
    else return ()
-- | Read a single bit as a 'Bool' (the 'bool' block run through 'block').
getBool :: BitGet Bool
getBool = block bool
-- | Read the given number of bits as a 'Word8' (see 'word8').
getWord8 :: Int -> BitGet Word8
getWord8 = block . word8
-- | Read the given number of bits as a big-endian 'Word16' (see 'word16be').
getWord16be :: Int -> BitGet Word16
getWord16be = block . word16be
-- | Read the given number of bits as a big-endian 'Word32' (see 'word32be').
getWord32be :: Int -> BitGet Word32
getWord32be = block . word32be
-- | Read the given number of bits as a big-endian 'Word64' (see 'word64be').
getWord64be :: Int -> BitGet Word64
getWord64be = block . word64be
-- | Read the given number of bytes as a strict 'ByteString' (see
-- 'byteString').
getByteString :: Int -> BitGet ByteString
getByteString = block . byteString
-- | Read @n@ bytes as a lazy 'L.ByteString'.
--
-- With a zero bit offset the buffered input is handed back to 'Get', the
-- bytes are read with 'B.getLazyByteString', and the bit-level buffer is
-- left empty. Otherwise the bytes are read as a strict 'ByteString' and
-- wrapped as a single chunk.
getLazyByteString :: Int -> BitGet L.ByteString
getLazyByteString n = do
  S _ bitOff <- getState
  if bitOff == 0
    then B $ \(S buffered off) -> do
      putBackState buffered off
      chunked <- B.getLazyByteString (fromIntegral n)
      return (S B.empty 0, chunked)
    else do
      strict <- Data.Binary.Bits.Get.getByteString n
      return (L.fromChunks [strict])
-- | Test for end of input. A non-empty bit-level buffer means 'False';
-- an empty one defers to 'B.isEmpty' on the underlying 'Get'. The state
-- is returned unchanged.
isEmpty :: BitGet Bool
isEmpty = B $ \st@(S bs _) ->
  (,) st <$> (if B.null bs then B.isEmpty else return False)
-- | A one-bit block whose value is read with 'readBool'.
bool :: Block Bool
bool = Block 1 readBool
-- | A block of @n@ bits read as a 'Word8' with 'readWord8'.
word8 :: Int -> Block Word8
word8 n = Block n extract
  where
    extract = readWord8 n
-- | A block of @n@ bits read as a big-endian 'Word16' with 'readWord16be'.
word16be :: Int -> Block Word16
word16be n = Block n extract
  where
    extract = readWord16be n
-- | A block of @n@ bits read as a big-endian 'Word32' with 'readWord32be'.
word32be :: Int -> Block Word32
word32be n = Block n extract
  where
    extract = readWord32be n
-- | A block of @n@ bits read as a big-endian 'Word64' with 'readWord64be'.
word64be :: Int -> Block Word64
word64be n = Block n extract
  where
    extract = readWord64be n
-- | A block of @n@ bytes (@n * 8@ bits) read as a strict 'ByteString'.
-- A non-positive @n@ gives a zero-width block yielding the empty string.
byteString :: Int -> Block ByteString
byteString n
  | n <= 0 = Block 0 (const B.empty)
  | otherwise = Block (n * 8) (readByteString n)
-- Monomorphic shift helpers. Each is the unchecked shift from "Data.Bits"
-- fixed to one word size.

-- | Unchecked left shift on 'Word16'.
shiftl_w16 :: Word16 -> Int -> Word16
shiftl_w16 w k = w `unsafeShiftL` k

-- | Unchecked left shift on 'Word32'.
shiftl_w32 :: Word32 -> Int -> Word32
shiftl_w32 w k = w `unsafeShiftL` k

-- | Unchecked left shift on 'Word64'.
shiftl_w64 :: Word64 -> Int -> Word64
shiftl_w64 w k = w `unsafeShiftL` k

-- | Unchecked right shift on 'Word8'.
shiftr_w8 :: Word8 -> Int -> Word8
shiftr_w8 w k = w `unsafeShiftR` k

-- | Unchecked right shift on 'Word16'.
shiftr_w16 :: Word16 -> Int -> Word16
shiftr_w16 w k = w `unsafeShiftR` k

-- | Unchecked right shift on 'Word32'.
shiftr_w32 :: Word32 -> Int -> Word32
shiftr_w32 w k = w `unsafeShiftR` k

-- | Unchecked right shift on 'Word64'.
shiftr_w64 :: Word64 -> Int -> Word64
shiftr_w64 w k = w `unsafeShiftR` k