-- | Data null-padded to a given length.

{-# LANGUAGE OverloadedStrings #-}

module Binrep.Type.NullPadded where

import Binrep
import Bytezap.Poke qualified as BZ
import Bytezap.Struct qualified as BZ.Struct
import FlatParse.Basic qualified as FP
import Raehik.Compat.FlatParse.Basic.WithLength qualified as FP
import Control.Monad.Combinators ( skipCount )

import Binrep.Util ( tshow )

import Refined
import Refined.Unsafe

import GHC.TypeNats
import Util.TypeNats ( natValInt )

import Data.Typeable ( typeRep )

-- | Refinement predicate: the serialized form of the refined value must fit
--   within @len@ bytes. See 'NullPadded'.
data NullPad (len :: Natural)

-- | A value that is serialized null-padded to exactly @n@ bytes.
--
-- Given some @x :: 'NullPadded' n a@, it is guaranteed that
--
-- @
-- 'blen' x '<=' 'natValInt' \@n
-- @
--
-- and therefore the padding length is never negative:
--
-- @
-- 'natValInt' \@n '-' 'blen' x '>=' 0
-- @
--
-- In other words, the serialized data never exceeds the total length; the
-- remaining bytes are filled with @0x00@.
--
-- The binrep instances are careful not to construct bytestrings unnecessarily.
type NullPadded n a = Refined (NullPad n) a

-- | A value satisfies @'NullPad' n@ when its serialized byte length ('blen')
--   is at most @n@. Otherwise validation fails with a "too long" message that
--   reports both the actual and the permitted length.
instance (BLen a, KnownNat n) => Predicate (NullPad n) a where
    validate :: Proxy (NullPad n) -> a -> Maybe RefineException
    validate p a
      | actual > limit
          = throwRefineOtherException (typeRep p) $
                "too long: " <> tshow actual <> " > " <> tshow limit
      | otherwise = success
      where
        -- Total padded length, taken from the type-level natural.
        limit :: Int
        limit = natValInt @n
        -- Serialized length of the candidate value.
        actual :: Int
        actual = blen a

-- | The constant byte length of a 'NullPadded' value is exactly its padded
--   length @n@, independent of the wrapped type @a@.
instance IsCBLen (NullPadded n a) where
    type CBLen (NullPadded n a) = n

-- | Runtime byte length is obtained from the constant length via 'ViaCBLen'.
deriving via ViaCBLen (NullPadded n a)
    instance KnownNat n => BLen (NullPadded n a)

-- | Constant-size serializer: write the wrapped value, then fill the rest of
--   the @n@ bytes with @0x00@.
instance (BLen a, KnownNat n, PutC a) => PutC (NullPadded n a) where
    putC :: NullPadded n a -> PutterC
    putC ra = BZ.Struct.sequencePokes (putC x) xLen padding
      where
        x :: a
        x = unrefine ra
        -- Serialized length of the wrapped value. It is passed to
        -- 'BZ.Struct.sequencePokes' as the size of the first poke (presumably
        -- the offset at which the padding poke starts).
        xLen :: Int
        xLen = blen x
        -- The refinement guarantees @xLen <= n@, so the count is never negative.
        padding :: PutterC
        padding = BZ.Struct.replicateByte (natValInt @n - xLen) 0x00

-- | Serialize the wrapped value followed by enough @0x00@ bytes to reach a
--   total of @n@ bytes.
instance (BLen a, KnownNat n, Put a) => Put (NullPadded n a) where
    put :: NullPadded n a -> Putter
    put ra = put x <> padding
      where
        x :: a
        x = unrefine ra
        -- The refinement guarantees @blen x <= n@, so the count is never negative.
        padding :: Putter
        padding = BZ.replicateByte (natValInt @n - blen x) 0x00

-- | Parse a wrapped value, then consume and check the null padding.
--
-- The inner value is parsed first and the number of bytes it consumed
-- (@len@) is recorded. If @len@ exceeds @n@, the parse fails with an error
-- that reports both lengths. Otherwise each of the following @n - len@ bytes
-- must be @0x00@.
--
-- NOTE(review): the inner parser is not bounded to @n@ bytes, so an overlong
-- value is consumed in full before the length check rejects it.
--
-- NOTE(review): the result is refined using the parsed byte count rather than
-- 'blen'. This assumes @a@'s 'Get' and 'BLen' agree on its length — confirm.
instance (Get a, KnownNat n) => Get (NullPadded n a) where
    get :: Getter (NullPadded n a)
    get = do
        (a, len) <- FP.parseWithLength get
        let total :: Int
            total = natValInt @n
            paddingLen :: Int
            paddingLen = total - len
        if paddingLen < 0
        then eBase $ EFailNamed $
            "NullPadded: value too long: " <> show len <> " > " <> show total
        else do
            skipCount paddingLen (FP.word8 0x00)
            -- Safe: we just checked that @len <= n@.
            pure $ reallyUnsafeRefine a