{-|
Module      : Z.IO.RPC.MessagePack
Description : MessagePack-RPC client and server
Copyright   : (c) Dong Han, 2019
License     : BSD
Maintainer  : winterland1989@gmail.com
Stability   : experimental
Portability : non-portable

This module provides a <https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md MessagePack-RPC> implementation.

@
-- server
import Z.IO.RPC.MessagePack
import Z.IO.Network
import Z.IO
import qualified Z.Data.Text as T

serveRPC (startTCPServer defaultTCPServerConfig) . simpleRouter $
 [ ("foo", CallHandler $ \\ (req :: Int) -> do
     return (req + 1))
 , ("bar", NotifyHandler $ \\ (req :: T.Text) -> do
     printStd (req <> "world"))
 , ("qux", StreamHandler $ \\ (_ :: ()) -> do
    withMVar stdinBuf (pure . sourceFromBuffered))
 ]

-- client
import Z.IO.RPC.MessagePack
import Z.IO.Network
import Z.IO
import qualified Z.Data.Text as T
import qualified Z.Data.Vector as V

withResource (initTCPClient defaultTCPClientConfig) $ \\ uvs -> do
    c <- rpcClient uvs
    -- single call
    call \@Int \@Int c "foo" 1
    -- notify without result
    notify \@T.Text c "bar" "hello"
    -- streaming result
    (_, src) <- callStream c "qux" ()
    runBIO $ src >|> sinkToIO (\\ b -> withMVar stdoutBuf (\\ bo -> do
        writeBuffer bo b
        flushBuffer bo))

@

-}

module Z.IO.RPC.MessagePack where

import           Control.Concurrent
import           Control.Monad
import           Data.Bits
import           Data.Int
import           Data.IORef
import           Z.Data.PrimRef.PrimIORef
import qualified Z.Data.MessagePack.Builder as MB
import qualified Z.Data.MessagePack.Value   as MV
import           Z.Data.MessagePack         (MessagePack)
import qualified Z.Data.MessagePack         as MP
import qualified Z.Data.Parser              as P
import qualified Z.Data.Text                as T
import qualified Z.Data.Vector.FlatIntMap   as FIM
import qualified Z.Data.Vector.FlatMap      as FM
import qualified Z.Data.Vector              as V
import           Z.IO
import           Z.IO.Network

-- | State of a RPC client connection.
data Client = Client
    { _clientSeqRef :: Counter
    -- ^ Sequence counter; 'callPipeline' reads it to derive the next msgid and then increments it.
    , _clientPipelineReqNum :: Counter
    -- ^ Number of calls queued since the last 'execPipeline';
    -- set to @-1@ by 'callStream' while a stream has not been consumed to its end.
    , _clientBufferedInput :: BufferedInput
    -- ^ Buffered input that responses are parsed from.
    , _clientBufferedOutput :: BufferedOutput
    -- ^ Buffered output that requests are written to.
    }

-- | Open a RPC client from an input/output device.
--
-- The same device is used for both reading and writing, and both buffers are
-- sized with 'V.defaultChunkSize'. Use 'rpcClient'' for separate devices or
-- custom buffer sizes.
rpcClient :: (Input dev, Output dev) => dev -> IO Client
rpcClient uvs = rpcClient' uvs uvs V.defaultChunkSize V.defaultChunkSize

-- | Open a RPC client with more control.
--
-- Both internal counters (message sequence and pending pipeline requests)
-- start at @0@.
rpcClient' :: (Input i, Output o)
              => i
              -> o
              -> Int          -- ^ recv buffer size
              -> Int          -- ^ send buffer size
              -> IO Client
rpcClient' i o recvBufSiz sendBufSiz = do
    seqRef <- newCounter 0
    reqNum <- newCounter 0
    bi <- newBufferedInput' recvBufSiz i
    bo <- newBufferedOutput' sendBufSiz o
    return (Client seqRef reqNum bi bo)

