{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Capnp.GenHelpers
  ( dataField,
    ptrField,
    groupField,
    voidField,
    readVariant,
    Mutability (..),
    TypeParam,
    newStruct,
    parseEnum,
    encodeEnum,
    getPtrConst,
    BS.ByteString,
    module F,
    module Capnp.Accessors,

    -- * Re-exports from the standard library.
    Proxy (..),
  )
where

import Capnp.Accessors
import qualified Capnp.Basics as NB
import Capnp.Bits
import qualified Capnp.Classes as NC
import Capnp.Constraints (TypeParam)
import Capnp.Convert (bsToRaw)
import Capnp.Fields as F
import Capnp.Message (Mutability (..))
import qualified Capnp.Message as M
import qualified Capnp.Repr as R
import Capnp.TraversalLimit (evalLimitT)
import qualified Capnp.Untyped as U
import Data.Bits
import qualified Data.ByteString as BS
import Data.Functor ((<&>))
import Data.Maybe (fromJust)
import Data.Proxy (Proxy (..))
import Data.Word

-- | Build the descriptor for a slot field that lives in a struct's data
-- section.
--
-- Arguments, in order:
--
-- * the field's bit offset (stored as the location's @shift@),
-- * a 'Word16' stored as the location's @index@,
-- * the field's width in bits,
-- * the field's default value.
--
-- The location's @mask@ is @width@ one bits, shifted left by the offset.
dataField ::
  forall b a sz.
  ( R.ReprFor b ~ 'R.Data sz,
    NC.IsWord (R.UntypedData sz)
  ) =>
  BitCount ->
  Word16 ->
  BitCount ->
  Word64 ->
  F.Field 'F.Slot a b
dataField offset wordIndex width defVal =
  F.Field (F.DataField @sz location)
  where
    -- @width@ low-order one bits. For a 64-bit field this relies on GHC's
    -- 'shiftL' for 'Word64' yielding 0 when shifting by 64, so the
    -- subtraction wraps around to all ones.
    lowBits = (1 `shiftL` fromIntegral width) - 1 :: Word64
    location =
      F.DataFieldLoc
        { shift = offset,
          index = wordIndex,
          mask = lowBits `shiftL` fromIntegral offset,
          defaultValue = defVal
        }

-- | Build the descriptor for a pointer slot field.
--
-- The 'Word16' argument is passed unchanged to 'F.PtrField' (presumably the
-- field's index within the struct's pointer section — confirm against
-- "Capnp.Fields"). The pointer representation is derived from @b@'s repr.
ptrField :: forall a b. R.IsPtr b => Word16 -> F.Field 'F.Slot a b
ptrField ptrIndex =
  F.Field (F.PtrField @(R.PtrReprFor (R.ReprFor b)) ptrIndex)

-- | The descriptor for a group field. The group's type @b@ must be
-- represented as a struct pointer.
groupField :: forall b a. (R.ReprFor b ~ 'R.Ptr ('Just 'R.Struct)) => F.Field 'F.Group a b
groupField = F.Field location
  where
    location = F.GroupField

-- | The descriptor for a slot field whose type @b@ has the zero-size data
-- representation ('R.Sz0').
voidField :: forall b a. (R.ReprFor b ~ 'R.Data 'R.Sz0) => F.Field 'F.Slot a b
voidField = F.Field location
  where
    location = F.VoidField

-- | Like 'readField', but accepts a variant. It reads the variant's
-- underlying field from the struct.
--
-- Warning: this *DOES NOT CHECK* that the variant is the one that is set.
-- It should only be used by generated code.
readVariant ::
  forall k a b mut m.
  ( R.IsStruct a,
    U.ReadCtx m mut
  ) =>
  F.Variant k a b ->
  R.Raw a mut ->
  m (R.Raw b mut)
readVariant variant =
  case variant of
    F.Variant {field = variantField} -> readField variantField

-- | Allocate a new struct of type @a@ in the given mutable message.
--
-- The allocation is done as an 'NB.AnyStruct'. The size hint is the pair
-- ('NC.numStructWords', 'NC.numStructPtrs') for @a@, and the result is
-- re-wrapped as a raw @a@.
newStruct :: forall a m s. (U.RWCtx m s, NC.TypedStruct a) => () -> M.Message ('Mut s) -> m (R.Raw a ('Mut s))
newStruct () msg =
  let sizeHint = (NC.numStructWords @a, NC.numStructPtrs @a)
   in fmap (R.Raw . R.fromRaw) (NC.new @NB.AnyStruct sizeHint msg)

-- | Decode an enum from its raw 16-bit representation by converting the
-- stored word with 'toEnum'.
parseEnum ::
  (R.ReprFor a ~ 'R.Data 'R.Sz16, Enum a, Applicative m) =>
  R.Raw a 'Const ->
  m a
parseEnum raw =
  case raw of
    R.Raw word -> pure (toEnum (fromIntegral word))

-- | Encode an enum as its raw 16-bit representation, using 'fromEnum'.
--
-- The message argument is ignored; no allocation is performed.
encodeEnum ::
  forall a m s.
  (R.ReprFor a ~ 'R.Data 'R.Sz16, Enum a, U.RWCtx m s) =>
  M.Message ('Mut s) ->
  a ->
  m (R.Raw a ('Mut s))
encodeEnum _msg value = pure (R.Raw encoded)
  where
    encoded = fromIntegral (fromEnum @a value)

-- | Get a pointer from a ByteString, where the root object is a struct with
-- one pointer, which is the pointer we will retrieve.
--
-- This is only safe for trusted inputs:
--
-- * it reads the message with a traversal limit of 'maxBound', and so is
--   susceptible to denial of service attacks;
-- * it calls 'error' if decoding is not successful.
--
-- The purpose of this is for defining constants of pointer type from a schema.
getPtrConst :: forall a. R.IsPtr a => BS.ByteString -> R.Raw a 'Const
getPtrConst bytes =
  case evalLimitT maxBound decodePtr of
    Just ptr -> ptr
    -- 'evalLimitT' runs over 'Maybe' here, so any decoding failure shows up
    -- as 'Nothing'. Report it with a descriptive message instead of the
    -- generic one 'fromJust' would give.
    Nothing ->
      error "Capnp.GenHelpers.getPtrConst: failed to decode pointer constant"
  where
    decodePtr = do
      R.Raw root <- bsToRaw @NB.AnyStruct bytes
      -- The constant is read from pointer 0 of the root struct.
      U.getPtr 0 root
        >>= R.fromPtr @(R.PtrReprFor (R.ReprFor a)) (U.message @U.Struct root)
        <&> R.Raw