{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Capnp.Message (
hPutMsg
, hGetMsg
, putMsg
, getMsg
, maxSegmentSize
, maxSegments
, encode
, decode
, Message(..)
, empty
, ConstMsg
, getSegment
, getWord
, MutMsg
, newMessage
, alloc
, allocInSeg
, newSegment
, setWord
, setSegment
, WriteCtx(..)
) where
import Prelude hiding (read)
import Data.Bits (shiftL)
import Control.Monad (void, when, (>=>))
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.Primitive (PrimMonad, PrimState)
import Control.Monad.State (evalStateT, get, put)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Writer (execWriterT, tell)
import Data.Bytes.Get (getWord32le, runGetS)
import Data.ByteString.Internal (ByteString(..))
import Data.Either (fromRight)
import Data.Primitive (MutVar, newMutVar, readMutVar, writeMutVar)
import Data.Word (Word32, Word64)
import System.Endian (fromLE64, toLE64)
import System.IO (Handle, stdin, stdout)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Storable.Mutable as SMV
import Data.Capnp.Address (WordAddr(..))
import Data.Capnp.Bits (WordCount(..), hi, lo)
import Data.Capnp.Errors (Error(..))
import Data.Capnp.TraversalLimit (LimitT, MonadLimit(invoice), evalLimitT)
import Data.Mutable (Mutable(..))
import Internal.Util (checkIndex)
-- | Upper bound on the size of a single segment, measured in 64-bit words
-- (2^28 words, i.e. 2 GiB).
maxSegmentSize :: Int
maxSegmentSize = 2 ^ (28 :: Int)

-- | Upper bound on the number of segments a message may contain.
maxSegments :: Int
maxSegments = 1024
-- | Read access to a Cap'n Proto message: an indexed collection of segments,
-- each of which is an indexed run of 64-bit words.
class Monad m => Message m msg where
    -- | The segment type belonging to messages of type @msg@.
    data Segment msg

    -- | How many segments the message has.
    numSegs :: msg -> m Int
    -- | Fetch a segment by index. This does no validation of its own; use
    -- 'getSegment', which checks the index first.
    internalGetSeg :: msg -> Int -> m (Segment msg)

    -- | Length of a segment, in words.
    numWords :: Segment msg -> m Int
    -- | @slice start len seg@: the sub-segment of @len@ words starting at
    -- word @start@.
    slice :: Int -> Int -> Segment msg -> m (Segment msg)
    -- | Read the word at the given index of a segment.
    read :: Segment msg -> Int -> m Word64

    -- | Interpret raw bytes as a segment.
    fromByteString :: ByteString -> m (Segment msg)
    -- | Render a segment as raw bytes.
    toByteString :: Segment msg -> m ByteString
-- | Fetch segment @i@ of a message. Before fetching, the index is validated
-- against the segment count with 'checkIndex'.
getSegment :: (MonadThrow m, Message m msg) => msg -> Int -> m (Segment msg)
getSegment msg i =
    numSegs msg >>= checkIndex i >> internalGetSeg msg i
-- | Read the word at an address in the message.
--
-- Both the segment index and the word index are validated with
-- 'checkIndex' before the read.
getWord :: (MonadThrow m, Message m msg) => msg -> WordAddr -> m Word64
getWord msg WordAt{wordIndex = WordCount i, segIndex} = do
    seg <- getSegment msg segIndex
    checkIndex i =<< numWords seg
    seg `read` i
-- | Replace segment @i@ of a mutable message. The index must refer to an
-- existing segment; it is validated with 'checkIndex' against the current
-- segment count.
setSegment :: (WriteCtx m s, MonadThrow m) => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
setSegment msg i seg = do
    count <- numSegs msg
    checkIndex i count
    internalSetSeg msg i seg
-- | Overwrite the word at an address in a mutable message. The segment and
-- word indices are both validated with 'checkIndex' first.
setWord :: (WriteCtx m s, MonadThrow m) => MutMsg s -> WordAddr -> Word64 -> m ()
setWord msg WordAt{wordIndex = WordCount i, segIndex} val = do
    seg <- getSegment msg segIndex
    numWords seg >>= checkIndex i
    write seg i val
-- | An immutable message: a boxed vector of read-only segments.
newtype ConstMsg = ConstMsg (V.Vector (Segment ConstMsg))
-- | Immutable messages. Segments are storable vectors of words stored in
-- little-endian form; 'read' converts to host order. 'fromByteString' and
-- 'toByteString' reinterpret the underlying buffer without copying it.
instance Monad m => Message m ConstMsg where
    newtype Segment ConstMsg = ConstSegment { constSegToVec :: SV.Vector Word64 }

    numSegs (ConstMsg segs) = pure (V.length segs)
    internalGetSeg (ConstMsg segs) = V.indexM segs

    numWords = pure . SV.length . constSegToVec
    slice off n (ConstSegment ws) = pure (ConstSegment (SV.slice off n ws))
    read (ConstSegment ws) idx = fmap fromLE64 (SV.indexM ws idx)

    -- Both conversions share the ByteString's buffer with the vector.
    fromByteString (PS buf off nbytes) =
        let bytes = SV.unsafeFromForeignPtr buf off nbytes
        in pure (ConstSegment (SV.unsafeCast bytes))
    toByteString (ConstSegment ws) =
        let (buf, off, nbytes) = SV.unsafeToForeignPtr (SV.unsafeCast ws)
        in pure (PS buf off nbytes)
-- | Parse a message in Cap'n Proto stream framing (segment table followed by
-- segment data). The resulting segments alias the input buffer rather than
-- copying it.
decode :: MonadThrow m => ByteString -> m ConstMsg
decode = fromByteString >=> decodeSeg
-- | Serialise a message in Cap'n Proto stream framing. The output is a
-- little-endian segment table followed by the raw bytes of each segment.
encode :: MonadThrow m => ConstMsg -> m BB.Builder
encode msg = execWriterT (writeMessage msg emitWord32 emitSegment)
  where
    emitWord32 w = tell (BB.word32LE w)
    emitSegment seg = toByteString seg >>= tell . BB.byteString
