-- | Data null-padded to a given length.

{- TODO
Null padding using the underlying type's instances doesn't necessarily work:
a 'ByteString' parser consumes until the end of its input, so it would also
swallow the padding. Or maybe that's correct, and we must use null-terminated
bytestrings inside the null padding...?

...well, doing that fixes my issue. Thinking about it, I imagine that's how C
does it (you still have to deal with C strings regardless of null padding).

OK, all good. But because of that, I should provide a convenience wrapper that
combines null padding and null termination.
-}

{-# LANGUAGE OverloadedStrings #-}

module Binrep.Type.NullPadded where

import Binrep
import Bytezap.Poke qualified as BZ
import Bytezap.Struct qualified as BZ.Struct
import FlatParse.Basic qualified as FP
import Raehik.Compat.FlatParse.Basic.WithLength qualified as FP
import Control.Monad.Combinators ( skipCount )

import Binrep.Util ( tshow )

import Refined
import Refined.Unsafe

import GHC.TypeNats
import Util.TypeNats ( natValInt )

import Data.Typeable ( typeRep )

import Bytezap.Parser.Struct qualified as BZG
import GHC.Exts ( Int(I#) )

-- | Type-level tag for the 'NullPadded' refinement. Its 'Predicate' instance
--   checks that a term's serialized length ('blen') is at most @n@ bytes.
data NullPad (n :: Natural)

{- | A type which is to be null-padded to a given total length.

Given some @a :: 'NullPadded' n a@, it is guaranteed that

@
'blen' a '<=' 'natValInt' \@n
@

thus

@
'natValInt' \@n '-' 'blen' a '>=' 0
@

That is, the serialized stored data will not be longer than the total length.
When serialized, the data is followed by null bytes so that exactly @n@ bytes
are written in total.
-}
type NullPadded n a = Refined (NullPad n) a

-- Null-padded data always occupies exactly @n@ bytes, so its length is a
-- compile-time constant.
instance IsCBLen (NullPadded n a) where
    type CBLen (NullPadded n a) = n

-- Runtime length, derived via 'ViaCBLen' from the 'CBLen' above
-- (presumably the type-level @n@ for every value).
deriving via ViaCBLen (NullPadded n a)
    instance KnownNat n => BLen (NullPadded n a)

-- | Assert that term will fit.
--
-- Validation succeeds when the serialized length of the term ('blen') is at
-- most @n@. Otherwise it fails, reporting both the actual and the maximum
-- length.
instance (BLen a, KnownNat n) => Predicate (NullPad n) a where
    validate p a
      | len <= n  = success
      | otherwise =
          throwRefineOtherException (typeRep p) $
              "too long: " <> tshow len <> " > " <> tshow n
      where
        -- total padded length requested at the type level
        n   = natValInt @n
        -- serialized length of the unpadded term
        len = blen a

-- | Serialize the wrapped term, then fill the rest of the @n@ bytes with
--   null bytes.
instance (BLen a, KnownNat n, PutC a) => PutC (NullPadded n a) where
    putC ra =
        -- @len@ is passed as the size of the first poke; presumably the
        -- padding poke is written starting at that offset.
        BZ.Struct.sequencePokes (putC a) len
            (BZ.Struct.replicateByte paddingLen 0x00)
      where
        a = unrefine ra
        len = blen a
        -- the NullPad refinement guarantees this is >= 0
        paddingLen = natValInt @n - len

-- | Serialize the wrapped term followed by enough null bytes to reach @n@
--   bytes in total.
instance (BLen a, KnownNat n, Put a) => Put (NullPadded n a) where
    put ra = put a <> BZ.replicateByte paddingLen 0x00
      where
        a = unrefine ra
        -- the NullPad refinement guarantees this is >= 0
        paddingLen = natValInt @n - blen a

-- | Run a @Getter a@ isolated to @n@ bytes.
--
-- NOTE(review): the bytes left unconsumed by the inner parser are not checked
-- to be nulls. The result is also refined with 'reallyUnsafeRefine' without
-- re-checking 'blen' (there is no @BLen a@ constraint here). The isolation
-- to @n@ bytes is relied upon instead; confirm that this is sufficient.
instance (KnownNat n, Get a) => GetC (NullPadded n a) where
    getC = fpToBz get len# $ \a _unconsumed# ->
        -- TODO: consume/check the trailing null padding
        -- (presumably @_unconsumed#@ bytes of it)
        BZG.constParse $ reallyUnsafeRefine a
      where
        -- total padded length, unboxed because 'fpToBz' takes an 'Int#'
        !(I# len#) = natValInt @n

-- | Parse the term, then require the rest of the @n@ bytes to be nulls.
--
-- The inner parser is not isolated to @n@ bytes. If it consumes more than
-- @n@ bytes, parsing fails with an error reporting both lengths.
instance (Get a, KnownNat n) => Get (NullPadded n a) where
    get = do
        -- parse the term and measure how many bytes it consumed
        (a, len) <- FP.parseWithLength get
        let n = natValInt @n
            paddingLen = n - len
        if paddingLen < 0
        then eBase $ EFailNamed $
                 "too long: " <> show len <> " > " <> show n
        else do
            skipCount paddingLen (FP.word8 0x00)
            -- safe: we just checked the consumed length is at most @n@
            pure $ reallyUnsafeRefine a