{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Capnp.Classes
( IsWord(..)
, ListElem(..)
, MutListElem(..)
, FromPtr(..)
, ToPtr(..)
, FromStruct(..)
, ToStruct(..)
, Allocate(..)
, Marshal(..)
, Cerialize(..)
, Decerialize(..)
, cerializeBasicVec
, cerializeCompositeVec
) where
import Prelude hiding (length)
import Data.Bits
import Data.Int
import Data.ReinterpretCast
import Data.Word
import Control.Monad.Catch (MonadThrow(throwM))
import Data.Foldable (for_)
import Capnp.Bits (Word1 (..))
import Capnp.Errors (Error(SchemaViolationError))
import Capnp.Untyped (Cap, ListOf, Ptr (..), ReadCtx, Struct, messageDefault)
import qualified Capnp.Message as M
import qualified Capnp.Untyped as U
import qualified Data.Vector as V
-- | Types that can be packed into, and unpacked from, a raw 'Word64'.
--
-- The instances in this module read a value from the low-order bits of the
-- word ('fromWord') and widen it back to 64 bits ('toWord').
class IsWord w where
    -- | Convert a raw 'Word64' into a value of this type.
    fromWord :: Word64 -> w

    -- | Convert a value of this type into a raw 'Word64'.
    toWord :: w -> Word64
-- | Element types that can be stored in a typed list inside a message of
-- type @msg@.
class ListElem msg e where
    -- | A typed list of @e@ stored in a message of type @msg@.
    data List msg e

    -- | Interpret an optional untyped pointer as a typed list.
    listFromPtr
        :: U.ReadCtx m msg
        => msg
        -> Maybe (U.Ptr msg)
        -> m (List msg e)

    -- | Forget the element type, giving back the underlying untyped list.
    toUntypedList :: List msg e -> U.List msg

    -- | Number of elements in the list.
    length :: List msg e -> Int

    -- | Read the element at the given index.
    index
        :: U.ReadCtx m msg
        => Int
        -> List msg e
        -> m e
-- | Element types of typed lists that can be modified and allocated inside a
-- mutable message.
class ListElem (M.MutMsg s) e => MutListElem s e where
    -- | Store an element at the given index of a list.
    setIndex
        :: U.RWCtx m s
        => e                    -- ^ element to store
        -> Int                  -- ^ index to store it at
        -> List (M.MutMsg s) e  -- ^ list to modify
        -> m ()

    -- | Allocate a new list with the given number of elements.
    newList
        :: M.WriteCtx m s
        => M.MutMsg s           -- ^ message to allocate in
        -> Int                  -- ^ number of elements
        -> m (List (M.MutMsg s) e)
-- | Values of type @e@ that can be freshly allocated inside a mutable
-- message.
class Allocate s e where
    -- | Allocate a new @e@ in the given message.
    new
        :: M.WriteCtx m s
        => M.MutMsg s
        -> m e
-- | High-level types that can be decoded from their representation in a
-- 'M.ConstMsg'.
class Decerialize a where
    -- | The representation of @a@ stored in a message of type @msg@.
    -- For example, a 'V.Vector' is represented as a typed 'List'.
    type Cerial msg a

    -- | Decode a high-level value from its in-message representation.
    decerialize
        :: U.ReadCtx m M.ConstMsg
        => Cerial M.ConstMsg a
        -> m a
-- | High-level types whose values can be written into an existing
-- in-message representation.
class Decerialize a => Marshal a where
    -- | Copy a high-level value into the given target representation.
    marshalInto
        :: U.RWCtx m s
        => Cerial (M.MutMsg s) a  -- ^ target to write into
        -> a                      -- ^ value to write
        -> m ()
-- | High-level types that can be encoded into a new representation inside a
-- mutable message.
class Decerialize a => Cerialize a where
    -- | Encode a value into the given message and return its in-message
    -- representation.
    cerialize :: U.RWCtx m s => M.MutMsg s -> a -> m (Cerial (M.MutMsg s) a)

    -- Default for types that are both 'Marshal' and 'Allocate': allocate
    -- a target with 'new', fill it via 'marshalInto', then return it.
    default cerialize
        :: (U.RWCtx m s, Marshal a, Allocate s (Cerial (M.MutMsg s) a))
        => M.MutMsg s
        -> a
        -> m (Cerial (M.MutMsg s) a)
    cerialize msg value = do
        target <- new msg
        target <$ marshalInto target value
-- | Types that can be decoded from an optional untyped pointer.
class FromPtr msg a where
    -- | Decode a value from a possibly-null pointer. The message is passed
    -- so that instances can produce a default value for 'Nothing'
    -- (the list instances in this module use 'messageDefault').
    fromPtr
        :: ReadCtx m msg
        => msg
        -> Maybe (Ptr msg)
        -> m a
-- | Types that can be converted into an optional untyped pointer in a
-- mutable message.
class ToPtr s a where
    -- | Convert a value into a possibly-null pointer.
    toPtr
        :: M.WriteCtx m s
        => M.MutMsg s
        -> a
        -> m (Maybe (Ptr (M.MutMsg s)))
-- | Types that can be decoded from an untyped struct.
class FromStruct msg a where
    -- | Decode a value from an untyped struct.
    fromStruct
        :: ReadCtx m msg
        => Struct msg
        -> m a
-- | Types that can be viewed as an untyped struct.
class ToStruct msg a where
    -- | Extract the underlying untyped struct (a pure conversion).
    toStruct
        :: a
        -> Struct msg
-- | A 'Bool' occupies only the lowest bit of the word. Higher bits are
-- ignored when reading; 'True' is written as 1 and 'False' as 0.
instance IsWord Bool where
    fromWord w = testBit w 0
    toWord b = if b then 1 else 0
-- | 'Word1' is encoded exactly like the 'Bool' it wraps.
instance IsWord Word1 where
    fromWord w = Word1 (fromWord w)
    toWord w1 = toWord (word1ToBool w1)
-- | Signed fixed-width integers. 'fromWord' keeps the low-order bits that
-- fit the target type, reinterpreted as two's complement. 'toWord' uses
-- 'fromIntegral', so negative values become their 64-bit two's-complement
-- bit pattern.
instance IsWord Int8 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Int16 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Int32 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Int64 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n
-- | Unsigned fixed-width integers. 'fromWord' keeps the low-order bits that
-- fit the target type; 'toWord' zero-extends to 64 bits.
instance IsWord Word8 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Word16 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Word32 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n

instance IsWord Word64 where
    fromWord w = fromIntegral w
    toWord n = fromIntegral n
-- | Single-precision floats. 'fromWord' reinterprets the low 32 bits of the
-- word as a 'Float' bit pattern. 'toWord' zero-extends that 32-bit pattern
-- to 64 bits.
instance IsWord Float where
    fromWord w = wordToFloat (fromIntegral w)
    toWord f = fromIntegral (floatToWord f)
-- | Double-precision floats: the word is reinterpreted directly as the
-- 'Double' bit pattern, and vice versa.
instance IsWord Double where
    fromWord w = wordToDouble w
    toWord d = doubleToWord d
-- | Abort with a 'SchemaViolationError' whose message is @"expected "@
-- followed by the given description.
expected :: MonadThrow m => String -> m a
expected what = throwM (SchemaViolationError ("expected " ++ what))
-- | Lists of @()@ ('U.List0'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg ()) where
    fromPtr msg ptr = case ptr of
        Nothing                       -> pure (messageDefault msg)
        Just (PtrList (U.List0 list)) -> pure list
        _                             -> expected "pointer to list with element size 0"

-- | Encode a list of @()@ as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) ()) where
    toPtr _ list = pure (Just (PtrList (U.List0 list)))