-- | Decode a framed message whose entire serialised form is held in a single
-- segment. The result's segments are slices of that segment.
--
-- Reads run under a traversal limit equal to the segment's word count. Every
-- header and segment word is invoiced before it is read, so a header that
-- claims more data than is present is rejected by the limit instead of
-- reading past the end.
decodeSeg :: MonadThrow m => Segment ConstMsg -> m ConstMsg
decodeSeg seg = do
    total <- numWords seg
    evalStateT (evalLimitT total (readMessage nextWord32 takeSegment)) (Nothing, 0)
  where
    -- State: (upper half of the last word fetched, if not yet consumed;
    --         index of the next unread word).
    -- A 64-bit word yields two 32-bit values, low half first.
    nextWord32 = do
        (pending, pos) <- get
        maybe (fetch pos) (\half -> put (Nothing, pos) >> return half) pending
    fetch pos = do
        w <- lift (lift (read seg pos))
        put (Just (hi w), pos + 1)
        return (lo w)
    -- readMessage reads padding so the header ends on a word boundary. The
    -- pending half is therefore always empty here and is carried unchanged.
    takeSegment (WordCount n) = do
        (pending, pos) <- get
        put (pending, pos + n)
        lift (lift (slice pos n seg))
-- | Read a message in Cap'n Proto stream framing from 32-bit and segment
-- reading actions.
--
-- Header layout: (segment count - 1), then one size (in words) per segment,
-- then one 32-bit padding value when the segment count is even, so that the
-- header fills whole 64-bit words. Header words and all segment sizes are
-- invoiced against the traversal limit before they are read.
--
-- Throws 'SizeError' if the header declares more than 'maxSegments' segments
-- or any segment larger than 'maxSegmentSize' words.
readMessage :: (MonadThrow m, MonadLimit m) => m Word32 -> (WordCount -> m (Segment ConstMsg)) -> m ConstMsg
readMessage read32 readSegment = do
    invoice 1
    numSegs' <- read32
    -- Widen before adding one: in Word32, 0xffffffff + 1 wraps to 0, which
    -- would skip the header invoice and silently yield an empty message.
    let numSegs = fromIntegral numSegs' + 1 :: Int
    when (numSegs > maxSegments) $
        throwM SizeError
    invoice (fromIntegral (numSegs `div` 2))
    segSizes <- V.replicateM numSegs read32
    when (even numSegs) $ void read32
    when (V.any (\sz -> fromIntegral sz > maxSegmentSize) segSizes) $
        throwM SizeError
    V.mapM_ (invoice . fromIntegral) segSizes
    ConstMsg <$> V.mapM (readSegment . fromIntegral) segSizes
