module Rattletrap.BitGet where

import qualified Control.Monad as Monad
import qualified Data.Binary.Bits.Get as BinaryBits
import qualified Data.Binary.Get as Binary
import qualified Data.Bits as Bits
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as LazyByteString
import qualified Data.Word as Word
import qualified Rattletrap.ByteGet as ByteGet
import qualified Rattletrap.Get as Get
import qualified Rattletrap.Utility.Bytes as Utility

-- | Bit-level parser used throughout Rattletrap; an alias for the
-- 'BinaryBits.BitGet' monad from "Data.Binary.Bits.Get".
type BitGet = BinaryBits.BitGet

-- | Run a bit-level parser as a byte-level parser.
--
-- The 'BitGet' is first lowered to a "Data.Binary.Get" parser with
-- 'BinaryBits.runBitGet', and that parser is then adapted to
-- 'ByteGet.ByteGet' by 'binaryGetToByteGet'.
toByteGet :: BitGet a -> ByteGet.ByteGet a
toByteGet = binaryGetToByteGet . BinaryBits.runBitGet

-- | Adapt a "Data.Binary.Get" parser so it runs inside 'ByteGet.ByteGet'.
--
-- The current input is read with 'Get.get' and handed, as a lazy
-- 'LazyByteString.ByteString', to 'Binary.runGetOrFail'.
--
-- * On success, the bytes the parser did not consume are stored back with
--   'Get.put' and the parsed value is returned.
-- * On failure, the error message from "Data.Binary.Get" is re-raised with
--   'fail'. The failure offset and the leftover input are discarded.
binaryGetToByteGet :: Binary.Get a -> ByteGet.ByteGet a
binaryGetToByteGet g = do
  s1 <- Get.get
  case Binary.runGetOrFail g $ LazyByteString.fromStrict s1 of
    Left (_, _, x) -> fail x
    Right (s2, _, x) -> do
      -- Keep only the unconsumed input so the next parser resumes right
      -- after whatever @g@ read.
      Get.put $ LazyByteString.toStrict s2
      pure x

-- | Run a byte-level parser over a fixed number of bytes taken from the
-- bit stream.
--
-- Exactly @n@ bytes are read with 'BinaryBits.getByteString'. They are
-- passed through 'Utility.reverseBytes' and then parsed by @f@ with
-- 'ByteGet.run'. A 'Left' from the byte parser becomes a 'fail' in
-- 'BitGet'.
--
-- NOTE(review): presumably 'Utility.reverseBytes' converts the stream's bit
-- order into the byte order @f@ expects. Confirm in
-- "Rattletrap.Utility.Bytes".
fromByteGet :: ByteGet.ByteGet a -> Int -> BitGet a
fromByteGet f n = do
  x <- BinaryBits.getByteString n
  either fail pure . ByteGet.run f $ Utility.reverseBytes x

-- | Read @n@ single bits with 'bool' and pack them into a 'Bits.Bits' value.
--
-- Packing is least-significant-bit first: the first bit read lands in
-- position 0, the second in position 1, and so on. When @n@ is not
-- positive, 'Monad.replicateM' reads nothing and the result is
-- 'Bits.zeroBits'.
bits :: Bits.Bits a => Int -> BitGet a
bits n = foldr push Bits.zeroBits <$> Monad.replicateM n bool
  where
    -- Shift the accumulator left to free bit 0, then fill that bit from
    -- the current flag. Because this is a right fold, earlier flags are
    -- pushed later and so end up in the lower positions.
    push flag acc =
      let shifted = Bits.shiftL acc 1
       in if flag then Bits.setBit shifted 0 else shifted

-- | Read a single bit from the stream as a 'Bool'.
bool :: BitGet Bool
bool = BinaryBits.getBool

-- | Read the given number of whole bytes from the bit stream as a strict
-- 'ByteString.ByteString'.
byteString :: Int -> BitGet ByteString.ByteString
byteString = BinaryBits.getByteString

-- | Read the given number of bits as a 'Word.Word8'.
--
-- NOTE(review): the bit count is presumably expected to be between 1 and 8.
-- Confirm against the binary-bits 'BinaryBits.getWord8' documentation.
word8 :: Int -> BitGet Word.Word8
word8 = BinaryBits.getWord8