{- |
Module: Capnp.IO
Description: Utilities for reading and writing values to handles.

This module provides utilities for reading and writing messages and values
to and from file 'Handle's and network 'Socket's.
-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
module Capnp.IO
    ( hGetValue
    , getValue
    , sGetMsg
    , sGetValue
    , hPutValue
    , putValue
    , sPutValue
    , sPutMsg
    , M.hGetMsg
    , M.getMsg
    , M.hPutMsg
    , M.putMsg

    , hGetParsed
    , sGetParsed
    , getParsed
    , hPutParsed
    , sPutParsed
    , putParsed
    , hGetRaw
    , getRaw
    , sGetRaw
    ) where

import Data.Bits

import Control.Exception         (throwIO)
import Control.Monad.Primitive   (RealWorld)
import Control.Monad.Trans.Class (lift)
import Network.Simple.TCP        (Socket, recv, sendLazy)
import System.IO                 (Handle, stdin, stdout)
import System.IO.Error           (eofErrorType, mkIOError)

import qualified Data.ByteString         as BS
import qualified Data.ByteString.Builder as BB

import Capnp.Bits           (WordCount, wordsToBytes)
import Capnp.Classes
    (Cerialize(..), Decerialize(..), FromStruct(..), ToStruct(..))
import Capnp.Convert
    (msgToLBS, msgToParsed, msgToRaw, parsedToBuilder, parsedToLBS, valueToLBS)
import Capnp.Message        (Mutability(..))
import Capnp.New.Classes    (Parse)
import Capnp.TraversalLimit (evalLimitT)
import Codec.Capnp          (getRoot, setRoot)
import Data.Mutable         (Thaw(..))

import qualified Capnp.Message as M
import qualified Capnp.Repr    as R

-- | @'hGetValue' handle limit@ reads a message from @handle@, returning its
-- root object. @limit@ is used as both a cap on the size of a message which
-- may be read and, for types in the high-level API, the traversal limit when
-- decoding the message.
--
-- It may throw a 'Capnp.Errors.Error' if there is a problem decoding the
-- message, or an 'IOError' raised by the underlying IO libraries.
hGetValue :: FromStruct 'Const a => Handle -> WordCount -> IO a
hGetValue handle limit = do
    msg <- M.hGetMsg handle limit
    -- The same limit doubles as the traversal budget for decoding the root.
    evalLimitT limit (getRoot msg)

-- | Read a message from 'stdin' and return its root object.
-- This is exactly @'hGetValue' 'stdin'@.
getValue :: FromStruct 'Const a => WordCount -> IO a
getValue limit = hGetValue stdin limit

-- | Like 'hGetValue', but reads the message from a 'Socket' rather than a
-- 'Handle'. @limit@ is passed to 'sGetMsg' and is also the traversal limit
-- used when decoding the root object.
sGetValue :: FromStruct 'Const a => Socket -> WordCount -> IO a
sGetValue socket limit =
    sGetMsg socket limit >>= evalLimitT limit . getRoot

-- | Like 'M.hGetMsg', except that it reads from a 'Socket' instead of a
-- 'Handle'. @limit@ is the limit under which 'M.readMessage' runs.
--
-- Throws an EOF 'IOError' if the remote end closes the connection before a
-- complete message has been received.
sGetMsg :: Socket -> WordCount -> IO (M.Message 'Const)
sGetMsg socket limit =
    evalLimitT limit $ M.readMessage (lift read32) (lift . readSegment)
  where
    -- Decode the next 4 bytes from the socket as a little-endian 32-bit word;
    -- 'M.readMessage' uses this to read the message header.
    read32 = do
        bytes <- recvFull 4
        -- recvFull guarantees exactly 4 bytes, so these indices are in range.
        pure $
            (fromIntegral (bytes `BS.index` 0) `shiftL`  0) .|.
            (fromIntegral (bytes `BS.index` 1) `shiftL`  8) .|.
            (fromIntegral (bytes `BS.index` 2) `shiftL` 16) .|.
            (fromIntegral (bytes `BS.index` 3) `shiftL` 24)

    -- Read a segment of exactly @nWords@ words from the socket.
    readSegment !nWords =
        M.fromByteString <$> recvFull (fromIntegral $ wordsToBytes nWords)

    -- Like 'recv', but (1) never returns fewer than @count@ bytes, (2) uses
    -- @socket@ rather than taking the socket as an argument, and (3) throws
    -- an EOF exception when the connection is closed.
    --
    -- A non-positive @count@ yields an empty string without touching the
    -- socket: asking 'recv' for zero bytes would otherwise be reported as a
    -- closed connection (or rejected), breaking zero-length segments.
    -- Partial reads are collected and concatenated once at the end, so large
    -- segments arriving in many pieces are not copied repeatedly.
    recvFull :: Int -> IO BS.ByteString
    recvFull = go []
      where
        go chunks !remaining
            | remaining <= 0 =
                pure $! BS.concat (reverse chunks)
            | otherwise = do
                maybeBytes <- recv socket remaining
                case maybeBytes of
                    Nothing ->
                        throwIO $
                            mkIOError eofErrorType "Remote socket closed" Nothing Nothing
                    Just bytes ->
                        go (bytes : chunks) (remaining - BS.length bytes)

-- | @'hPutValue' handle value@ builds a fresh message whose root object is
-- @value@ and writes it to @handle@. Any exception it throws is expected to
-- be an 'IOError' raised by the underlying IO libraries.
hPutValue :: (Cerialize RealWorld a, ToStruct ('Mut RealWorld) (Cerial ('Mut RealWorld) a))
    => Handle -> a -> IO ()
hPutValue handle value = do
    mutMsg <- M.newMessage Nothing
    -- Serialize into the mutable message (effectively unbounded traversal),
    -- then install the result as the root.
    evalLimitT maxBound (cerialize mutMsg value) >>= setRoot
    M.hPutMsg handle =<< freeze mutMsg

-- | 'putValue' is equivalent to @'hPutValue' 'stdout'@: it writes @value@ as
-- the root object of a message on standard output.
putValue :: (Cerialize RealWorld a, ToStruct ('Mut RealWorld) (Cerial ('Mut RealWorld) a))
    => a -> IO ()