-- | Lists of 'Word8' ('U.List8'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg Word8) where
    fromPtr msg ptr = case ptr of
        Nothing                       -> pure (messageDefault msg)
        Just (PtrList (U.List8 list)) -> pure list
        _                             -> expected "pointer to list with element size 8"

-- | Encode a list of 'Word8' as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) Word8) where
    toPtr _ list = pure (Just (PtrList (U.List8 list)))
-- | Lists of 'Word16' ('U.List16'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg Word16) where
    fromPtr msg ptr = case ptr of
        Nothing                        -> pure (messageDefault msg)
        Just (PtrList (U.List16 list)) -> pure list
        _                              -> expected "pointer to list with element size 16"

-- | Encode a list of 'Word16' as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) Word16) where
    toPtr _ list = pure (Just (PtrList (U.List16 list)))
-- | Lists of 'Word32' ('U.List32'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg Word32) where
    fromPtr msg ptr = case ptr of
        Nothing                        -> pure (messageDefault msg)
        Just (PtrList (U.List32 list)) -> pure list
        _                              -> expected "pointer to list with element size 32"

-- | Encode a list of 'Word32' as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) Word32) where
    toPtr _ list = pure (Just (PtrList (U.List32 list)))
-- | Lists of 'Word64' ('U.List64'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg Word64) where
    fromPtr msg ptr = case ptr of
        Nothing                        -> pure (messageDefault msg)
        Just (PtrList (U.List64 list)) -> pure list
        _                              -> expected "pointer to list with element size 64"