-- | Send a single RPC call and get the result.
--
-- This queues the request with 'callPipeline', runs 'execPipeline' right away,
-- and fetches the response matching the request's id with 'fetchPipeline'.
-- Any exception thrown by those three functions propagates to the caller.
call :: (MessagePack req, MessagePack res, HasCallStack) => Client -> T.Text -> req -> IO res
call cli name req = do
    msgid <- callPipeline cli name req
    fetchPipeline msgid =<< execPipeline cli

-- | Send a single notification RPC call without getting a result.
--
-- The notification is written with 'notifyPipeline' and the output buffer is
-- flushed immediately; no response is read.
notify :: (MessagePack req, HasCallStack) => Client -> T.Text -> req -> IO ()
notify c@(Client _ _ _ bo) name req = do
    notifyPipeline c name req
    flushBuffer bo

-- | Identifier of a queued call, as returned by 'callPipeline'.
type PipelineId = Int
-- | Responses of an executed pipeline, keyed by 'PipelineId'.
type PipelineResult = FIM.FlatIntMap MV.Value

-- | Make a call inside a pipeline, which will be sent in batch when `execPipeline`.
--
-- @
--  ...
--  fooId <- callPipeline client "foo" $ ...
--  barId <- callPipeline client "bar" $ ...
--  notifyPipeline client "qux" $ ...
--
--  r <- execPipeline client
--
--  fooResult <- fetchPipeline fooId r
--  barResult <- fetchPipeline barId r
-- @
--
-- The request is only written to the output buffer here; this function does
-- not flush it. Throws 'RPCStreamUnconsumed' if a stream started by
-- 'callStream' has not been consumed to its end.
callPipeline :: HasCallStack => MessagePack req => Client -> T.Text -> req -> IO PipelineId
callPipeline (Client seqRef reqNum _ bo) name req = do
    -- a pending count of -1 marks an unconsumed stream
    x <- readPrimIORef reqNum
    when (x == (-1)) $ throwIO (RPCStreamUnconsumed callStack)
    writePrimIORef reqNum (x + 1)
    msgid <- readPrimIORef seqRef
    writePrimIORef seqRef (msgid + 1)
    let !msgid' = msgid .&. 0xFFFFFFFF  -- shrink to unsigned 32bits
    writeBuilder bo $ do
        MB.arrayHeader 4
        MB.int 0                        -- type request
        MB.int (fromIntegral msgid')    -- msgid
        MB.str name                     -- method name
        MP.encodeMessagePack req        -- param
    return msgid'

-- | Make a notify inside a pipeline, which will be sent in batch when `execPipeline`.
--
-- Notify calls don't affect the execution's result: the pending request count
-- is not incremented, so 'execPipeline' reads no response for them.
--
-- Throws 'RPCStreamUnconsumed' if a stream started by 'callStream' has not
-- been consumed to its end.
notifyPipeline :: HasCallStack => MessagePack req => Client -> T.Text -> req -> IO ()
notifyPipeline (Client _ reqNum _ bo) name req = do
    -- a pending count of -1 marks an unconsumed stream
    x <- readPrimIORef reqNum
    when (x == (-1)) $ throwIO (RPCStreamUnconsumed callStack)
    writeBuilder bo $ do
        MB.arrayHeader 3
        MB.int 2                        -- type notification
        MB.str name                     -- method name
        MP.encodeMessagePack req        -- param

-- | Exceptions thrown by the RPC client.
data RPCException
    = RPCStreamUnconsumed CallStack
    -- ^ A call, notify or pipeline execution was attempted while a stream
    -- started by 'callStream' had not been consumed to its end.
    | RPCException MV.Value CallStack
    -- ^ The remote endpoint returned a non-nil error value.
  deriving Show
instance Exception RPCException

-- | Send queued requests in batch and get results in a map identified by 'PipelineId'.
--
-- The output buffer is flushed, then one response is read for every call
-- queued by 'callPipeline' since the last execution.
--
-- Throws 'RPCStreamUnconsumed' if a stream started by 'callStream' has not
-- been consumed to its end, and 'RPCException' carrying the first non-nil
-- error value found among the responses. All pending responses are read from
-- the connection before an 'RPCException' is thrown, so a failed call does not
-- leave unread responses behind for the next pipeline.
execPipeline :: HasCallStack => Client -> IO PipelineResult
execPipeline (Client _ reqNum bi bo) = do
    flushBuffer bo
    x <- readPrimIORef reqNum
    when (x == (-1)) $ throwIO (RPCStreamUnconsumed callStack)
    writePrimIORef reqNum 0
    -- Drain every expected response first; errors are inspected afterwards so
    -- the input stays aligned with the (already reset) pending count.
    rs <- replicateM x $ readParser (do
        tag <- P.anyWord8
        -- 0x94 is the MessagePack header of a 4-element array
        when (tag /= 0x94) (P.fail' $ "wrong response tag: " <> T.toText tag)
        !typ <- MV.value
        !mid <- MV.value
        !err <- MV.value
        !v <- MV.value
        case typ of
            MV.Int 1 -> case mid of
                MV.Int msgid | msgid >= 0 && msgid <= 0xFFFFFFFF ->
                    return (msgid, err, v)
                _ -> P.fail' $ "wrong msgid: " <> T.toText mid
            _ -> P.fail' $ "wrong response type: " <> T.toText typ
        ) bi
    forM_ rs $ \ (_, err, _) ->
        when (err /= MV.Nil) $ throwIO (RPCException err callStack)
    return (FIM.packN x [ V.IPair (fromIntegral msgid) v | (msgid, _, v) <- rs ])

-- | Use the `PipelineId` returned by `callPipeline` to fetch a call's result.
--
-- Fails via @unwrap'@ tagged @\"ENOMSG\"@ when the id is not present in the
-- result map, and via @unwrap@ tagged @\"EPARSE\"@ when the stored value
-- cannot be converted to the requested type.
fetchPipeline :: HasCallStack => MessagePack res => PipelineId -> PipelineResult -> IO res
fetchPipeline msgid r = do
    v <- unwrap' "ENOMSG" ("missing message in response: " <> T.toText msgid)
            (FIM.lookup msgid r)
    unwrap "EPARSE" (MP.convertValue v)

-- | Call a stream method, no other `call` or `notify` should be sent until
-- returned stream is consumed completely.
--
-- This is implemented by extending the MessagePack-RPC protocol with the following new message types:
--
-- @
-- -- start stream request
-- [typ 0x04, name, param]
--
-- -- stop stream request
-- [typ 0x05]
--
-- -- each stream response
-- [typ 0x06, err, value]
--
-- -- stream response end
-- [typ 0x07]
-- @
--
-- The returned tuple is a pair of a stop action and a `Source`. To terminate the stream early, call the
-- stop action. Please continue consuming until EOF is reached,
-- otherwise the state of the `Client` will be incorrect.
--
-- Throws 'RPCStreamUnconsumed' if a previous stream has not been consumed to
-- its end. Pulling from the source throws 'RPCException' when a stream
-- response carries a non-nil error value.
callStream :: (MessagePack req, MessagePack res, HasCallStack) => Client -> T.Text -> req -> IO (IO (), Source res)
callStream (Client _ reqNum bi bo) name req = do
    x <- readPrimIORef reqNum
    when (x == (-1)) $ throwIO (RPCStreamUnconsumed callStack)
    -- mark the client as streaming until the source reaches its end
    writePrimIORef reqNum (-1)
    writeBuilder bo $ do
        MB.arrayHeader 3
        MB.int 4                        -- type request
        MB.str name                     -- method name
        MP.encodeMessagePack req        -- param
    flushBuffer bo
    return (sendEOF, sourceFromIO $ do
        res <- pull (sourceParserFromBuffered (do
            tag <- P.anyWord8
            case tag of
                -- stream stop: 1-element array
                0x91 -> do
                    !typ <- MV.value
                    when (typ /= MV.Int 7) $
                        P.fail' $ "wrong response type: " <> T.toText typ
                    return Nothing
                -- stream item: 3-element array
                0x93 -> do
                    !typ <- MV.value
                    !err <- MV.value
                    !v <- MV.value
                    when (typ /= MV.Int 6) $
                        P.fail' $ "wrong response type: " <> T.toText typ
                    return (Just (err, v))
                _ -> P.fail' $ "wrong response tag: " <> T.toText tag
            ) bi)

        -- we take tcp disconnect as eof too
        case join res of
            Just (err, v) -> do
                when (err /= MV.Nil) $ throwIO (RPCException err callStack)
                -- Decode the payload at the element type and wrap it explicitly.
                -- Decoding at @Maybe res@ would let a payload that converts to
                -- 'Nothing' (presumably nil -- confirm against the 'Maybe' instance)
                -- end the source early without resetting the client state.
                Just <$> unwrap "EPARSE" (MP.convertValue v)
            _ -> do
                -- stream finished: allow calls and pipelines again
                writePrimIORef reqNum 0
                return Nothing
        )
  where
    -- ask the server to stop the stream: [typ 0x05]
    sendEOF = do
        writeBuilder bo $ do
            MB.arrayHeader 1
            MB.int 5
        flushBuffer bo

--------------------------------------------------------------------------------

-- | A server loop: given a per-connection worker, run it for connections.
type ServerLoop = (UVStream -> IO ()) -> IO ()
-- | A service maps a method name to its handler, if there is one.
type ServerService = T.Text -> Maybe ServerHandler
-- | Handler for one RPC method; the constructor selects the request kind.
data ServerHandler where
    -- | Handle a call request, producing a response value.
    CallHandler :: (MessagePack req, MessagePack res) => (req -> IO res) -> ServerHandler
    -- | Handle a notification, producing no response value.
    NotifyHandler :: MessagePack req => (req -> IO ()) -> ServerHandler
    -- | Handle a stream request, producing a 'Source' of response values.
    StreamHandler :: (MessagePack req, MessagePack res) => (req -> IO (Source res)) -> ServerHandler

-- | Simple router using `FlatMap`, lookup name in /O(log(N))/.
--
-- @
-- import Z.IO.RPC.MessagePack
-- import Z.IO.Network
-- import Z.IO
--
-- serveRPC (startTCPServer defaultTCPServerConfig) . simpleRouter $
--  [ ("foo", CallHandler $ \\ req -> do
--      ... )
--  , ("bar", CallHandler $ \\ req -> do
--      ... )
--  ]
--
-- @
simpleRouter :: [(T.Text, ServerHandler)] -> ServerService
simpleRouter handles = \ name -> FM.lookup name handleMap
  where
    -- Bound after the first argument only, so that the map is meant to be
    -- built once per router and shared by every lookup, instead of being
    -- rebuilt from the list on each call.
    handleMap = FM.packR handles

-- | Serve a RPC service.
--
-- Uses 'V.defaultChunkSize' for both the receive and the send buffer; see
-- 'serveRPC'' to choose the sizes.
serveRPC :: ServerLoop -> ServerService -> IO ()
serveRPC serve = serveRPC' serve V.defaultChunkSize V.defaultChunkSize

-- | A request parsed from a client connection.
data Request a
    = Notify (T.Text, a)            -- ^ method name, param
    | Call (Int64, T.Text, a)       -- ^ msgid, method name, param
    | StreamStart (T.Text, a)       -- ^ method name, param
  deriving Show

-- | Serve a RPC service with more control.
serveRPC' :: ServerLoop
          -> Int          -- ^ recv buffer size
          -> Int          -- ^ send buffer size
          -> ServerService -> IO ()
serveRPC' :: ServerLoop -> Int -> Int -> ServerService -> IO ()
serveRPC' ServerLoop
serve Int
recvBufSiz Int
sendBufSiz ServerService
handle = ServerLoop
serve ServerLoop -> ServerLoop
forall a b. (a -> b) -> a -> b
$ \ UVStream
uvs -> do
    BufferedInput
bi <- Int -> UVStream -> IO BufferedInput
forall i. Input i => Int -> i -> IO BufferedInput
newBufferedInput' Int
recvBufSiz UVStream
uvs
    BufferedOutput
bo <- Int -> UVStream -> IO BufferedOutput
forall o. Output o => Int -> o -> IO BufferedOutput
newBufferedOutput' Int
sendBufSiz UVStream
uvs
    BufferedInput -> BufferedOutput -> IO ()
loop BufferedInput
bi BufferedOutput
bo
  where
    loop :: BufferedInput -> BufferedOutput -> IO ()
loop BufferedInput
bi BufferedOutput
bo = do
        Maybe (Request Value)
req <- BIO Void (Request Value) -> IO (Maybe (Request Value))
forall inp out. BIO inp out -> IO (Maybe out)
pull (Parser (Request Value) -> BufferedInput -> BIO Void (Request Value)
forall a. HasCallStack => Parser a -> BufferedInput -> Source a
sourceParserFromBuffered (do
            Word8
tag <- Parser Word8
P.anyWord8
            case Word8
tag of
                -- notify or stream start
                Word8
0x93 -> do
                    !Value
typ <- Parser Value
MV.value
                    !Value
name <- Parser Value
MV.value
                    !Value
v <- Parser Value
MV.value
                    case Value
typ of
                        MV.Int Int64
2 -> case Value
name of
                            MV.Str Text
name' -> Request Value -> Parser (Request Value)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Text, Value) -> Request Value
forall a. (Text, a) -> Request a
Notify (Text
name', Value
v))
                            Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong RPC name: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
name
                        MV.Int Int64
4 -> case Value
name of
                            MV.Str Text
name' -> Request Value -> Parser (Request Value)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Text, Value) -> Request Value
forall a. (Text, a) -> Request a
StreamStart (Text
name', Value
v))
                            Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong RPC name: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
name
                        Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong request type: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
typ
                -- call
                Word8
0x94 -> do
                    !Value
typ <- Parser Value
MV.value
                    !Value
seq <- Parser Value
MV.value
                    !Value
name <- Parser Value
MV.value
                    !Value
v <- Parser Value
MV.value
                    case Value
typ of
                        MV.Int Int64
0 -> case Value
seq of
                            MV.Int Int64
msgid | Int64
msgid Int64 -> Int64 -> Bool
forall a. Ord a => a -> a -> Bool
>= Int64
0 Bool -> Bool -> Bool
&& Int64
msgid Int64 -> Int64 -> Bool
forall a. Ord a => a -> a -> Bool
<= Int64
0xFFFFFFFF -> case Value
name of
                                MV.Str Text
name' -> Request Value -> Parser (Request Value)
forall (m :: * -> *) a. Monad m => a -> m a
return ((Int64, Text, Value) -> Request Value
forall a. (Int64, Text, a) -> Request a
Call (Int64
msgid, Text
name', Value
v))
                                Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong RPC name: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
name
                            Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong msgid: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
seq
                        Value
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong request type: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
typ
                Word8
_ -> Text -> Parser (Request Value)
forall a. Text -> Parser a
P.fail' (Text -> Parser (Request Value)) -> Text -> Parser (Request Value)
forall a b. (a -> b) -> a -> b
$ Text
"wrong request tag: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Word8 -> Text
forall a. Print a => a -> Text
T.toText Word8
tag
            ) BufferedInput
bi)
        Maybe (Request Value) -> IO ()