-- | Emit a message in Cap'n Proto stream framing through caller-supplied
-- writers.
--
-- Order of output: (segment count - 1), each segment's size in words, a zero
-- pad when the count is even, then every segment in order.
writeMessage :: MonadThrow m => ConstMsg -> (Word32 -> m ()) -> (Segment ConstMsg -> m ()) -> m ()
writeMessage (ConstMsg segs) write32 writeSegment = do
    let count = V.length segs
    write32 (fromIntegral count - 1)
    V.forM_ segs (numWords >=> write32 . fromIntegral)
    when (even count) (write32 0)
    V.mapM_ writeSegment segs
-- | Encode a message (see 'encode') and write it to the given handle.
hPutMsg :: Handle -> ConstMsg -> IO ()
hPutMsg handle msg = BB.hPutBuilder handle =<< encode msg
-- | Encode a message and write it to standard output.
putMsg :: ConstMsg -> IO ()
putMsg msg = hPutMsg stdout msg
-- | Read a framed message from a handle.
--
-- @size@ is the traversal limit passed to 'evalLimitT'. Header words and
-- declared segment sizes are invoiced against it before being read, so it
-- bounds how many words will be requested from the handle.
--
-- Throws an 'IOError' if the handle reaches end-of-input before the
-- declared header or segment data has been read.
hGetMsg :: Handle -> Int -> IO ConstMsg
hGetMsg handle size =
    evalLimitT size $ readMessage read32 readSegment
  where
    read32 :: LimitT IO Word32
    read32 = lift $ do
        bytes <- hGetExactly 4
        either (ioError . userError . ("hGetMsg: " ++)) pure
            (runGetS getWord32le bytes)

    readSegment :: WordCount -> LimitT IO (Segment ConstMsg)
    readSegment n = lift $ hGetExactly (fromIntegral n * 8) >>= fromByteString

    -- BS.hGet returns fewer bytes than requested at end-of-input; treat that
    -- as a truncated message instead of decoding a short segment.
    hGetExactly :: Int -> IO ByteString
    hGetExactly count = do
        bytes <- BS.hGet handle count
        when (BS.length bytes /= count) $
            ioError $ userError $
                "hGetMsg: unexpected end of input (wanted "
                    ++ show count ++ " bytes, got "
                    ++ show (BS.length bytes) ++ ")"
        pure bytes
-- | Read a framed message from standard input; the argument is the traversal
-- limit, as for 'hGetMsg'.
getMsg :: Int -> IO ConstMsg
getMsg limit = hGetMsg stdin limit
-- | A mutable message in state thread @s@.
data MutMsg s = MutMsg
    { mutMsgSegs :: MutVar s (MV.MVector s (Segment (MutMsg s)))
    -- ^ The segment table. Its length is its capacity, which may exceed
    -- the number of segments in use ('mutMsgLen'); when it must grow, a
    -- new table is written into the 'MutVar'.
    , mutMsgLen :: MutVar s Int
    -- ^ Number of segments currently in use.
    }
-- | Constraints for mutating a @'MutMsg' s@ in monad @m@: @m@ is a
-- 'PrimMonad' whose state token is @s@, and it supports 'MonadThrow'.
type WriteCtx m s = (PrimMonad m, s ~ PrimState m, MonadThrow m)
-- | Mutable messages.
--
-- A segment is a storable vector plus the count of words in use. The vector
-- may be longer than that count (spare capacity); 'numWords' reports only the
-- words in use. Words are stored little-endian; 'read' converts to host order.
instance (PrimMonad m, s ~ PrimState m) => Message m (MutMsg s) where
    data Segment (MutMsg s) = MutSegment
        { mutSegVec :: !(SMV.MVector s Word64)
        , mutSegLen :: !Int
        }

    numWords MutSegment{mutSegLen} = pure mutSegLen
    -- The sliced segment uses its whole (sliced) vector.
    slice start len MutSegment{mutSegVec} =
        pure MutSegment
            { mutSegVec = SMV.slice start len mutSegVec
            , mutSegLen = len
            }
    read MutSegment{mutSegVec} i = fromLE64 <$> SMV.read mutSegVec i
    -- Unlike the immutable instance, this copies the bytes (via 'SV.thaw').
    fromByteString bytes = do
        vec <- constSegToVec <$> fromByteString bytes
        mvec <- SV.thaw vec
        pure MutSegment
            { mutSegVec = mvec
            , mutSegLen = SV.length vec
            }
    toByteString mseg = do
        seg <- freeze mseg
        toByteString (seg :: Segment ConstMsg)

    numSegs = readMutVar . mutMsgLen
    internalGetSeg MutMsg{mutMsgSegs} i = do
        segs <- readMutVar mutMsgSegs
        MV.read segs i