-- | Encode a list of 'Word64' as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) Word64) where
    toPtr _ list = pure (Just (PtrList (U.List64 list)))
-- | Lists of 'Bool' ('U.List1'). A null pointer decodes to the message's
-- default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg Bool) where
    fromPtr msg ptr = case ptr of
        Nothing                       -> pure (messageDefault msg)
        Just (PtrList (U.List1 list)) -> pure list
        _                             -> expected "pointer to list with element size 1."

-- | Encode a list of 'Bool' as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) Bool) where
    toPtr _ list = pure (Just (PtrList (U.List1 list)))
-- | Raw optional pointers are passed through unchanged when decoding.
instance FromPtr msg (Maybe (Ptr msg)) where
    fromPtr _ ptr = pure ptr

-- | Raw optional pointers are passed through unchanged when encoding.
instance ToPtr s (Maybe (Ptr (M.MutMsg s))) where
    toPtr _ ptr = pure ptr
-- | Lists of untyped structs ('U.ListStruct'). A null pointer decodes to the
-- message's default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg (Struct msg)) where
    fromPtr msg ptr = case ptr of
        Nothing                            -> pure (messageDefault msg)
        Just (PtrList (U.ListStruct list)) -> pure list
        _                                  -> expected "pointer to list of structs"

-- | Encode a list of structs as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) (Struct (M.MutMsg s))) where
    toPtr _ list = pure (Just (PtrList (U.ListStruct list)))
-- | Lists of optional untyped pointers ('U.ListPtr'). A null pointer decodes
-- to the message's default value; any other kind of pointer is rejected.
instance FromPtr msg (ListOf msg (Maybe (Ptr msg))) where
    fromPtr msg ptr = case ptr of
        Nothing                         -> pure (messageDefault msg)
        Just (PtrList (U.ListPtr list)) -> pure list
        _                               -> expected "pointer to list of pointers"

-- | Encode a list of pointers as a non-null list pointer.
instance ToPtr s (ListOf (M.MutMsg s) (Maybe (Ptr (M.MutMsg s)))) where
    toPtr _ list = pure (Just (PtrList (U.ListPtr list)))
-- | Typed lists decode through their element type's 'listFromPtr'.
instance ListElem msg e => FromPtr msg (List msg e) where
    fromPtr msg ptr = listFromPtr msg ptr

-- | Typed lists encode as a non-null pointer to their untyped list.
instance ListElem (M.MutMsg s) e => ToPtr s (List (M.MutMsg s) e) where
    toPtr _ list = pure (Just (PtrList (toUntypedList list)))
-- | Lists of lists are stored as a list of pointers. Each element is decoded
-- on access by applying 'fromPtr' to the stored pointer.
instance ListElem msg e => ListElem msg (List msg e) where
    newtype List msg (List msg e) = NestedList (U.ListOf msg (Maybe (U.Ptr msg)))
    listFromPtr msg ptr = fmap NestedList (fromPtr msg ptr)
    toUntypedList (NestedList inner) = U.ListPtr inner
    length (NestedList inner) = U.length inner
    index i (NestedList inner) = U.index i inner >>= fromPtr (U.message inner)
-- | Nested lists in mutable messages. Each element is stored as a list
-- pointer to the inner list, and new outer lists are allocated as lists of
-- pointers.
instance MutListElem s e => MutListElem s (List (M.MutMsg s) e) where
    setIndex e i (NestedList inner) =
        let ptr = Just (U.PtrList (toUntypedList e))
        in U.setIndex ptr i inner
    newList msg len = fmap NestedList (U.allocListPtr msg len)