forall a. Show a => a -> IO ()
print Maybe (Request Value)
req
        case Maybe (Request Value)
req of
            Just (Notify (Text
name, Value
v)) -> do
                case ServerService
handle Text
name of
                    Just (NotifyHandler req -> IO ()
f) -> do
                        req -> IO ()
f (req -> IO ()) -> IO req -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Text -> Either ConvertError req -> IO req
forall e a. (HasCallStack, Print e) => Text -> Either e a -> IO a
unwrap Text
"EPARSE" (Value -> Either ConvertError req
forall a. MessagePack a => Value -> Either ConvertError a
MP.convertValue Value
v)
                    Maybe ServerHandler
_ -> Text -> Text -> IO ()
forall a. HasCallStack => Text -> Text -> IO a
throwOtherError Text
"ENOTFOUND" Text
"notification method not found"
                BufferedInput -> BufferedOutput -> IO ()
loop BufferedInput
bi BufferedOutput
bo
            Just (Call (Int64
msgid, Text
name, Value
v)) -> do
                case ServerService
handle Text
name of
                    Just (CallHandler req -> IO res
f) -> do
                        Either SomeException res
res <- IO res -> IO (Either SomeException res)
forall e a. Exception e => IO a -> IO (Either e a)
try (req -> IO res
f (req -> IO res) -> IO req -> IO res
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Text -> Either ConvertError req -> IO req
forall e a. (HasCallStack, Print e) => Text -> Either e a -> IO a
unwrap Text
"EPARSE" (Value -> Either ConvertError req
forall a. MessagePack a => Value -> Either ConvertError a
MP.convertValue Value
v))
                        BufferedOutput -> Builder () -> IO ()
