{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Capnp.IO
( hGetValue
, getValue
, sGetMsg
, sGetValue
, hPutValue
, putValue
, sPutValue
, sPutMsg
, M.hGetMsg
, M.getMsg
, M.hPutMsg
, M.putMsg
) where
import Data.Bits
import Control.Exception (throwIO)
import Control.Monad.Primitive (RealWorld)
import Control.Monad.Trans.Class (lift)
import Network.Simple.TCP (Socket, recv, sendLazy)
import System.IO (Handle, stdin, stdout)
import System.IO.Error (eofErrorType, mkIOError)
import qualified Data.ByteString as BS
import Capnp.Bits (WordCount, wordsToBytes)
import Capnp.Classes
(Cerialize (..), Decerialize (..), FromStruct (..), ToStruct (..))
import Capnp.Convert (msgToLBS, valueToLBS)
import Capnp.TraversalLimit (evalLimitT)
import Codec.Capnp (getRoot, setRoot)
import Data.Mutable (Thaw (..))
import qualified Capnp.Message as M
-- | Read one message from a handle and decode its root struct as an @a@.
--
-- The same 'WordCount' is handed to 'M.hGetMsg' when reading the message
-- and used as the traversal limit while decoding the root with 'getRoot'.
hGetValue :: FromStruct M.ConstMsg a => Handle -> WordCount -> IO a
hGetValue handle limit =
    M.hGetMsg handle limit >>= evalLimitT limit . getRoot
-- | Like 'hGetValue', but reads from standard input.
getValue :: FromStruct M.ConstMsg a => WordCount -> IO a
getValue limit = hGetValue stdin limit
-- | Read one message from a socket and decode its root struct as an @a@.
--
-- The 'WordCount' is passed to 'sGetMsg' and reused as the traversal limit
-- while decoding the root with 'getRoot'.
sGetValue :: FromStruct M.ConstMsg a => Socket -> WordCount -> IO a
sGetValue socket limit =
    sGetMsg socket limit >>= evalLimitT limit . getRoot
-- | Read a single Cap'n Proto message from a socket.
--
-- 'M.readMessage' is driven by two socket-backed readers: one that decodes a
-- little-endian 32-bit word from the next four bytes, and one that receives
-- a segment of a given size in words. The whole read runs under
-- 'evalLimitT' with the supplied 'WordCount' as the limit.
--
-- Throws an 'IOError' of type 'eofErrorType' ("Remote socket closed") if the
-- peer closes the connection before all requested bytes have arrived.
sGetMsg :: Socket -> WordCount -> IO M.ConstMsg
sGetMsg socket limit =
    evalLimitT limit $ M.readMessage (lift read32) (lift . readSegment)
  where
    -- Byte 0 is the least significant byte (little-endian).
    read32 = do
        bytes <- recvFull 4
        pure $
            (fromIntegral (bytes `BS.index` 0) `shiftL`  0) .|.
            (fromIntegral (bytes `BS.index` 1) `shiftL`  8) .|.
            (fromIntegral (bytes `BS.index` 2) `shiftL` 16) .|.
            (fromIntegral (bytes `BS.index` 3) `shiftL` 24)

    -- Receive exactly enough bytes for @nWords@ words and wrap them as a segment.
    readSegment !nWords = do
        bytes <- recvFull (fromIntegral $ wordsToBytes nWords)
        M.fromByteString bytes

    -- Receive exactly @count@ bytes, looping over short reads.
    --
    -- Chunks are collected in reverse and concatenated once at the end, so
    -- reassembly is linear in the total size. The previous right-nested
    -- '<>' re-copied the growing suffix on every step.
    --
    -- A request for zero (or fewer) bytes returns immediately without
    -- touching the socket. A 0-byte 'recv' is either misreported as a
    -- closed connection or rejected as an invalid argument, depending on
    -- the network library version.
    recvFull :: Int -> IO BS.ByteString
    recvFull = go []
      where
        go chunks !remaining
            | remaining <= 0 = pure $! BS.concat (reverse chunks)
            | otherwise = do
                maybeBytes <- recv socket remaining
                case maybeBytes of
                    Nothing ->
                        throwIO $ mkIOError eofErrorType "Remote socket closed" Nothing Nothing
                    Just bytes ->
                        go (bytes : chunks) (remaining - BS.length bytes)
-- | Serialize a value into a fresh Cap'n Proto message and write it to a handle.
--
-- The value is cerialized into a newly allocated mutable message under the
-- largest possible traversal limit ('maxBound'). The result is installed as
-- the message root, and the message is frozen and written with 'M.hPutMsg'.
hPutValue
    :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => Handle -> a -> IO ()
hPutValue handle value = do
    mutMsg <- M.newMessage Nothing
    evalLimitT maxBound (cerialize mutMsg value) >>= setRoot
    frozen <- freeze mutMsg
    M.hPutMsg handle frozen
-- | Like 'hPutValue', but writes to standard output.
putValue
    :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => a -> IO ()
putValue value = hPutValue stdout value
-- | Encode a message with 'msgToLBS' and send the resulting bytes on a socket.
sPutMsg :: Socket -> M.ConstMsg -> IO ()
sPutMsg socket msg = sendLazy socket (msgToLBS msg)
-- | Serialize a value with 'valueToLBS' and send the bytes on a socket.
--
-- Serialization runs under the largest possible traversal limit ('maxBound').
sPutValue
    :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => Socket -> a -> IO ()
sPutValue socket value =
    evalLimitT maxBound (valueToLBS value) >>= sendLazy socket