-- | An untyped struct decodes to itself.
instance FromStruct msg (Struct msg) where
    fromStruct s = pure s

-- | An untyped struct is already a struct.
instance ToStruct msg (Struct msg) where
    toStruct s = s
-- | Decode an untyped struct from an optional pointer.
--
-- * A null pointer yields the message's default struct ('messageDefault').
-- * A struct pointer is passed to 'fromStruct'.
-- * Any other pointer is rejected with a 'SchemaViolationError' via
--   'expected'.
instance FromPtr msg (Struct msg) where
    fromPtr msg Nothing = fromStruct (go msg) where
        -- The local signature ties the default struct's message type to the
        -- type of the message argument. NOTE(review): this presumably exists
        -- to resolve a type that would otherwise be ambiguous
        -- (ScopedTypeVariables is not enabled here). Confirm before removing.
        go :: msg -> Struct msg
        go = messageDefault
    fromPtr _ (Just (PtrStruct s)) = fromStruct s
    fromPtr _ _ = expected "pointer to struct"
-- | An untyped struct encodes as a non-null struct pointer.
instance ToPtr s (Struct (M.MutMsg s)) where
    toPtr _ st = pure (Just (PtrStruct st))
-- | Optional capabilities. A null pointer decodes to 'Nothing'; any
-- non-capability pointer is rejected.
instance FromPtr msg (Maybe (Cap msg)) where
    fromPtr _ ptr = case ptr of
        Nothing           -> pure Nothing
        Just (PtrCap cap) -> pure (Just cap)
        _                 -> expected "pointer to capability"

-- | 'Nothing' encodes as a null pointer; a capability encodes as a
-- capability pointer.
instance ToPtr s (Maybe (Cap (M.MutMsg s))) where
    toPtr _ mcap = pure (PtrCap <$> mcap)
-- | Encode a vector into a new typed list in @msg@.
--
-- A list with the same length as the vector is allocated with 'newList'.
-- Each element is then encoded independently with 'cerialize' and stored at
-- its index with 'setIndex'. Elements are processed in index order. An empty
-- vector yields an empty list.
cerializeBasicVec ::
    ( U.RWCtx m s
    , MutListElem s (Cerial (M.MutMsg s) a)
    , Cerialize a
    )
    => M.MutMsg s
    -> V.Vector a
    -> m (List (M.MutMsg s) (Cerial (M.MutMsg s) a))
cerializeBasicVec msg vec = do
    list <- newList msg (V.length vec)
    -- Index-aware traversal: no hand-built index range and no partial
    -- 'V.!' lookup.
    V.imapM_ (\i value -> cerialize msg value >>= \e -> setIndex e i list) vec
    pure list
-- | Encode a vector into a new typed list in @msg@ by marshalling in place.
--
-- A list with the same length as the vector is allocated with 'newList'.
-- For each element, the list's slot is fetched with 'index' and the value is
-- written into it with 'marshalInto'. Unlike 'cerializeBasicVec', this never
-- calls 'cerialize' per element. Elements are processed in index order. An
-- empty vector yields an empty list.
cerializeCompositeVec ::
    ( U.RWCtx m s
    , MutListElem s (Cerial (M.MutMsg s) a)
    , Marshal a
    )
    => M.MutMsg s
    -> V.Vector a
    -> m (List (M.MutMsg s) (Cerial (M.MutMsg s) a))
cerializeCompositeVec msg vec = do
    list <- newList msg (V.length vec)
    -- Index-aware traversal: no hand-built index range and no partial
    -- 'V.!' lookup.
    V.imapM_ (\i value -> index i list >>= \targ -> marshalInto targ value) vec
    pure list
-- | A vector is represented as a typed list of its elements' representations,
-- and is decoded element by element in index order.
instance
    ( ListElem M.ConstMsg (Cerial M.ConstMsg a)
    , Decerialize a
    ) => Decerialize (V.Vector a)
  where
    type Cerial msg (V.Vector a) = List msg (Cerial msg a)
    decerialize list = V.generateM (length list) $ \i ->
        decerialize =<< index i list