{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

module Data.LLVM.BitCode.GetBits (
    GetBits
  , runGetBits
  , fixed, align32bits
  , bytestring
  , label
  , isolate
  , try
  , skip
  ) where

import           Data.LLVM.BitCode.BitString

import           Control.Applicative ( Alternative(..) )
import           Control.Monad ( MonadPlus(..) )
import           Data.Bits ( finiteBitSize, shiftR, shiftL, (.&.), (.|.) )
import           Data.ByteString ( ByteString )
import qualified Data.ByteString as BS
import           GHC.Exts
import           GHC.Word

#if !MIN_VERSION_base(4,13,0)
import           Control.Monad.Fail ( MonadFail )
import qualified Control.Monad.Fail
#endif

-- Bit-level Parsing -----------------------------------------------------------

-- | A bit-level parser over a strict 'ByteString'.  Each step receives the
-- unboxed position state and the whole input, and returns its outcome
-- ('Left' holds an error message) paired with the updated position state.
newtype GetBits a = GetBits
  { unGetBits :: BitPosition
              -> BS.ByteString
              -> (# BitsGetter a, BitPosition #)
  }

-- | Unboxed parser state: @(# current bit position, maximum bit position #)@.
type BitPosition = (# Int#, Int# #)

-- | Outcome of a parse step: 'Left' carries the failure message.
type BitsGetter a = Either String a


-- | Run a @GetBits@ action over an entire input.  Parsing starts at bit 0 and
-- may read up to the last bit of the input.  Only the action's result is
-- returned ('Left' carries the failure message); the final bit position is
-- discarded.
runGetBits :: GetBits a -> ByteString -> Either String a
runGetBits m bs =
  let -- (# starting bit, bit limit = total number of bits in the input #)
      !startPos# = (# 0#, bitCount# $ bytesToBits $ Bytes' $ BS.length bs #)
  in case unGetBits m startPos# bs of
       (# result, _finalPos #) -> result


-- | Map over the outcome of a parse step; the resulting position is whatever
-- the underlying action produced, whether it succeeded or failed.
instance Functor GetBits where
  {-# INLINE fmap #-}
  fmap f (GetBits run) = GetBits $ \ !start# input ->
    case run start# input of
      (# outcome, end# #) -> (# fmap f outcome, end# #)

-- | Sequential composition.  'pure' leaves the position untouched.  For
-- @pf <*> px@ the function parser runs first; if it fails, its error and
-- position are returned without running @px@.  Otherwise @px@ runs from where
-- @pf@ stopped, and its position is returned whether it succeeds or fails.
instance Applicative GetBits where
  {-# INLINE pure #-}
  pure x = GetBits (\ !here# _ -> (# Right x, here# #))

  {-# INLINE (<*>) #-}
  GetBits runF <*> GetBits runX = GetBits $ \ !pos0# input ->
    case runF pos0# input of
      (# Left err, pos1# #) -> (# Left err, pos1# #)
      (# Right fn, pos1# #) ->
        case runX pos1# input of
          (# Left err, pos2# #)  -> (# Left err, pos2# #)
          (# Right arg, pos2# #) -> (# Right (fn arg), pos2# #)

-- | Monadic sequencing: the continuation runs from the position where the
-- first action stopped; a failure short-circuits and keeps that position.
instance Monad GetBits where
  {-# INLINE return #-}
  return = pure

  {-# INLINE (>>=) #-}
  GetBits runM >>= k = GetBits $ \ !pos0# input ->
    case runM pos0# input of
      (# Left err, pos1# #) -> (# Left err, pos1# #)
      (# Right a, pos1# #)  -> unGetBits (k a) pos1# input

#if !MIN_VERSION_base(4,13,0)
  {-# INLINE fail #-}
  fail msg = GetBits (\ pos _ -> (# Left msg, pos #))
#endif

-- | Fail with the given message, leaving the position unchanged.
instance MonadFail GetBits where
  {-# INLINE fail #-}
  fail msg = GetBits (\ pos _ -> (# Left msg, pos #))

-- | Backtracking choice: @l <|> r@ runs @l@, and if it fails runs @r@ from
-- the original position, discarding @l@'s error.  'empty' always fails
-- without moving the position.
instance Alternative GetBits where
  {-# INLINE empty #-}
  empty = GetBits (\ pos _ -> (# Left "GetBits is empty!", pos #))

  {-# INLINE (<|>) #-}
  GetBits runL <|> GetBits runR = GetBits $ \ !start# input ->
    case runL start# input of
      (# Left _, _ #)     -> runR start# input
      (# Right v, end# #) -> (# Right v, end# #)

-- | 'MonadPlus' simply reuses the 'Alternative' operations.
instance MonadPlus GetBits where
  {-# INLINE mzero #-}
  mzero = empty

  {-# INLINE mplus #-}
  mplus l r = l <|> r


-- | Extracts an 'Int' value of the specified number of bits from a ByteString,
-- starting at the indicated bit position (fails with 'Left' if the range to
-- extract is not valid, i.e. if it would end past the bit limit).  Returns the
-- 'Int' value along with the bit position following the extraction.

-- There are two implementations: one builds an integer value by shifting in
-- whole bytes, then shifts and masks the result to get the final value.  The
-- other uses unlifted values and avoids the final shift by being smarter about
-- how the individual bytes are combined.  Their functionality should be
-- identical, but the first may be easier to debug.

-- | Reference (boxed) bit extractor: reads @numBits@ bits starting at
-- @startBit@, least-significant bit first, and returns the value together
-- with the position just past it.  Fails with 'Left' when the extraction
-- would end beyond @bitLimit@.
_extractFromByteString' :: NumBits {-^ bit limit: the extraction may not end past this position -}
                        -> NumBits {-^ the bit to start extraction at -}
                        -> NumBits {-^ the number of bits to extract -}
                        -> ByteString {-^ the ByteString to extract from -}
                        -> Either String (Int, NumBits)
_extractFromByteString' bitLimit startBit numBits bs
  | endPos > bitLimit =
      Left ("Attempt to read bits past limit (newPos="
            <> show endPos <> ", limit=" <> show bitLimit <> ")")
  | otherwise = Right (value, endPos)
  where
    endPos = addBitCounts startBit numBits
    Bytes' firstByte = fst (bitsToBytes startBit)
    Bytes' wholeBytes = fst (bitsToBytes numBits)
    -- Two spare bytes cover the sub-byte offset of startBit on either side.
    window = BS.take (wholeBytes + 2) (BS.drop firstByte bs)
    -- The first byte of the window lands in the lowest 8 bits.
    packed = BS.foldr (\w acc -> (acc `shiftL` 8) .|. fromIntegral w) (0 :: Int) window
    mask = ((1 :: Int) `shiftL` bitCount numBits) - 1
    value = (packed `shiftR` (bitCount startBit .&. 7)) .&. mask

-- | Unboxed worker shared by 'fixed' and 'align32bits': extract @nbits#@ bits
-- from @bs@ starting at bit position @sBit#@.  Bits are numbered
-- least-significant first within each byte, so stream bit @k@ is bit
-- @k .&. 7@ of byte @k `shiftR` 3@.
--
-- On success, returns a continuation yielding
-- @(# extracted value, bit position just past the extracted bits #)@.
-- Fails with 'Left' when the extraction would end past @bitLim#@, or when
-- @nbits#@ is negative or not smaller than the machine word size.
--
-- NOTE(review): assumes @bitLim#@ never exceeds @8 * BS.length bs@;
-- otherwise 'BS.index' below may throw instead of returning 'Left'.
extractFromByteString :: Int# {-^ bit limit: the extraction may not end past this position -}
                      -> Int# {-^ the bit to start extraction at -}
                      -> Int# {-^ the number of bits to extract -}
                      -> ByteString {-^ the ByteString to extract from -}
                      -> Either String (() -> (# Int#, Int# #))
extractFromByteString !bitLim# !sBit# !nbits# bs =
  let !(I# wordBits#) = finiteBitSize (0 :: Int)
  in
  -- Only widths 0 .. wordBits#-1 are accepted.  Do not test this with
  -- ((1 << nbits#) /= 0): uncheckedIShiftL# is undefined for shift amounts
  -- outside [0, wordBits#-1], so such a test cannot reliably reject oversized
  -- widths (and the mask below would then be built from an undefined shift).
  -- A full word-sized width is excluded too, because Int# is signed and its
  -- high bit is not handled properly by the numeric operations here; per the
  -- original design note, LLVM bitcode does not attempt to encode actual
  -- 64-bit values.
  if isTrue# (nbits# >=# 0#) && isTrue# (nbits# <# wordBits#)
  then
    let !updPos# = sBit# +# nbits#
    in if isTrue# (updPos# <=# bitLim#)
       then
         let -- s8#: index of the byte holding the first requested bit.
             !s8# = sBit# `uncheckedIShiftRL#` 3#
             -- hop#: offset of the first requested bit within that byte.
             !hop# = sBit# `andI#` 7#
             -- r8#: number of bytes spanned by the requested bits.
             !r8# = (hop# +# nbits# +# 7#) `uncheckedIShiftRL#` 3#
             -- mask#: selects the low nbits# bits of the assembled value.
             !mask# = (1# `uncheckedIShiftL#` nbits#) -# 1#
#if MIN_VERSION_base(4,16,0)
             word8ToInt :: Word8# -> Int#
             word8ToInt !w8# = word2Int# (word8ToWord# w8#)
#else
             -- technically #if !MIN_VERSION_ghc_prim(0,8,0), for GHC 9.2, but
             -- since ghc_prim isn't a direct dependency and is re-exported
             -- from base, this define needs to reference the base version.
             word8ToInt = word2Int#
#endif
             -- byteAt# i# reads byte (s8# + i#) of the input.
             byteAt# :: Int# -> Int#
             byteAt# i# = let !(W8# w#) = bs `BS.index` (I# (s8# +# i#))
                          in word8ToInt w#
             -- getB# i# places byte (s8# + i#) at bits 8*i# .. 8*i#+7;
             -- used when the extraction starts on a byte boundary.
             getB# :: Int# -> Int#
             getB# 0# = byteAt# 0#
             getB# i# = byteAt# i# `uncheckedIShiftL#` (8# *# i#)
             -- getSB# i# is like getB# but also drops the hop# bits that
             -- precede sBit# within its first byte.
             getSB# :: Int# -> Int#
             getSB# 0# = byteAt# 0# `uncheckedIShiftRL#` hop#
             getSB# i# = byteAt# i# `uncheckedIShiftL#` ((8# *# i#) -# hop#)
             -- Generic path: OR together get# i# for i# in [i#, r8#).
             orBytes# :: (Int# -> Int#) -> Int# -> Int# -> Int#
             orBytes# get# i# acc#
               | isTrue# (i# >=# r8#) = acc#
               | otherwise = orBytes# get# (i# +# 1#) (acc# `orI#` get# i#)
             -- Assemble the r8# bytes covering the requested bits, then keep
             -- only the low nbits# bits.  The byte counts reachable for
             -- widths below the word size are hand-unrolled; any other count
             -- (e.g. r8# == 0 for a zero-width read) uses orBytes#.
             !vi# = mask# `andI#`
                      (case hop# of
                         0# -> case r8# of
                                 1# -> getB# 0#
                                 2# -> getB# 0# `orI#` getB# 1#
                                 3# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                 4# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                         `orI#` getB# 3#
                                 5# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                         `orI#` getB# 3# `orI#` getB# 4#
                                 6# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                         `orI#` getB# 3# `orI#` getB# 4#
                                         `orI#` getB# 5#
                                 7# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                         `orI#` getB# 3# `orI#` getB# 4#
                                         `orI#` getB# 5# `orI#` getB# 6#
                                 8# -> getB# 0# `orI#` getB# 1# `orI#` getB# 2#
                                         `orI#` getB# 3# `orI#` getB# 4#
                                         `orI#` getB# 5# `orI#` getB# 6#
                                         `orI#` getB# 7#
                                 _  -> orBytes# getB# 0# 0#
                         _  -> case r8# of
                                 1# -> getSB# 0#
                                 2# -> getSB# 0# `orI#` getSB# 1#
                                 3# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                 4# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3#
                                 5# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3# `orI#` getSB# 4#
                                 6# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3# `orI#` getSB# 4#
                                         `orI#` getSB# 5#
                                 7# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3# `orI#` getSB# 4#
                                         `orI#` getSB# 5# `orI#` getSB# 6#
                                 8# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3# `orI#` getSB# 4#
                                         `orI#` getSB# 5# `orI#` getSB# 6#
                                         `orI#` getSB# 7#
                                 9# -> getSB# 0# `orI#` getSB# 1# `orI#` getSB# 2#
                                         `orI#` getSB# 3# `orI#` getSB# 4#
                                         `orI#` getSB# 5# `orI#` getSB# 6#
                                         `orI#` getSB# 7# `orI#` getSB# 8#
                                 _  -> orBytes# getSB# 0# 0#
                      )
         in Right $ \_ -> (# vi#, updPos# #)
       else Left "Attempt to read bits past limit"
  else
    -- BitString stores an Int, but the requested number of bits is larger
    -- than an Int can represent (or is negative).
    Left "Attempt to extracted large value"


-- Basic Interface -------------------------------------------------------------

-- | Consume the padding bits between the current position and the next
-- 32-bit boundary, failing unless every one of them is zero.  Nothing is
-- consumed when the position is already aligned, and on failure the position
-- is left unchanged.
align32bits :: GetBits ()
align32bits = GetBits step
  where
    step :: BitPosition -> ByteString -> (# BitsGetter (), BitPosition #)
    step !pos@(# cur#, total# #) input =
      case cur# `andI#` 31# of
        0# -> (# Right (), pos #)
        offset# ->
          -- number of bits left before the next 32-bit boundary
          let !pad# = 32# -# offset#
          in case extractFromByteString total# cur# pad# input of
               Left err -> (# Left err, pos #)
               Right readPad ->
                 case readPad () of
                   (# 0#, next# #) -> (# Right (), (# next#, total# #) #)
                   (# _, _ #) ->
                     (# Left ("alignments @" <> show (I# cur#)
                              <> " not zeroes up to 32-bit boundary")
                      , pos #)


-- | Read the next @n@ bits as a 'BitString' of width @n@, advancing the
-- position past them.  On failure the position is left unchanged.
fixed :: NumBits -> GetBits BitString
fixed !width@(Bits' (I# n#)) = GetBits $ \ !here@(# cur#, lim# #) inp ->
  case extractFromByteString lim# cur# n# inp of
    Left err -> (# Left err, here #)
    Right readBits ->
      case readBits () of
        (# v#, next# #) ->
          (# Right (toBitString width (I# v#)), (# next#, lim# #) #)


-- | Read out n bytes as a @ByteString@, aligning to a 32-bit boundary before
-- and after.
--
-- Fails with an error if the requested bytes would end past the current
-- limit.
bytestring :: NumBytes -> GetBits ByteString
bytestring n@(Bytes' nbytes) = do
  align32bits
  result <- sliceAligned
  align32bits
  return result
  where
    -- Slice @nbytes@ bytes starting at the current position, which the
    -- preceding 'align32bits' has made byte-aligned.
    sliceAligned :: GetBits ByteString
    sliceAligned = GetBits $ \ !(# pos#, lim# #) inp ->
      let !startByte# = pos# `uncheckedIShiftRL#` 3#
          !stop#      = pos# +# bitCount# (bytesToBits n)
          !after#     = (# stop#, lim# #)
      in if isTrue# (stop# ># lim#)
           then (# Left "Sub-bytestring attempted beyond end of input bytestring"
                 , after#
                 #)
           else (# Right (BS.take nbytes (BS.drop (I# startByte#) inp))
                 , after#
                 #)


-- | Add a label to the error tag stack.
--
-- If the wrapped parser fails, the label is appended to its error message on
-- a new, indented line. Successful results and the resulting position are
-- passed through unchanged.
label :: String -> GetBits a -> GetBits a
label l m = GetBits $ \ !pos# inp ->
  case unGetBits m pos# inp of
    (# Left e, n# #)  -> (# Left (e <> "\n  " <> l), n# #)
    (# Right r, n# #) -> (# Right r, n# #)


-- | Isolate input to a sub-span of the specified byte length.
--
-- The inner parser's limit is the current position plus @ws@ bytes. It is
-- clamped so that it never exceeds the enclosing limit: a nested 'isolate'
-- may narrow the visible input but can never widen it.
--
-- Afterwards the enclosing limit is restored, and the position is left
-- wherever the inner parser finished. The inner parser is not required to
-- consume the whole sub-span.
isolate :: NumBytes -> GetBits a -> GetBits a
isolate ws m =
  GetBits $ \ !(# pos#, lim# #) ->
              \inp ->
                let !want# = pos# +# bitCount# (bytesToBits ws)
                    -- Without this clamp, a span reaching past lim# would
                    -- let the inner parser read outside the enclosing
                    -- (possibly already isolated) region.
                    !l# = if isTrue# (want# <=# lim#) then want# else lim#
                    !(# r, (# x#, _ #) #) = unGetBits m (# pos#, l# #) inp
                in (# r, (# x#, lim# #) #)


-- | Try to parse something, returning Nothing when it fails.
--
-- The given parser's result is wrapped in 'Just'. If it fails, the
-- alternative supplied through 'mplus' yields 'Nothing'.
try :: GetBits a -> GetBits (Maybe a)
try m = mplus (fmap Just m) (return Nothing)


-- | Skips the specified number of bits.
--
-- Fails if the new position would lie past the current limit. In both the
-- success and failure cases, the returned position is the advanced one.
skip :: NumBits -> GetBits ()
skip !(Bits' (I# n#)) =
  GetBits $ \ !(# cur#, lim# #) ->
    let !target# = cur# +# n#
        !moved#  = (# target#, lim# #)
        !outcome = if isTrue# (target# ># lim#)
                     then Left "skipped past end of bytestring"
                     else Right ()
    in \_ -> (# outcome, moved# #)