forall a. HasCallStack => BufferedOutput -> Builder a -> IO ()
writeBuilder BufferedOutput
bo (Builder () -> IO ()) -> Builder () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
                            Int -> Builder ()
MB.arrayHeader Int
4
                            Int64 -> Builder ()
MB.int Int64
1                        -- type response
                            Int64 -> Builder ()
MB.int (Int64 -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int64
msgid)     -- msgid
                            case Either SomeException res
res of
                                Left SomeException
e -> do
                                    Text -> Builder ()
MB.str (String -> Text
T.pack (String -> Text) -> String -> Text
forall a b. (a -> b) -> a -> b
$ SomeException -> String
forall a. Show a => a -> String
show (SomeException
e :: SomeException))
                                    Builder ()
MB.nil
                                Right res
res -> do
                                    Builder ()
MB.nil
                                    res -> Builder ()
forall a. MessagePack a => a -> Builder ()
MP.encodeMessagePack res
res
                        HasCallStack => BufferedOutput -> IO ()
BufferedOutput -> IO ()
flushBuffer BufferedOutput
bo
                    Maybe ServerHandler
_ -> do
                        BufferedOutput -> Builder () -> IO ()
forall a. HasCallStack => BufferedOutput -> Builder a -> IO ()
writeBuilder BufferedOutput
bo (Builder () -> IO ()) -> Builder () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
                            Int -> Builder ()
