{- |
Module: Capnp.IO
Description: Utilities for reading and writing values to handles.

This module provides utilities for reading and writing values to and
from file 'Handle's and network 'Socket's.
-}
{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE FlexibleContexts #-}
module Capnp.IO
    ( hGetValue
    , getValue
    , sGetMsg
    , sGetValue
    , hPutValue
    , putValue
    , sPutValue
    , sPutMsg
    , M.hGetMsg
    , M.getMsg
    , M.hPutMsg
    , M.putMsg
    ) where

import Data.Bits

import Control.Exception         (throwIO)
import Control.Monad.Primitive   (RealWorld)
import Control.Monad.Trans.Class (lift)
import Network.Simple.TCP        (Socket, recv, sendLazy)
import System.IO                 (Handle, stdin, stdout)
import System.IO.Error           (eofErrorType, mkIOError)

import qualified Data.ByteString as BS

import Capnp.Bits    (WordCount, wordsToBytes)
import Capnp.Classes
    (Cerialize (..), Decerialize (..), FromStruct (..), ToStruct (..))
import Capnp.Convert        (msgToLBS, valueToLBS)
import Capnp.TraversalLimit (evalLimitT)
import Codec.Capnp          (getRoot, setRoot)
import Data.Mutable         (Thaw (..))

import qualified Capnp.Message as M

-- | @'hGetValue' handle limit@ reads a message from @handle@, returning its root object.
-- @limit@ is used as both a cap on the size of a message which may be read and, for types
-- in the high-level API, the traversal limit when decoding the message.
--
-- It may throw a 'Capnp.Errors.Error' if there is a problem decoding the message,
-- or an 'IOError' raised by the underlying IO libraries.
hGetValue :: FromStruct M.ConstMsg a => Handle -> WordCount -> IO a
hGetValue handle limit = do
    msg <- M.hGetMsg handle limit
    -- The same limit bounds decoding of the root, not just reading the bytes.
    evalLimitT limit (getRoot msg)

-- | Read a message from standard input and return its root object.
--
-- This is exactly @'hGetValue' 'stdin'@: @limit@ caps the size of the
-- message read and is also used as the traversal limit when decoding.
getValue :: FromStruct M.ConstMsg a => WordCount -> IO a
getValue limit = hGetValue stdin limit

-- | Socket counterpart of 'hGetValue'.
--
-- Receives one message from @socket@ via 'sGetMsg' and decodes its root
-- object. @limit@ is applied both while receiving the message and as the
-- traversal limit for decoding the root.
sGetValue :: FromStruct M.ConstMsg a => Socket -> WordCount -> IO a
sGetValue socket limit =
    sGetMsg socket limit >>= evalLimitT limit . getRoot

-- | Like 'M.hGetMsg', except that it takes a 'Socket' instead of a 'Handle'.
--
-- The whole read runs under @'evalLimitT' limit@, so @limit@ caps the size of
-- the message that may be received. Throws an EOF 'IOError' if the remote end
-- closes the connection before the message is complete.
sGetMsg :: Socket -> WordCount -> IO M.ConstMsg
sGetMsg socket limit =
    evalLimitT limit $ M.readMessage (lift read32) (lift . readSegment)
  where
    -- Receive exactly four bytes and decode them as a little-endian
    -- 32-bit unsigned integer (byte 0 is least significant).
    read32 = do
        bytes <- recvFull 4
        pure $
            (fromIntegral (bytes `BS.index` 0) `shiftL`  0) .|.
            (fromIntegral (bytes `BS.index` 1) `shiftL`  8) .|.
            (fromIntegral (bytes `BS.index` 2) `shiftL` 16) .|.
            (fromIntegral (bytes `BS.index` 3) `shiftL` 24)

    -- Receive a segment of exactly @nwords@ words and wrap the bytes with
    -- 'M.fromByteString'. (Named @nwords@ to avoid shadowing Prelude 'words'.)
    readSegment !nwords = do
        bytes <- recvFull (fromIntegral $ wordsToBytes nwords)
        M.fromByteString bytes

    -- Like recv, but (1) never returns less than `count` bytes, (2)
    -- uses `socket`, rather than taking the socket as an argument, and (3)
    -- throws an EOF exception when the connection is closed.
    --
    -- A request for zero bytes returns the empty string without calling
    -- 'recv'. A zero-length read is indistinguishable from "connection
    -- closed" for 'recv', so asking it for 0 bytes would report a spurious
    -- EOF. This can happen for an empty segment.
    --
    -- Received chunks are accumulated in reverse and concatenated once at
    -- the end, so each byte is copied a single time no matter how many short
    -- reads it takes.
    recvFull :: Int -> IO BS.ByteString
    recvFull !count = go count []
      where
        go !remaining chunks
            | remaining <= 0 = pure (BS.concat (reverse chunks))
            | otherwise = do
                maybeBytes <- recv socket remaining
                case maybeBytes of
                    Nothing ->
                        throwIO $ mkIOError eofErrorType "Remote socket closed" Nothing Nothing
                    Just bytes ->
                        -- recv never returns more than requested, so
                        -- `remaining` shrinks toward exactly 0.
                        go (remaining - BS.length bytes) (bytes : chunks)

-- | @'hPutValue' handle value@ serializes @value@ as the root object of a
-- freshly allocated message and writes that message to @handle@ with
-- 'M.hPutMsg'.
--
-- Serialization runs under a traversal limit of 'maxBound'. If it throws an
-- exception, it will be an 'IOError' raised by the underlying IO libraries.
hPutValue :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => Handle -> a -> IO ()
hPutValue handle value = do
    mutMsg <- M.newMessage Nothing
    -- Build the value inside the mutable message, then install it as the root.
    evalLimitT maxBound (cerialize mutMsg value) >>= setRoot
    frozen <- freeze mutMsg
    M.hPutMsg handle frozen

-- | 'putValue' is equivalent to @'hPutValue' 'stdout'@: it serializes the
-- value as the root of a new message and writes that message to standard
-- output.
putValue :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => a -> IO ()
putValue = hPutValue stdout

-- | Socket counterpart of 'M.hPutMsg'.
--
-- Encodes @msg@ to a lazy 'ByteString' with 'msgToLBS' and transmits it over
-- @socket@ with 'sendLazy'.
sPutMsg :: Socket -> M.ConstMsg -> IO ()
sPutMsg socket msg = sendLazy socket (msgToLBS msg)

-- | Socket counterpart of 'hPutValue'.
--
-- Encodes @value@ as a complete message with 'valueToLBS', under a traversal
-- limit of 'maxBound', and sends the resulting bytes over @socket@ with
-- 'sendLazy'.
sPutValue :: (Cerialize RealWorld a, ToStruct (M.MutMsg RealWorld) (Cerial (M.MutMsg RealWorld) a))
    => Socket -> a -> IO ()
sPutValue socket value =
    evalLimitT maxBound (valueToLBS value) >>= sendLazy socket