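-- | Put bits easily: serialise values of arbitrary bit width into a
-- "Data.Binary.Put" stream. Bits are written most significant first, and
-- multi-bit words are written big-endian. A trailing partial byte is padded
-- with zero bits when the 'BitPut' is run or a byte-aligned 'Put' is joined in.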

module Data.Binary.Bits.Put
  ( BitPut
  , runBitPut
  , joinPut

          -- * Data types
          -- ** Bool
  , putBool

          -- ** Words
  , putWord8
  , putWord16be
  , putWord32be
  , putWord64be

          -- ** ByteString
  , putByteString
  ) where

import Data.Bits ((.&.), (.|.))

import qualified Data.Binary.Builder as B
import qualified Data.Binary.Put as Put
import qualified Data.Bits as Bits
import qualified Data.ByteString as ByteString
import qualified Data.Word as Word

-- | A state-passing computation that appends bits to an output buffer.
newtype BitPut a = BitPut { run :: S -> PairS a }

-- | The result of one 'BitPut' step: a value and the successor state.
-- The state field is strict, so building a 'PairS' evaluates the state.
data PairS a = PairS a {-# UNPACK #-} !S

-- | Serialiser state.
data S = S
  !B.Builder   -- ^ bytes completed so far
  !Word.Word8  -- ^ byte currently being filled, from its most significant end
  !Int         -- ^ how many bits of the current byte are already used

-- | Put a 1 bit 'Bool': a set bit for 'True', a clear bit for 'False'.
putBool :: Bool -> BitPut ()
putBool b = putWord8 1 byte
  where
    -- only the lowest bit survives the 1-bit mask inside 'putWord8'
    byte
      | b = 0xff
      | otherwise = 0x00

-- | A value whose @n@ least significant bits are set and all other bits are
-- clear, e.g. @makeMask 3 = 00000111@.
--
-- When @n@ equals the bit width of a fixed-width result type (e.g. 8 for
-- 'Word.Word8'), the shift yields 0 and the subtraction wraps around to an
-- all-ones value, so a full-width mask is still correct.
makeMask :: (Bits.Bits a, Num a) => Int -> a
-- 'Bits.shiftL' already takes an 'Int', so no conversion of @n@ is needed.
makeMask n = (1 `Bits.shiftL` n) - 1
{-# SPECIALIZE makeMask :: Int -> Int #-}
{-# SPECIALIZE makeMask :: Int -> Word #-}
{-# SPECIALIZE makeMask :: Int -> Word.Word8 #-}
{-# SPECIALIZE makeMask :: Int -> Word.Word16 #-}
{-# SPECIALIZE makeMask :: Int -> Word.Word32 #-}
{-# SPECIALIZE makeMask :: Int -> Word.Word64 #-}

-- | Put the @n@ lower bits of a 'Word8'.
--
-- The bits are appended after those already written, filling the current
-- byte from its most significant end. A byte that becomes full is moved into
-- the builder by 'flush' straight away.
-- NOTE(review): @n@ is not range-checked here.
putWord8 :: Int -> Word.Word8 -> BitPut ()
putWord8 n w = BitPut $ \s -> PairS () (flush (place s))
  where
    -- @w@ with everything above its lowest @n@ bits cleared
    masked = makeMask n .&. w

    place (S acc cur off)
      -- a whole byte while aligned: it becomes the current byte as-is
      | n == 8 && off == 0 = S acc w n
      -- fits in what is left of the current byte
      | n <= 8 - off =
          S acc (cur .|. (masked `Bits.shiftL` (8 - n - off))) (off + n)
      -- completes the current byte and spills into the next one
      | otherwise =
          let off' = off + n - 8
              full = cur .|. (masked `Bits.shiftR` off')
              rest = w `Bits.shiftL` (8 - off')
          in S (acc `mappend` B.singleton full) rest off'

-- | Put the @n@ lower bits of a 'Word16', most significant bit first.
--
-- Up to 8 bits are delegated to 'putWord8'. Larger widths cannot fit in a
-- single byte, so they complete the current byte and span either two or
-- three output bytes in total.
putWord16be :: Int -> Word.Word16 -> BitPut ()
putWord16be n w
  | n <= 8 = putWord8 n (fromIntegral w)
  | otherwise = BitPut $ \s -> PairS () (flush (place s))
  where
    -- @w@ with everything above its lowest @n@ bits cleared
    masked = makeMask n .&. w

    place (S acc cur off)
      -- two bytes: finish the current byte, the remainder starts the next
      | off + n <= 16 =
          let off' = off + n - 8
              hi = cur .|. fromIntegral (masked `Bits.shiftR` off')
              rest = fromIntegral (w `Bits.shiftL` (8 - off'))
          in S (acc `mappend` B.singleton hi) rest off'
      -- three bytes: finish the current byte, emit one full middle byte,
      -- and the remainder starts the third
      | otherwise =
          let off' = off + n - 16
              hi = cur .|. fromIntegral (masked `Bits.shiftR` (off' + 8))
              mid = fromIntegral ((w `Bits.shiftR` off') .&. 0xff)
              rest = fromIntegral (w `Bits.shiftL` (8 - off'))
          in S
               (acc `mappend` B.singleton hi `mappend` B.singleton mid)
               rest
               off'

-- | Put the @n@ lower bits of a 'Word32', most significant bit first.
--
-- Up to 16 bits are delegated to 'putWord16be'. Wider values are split, and
-- the upper @n - 16@ bits are written before the lower 16.
putWord32be :: Int -> Word.Word32 -> BitPut ()
putWord32be n w
  | n <= 16 = putWord16be n (fromIntegral w)
  | otherwise = upper >> lower
  where
    upper = putWord32be (n - 16) (w `Bits.shiftR` 16)
    lower = putWord32be 16 (w .&. 0x0000ffff)

-- | Put the @n@ lower bits of a 'Word64', most significant bit first.
--
-- Up to 32 bits are delegated to 'putWord32be'. Wider values are split, and
-- the upper @n - 32@ bits are written before the lower 32.
putWord64be :: Int -> Word.Word64 -> BitPut ()
putWord64be n w
  | n <= 32 = putWord32be n (fromIntegral w)
  | otherwise = upper >> lower
  where
    upper = putWord64be (n - 32) (w `Bits.shiftR` 32)
    lower = putWord64be 32 (w .&. 0xffffffff)

-- | Put a 'ByteString'.
--
-- When the output is byte-aligned, the bytes are appended in one go through
-- 'joinPut'. Otherwise each byte is written as 8 bits via 'putWord8', so it
-- straddles the current byte boundary.
putByteString :: ByteString.ByteString -> BitPut ()
putByteString bs = misaligned >>= emit
  where
    -- does the current byte already hold some bits?
    misaligned = BitPut $ \s@(S _ _ o) -> PairS (o /= 0) s

    emit True = mapM_ (putWord8 8) (ByteString.unpack bs) -- naive
    emit False = joinPut (Put.putByteString bs)

-- | Run a 'Put' inside 'BitPut'. Any partially written bytes will be flushed
-- before 'Put' executes to ensure byte alignment.
--
-- Afterwards the state is byte-aligned with an empty current byte.
joinPut :: Put.Put -> BitPut ()
joinPut m = BitPut step
  where
    step s0 =
      let S pending _ _ = flushIncomplete s0
      in PairS () (S (pending `mappend` Put.execPut m) 0 0)

-- | Move the current byte into the builder once all 8 of its bits are used.
-- A state with fewer bits used is returned unchanged; an offset above 8 is
-- an internal error.
flush :: S -> S
flush s@(S b w o) =
  case compare o 8 of
    GT -> error "flush: offset > 8"
    EQ -> S (b `mappend` B.singleton w) 0 0
    LT -> s

-- | Emit the current byte if it holds any bits, even if it is not full.
-- The result is always byte-aligned with an empty current byte.
flushIncomplete :: S -> S
flushIncomplete s@(S b w o)
  | o /= 0 = S (b `mappend` B.singleton w) 0 0
  | otherwise = s

-- | Run the 'BitPut' monad inside 'Put'.
--
-- Starts from an empty, aligned state. A trailing partial byte is still
-- emitted at the end.
runBitPut :: BitPut () -> Put.Put
runBitPut m =
  let PairS _ final = run m (S mempty 0 0)
      S bytes _ _ = flushIncomplete final
  in Put.putBuilder bytes

-- | Apply a function to the result, threading the state through unchanged.
instance Functor BitPut where
  fmap f m = BitPut $ \s ->
    let PairS x s' = run m s
    in PairS (f x) s'

-- | Sequence two computations left to right. The state produced by the
-- function-yielding step is fed into the argument step.
instance Applicative BitPut where
  pure x = BitPut (\s -> PairS x s)
  mf <*> mx = BitPut $ \s ->
    let PairS f s1 = run mf s
        PairS x s2 = run mx s1
    in PairS (f x) s2

-- | Run the first computation, then continue with the computation chosen
-- from its result, starting at the state it left behind.
instance Monad BitPut where
  m >>= k = BitPut $ \s ->
    let PairS a s' = run m s
    in run (k a) s'