-- | Store a segment into the segment table at the given slot, without
-- validating the index against the number of segments in use ('setSegment'
-- does that).
internalSetSeg :: WriteCtx m s => MutMsg s -> Int -> Segment (MutMsg s) -> m ()
internalSetSeg MutMsg{mutMsgSegs} segIndex seg =
    readMutVar mutMsgSegs >>= \segs -> MV.write segs segIndex seg
-- | Write a host-order word into a segment at the given index, storing it
-- little-endian. The index is not checked against the segment's in-use
-- length; 'setWord' does that.
write :: WriteCtx m s => Segment (MutMsg s) -> Int -> Word64 -> m ()
write MutSegment{mutSegVec} i = SMV.write mutSegVec i . toLE64
-- | @grow seg amount@ extends the in-use part of @seg@ by @amount@ words and
-- returns the updated segment.
--
-- The underlying vector is reused when it has enough spare capacity, and
-- reallocated otherwise. Because the result may share storage with @seg@,
-- the old value should not be used afterwards. The newly claimed words are
-- set to zero.
--
-- Throws 'SizeError' if the in-use length would exceed 'maxSegmentSize'.
grow :: WriteCtx m s => Segment (MutMsg s) -> Int -> m (Segment (MutMsg s))
grow MutSegment{mutSegVec, mutSegLen} amount = do
    -- Written as a subtraction so the check itself cannot overflow.
    when (maxSegmentSize - amount < mutSegLen) $
        throwM SizeError
    let newLen = mutSegLen + amount
        capacity = SMV.length mutSegVec
    newVec <-
        if newLen <= capacity
            then pure mutSegVec
            else SMV.grow mutSegVec (newLen - capacity)
    -- Memory from SMV.new / SMV.grow is not guaranteed to be zeroed, so clear
    -- the words being handed out.
    SMV.set (SMV.slice mutSegLen amount newVec) 0
    pure MutSegment
        { mutSegVec = newVec
        , mutSegLen = newLen
        }
-- | Append a new, empty segment to a mutable message. @sizeHint@ is the
-- number of words of capacity to reserve in it.
--
-- Returns the new segment's index together with the segment. Throws
-- 'SizeError' if the message already has 'maxSegments' segments.
newSegment :: WriteCtx m s => MutMsg s -> Int -> m (Int, Segment (MutMsg s))
newSegment msg@MutMsg{mutMsgSegs,mutMsgLen} sizeHint = do
    -- The new segment's index is the current number of segments.
    segIndex <- numSegs msg
    when (segIndex >= maxSegments) $
        throwM SizeError
    -- Make room in the segment table if it is full. This roughly doubles the
    -- capacity, and always adds at least one slot so an empty table can grow.
    segs <- readMutVar mutMsgSegs
    let capacity = MV.length segs
    when (capacity <= segIndex) $
        MV.grow segs (max (segIndex + 1 - capacity) capacity)
            >>= writeMutVar mutMsgSegs
    -- Exactly one segment is being added.
    writeMutVar mutMsgLen (segIndex + 1)
    -- Zero the reserved capacity so words later claimed from it start as 0.
    newSegVec <- SMV.replicate sizeHint 0
    let newSeg = MutSegment
            { mutSegVec = newSegVec
            , mutSegLen = 0
            }
    setSegment msg segIndex newSeg
    pure (segIndex, newSeg)