MB.arrayHeader Int
4
                            Int64 -> Builder ()
MB.int Int64
1                        -- type response
                            Int64 -> Builder ()
MB.int (Int64 -> Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int64
msgid)     -- msgid
                            Text -> Builder ()
MB.str (Text -> Builder ()) -> Text -> Builder ()
forall a b. (a -> b) -> a -> b
$ Text
"request method: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
name Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" not found"
                            Builder ()
MB.nil
                        HasCallStack => BufferedOutput -> IO ()
BufferedOutput -> IO ()
flushBuffer BufferedOutput
bo
                BufferedInput -> BufferedOutput -> IO ()
loop BufferedInput
bi BufferedOutput
bo
            Just (StreamStart (Text
name, Value
v)) -> do
                IORef Bool
eofRef <- Bool -> IO (IORef Bool)
forall a. a -> IO (IORef a)
newIORef Bool
False
                -- fork new thread to get stream end notification
                IO () -> IO ThreadId
forkIO (IO () -> IO ThreadId) -> IO () -> IO ThreadId
forall a b. (a -> b) -> a -> b
$ do
                    BIO Void () -> IO (Maybe ())
forall inp out. BIO inp out -> IO (Maybe out)
pull (Parser () -> BufferedInput -> BIO Void ()
forall a. HasCallStack => Parser a -> BufferedInput -> Source a
sourceParserFromBuffered (do
                        Word8
tag <- Parser Word8
P.anyWord8
                        -- stream stop
                        Bool -> Parser () -> Parser ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Word8
tag Word8 -> Word8 -> Bool
forall a. Eq a => a -> a -> Bool
/= Word8
0x91) (Parser () -> Parser ()) -> Parser () -> Parser ()
forall a b. (a -> b) -> a -> b
$
                            Text -> Parser ()
forall a. Text -> Parser a
P.fail' (Text -> Parser ()) -> Text -> Parser ()
forall a b. (a -> b) -> a -> b
$ Text
"wrong request tag: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Word8 -> Text
forall a. Print a => a -> Text
T.toText Word8
tag
                        !Value
typ <- Parser Value
MV.value
                        Bool -> Parser () -> Parser ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Value
typ Value -> Value -> Bool
forall a. Eq a => a -> a -> Bool
/= Int64 -> Value
MV.Int Int64
5) (Parser () -> Parser ()) -> Parser () -> Parser ()
forall a b. (a -> b) -> a -> b
$
                            Text -> Parser ()
