{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Capnp.Message (
hPutMsg
, hGetMsg
, putMsg
, getMsg
, maxSegmentSize
, maxSegments
, encode
, decode
, Message(..)
, empty
, ConstMsg
, getSegment
, getWord
, MutMsg
, newMessage
, alloc
, allocInSeg
, newSegment
, setWord
, setSegment
, WriteCtx(..)
) where
import Prelude hiding (read)
import Data.Bits (shiftL)
import Control.Monad (void, when, (>=>))
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.State (evalStateT, get, put)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Writer (execWriterT, tell)
import Data.Bytes.Get (getWord32le, runGetS)
import Data.ByteString.Internal (ByteString(..))
import Data.Either (fromRight)
import Data.Maybe (fromJust)
import Data.Primitive (MutVar, newMutVar, readMutVar, writeMutVar)
import Data.Word (Word32, Word64)
import System.Endian (fromLE64, toLE64)
import System.IO (Handle, stdin, stdout)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.Vector as V
import qualified Data.Vector.Generic.Mutable as GMV
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Storable.Mutable as SMV
import Data.Capnp.Address (WordAddr(..))
import Data.Capnp.Bits (WordCount(..), hi, lo)
import Data.Capnp.TraversalLimit (LimitT, MonadLimit(invoice), evalLimitT)
import Data.Mutable (Mutable(..))
import Internal.AppendVec (AppendVec)
import Internal.Util (checkIndex)
import qualified Internal.AppendVec as AppendVec
-- | Largest size a single mutable segment may grow to (2^28); passed as the
-- limit to 'AppendVec.grow' when a segment's word vector is extended.
maxSegmentSize :: Int
maxSegmentSize = 2 ^ (28 :: Int)

-- | Largest number of segments a mutable message may grow to (2^10); passed
-- as the limit to 'AppendVec.grow' when 'newSegment' adds a segment.
-- NOTE(review): 'readMessage' does not check decoded messages against this.
maxSegments :: Int
maxSegments = 2 ^ (10 :: Int)
-- | Read-only operations shared by immutable ('ConstMsg') and mutable
-- ('MutMsg') messages, performed in the monad @m@.
class Monad m => Message m msg where
    -- | The representation of one segment of a message of type @msg@.
    data Segment msg

    -- | Number of segments in the message.
    numSegs :: msg -> m Int
    -- | Fetch a segment by index. Unlike 'getSegment', this does not run
    -- 'checkIndex' first.
    internalGetSeg :: msg -> Int -> m (Segment msg)

    -- | Number of 64-bit words in a segment.
    numWords :: Segment msg -> m Int
    -- | @slice start len seg@: the @len@ words of @seg@ beginning at @start@.
    slice :: Int -> Int -> Segment msg -> m (Segment msg)
    -- | Read one word of a segment by index. The instances in this module
    -- convert the stored little-endian word to host order.
    read :: Segment msg -> Int -> m Word64

    -- | Interpret raw bytes as a segment.
    fromByteString :: ByteString -> m (Segment msg)
    -- | Produce the raw bytes of a segment.
    toByteString :: Segment msg -> m ByteString
-- | Fetch the segment at index @i@, throwing (via 'checkIndex') when @i@ is
-- not a valid segment index for @msg@.
getSegment :: (MonadThrow m, Message m msg) => msg -> Int -> m (Segment msg)
getSegment msg i = do
    count <- numSegs msg
    checkIndex i count
    internalGetSeg msg i
-- | Read the word at the given address. Throws (via 'checkIndex') if either
-- the segment index or the word index within that segment is out of range.
getWord :: (MonadThrow m, Message m msg) => msg -> WordAddr -> m Word64
-- The as-pattern @wordIndex\@(...)@ the original bound here was never used;
-- match the field the same way 'setWord' does.
getWord msg WordAt{wordIndex = WordCount i, segIndex} = do
    seg <- getSegment msg segIndex
    checkIndex i =<< numWords seg
    seg `read` i
-- | Replace the segment at index @i@ of a mutable message, throwing (via
-- 'checkIndex') when @i@ is not a valid segment index.
setSegment :: (WriteCtx m s, MonadThrow m) => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
setSegment msg i seg = do
    count <- numSegs msg
    checkIndex i count
    internalSetSeg msg i seg
-- | Store @val@ at the given address of a mutable message. Throws (via
-- 'checkIndex') if the segment index or word index is out of range.
setWord :: (WriteCtx m s, MonadThrow m) => MutMsg s -> WordAddr -> Word64 -> m ()
setWord msg WordAt{wordIndex = WordCount i, segIndex} val = do
    seg <- getSegment msg segIndex
    len <- numWords seg
    checkIndex i len
    write seg i val
-- | An immutable message: a boxed vector of immutable segments.
newtype ConstMsg = ConstMsg (V.Vector (Segment ConstMsg))
-- | Immutable messages can be read in any monad.
instance Monad m => Message m ConstMsg where
    -- A segment of an immutable message is a storable vector of words.
    newtype Segment ConstMsg = ConstSegment { constSegToVec :: SV.Vector Word64 }

    numSegs (ConstMsg segs) = pure (V.length segs)
    internalGetSeg (ConstMsg segs) = V.indexM segs

    numWords (ConstSegment ws) = pure (SV.length ws)
    slice start len (ConstSegment ws) =
        pure . ConstSegment $ SV.slice start len ws
    -- Words are stored little-endian; convert to host order when reading.
    read (ConstSegment ws) i = fmap fromLE64 (SV.indexM ws i)

    -- Both conversions reinterpret the same underlying buffer via the
    -- unsafe foreign-pointer functions rather than copying it.
    -- NOTE(review): if the byte length is not a multiple of 8, the trailing
    -- bytes are presumably dropped by 'SV.unsafeCast' — confirm.
    fromByteString (PS fptr off n) =
        let bytes = SV.unsafeFromForeignPtr fptr off n
        in pure (ConstSegment (SV.unsafeCast bytes))
    toByteString (ConstSegment ws) =
        let (fptr, off, n) = SV.unsafeToForeignPtr (SV.unsafeCast ws)
        in pure (PS fptr off n)
-- | Decode a serialized message (segment table followed by segment data)
-- held in a 'ByteString'. The bytes are first viewed as a single segment,
-- which 'decodeSeg' then splits into the message's segments.
decode :: MonadThrow m => ByteString -> m ConstMsg
decode = fromByteString >=> decodeSeg
-- | Serialize a message into a 'BB.Builder': the segment table (32-bit
-- little-endian values) followed by the raw bytes of each segment.
encode :: Monad m => ConstMsg -> m BB.Builder
encode msg = pure . fromJust . execWriterT $ writeMessage msg emit32 emitSeg
  where
    -- The writer runs over 'Maybe' only to satisfy 'writeMessage''s
    -- 'MonadThrow' constraint. Nothing emitted here throws, so the result
    -- is always 'Just'.
    emit32 = tell . BB.word32LE
    emitSeg seg = toByteString seg >>= tell . BB.byteString
-- | Parse a message out of a single segment holding its serialized form
-- (segment table followed by segment data). The traversal limit is the
-- segment's length in words. Presumably this keeps the header and the
-- declared segment sizes from describing more data than the input holds —
-- confirm against 'invoice'.
decodeSeg :: MonadThrow m => Segment ConstMsg -> m ConstMsg
decodeSeg seg = do
    len <- numWords seg
    -- State: (high half of the last word read, if not yet consumed;
    --         index of the next unread word).
    flip evalStateT (Nothing, 0) $ evalLimitT len $
        readMessage read32 readSegment
  where
    -- 32-bit values are taken two per word: first the low half of a freshly
    -- read word, then its stashed high half.
    read32 = do
        (cur, idx) <- get
        case cur of
            Just n -> do
                put (Nothing, idx)
                return n
            Nothing -> do
                word <- lift $ lift $ read seg idx
                put (Just $ hi word, idx + 1)
                return (lo word)
    -- Take the next @len@ words as a slice of the input segment and advance
    -- past them.
    readSegment (WordCount len) = do
        (cur, idx) <- get
        put (cur, idx + len)
        lift $ lift $ slice idx len seg
-- | Read a message through the supplied primitives. Each part of the input
-- is invoiced against the traversal limit before it is read.
--
-- Header layout:
--
-- * the segment count minus one;
-- * each segment's size in words;
-- * one padding value when the segment count is even, so the header ends on
--   a word boundary.
--
-- All three are 32-bit values. The segments follow the header.
readMessage :: (MonadThrow m, MonadLimit m) => m Word32 -> (WordCount -> m (Segment ConstMsg)) -> m ConstMsg
readMessage read32 readSegment = do
    -- The first header word holds the count field and the first size.
    invoice 1
    countField <- read32
    -- Widen before adding one. In 'Word32', a count field of 0xffffffff
    -- would wrap around to zero segments and decode as an empty message,
    -- instead of being charged (and rejected) as 2^32 segments.
    let segCount = fromIntegral countField + 1 :: Int
    -- The remaining header words; the first was invoiced above.
    invoice (fromIntegral (segCount `div` 2))
    segSizes <- V.replicateM segCount read32
    when (even segCount) $ void read32
    V.mapM_ (invoice . fromIntegral) segSizes
    ConstMsg <$> V.mapM (readSegment . fromIntegral) segSizes
-- | Write a message through the supplied primitives, in this order:
--
-- 1. the segment count minus one;
-- 2. each segment's size in words;
-- 3. a zero padding value when the segment count is even, so the table
--    ends on a word boundary;
-- 4. each segment.
--
-- NOTE(review): for a message with no segments, @0 - 1@ wraps in 'Word32'
-- and the count field is written as 0xffffffff — confirm this cannot happen.
writeMessage :: MonadThrow m => ConstMsg -> (Word32 -> m ()) -> (Segment ConstMsg -> m ()) -> m ()
writeMessage (ConstMsg segs) write32 writeSegment = do
    let segCount = V.length segs
    write32 (fromIntegral segCount - 1)
    V.forM_ segs $ \seg -> numWords seg >>= write32 . fromIntegral
    when (even segCount) $ write32 0
    V.forM_ segs writeSegment
-- | Serialize a message (see 'encode') and write it to the given handle.
hPutMsg :: Handle -> ConstMsg -> IO ()
hPutMsg handle msg = BB.hPutBuilder handle =<< encode msg

-- | 'hPutMsg' on standard output.
putMsg :: ConstMsg -> IO ()
putMsg msg = hPutMsg stdout msg
-- | Read one serialized message from a handle.
--
-- @size@ is the traversal limit. The header and every declared segment
-- size are invoiced against it before the corresponding bytes are read.
--
-- Throws an 'IOError' if the input ends before the header or a segment has
-- been read in full. Previously a truncated header word hit a lazily raised
-- @error "impossible"@, and a truncated segment was silently accepted.
hGetMsg :: Handle -> Int -> IO ConstMsg
hGetMsg handle size =
    evalLimitT size $ readMessage read32 readSegment
  where
    read32 :: LimitT IO Word32
    read32 = lift $ do
        bytes <- hGetExactly 4
        case runGetS getWord32le bytes of
            Right n -> pure n
            -- Not expected after the length check in 'hGetExactly', but
            -- report it as an I/O error rather than crash with 'error'.
            Left err -> throwM $ userError $ "hGetMsg: cannot decode header word: " ++ err
    readSegment n = lift $ hGetExactly (fromIntegral n * 8) >>= fromByteString

    -- Read exactly @len@ bytes. 'BS.hGet' signals end of input only by
    -- returning fewer bytes than requested, so check the length explicitly.
    hGetExactly :: Int -> IO ByteString
    hGetExactly len = do
        bytes <- BS.hGet handle len
        when (BS.length bytes /= len) $
            throwM $ userError $
                "hGetMsg: unexpected end of input (wanted " ++ show len
                    ++ " bytes, got " ++ show (BS.length bytes) ++ ")"
        pure bytes
-- | 'hGetMsg' on standard input; @size@ is the traversal limit.
getMsg :: Int -> IO ConstMsg
getMsg size = hGetMsg stdin size
-- | A mutable message: a mutable reference to a growable vector of mutable
-- segments. The reference lets 'newSegment' swap in a grown segment vector.
newtype MutMsg s = MutMsg (MutVar s (AppendVec MV.MVector s (Segment (MutMsg s))))
    deriving(Eq)
-- | Constraints for mutating a @'MutMsg' s@ in monad @m@: @m@ is a primitive
-- state monad whose state token is @s@, and it can throw.
type WriteCtx m s = (PrimMonad m, s ~ PrimState m, MonadThrow m)
-- | Mutable messages are read in any 'PrimMonad' whose state token matches
-- the message's.
instance (PrimMonad m, s ~ PrimState m) => Message m (MutMsg s) where
    -- A segment of a mutable message is a growable mutable vector of words.
    newtype Segment (MutMsg s) = MutSegment (AppendVec SMV.MVector s Word64)

    numSegs (MutMsg segVar) = do
        segs <- readMutVar segVar
        pure $ GMV.length (AppendVec.getVector segs)
    internalGetSeg (MutMsg segVar) i = do
        segs <- readMutVar segVar
        MV.read (AppendVec.getVector segs) i

    numWords (MutSegment mseg) = pure . GMV.length $ AppendVec.getVector mseg
    slice start len (MutSegment mseg) =
        let window = SMV.slice start len (AppendVec.getVector mseg)
        in pure (MutSegment (AppendVec.fromVector window))
    -- Words are stored little-endian; convert to host order when reading.
    read (MutSegment mseg) i = do
        raw <- SMV.read (AppendVec.getVector mseg) i
        pure (fromLE64 raw)

    -- The immutable segment shares the ByteString's buffer; 'SV.thaw'
    -- copies it, so later writes do not touch the caller's bytes.
    fromByteString bytes = do
        frozen <- fromByteString bytes
        MutSegment . AppendVec.fromVector <$> SV.thaw (constSegToVec frozen)
    toByteString mseg =
        freeze mseg >>= \seg -> toByteString (seg :: Segment ConstMsg)
-- | Overwrite the segment at @segIndex@. Unlike 'setSegment', this does not
-- run 'checkIndex' first.
internalSetSeg :: WriteCtx m s => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
internalSetSeg (MutMsg segVar) segIndex seg = do
    appendVec <- readMutVar segVar
    MV.write (AppendVec.getVector appendVec) segIndex seg
-- | Store a word at index @i@ of a mutable segment, converting it to
-- little-endian first. No 'checkIndex' is performed here.
write :: WriteCtx m s => Segment (MutMsg s) -> Int -> Word64 -> m ()
write (MutSegment seg) i = SMV.write (AppendVec.getVector seg) i . toLE64
-- | Extend a mutable segment by @amount@ words. 'maxSegmentSize' is passed
-- as the limit; presumably 'AppendVec.grow' throws beyond it — confirm.
-- 'allocInSeg' stores the returned segment back into the message rather
-- than reusing the old value.
grow :: WriteCtx m s => Segment (MutMsg s) -> Int -> m (Segment (MutMsg s))
grow (MutSegment vec) amount =
    fmap MutSegment (AppendVec.grow vec amount maxSegmentSize)
-- | Append a new, empty segment to a mutable message. Its storage is
-- created with 'SMV.new' of @sizeHint@ words, presumably as spare capacity
-- for 'AppendVec.makeEmpty' — confirm.
--
-- Returns the new segment's index together with the segment.
newSegment :: WriteCtx m s => MutMsg s -> Int -> m (Int, Segment (MutMsg s))
newSegment msg@(MutMsg segVar) sizeHint = do
    -- The new segment goes at the current end of the segment vector.
    segIndex <- numSegs msg
    oldSegs <- readMutVar segVar
    -- Make room for one more segment (limit 'maxSegments') and publish the
    -- grown vector before storing into it.
    grownSegs <- AppendVec.grow oldSegs 1 maxSegments
    writeMutVar segVar grownSegs
    storage <- SMV.new sizeHint
    let newSeg = MutSegment (AppendVec.makeEmpty storage)
    setSegment msg segIndex newSeg
    pure (segIndex, newSeg)
-- | Allocate @size@ words at the end of segment @segIndex@ and return the
-- address of the first allocated word. Throws (via 'getSegment') if the
-- segment index is out of range.
allocInSeg :: WriteCtx m s => MutMsg s -> Int -> WordCount -> m WordAddr
allocInSeg msg segIndex (WordCount size) = do
    oldSeg@(MutSegment vec) <- getSegment msg segIndex
    -- The allocation begins at the segment's length before growing.
    let start = GMV.length (AppendVec.getVector vec)
    newSeg <- grow oldSeg size
    setSegment msg segIndex newSeg
    pure WordAt { segIndex, wordIndex = WordCount start }
-- | Allocate @size@ words at the end of the message's last segment.
-- With no segments the index is -1, which 'checkIndex' presumably rejects.
alloc :: WriteCtx m s => MutMsg s -> WordCount -> m WordAddr
alloc msg size = do
    count <- numSegs msg
    allocInSeg msg (count - 1) size
-- | A message with a single segment holding one zero word.
empty :: ConstMsg
empty = ConstMsg (V.singleton (ConstSegment (SV.singleton 0)))

-- | A fresh mutable message whose contents start out equal to 'empty'.
newMessage :: WriteCtx m s => m (MutMsg s)
newMessage = thaw empty
-- | Segments convert between immutable and mutable form via the 'Thaw'
-- operations on their underlying 'AppendVec'.
instance Thaw (Segment ConstMsg) where
    type Mutable s (Segment ConstMsg) = Segment (MutMsg s)

    thaw seg = thawSeg thaw seg
    unsafeThaw seg = thawSeg unsafeThaw seg
    freeze mseg = freezeSeg freeze mseg
    unsafeFreeze mseg = freezeSeg unsafeFreeze mseg
-- | Helper for the 'Thaw' instance of @'Segment' 'ConstMsg'@. Wraps the
-- segment's words as a frozen 'AppendVec' and converts them with @thawVec@
-- (the instance passes 'thaw' or 'unsafeThaw').
thawSeg thawVec (ConstSegment vec) =
    fmap MutSegment (thawVec (AppendVec.FrozenAppendVec vec))

-- | Inverse helper. Converts the mutable segment's 'AppendVec' with
-- @freezeVec@ (the instance passes 'freeze' or 'unsafeFreeze') and unwraps
-- the frozen vector.
freezeSeg freezeVec (MutSegment mvec) =
    fmap (ConstSegment . AppendVec.getFrozenVector) (freezeVec mvec)
-- | Whole messages convert segment by segment; see 'thawMsg' and
-- 'freezeMsg'.
instance Thaw ConstMsg where
    type Mutable s ConstMsg = MutMsg s

    thaw msg = thawMsg thaw msg
    unsafeThaw msg = thawMsg unsafeThaw msg
    freeze mmsg = freezeMsg freeze mmsg
    unsafeFreeze mmsg = freezeMsg unsafeFreeze mmsg
-- | Helper for the 'Thaw' instance of 'ConstMsg'. Converts every segment
-- with @thawOne@, then wraps the new segment vector in a fresh 'MutVar'.
thawMsg thawOne (ConstMsg vec) = do
    thawed <- V.mapM thawOne vec
    -- 'thawed' was just built and is not shared, so taking it over without
    -- a copy is safe.
    segments <- V.unsafeThaw thawed
    MutMsg <$> newMutVar (AppendVec.fromVector segments)

-- | Helper for the 'Thaw' instance of 'ConstMsg'. Converts each segment of
-- the message with @freezeOne@ into a new immutable message.
freezeMsg freezeOne msg = do
    count <- numSegs msg
    ConstMsg <$> V.generateM count (\i -> internalGetSeg msg i >>= freezeOne)