putValue = hPutValue stdout

-- | Like 'M.hPutMsg', except that it writes to a 'Socket' instead of a
-- 'Handle'. The message is serialized with 'msgToLBS' and sent with
-- 'sendLazy'.
sPutMsg :: Socket -> M.Message 'Const -> IO ()
sPutMsg socket msg = sendLazy socket (msgToLBS msg)

-- | Like 'hPutValue', except that it writes to a 'Socket' instead of a
-- 'Handle'. The value is encoded with 'valueToLBS' and sent with 'sendLazy'.
sPutValue :: (Cerialize RealWorld a, ToStruct ('Mut RealWorld) (Cerial ('Mut RealWorld) a))
    => Socket -> a -> IO ()
sPutValue socket value =
    evalLimitT maxBound (valueToLBS value) >>= sendLazy socket

-- | Read a message from the handle and return its root struct in parsed
-- form. @limit@ is passed to 'M.hGetMsg' and is also the traversal limit
-- used while parsing.
hGetParsed :: forall a pa. (R.IsStruct a, Parse a pa) => Handle -> WordCount -> IO pa
hGetParsed handle limit =
    M.hGetMsg handle limit >>= evalLimitT limit . msgToParsed @a

-- | Read a message from the socket and return its root struct in parsed
-- form. @limit@ is passed to 'sGetMsg' and is also the traversal limit used
-- while parsing.
sGetParsed :: forall a pa. (R.IsStruct a, Parse a pa) => Socket -> WordCount -> IO pa
sGetParsed socket limit = do
    message <- sGetMsg socket limit
    evalLimitT limit (msgToParsed @a message)

-- | Read a struct from 'stdin' in its parsed form; this is
-- @'hGetParsed' 'stdin'@.
getParsed :: (R.IsStruct a, Parse a pa) => WordCount -> IO pa
getParsed limit = hGetParsed stdin limit

-- | Encode the parsed form of a struct as a message and write it to the
-- handle, via 'parsedToBuilder' and 'BB.hPutBuilder'.
hPutParsed :: (R.IsStruct a, Parse a pa) => Handle -> pa -> IO ()
hPutParsed h value =
    BB.hPutBuilder h =<< evalLimitT maxBound (parsedToBuilder value)

-- | Write the parsed form of a struct to 'stdout'; this is
-- @'hPutParsed' 'stdout'@.
putParsed :: (R.IsStruct a, Parse a pa) => pa -> IO ()
putParsed value = hPutParsed stdout value

-- | Encode the parsed form of a struct as a message and send it over the
-- socket, via 'parsedToLBS' and 'sendLazy'.
sPutParsed :: (R.IsStruct a, Parse a pa) => Socket -> pa -> IO ()
sPutParsed socket value =
    evalLimitT maxBound (parsedToLBS value) >>= sendLazy socket

-- | Read a message from the handle and return its root pointer as a raw
-- struct. @limit@ is passed to 'M.hGetMsg' and is also the traversal limit
-- used by 'msgToRaw'.
hGetRaw :: R.IsStruct a => Handle -> WordCount -> IO (R.Raw 'Const a)
hGetRaw h limit =
    evalLimitT limit . msgToRaw =<< M.hGetMsg h limit

-- | Read a struct from 'stdin' and return its root pointer; this is
-- @'hGetRaw' 'stdin'@.
getRaw :: R.IsStruct a => WordCount -> IO (R.Raw 'Const a)
getRaw limit = hGetRaw stdin limit

-- | Read a message from the socket and return its root pointer as a raw
-- struct. @limit@ is passed to 'sGetMsg' and is also the traversal limit
-- used by 'msgToRaw'.
sGetRaw :: R.IsStruct a => Socket -> WordCount -> IO (R.Raw 'Const a)
sGetRaw sock limit = do
    message <- sGetMsg sock limit
    evalLimitT limit (msgToRaw message)