forall a. Text -> Parser a
P.fail' (Text -> Parser ()) -> Text -> Parser ()
forall a b. (a -> b) -> a -> b
$ Text
"wrong request type: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Value -> Text
forall a. Print a => a -> Text
T.toText Value
typ
                        ) BufferedInput
bi)
                    IORef Bool -> Bool -> IO ()
forall a. IORef a -> a -> IO ()
atomicWriteIORef IORef Bool
eofRef Bool
True

                case ServerService
handle Text
name of
                    Just (StreamHandler req -> IO (Source res)
f) -> do
                        Source res
src <- req -> IO (Source res)
f (req -> IO (Source res)) -> IO req -> IO (Source res)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Text -> Either ConvertError req -> IO req
forall e a. (HasCallStack, Print e) => Text -> Either e a -> IO a
unwrap Text
"EPARSE" (Value -> Either ConvertError req
forall a. MessagePack a => Value -> Either ConvertError a
MP.convertValue Value
v)
                        IORef Bool -> Source res -> BufferedOutput -> IO ()
forall a inp.
MessagePack a =>
IORef Bool -> BIO inp a -> BufferedOutput -> IO ()
loopSend IORef Bool
eofRef Source res
src BufferedOutput
bo
                    Maybe ServerHandler
_ -> do
                        BufferedOutput -> Builder () -> IO ()