-- | Claim @size@ words at the end of segment @segIndex@ and return the
-- address of the first claimed word (the segment's previous in-use length).
allocInSeg :: WriteCtx m s => MutMsg s -> Int -> WordCount -> m WordAddr
allocInSeg msg segIndex (WordCount size) = do
    seg <- getSegment msg segIndex
    let start = WordCount (mutSegLen seg)
    grown <- grow seg size
    setSegment msg segIndex grown
    pure WordAt { segIndex, wordIndex = start }
-- | Claim @size@ words at the end of the message's last segment.
alloc :: WriteCtx m s => MutMsg s -> WordCount -> m WordAddr
alloc msg size = do
    count <- numSegs msg
    allocInSeg msg (count - 1) size
-- | A message with a single segment containing one zero word.
empty :: ConstMsg
empty = ConstMsg (V.singleton (ConstSegment (SV.singleton 0)))
-- | Create a fresh mutable message holding a copy of 'empty'.
-- 'thaw' (not 'unsafeThaw') is required: 'empty' is a shared constant and
-- must never be aliased by a mutable message.
newMessage :: WriteCtx m s => m (MutMsg s)
newMessage = thaw empty
-- | Conversions between immutable and mutable segments. 'thaw' / 'freeze'
-- copy the word vector ('SV.thaw' / 'SV.freeze'); the unsafe variants reuse
-- it in place ('SV.unsafeThaw' / 'SV.unsafeFreeze').
instance Thaw (Segment ConstMsg) where
    type Mutable s (Segment ConstMsg) = Segment (MutMsg s)
    thaw = thawSeg SV.thaw
    unsafeThaw = thawSeg SV.unsafeThaw
    freeze = freezeSeg SV.freeze
    unsafeFreeze = freezeSeg SV.unsafeFreeze
-- | Build a mutable segment from an immutable one, using the supplied vector
-- thawing function (e.g. 'SV.thaw' to copy, 'SV.unsafeThaw' to reuse). The
-- entire vector is marked as in use.
thawSeg
    :: Functor m
    => (SV.Vector Word64 -> m (SMV.MVector s Word64))
    -> Segment ConstMsg
    -> m (Segment (MutMsg s))
thawSeg thawVec (ConstSegment vec) = toMutSeg <$> thawVec vec
  where
    toMutSeg mvec = MutSegment
        { mutSegVec = mvec
        , mutSegLen = SV.length vec
        }
-- | Build an immutable segment from a mutable one, using the supplied vector
-- freezing function. Only the in-use words are kept; spare capacity past
-- 'mutSegLen' is dropped.
freezeSeg
    :: Functor m
    => (SMV.MVector s Word64 -> m (SV.Vector Word64))
    -> Segment (MutMsg s)
    -> m (Segment ConstMsg)
freezeSeg freezeVec MutSegment{mutSegVec, mutSegLen} =
    ConstSegment <$> freezeVec (SMV.slice 0 mutSegLen mutSegVec)
-- | Conversions between immutable and mutable messages. Each method applies
-- the matching segment conversion (see the 'Segment' 'Thaw' instance) to every
-- segment.
instance Thaw ConstMsg where
    type Mutable s ConstMsg = MutMsg s
    thaw = thawMsg thaw
    unsafeThaw = thawMsg unsafeThaw
    freeze = freezeMsg freeze
    unsafeFreeze = freezeMsg unsafeFreeze
-- | Build a mutable message by converting every segment with the supplied
-- function. The segment table is exactly as long as the number of segments.
thawMsg
    :: PrimMonad m
    => (Segment ConstMsg -> m (Segment (MutMsg (PrimState m))))
    -> ConstMsg
    -> m (MutMsg (PrimState m))
thawMsg thawSegment (ConstMsg vec) = do
    -- V.mapM builds a fresh vector, so unsafeThaw cannot alias caller data.
    segments <- V.mapM thawSegment vec >>= V.unsafeThaw
    MutMsg
        <$> newMutVar segments
        <*> newMutVar (MV.length segments)
-- | Build an immutable message from the in-use segments of a mutable one,
-- converting each with the supplied function. Unused slots in the segment
-- table are ignored.
freezeMsg
    :: (PrimMonad m, s ~ PrimState m)
    => (Segment (MutMsg s) -> m (Segment ConstMsg))
    -> MutMsg s
    -> m ConstMsg
freezeMsg freezeSegment msg@MutMsg{mutMsgLen} = do
    len <- readMutVar mutMsgLen
    ConstMsg <$> V.generateM len (internalGetSeg msg >=> freezeSegment)