{-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TypeFamilies #-} {-| Module: Capnp.Untyped Description: Utilities for reading capnproto messages with no schema. The types and functions in this module know about things like structs and lists, but are not schema aware. Each of the data types exported by this module is parametrized over a Message type (see "Capnp.Message"), used as the underlying storage. -} module Capnp.Untyped ( Ptr(..), List(..), Struct, ListOf, Cap , dataSection, ptrSection , getData, getPtr , setData, setPtr , copyStruct , getClient , get, index, length , setIndex , take , rootPtr , setRoot , rawBytes , ReadCtx , RWCtx , HasMessage(..), MessageDefault(..) , allocStruct , allocCompositeList , allocList0 , allocList1 , allocList8 , allocList16 , allocList32 , allocList64 , allocListPtr , appendCap , TraverseMsg(..) ) where import Prelude hiding (length, take) import Data.Bits import Data.Word import Control.Monad (forM_) import Control.Monad.Catch (MonadThrow(throwM)) import qualified Data.ByteString as BS import Capnp.Address (OffsetError(..), WordAddr(..), pointerFrom) import Capnp.Bits ( BitCount(..) , ByteCount(..) , Word1(..) , WordCount(..) , bitsToBytesCeil , bytesToWordsCeil , replaceBits , wordsToBytes ) import Capnp.Pointer (ElementSize(..)) import Capnp.TraversalLimit (MonadLimit(invoice)) import Data.Mutable (Thaw(..)) import qualified Capnp.Errors as E import qualified Capnp.Message as M import qualified Capnp.Pointer as P -- | Type (constraint) synonym for the constraints needed for most read -- operations. type ReadCtx m msg = (M.Message m msg, MonadThrow m, MonadLimit m) -- | Synonym for ReadCtx + WriteCtx type RWCtx m s = (ReadCtx m (M.MutMsg s), M.WriteCtx m s) -- | A an absolute pointer to a value (of arbitrary type) in a message. -- Note that there is no variant for far pointers, which don't make sense -- with absolute addressing. data Ptr msg = PtrCap (Cap msg) | PtrList (List msg) | PtrStruct (Struct msg) -- | A list of values (of arbitrary type) in a message. data List msg = List0 (ListOf msg ()) | List1 (ListOf msg Bool) | List8 (ListOf msg Word8) | List16 (ListOf msg Word16) | List32 (ListOf msg Word32) | List64 (ListOf msg Word64) | ListPtr (ListOf msg (Maybe (Ptr msg))) | ListStruct (ListOf msg (Struct msg)) -- | A "normal" (non-composite) list. data NormalList msg = NormalList { nMsg :: msg , nAddr :: WordAddr , nLen :: Int } -- | A list of values of type 'a' in a message. data ListOf msg a where ListOfVoid :: msg -> !Int -- number of elements -> ListOf msg () ListOfStruct :: Struct msg -- First element. data/ptr sizes are the same for -- all elements. -> !Int -- Number of elements -> ListOf msg (Struct msg) ListOfBool :: !(NormalList msg) -> ListOf msg Bool ListOfWord8 :: !(NormalList msg) -> ListOf msg Word8 ListOfWord16 :: !(NormalList msg) -> ListOf msg Word16 ListOfWord32 :: !(NormalList msg) -> ListOf msg Word32 ListOfWord64 :: !(NormalList msg) -> ListOf msg Word64 ListOfPtr :: !(NormalList msg) -> ListOf msg (Maybe (Ptr msg)) -- | A Capability in a message. data Cap msg = Cap msg !Word32 -- | A struct value in a message. data Struct msg = Struct msg !WordAddr -- Start of struct !Word16 -- Data section size. !Word16 -- Pointer section size. -- | 'TraverseMsg' is basically 'Traversable' from the prelude, but -- the intent is that rather than conceptually being a "container", -- the instance is a value backed by a message, and the point of the -- type class is to be able to apply transformations to the underlying -- message. -- -- We don't just use 'Traversable' for this because while algebraically -- it makes sense, it would be very unintuitive to e.g. have the -- 'Traversable' instance for 'List' not traverse over the *elements* -- of the list. class TraverseMsg f where tMsg :: Applicative m => (msgA -> m msgB) -> f msgA -> m (f msgB) instance TraverseMsg Ptr where tMsg f = \case PtrCap cap -> PtrCap <$> tMsg f cap PtrList l -> PtrList <$> tMsg f l PtrStruct s -> PtrStruct <$> tMsg f s instance TraverseMsg Cap where tMsg f (Cap msg n) = Cap <$> f msg <*> pure n instance TraverseMsg Struct where tMsg f (Struct msg addr dataSz ptrSz) = Struct <$> f msg <*> pure addr <*> pure dataSz <*> pure ptrSz instance TraverseMsg List where tMsg f = \case List0 l -> List0 . unflip <$> tMsg f (FlipList l) List1 l -> List1 . unflip <$> tMsg f (FlipList l) List8 l -> List8 . unflip <$> tMsg f (FlipList l) List16 l -> List16 . unflip <$> tMsg f (FlipList l) List32 l -> List32 . unflip <$> tMsg f (FlipList l) List64 l -> List64 . unflip <$> tMsg f (FlipList l) ListPtr l -> ListPtr . unflipP <$> tMsg f (FlipListP l) ListStruct l -> ListStruct . unflipS <$> tMsg f (FlipListS l) instance TraverseMsg NormalList where tMsg f NormalList{..} = do msg <- f nMsg pure NormalList { nMsg = msg, .. } ------------------------------------------------------------------------------- -- newtype wrappers for the purpose of implementing 'TraverseMsg'; these adjust -- the shape of 'ListOf' so that we can define an instance. We need a couple -- different wrappers depending on the shape of the element type. ------------------------------------------------------------------------------- -- 'FlipList' wraps a @ListOf msg a@ where 'a' is of kind @*@. newtype FlipList a msg = FlipList { unflip :: ListOf msg a } -- 'FlipListS' wraps a @ListOf msg (Struct msg)@. We can't use 'FlipList' for -- our instances, because we need both instances of the 'msg' parameter to stay -- equal. newtype FlipListS msg = FlipListS { unflipS :: ListOf msg (Struct msg) } -- 'FlipListP' wraps a @ListOf msg (Maybe (Ptr msg))@. Pointers can't use -- 'FlipList' for the same reason as structs. newtype FlipListP msg = FlipListP { unflipP :: ListOf msg (Maybe (Ptr msg)) } ------------------------------------------------------------------------------- -- 'TraverseMsg' instances for 'FlipList' ------------------------------------------------------------------------------- instance TraverseMsg (FlipList ()) where tMsg f (FlipList (ListOfVoid msg len)) = FlipList <$> (ListOfVoid <$> f msg <*> pure len) instance TraverseMsg (FlipList Bool) where tMsg f (FlipList (ListOfBool nlist)) = FlipList . ListOfBool <$> tMsg f nlist instance TraverseMsg (FlipList Word8) where tMsg f (FlipList (ListOfWord8 nlist)) = FlipList . ListOfWord8 <$> tMsg f nlist instance TraverseMsg (FlipList Word16) where tMsg f (FlipList (ListOfWord16 nlist)) = FlipList . ListOfWord16 <$> tMsg f nlist instance TraverseMsg (FlipList Word32) where tMsg f (FlipList (ListOfWord32 nlist)) = FlipList . ListOfWord32 <$> tMsg f nlist instance TraverseMsg (FlipList Word64) where tMsg f (FlipList (ListOfWord64 nlist)) = FlipList . ListOfWord64 <$> tMsg f nlist ------------------------------------------------------------------------------- -- 'TraverseMsg' instances for struct and pointer lists. ------------------------------------------------------------------------------- instance TraverseMsg FlipListP where tMsg f (FlipListP (ListOfPtr nlist)) = FlipListP . ListOfPtr <$> tMsg f nlist instance TraverseMsg FlipListS where tMsg f (FlipListS (ListOfStruct tag size)) = FlipListS <$> (ListOfStruct <$> tMsg f tag <*> pure size) -- helpers for applying tMsg to a @ListOf@. tFlip :: (TraverseMsg (FlipList a), Applicative m) => (msgA -> m msg) -> ListOf msgA a -> m (ListOf msg a) tFlipS :: Applicative m => (msgA -> m msg) -> ListOf msgA (Struct msgA) -> m (ListOf msg (Struct msg)) tFlipP :: Applicative m => (msgA -> m msg) -> ListOf msgA (Maybe (Ptr msgA)) -> m (ListOf msg (Maybe (Ptr msg))) tFlip f list = unflip <$> tMsg f (FlipList list) tFlipS f list = unflipS <$> tMsg f (FlipListS list) tFlipP f list = unflipP <$> tMsg f (FlipListP list) ------------------------------------------------------------------------------- -- Boilerplate 'Thaw' instances. -- -- These all just call the underlying methods on the message, using 'TraverseMsg'. ------------------------------------------------------------------------------- instance Thaw a => Thaw (Maybe a) where type Mutable s (Maybe a) = Maybe (Mutable s a) thaw = traverse thaw freeze = traverse freeze unsafeThaw = traverse unsafeThaw unsafeFreeze = traverse unsafeFreeze instance Thaw msg => Thaw (Ptr msg) where type Mutable s (Ptr msg) = Ptr (Mutable s msg) thaw = tMsg thaw freeze = tMsg freeze unsafeThaw = tMsg unsafeThaw unsafeFreeze = tMsg unsafeFreeze instance Thaw msg => Thaw (List msg) where type Mutable s (List msg) = List (Mutable s msg) thaw = tMsg thaw freeze = tMsg freeze unsafeThaw = tMsg unsafeThaw unsafeFreeze = tMsg unsafeFreeze instance Thaw msg => Thaw (NormalList msg) where type Mutable s (NormalList msg) = NormalList (Mutable s msg) thaw = tMsg thaw freeze = tMsg freeze unsafeThaw = tMsg unsafeThaw unsafeFreeze = tMsg unsafeFreeze instance Thaw msg => Thaw (ListOf msg ()) where type Mutable s (ListOf msg ()) = ListOf (Mutable s msg) () thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg Bool) where type Mutable s (ListOf msg Bool) = ListOf (Mutable s msg) Bool thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg Word8) where type Mutable s (ListOf msg Word8) = ListOf (Mutable s msg) Word8 thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg Word16) where type Mutable s (ListOf msg Word16) = ListOf (Mutable s msg) Word16 thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg Word32) where type Mutable s (ListOf msg Word32) = ListOf (Mutable s msg) Word32 thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg Word64) where type Mutable s (ListOf msg Word64) = ListOf (Mutable s msg) Word64 thaw = tFlip thaw freeze = tFlip freeze unsafeThaw = tFlip unsafeThaw unsafeFreeze = tFlip unsafeFreeze instance Thaw msg => Thaw (ListOf msg (Struct msg)) where type Mutable s (ListOf msg (Struct msg)) = ListOf (Mutable s msg) (Struct (Mutable s msg)) thaw = tFlipS thaw freeze = tFlipS freeze unsafeThaw = tFlipS unsafeThaw unsafeFreeze = tFlipS unsafeFreeze instance Thaw msg => Thaw (ListOf msg (Maybe (Ptr msg))) where type Mutable s (ListOf msg (Maybe (Ptr msg))) = ListOf (Mutable s msg) (Maybe (Ptr (Mutable s msg))) thaw = tFlipP thaw freeze = tFlipP freeze unsafeThaw = tFlipP unsafeThaw unsafeFreeze = tFlipP unsafeFreeze instance Thaw msg => Thaw (Struct msg) where type Mutable s (Struct msg) = Struct (Mutable s msg) thaw = tMsg thaw freeze = tMsg freeze unsafeThaw = tMsg unsafeThaw unsafeFreeze = tMsg unsafeFreeze ------------------------------------------------------------------------------- -- | Types @a@ whose storage is owned by a message.. class HasMessage a where -- | The type of the messages containing @a@s. type InMessage a -- | Get the message in which the @a@ is stored. message :: a -> InMessage a -- | Types which have a "default" value, but require a message -- to construct it. -- -- The default is usually conceptually zero-size. This is mostly useful -- for generated code, so that it can use standard decoding techniques -- on default values. class HasMessage a => MessageDefault a where messageDefault :: InMessage a -> a instance HasMessage (Ptr msg) where type InMessage (Ptr msg) = msg message (PtrCap cap) = message cap message (PtrList list) = message list message (PtrStruct struct) = message struct instance HasMessage (Cap msg) where type InMessage (Cap msg) = msg message (Cap msg _) = msg instance HasMessage (Struct msg) where type InMessage (Struct msg) = msg message (Struct msg _ _ _) = msg instance MessageDefault (Struct msg) where messageDefault msg = Struct msg (WordAt 0 0) 0 0 instance HasMessage (List msg) where type InMessage (List msg) = msg message (List0 list) = message list message (List1 list) = message list message (List8 list) = message list message (List16 list) = message list message (List32 list) = message list message (List64 list) = message list message (ListPtr list) = message list message (ListStruct list) = message list instance HasMessage (ListOf msg a) where type InMessage (ListOf msg a) = msg message (ListOfVoid msg _) = msg message (ListOfStruct tag _) = message tag message (ListOfBool list) = message list message (ListOfWord8 list) = message list message (ListOfWord16 list) = message list message (ListOfWord32 list) = message list message (ListOfWord64 list) = message list message (ListOfPtr list) = message list instance MessageDefault (ListOf msg ()) where messageDefault msg = ListOfVoid msg 0 instance MessageDefault (ListOf msg (Struct msg)) where messageDefault msg = ListOfStruct (messageDefault msg) 0 instance MessageDefault (ListOf msg Bool) where messageDefault msg = ListOfBool (messageDefault msg) instance MessageDefault (ListOf msg Word8) where messageDefault msg = ListOfWord8 (messageDefault msg) instance MessageDefault (ListOf msg Word16) where messageDefault msg = ListOfWord16 (messageDefault msg) instance MessageDefault (ListOf msg Word32) where messageDefault msg = ListOfWord32 (messageDefault msg) instance MessageDefault (ListOf msg Word64) where messageDefault msg = ListOfWord64 (messageDefault msg) instance MessageDefault (ListOf msg (Maybe (Ptr msg))) where messageDefault msg = ListOfPtr (messageDefault msg) instance HasMessage (NormalList msg) where type InMessage (NormalList msg) = msg message = nMsg instance MessageDefault (NormalList msg) where messageDefault msg = NormalList msg (WordAt 0 0) 0 getClient :: ReadCtx m msg => Cap msg -> m M.Client getClient (Cap msg idx) = M.getCap msg (fromIntegral idx) -- | @get msg addr@ returns the Ptr stored at @addr@ in @msg@. -- Deducts 1 from the quota for each word read (which may be multiple in the -- case of far pointers). get :: ReadCtx m msg => msg -> WordAddr -> m (Maybe (Ptr msg)) get msg addr = do word <- getWord msg addr case P.parsePtr word of Nothing -> return Nothing Just p -> case p of P.CapPtr cap -> return $ Just $ PtrCap (Cap msg cap) P.StructPtr off dataSz ptrSz -> return $ Just $ PtrStruct $ Struct msg (resolveOffset addr off) dataSz ptrSz P.ListPtr off eltSpec -> Just <$> getList (resolveOffset addr off) eltSpec P.FarPtr twoWords offset segment -> do let addr' = WordAt { wordIndex = fromIntegral offset , segIndex = fromIntegral segment } if not twoWords then get msg addr' else do landingPad <- getWord msg addr' case P.parsePtr landingPad of Just (P.FarPtr False off seg) -> do tagWord <- getWord msg addr' { wordIndex = wordIndex addr' + 1 } let finalAddr = WordAt { wordIndex = fromIntegral off , segIndex = fromIntegral seg } case P.parsePtr tagWord of Just (P.StructPtr 0 dataSz ptrSz) -> return $ Just $ PtrStruct $ Struct msg finalAddr dataSz ptrSz Just (P.ListPtr 0 eltSpec) -> Just <$> getList finalAddr eltSpec -- TODO: I'm not sure whether far pointers to caps are -- legal; it's clear how they would work, but I don't -- see a use, and the spec is unclear. Should check -- how the reference implementation does this, copy -- that, and submit a patch to the spec. Just (P.CapPtr cap) -> return $ Just $ PtrCap (Cap msg cap) ptr -> throwM $ E.InvalidDataError $ "The tag word of a far pointer's " ++ "2-word landing pad should be an intra " ++ "segment pointer with offset 0, but " ++ "we read " ++ show ptr ptr -> throwM $ E.InvalidDataError $ "The first word of a far pointer's 2-word " ++ "landing pad should be another far pointer " ++ "(with a one-word landing pad), but we read " ++ show ptr where getWord msg addr = invoice 1 *> M.getWord msg addr resolveOffset addr@WordAt{..} off = addr { wordIndex = wordIndex + fromIntegral off + 1 } getList addr@WordAt{..} eltSpec = PtrList <$> case eltSpec of P.EltNormal sz len -> pure $ case sz of Sz0 -> List0 (ListOfVoid msg (fromIntegral len)) Sz1 -> List1 (ListOfBool nlist) Sz8 -> List8 (ListOfWord8 nlist) Sz16 -> List16 (ListOfWord16 nlist) Sz32 -> List32 (ListOfWord32 nlist) Sz64 -> List64 (ListOfWord64 nlist) SzPtr -> ListPtr (ListOfPtr nlist) where nlist = NormalList msg addr (fromIntegral len) P.EltComposite _ -> do tagWord <- getWord msg addr case P.parsePtr' tagWord of P.StructPtr numElts dataSz ptrSz -> pure $ ListStruct $ ListOfStruct (Struct msg addr { wordIndex = wordIndex + 1 } dataSz ptrSz) (fromIntegral numElts) tag -> throwM $ E.InvalidDataError $ "Composite list tag was not a struct-" ++ "formatted word: " ++ show tag -- | Return the EltSpec needed for a pointer to the given list. listEltSpec :: List msg -> P.EltSpec listEltSpec (ListStruct list@(ListOfStruct (Struct _ _ dataSz ptrSz) _)) = P.EltComposite $ fromIntegral (length list) * (fromIntegral dataSz + fromIntegral ptrSz) listEltSpec (List0 list) = P.EltNormal Sz0 $ fromIntegral (length list) listEltSpec (List1 list) = P.EltNormal Sz1 $ fromIntegral (length list) listEltSpec (List8 list) = P.EltNormal Sz8 $ fromIntegral (length list) listEltSpec (List16 list) = P.EltNormal Sz16 $ fromIntegral (length list) listEltSpec (List32 list) = P.EltNormal Sz32 $ fromIntegral (length list) listEltSpec (List64 list) = P.EltNormal Sz64 $ fromIntegral (length list) listEltSpec (ListPtr list) = P.EltNormal SzPtr $ fromIntegral (length list) -- | Return the starting address of the list. listAddr :: List msg -> WordAddr listAddr (ListStruct (ListOfStruct (Struct _ addr _ _) _)) = -- addr is the address of the first element of the list, but -- composite lists start with a tag word: addr { wordIndex = wordIndex addr - 1 } listAddr (List0 _) = WordAt { segIndex = 0, wordIndex = 1 } listAddr (List1 (ListOfBool NormalList{nAddr})) = nAddr listAddr (List8 (ListOfWord8 NormalList{nAddr})) = nAddr listAddr (List16 (ListOfWord16 NormalList{nAddr})) = nAddr listAddr (List32 (ListOfWord32 NormalList{nAddr})) = nAddr listAddr (List64 (ListOfWord64 NormalList{nAddr})) = nAddr listAddr (ListPtr (ListOfPtr NormalList{nAddr})) = nAddr -- | Return the address of the pointer's target. It is illegal to call this on -- a pointer which targets a capability. ptrAddr :: Ptr msg -> WordAddr ptrAddr (PtrCap _) = error "ptrAddr called on a capability pointer." ptrAddr (PtrStruct (Struct _ addr _ _)) = addr ptrAddr (PtrList list) = listAddr list -- | @'setIndex value i list@ Set the @i@th element of @list@ to @value@. setIndex :: RWCtx m s => a -> Int -> ListOf (M.MutMsg s) a -> m () setIndex _ i list | length list <= i = throwM E.BoundsError { E.index = i, E.maxIndex = length list } setIndex value i list = case list of ListOfVoid _ _ -> pure () ListOfBool nlist -> setNIndex nlist 64 (Word1 value) ListOfWord8 nlist -> setNIndex nlist 8 value ListOfWord16 nlist -> setNIndex nlist 4 value ListOfWord32 nlist -> setNIndex nlist 2 value ListOfWord64 nlist -> setNIndex nlist 1 value ListOfPtr nlist -> case value of Just p | message p /= message list -> do newPtr <- copyPtr (message list) value setIndex newPtr i list Nothing -> setNIndex nlist 1 (P.serializePtr Nothing) Just (PtrCap (Cap _ cap)) -> setNIndex nlist 1 (P.serializePtr (Just (P.CapPtr cap))) Just p@(PtrList ptrList) -> setPtrIndex nlist p $ P.ListPtr 0 (listEltSpec ptrList) Just p@(PtrStruct (Struct _ _ dataSz ptrSz)) -> setPtrIndex nlist p $ P.StructPtr 0 dataSz ptrSz list@(ListOfStruct _ _) -> do dest <- index i list copyStruct dest value where setNIndex :: (ReadCtx m (M.MutMsg s), M.WriteCtx m s, Bounded a, Integral a) => NormalList (M.MutMsg s) -> Int -> a -> m () setNIndex NormalList{nAddr=nAddr@WordAt{..},..} eltsPerWord value = do let wordAddr = nAddr { wordIndex = wordIndex + WordCount (i `div` eltsPerWord) } word <- M.getWord nMsg wordAddr let shift = (i `mod` eltsPerWord) * (64 `div` eltsPerWord) M.setWord nMsg wordAddr $ replaceBits value word shift setPtrIndex :: (ReadCtx m (M.MutMsg s), M.WriteCtx m s) => NormalList (M.MutMsg s) -> Ptr (M.MutMsg s) -> P.Ptr -> m () setPtrIndex NormalList{..} absPtr relPtr = let srcAddr = nAddr { wordIndex = wordIndex nAddr + WordCount i } in setPointerTo nMsg srcAddr (ptrAddr absPtr) relPtr -- | @'setPointerTo' msg srcAddr dstAddr relPtr@ sets the word at @srcAddr@ in @msg@ to a -- pointer like @relPtr@, but pointing to @dstAddr@. @relPtr@ should not be a far pointer. -- If the two addresses are in different segments, a landing pad will be allocated and -- @dstAddr@ will contain a far pointer. setPointerTo :: M.WriteCtx m s => M.MutMsg s -> WordAddr -> WordAddr -> P.Ptr -> m () setPointerTo msg srcAddr dstAddr relPtr = case pointerFrom srcAddr dstAddr relPtr of Right absPtr -> M.setWord msg srcAddr (P.serializePtr $ Just absPtr) Left OutOfRange -> error "BUG: segment is too large to set the pointer." Left DifferentSegments -> do -- We need a far pointer; allocate a landing pad in the target segment, -- set it to point to the final destination, an then set the source pointer -- pointer to point to the landing pad. let WordAt{segIndex} = dstAddr landingPadAddr <- M.allocInSeg msg segIndex 1 case pointerFrom landingPadAddr dstAddr relPtr of Right landingPad -> do M.setWord msg landingPadAddr (P.serializePtr $ Just landingPad) let WordAt{segIndex,wordIndex} = landingPadAddr M.setWord msg srcAddr $ P.serializePtr $ Just $ P.FarPtr False (fromIntegral wordIndex) (fromIntegral segIndex) Left DifferentSegments -> error "BUG: allocated a landing pad in the wrong segment!" Left OutOfRange -> error "BUG: segment is too large to set the pointer." copyCap :: RWCtx m s => M.MutMsg s -> Cap (M.MutMsg s) -> m (Cap (M.MutMsg s)) copyCap dest cap = getClient cap >>= appendCap dest copyPtr :: RWCtx m s => M.MutMsg s -> Maybe (Ptr (M.MutMsg s)) -> m (Maybe (Ptr (M.MutMsg s))) copyPtr _ Nothing = pure Nothing copyPtr dest (Just (PtrCap cap)) = Just . PtrCap <$> copyCap dest cap copyPtr dest (Just (PtrList src)) = Just . PtrList <$> copyList dest src copyPtr dest (Just (PtrStruct src)) = Just . PtrStruct <$> do destStruct <- allocStruct dest (fromIntegral $ length (dataSection src)) (fromIntegral $ length (ptrSection src)) copyStruct destStruct src pure destStruct copyList :: RWCtx m s => M.MutMsg s -> List (M.MutMsg s) -> m (List (M.MutMsg s)) copyList dest src = case src of List0 src -> List0 <$> allocList0 dest (length src) List1 src -> List1 <$> copyNewListOf dest src allocList1 List8 src -> List8 <$> copyNewListOf dest src allocList8 List16 src -> List16 <$> copyNewListOf dest src allocList16 List32 src -> List32 <$> copyNewListOf dest src allocList32 List64 src -> List64 <$> copyNewListOf dest src allocList64 ListPtr src -> ListPtr <$> copyNewListOf dest src allocListPtr ListStruct src -> ListStruct <$> do destList <- allocCompositeList dest (structListDataCount src) (structListPtrCount src) (length src) copyListOf destList src pure destList copyNewListOf :: RWCtx m s => M.MutMsg s -> ListOf (M.MutMsg s) a -> (M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) a)) -> m (ListOf (M.MutMsg s) a) copyNewListOf destMsg src new = do dest <- new destMsg (length src) copyListOf dest src pure dest copyListOf :: RWCtx m s => ListOf (M.MutMsg s) a -> ListOf (M.MutMsg s) a -> m () copyListOf dest src = forM_ [0..length src - 1] $ \i -> do value <- index i src setIndex value i dest -- | @'copyStruct' dest src@ copies the source struct to the destination struct. copyStruct :: RWCtx m s => Struct (M.MutMsg s) -> Struct (M.MutMsg s) -> m () copyStruct dest src = do -- We copy both the data and pointer sections from src to dest, -- padding the tail of the destination section with zeros/null -- pointers as necessary. If the destination section is -- smaller than the source section, this will raise a BoundsError. -- -- TODO: possible enhancement: allow the destination section to be -- smaller than the source section if and only if the tail of the -- source section is all zeros (default values). copySection (dataSection dest) (dataSection src) 0 copySection (ptrSection dest) (ptrSection src) Nothing where copySection dest src pad = do -- Copy the source section to the destination section: copyListOf dest src -- Pad the remainder with zeros/default values: forM_ [length src..length dest - 1] $ \i -> setIndex pad i dest -- | @index i list@ returns the ith element in @list@. Deducts 1 from the quota index :: ReadCtx m msg => Int -> ListOf msg a -> m a index i list = invoice 1 >> index' list where index' :: ReadCtx m msg => ListOf msg a -> m a index' (ListOfVoid _ len) | i < len = pure () | otherwise = throwM E.BoundsError { E.index = i, E.maxIndex = len - 1 } index' (ListOfStruct (Struct msg addr@WordAt{..} dataSz ptrSz) len) | i < len = do let offset = WordCount $ i * (fromIntegral dataSz + fromIntegral ptrSz) let addr' = addr { wordIndex = wordIndex + offset } return $ Struct msg addr' dataSz ptrSz | otherwise = throwM E.BoundsError { E.index = i, E.maxIndex = len - 1} index' (ListOfBool nlist) = do Word1 val <- indexNList nlist 64 pure val index' (ListOfWord8 nlist) = indexNList nlist 8 index' (ListOfWord16 nlist) = indexNList nlist 4 index' (ListOfWord32 nlist) = indexNList nlist 2 index' (ListOfWord64 (NormalList msg addr@WordAt{..} len)) | i < len = M.getWord msg addr { wordIndex = wordIndex + WordCount i } | otherwise = throwM E.BoundsError { E.index = i, E.maxIndex = len - 1} index' (ListOfPtr (NormalList msg addr@WordAt{..} len)) | i < len = get msg addr { wordIndex = wordIndex + WordCount i } | otherwise = throwM E.BoundsError { E.index = i, E.maxIndex = len - 1} indexNList :: (ReadCtx m msg, Integral a) => NormalList msg -> Int -> m a indexNList (NormalList msg addr@WordAt{..} len) eltsPerWord | i < len = do let wordIndex' = wordIndex + WordCount (i `div` eltsPerWord) word <- M.getWord msg addr { wordIndex = wordIndex' } let shift = (i `mod` eltsPerWord) * (64 `div` eltsPerWord) pure $ fromIntegral $ word `shiftR` shift | otherwise = throwM E.BoundsError { E.index = i, E.maxIndex = len - 1 } -- | Returns the length of a list length :: ListOf msg a -> Int length (ListOfVoid _ len) = len length (ListOfStruct _ len) = len length (ListOfBool nlist) = nLen nlist length (ListOfWord8 nlist) = nLen nlist length (ListOfWord16 nlist) = nLen nlist length (ListOfWord32 nlist) = nLen nlist length (ListOfWord64 nlist) = nLen nlist length (ListOfPtr nlist) = nLen nlist -- | Return a prefix of the list, of the given length. take :: MonadThrow m => Int -> ListOf msg a -> m (ListOf msg a) take count list | length list < count = throwM E.BoundsError { E.index = count, E.maxIndex = length list - 1 } | otherwise = pure $ go list where go (ListOfVoid msg _) = ListOfVoid msg count go (ListOfStruct tag _) = ListOfStruct tag count go (ListOfBool nlist) = ListOfBool $ nTake nlist go (ListOfWord8 nlist) = ListOfWord8 $ nTake nlist go (ListOfWord16 nlist) = ListOfWord16 $ nTake nlist go (ListOfWord32 nlist) = ListOfWord32 $ nTake nlist go (ListOfWord64 nlist) = ListOfWord64 $ nTake nlist go (ListOfPtr nlist) = ListOfPtr $ nTake nlist nTake :: NormalList msg -> NormalList msg nTake NormalList{..} = NormalList { nLen = count, .. } -- | The data section of a struct, as a list of Word64 dataSection :: Struct msg -> ListOf msg Word64 dataSection (Struct msg addr dataSz _) = ListOfWord64 $ NormalList msg addr (fromIntegral dataSz) -- | The pointer section of a struct, as a list of Ptr ptrSection :: Struct msg -> ListOf msg (Maybe (Ptr msg)) ptrSection (Struct msg addr@WordAt{..} dataSz ptrSz) = ListOfPtr $ NormalList msg addr { wordIndex = wordIndex + fromIntegral dataSz } (fromIntegral ptrSz) -- | Get the size (in words) of a struct's data section. structDataCount :: Struct msg -> Word16 structDataCount = fromIntegral . length . dataSection -- | Get the size of a struct's pointer section. structPtrCount :: Struct msg -> Word16 structPtrCount = fromIntegral . length . ptrSection -- | Get the size (in words) of the data sections in a composite list. structListDataCount :: ListOf msg (Struct msg) -> Word16 structListDataCount (ListOfStruct s _) = structDataCount s -- | Get the size of the pointer sections in a composite list. structListPtrCount :: ListOf msg (Struct msg) -> Word16 structListPtrCount (ListOfStruct s _) = structPtrCount s -- | @'getData' i struct@ gets the @i@th word from the struct's data section, -- returning 0 if it is absent. getData :: ReadCtx m msg => Int -> Struct msg -> m Word64 getData i struct | length (dataSection struct) <= i = 0 <$ invoice 1 | otherwise = index i (dataSection struct) -- | @'getPtr' i struct@ gets the @i@th word from the struct's pointer section, -- returning Nothing if it is absent. getPtr :: ReadCtx m msg => Int -> Struct msg -> m (Maybe (Ptr msg)) getPtr i struct | length (ptrSection struct) <= i = Nothing <$ invoice 1 | otherwise = index i (ptrSection struct) -- | @'setData' value i struct@ sets the @i@th word in the struct's data section -- to @value@. setData :: (ReadCtx m (M.MutMsg s), M.WriteCtx m s) => Word64 -> Int -> Struct (M.MutMsg s) -> m () setData value i = setIndex value i . dataSection -- | @'setData' value i struct@ sets the @i@th pointer in the struct's pointer -- section to @value@. setPtr :: (ReadCtx m (M.MutMsg s), M.WriteCtx m s) => Maybe (Ptr (M.MutMsg s)) -> Int -> Struct (M.MutMsg s) -> m () setPtr value i = setIndex value i . ptrSection -- | 'rawBytes' returns the raw bytes corresponding to the list. rawBytes :: ReadCtx m msg => ListOf msg Word8 -> m BS.ByteString rawBytes (ListOfWord8 (NormalList msg WordAt{..} len)) = do invoice $ fromIntegral $ bytesToWordsCeil (ByteCount len) bytes <- M.getSegment msg segIndex >>= M.toByteString let ByteCount byteOffset = wordsToBytes wordIndex pure $ BS.take len $ BS.drop byteOffset bytes -- | Returns the root pointer of a message. rootPtr :: ReadCtx m msg => msg -> m (Struct msg) rootPtr msg = do root <- get msg (WordAt 0 0) case root of Just (PtrStruct struct) -> pure struct Nothing -> pure (messageDefault msg) _ -> throwM $ E.SchemaViolationError "Unexpected root type; expected struct." -- | Make the given struct the root object of its message. setRoot :: M.WriteCtx m s => Struct (M.MutMsg s) -> m () setRoot (Struct msg addr dataSz ptrSz) = setPointerTo msg (WordAt 0 0) addr (P.StructPtr 0 dataSz ptrSz) -- | Allocate a struct in the message. allocStruct :: M.WriteCtx m s => M.MutMsg s -> Word16 -> Word16 -> m (Struct (M.MutMsg s)) allocStruct msg dataSz ptrSz = do let totalSz = fromIntegral dataSz + fromIntegral ptrSz addr <- M.alloc msg totalSz pure $ Struct msg addr dataSz ptrSz -- | Allocate a composite list. allocCompositeList :: M.WriteCtx m s => M.MutMsg s -- ^ The message to allocate in. -> Word16 -- ^ The size of the data sections -> Word16 -- ^ The size of the pointer sections -> Int -- ^ The length of the list in elements. -> m (ListOf (M.MutMsg s) (Struct (M.MutMsg s))) allocCompositeList msg dataSz ptrSz len = do let eltSize = fromIntegral dataSz + fromIntegral ptrSz addr <- M.alloc msg (WordCount $ len * eltSize + 1) -- + 1 for the tag word. M.setWord msg addr $ P.serializePtr' $ P.StructPtr (fromIntegral len) dataSz ptrSz let firstStruct = Struct msg addr { wordIndex = wordIndex addr + 1 } dataSz ptrSz pure $ ListOfStruct firstStruct len -- | Allocate a list of capnproto @Void@ values. allocList0 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) ()) -- | Allocate a list of booleans allocList1 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) Bool) -- | Allocate a list of 8-bit values. allocList8 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) Word8) -- | Allocate a list of 16-bit values. allocList16 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) Word16) -- | Allocate a list of 32-bit values. allocList32 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) Word32) -- | Allocate a list of 64-bit words. allocList64 :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) Word64) -- | Allocate a list of pointers. allocListPtr :: M.WriteCtx m s => M.MutMsg s -> Int -> m (ListOf (M.MutMsg s) (Maybe (Ptr (M.MutMsg s)))) allocList0 msg len = pure $ ListOfVoid msg len allocList1 msg len = ListOfBool <$> allocNormalList 1 msg len allocList8 msg len = ListOfWord8 <$> allocNormalList 8 msg len allocList16 msg len = ListOfWord16 <$> allocNormalList 16 msg len allocList32 msg len = ListOfWord32 <$> allocNormalList 32 msg len allocList64 msg len = ListOfWord64 <$> allocNormalList 64 msg len allocListPtr msg len = ListOfPtr <$> allocNormalList 64 msg len -- | Allocate a NormalList allocNormalList :: M.WriteCtx m s => Int -- ^ The number of elements per 64-bit word -> M.MutMsg s -- ^ The message to allocate in -> Int -- ^ The number of bits per element -> m (NormalList (M.MutMsg s)) allocNormalList bitsPerElt msg len = do -- round 'len' up to the nearest word boundary. let totalBits = BitCount (len * bitsPerElt) totalWords = bytesToWordsCeil $ bitsToBytesCeil totalBits addr <- M.alloc msg totalWords pure NormalList { nMsg = msg , nAddr = addr , nLen = len } appendCap :: M.WriteCtx m s => M.MutMsg s -> M.Client -> m (Cap (M.MutMsg s)) appendCap msg client = do i <- M.appendCap msg client pure $ Cap msg (fromIntegral i)