forall a. HasCallStack => BufferedOutput -> Builder a -> IO ()
writeBuilder BufferedOutput
bo (Builder () -> IO ()) -> Builder () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
                            Int -> Builder ()
MB.arrayHeader Int
3
                            Int64 -> Builder ()
MB.int Int64
6                        -- type response
                            Text -> Builder ()
MB.str (Text -> Builder ()) -> Text -> Builder ()
forall a b. (a -> b) -> a -> b
$ Text
"request method: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
name Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
" not found"
                            Builder ()
MB.nil
                        HasCallStack => BufferedOutput -> IO ()
BufferedOutput -> IO ()
flushBuffer BufferedOutput
bo
                BufferedInput -> BufferedOutput -> IO ()
loop BufferedInput
bi BufferedOutput
bo

            Maybe (Request Value)
_ -> () -> IO ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

    -- Drain a stream source to the client, one frame per pulled element.
    --
    -- Each round first reads the shared 'eofRef' flag:
    --
    --   * flag set: write the one-element frame @[7]@, flush and return
    --     (presumably the stream-end marker of the wire protocol).
    --   * flag clear: pull one element from 'src'. A value is written as the
    --     three-element frame @[6, nil, value]@ and flushed right away; an
    --     exhausted source sets the flag instead, so the following round
    --     emits the closing frame.
    --
    -- The flag is also written from outside this function, which is why it
    -- is re-read on every round rather than cached.
    loopSend eofRef src bo = do
        finished <- readIORef eofRef
        if not finished
        then do
            next <- pull src
            case next of
                -- source ran dry: let the next round send the closing frame
                Nothing -> atomicWriteIORef eofRef True
                Just item -> do
                    writeBuilder bo $ do
                        MB.arrayHeader 3
                        MB.int 6                -- message type tag
                        MB.nil                  -- error slot, empty for data
                        MP.encodeMessagePack item
                    flushBuffer bo
            loopSend eofRef src bo
        else do
            -- single-element frame carrying only the message type tag
            writeBuilder bo (MB.arrayHeader 1 >> MB.int 7)
            flushBuffer bo