{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE OverloadedStrings          #-}
{-# LANGUAGE StrictData                 #-}
{-# LANGUAGE TypeFamilies               #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Network.MessagePack.Client.Internal where

import           Control.Applicative               (Applicative)
import           Control.Monad                     (when)
import           Control.Monad.Catch               (MonadCatch, MonadThrow,
                                                    throwM)
import qualified Control.Monad.State.Strict        as CMS
import           Control.Monad.Validate            (runValidate)
import qualified Data.Binary                       as Binary
import qualified Data.ByteString                   as S
import           Data.Conduit                      (ConduitT, SealedConduitT,
                                                    Void, runConduit, ($$++),
                                                    (.|))
import qualified Data.Conduit.Binary               as CB
import           Data.Conduit.Serialization.Binary (sinkGet)
import           Data.MessagePack                  (MessagePack, Object,
                                                    defaultConfig, fromObject,
                                                    fromObjectWith)
import           Data.Monoid                       ((<>))
import           Data.Text                         (Text)
import qualified Data.Text                         as T
import qualified Network.MessagePack.Types.Result  as R

import           Network.MessagePack.Interface     (IsClientType (..), Returns,
                                                    ReturnsM)
import           Network.MessagePack.Types.Client
import           Network.MessagePack.Types.Error
import           Network.MessagePack.Types.Spec


-- | State of a single RPC connection, threaded through 'ClientT'.
--
-- Every field is strict because this module enables @StrictData@.
data Connection m = Connection
  { connSource :: SealedConduitT () S.ByteString m ()
    -- ^ Sealed byte source responses are read from; 'rpcCall' replaces it
    -- with the resumed source after reading each response.
  , connSink   :: ConduitT S.ByteString Void m ()
    -- ^ Byte sink packed requests are written to.
  , connMsgId  :: Int
    -- ^ Message id for the next request; 'rpcCall' increments it per call.
  , connMths   :: [Text]
    -- ^ Method list passed to 'packRequest' when encoding requests.
  }


-- | RPC client monad transformer: a strict 'CMS.StateT' over the current
-- 'Connection', running in the base monad @m@.
--
-- The listed instances are derived from the underlying state transformer
-- (the module enables @GeneralizedNewtypeDeriving@).
newtype ClientT m a = ClientT
  { runClientT :: CMS.StateT (Connection m) m a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , CMS.MonadIO
    , MonadThrow
    , MonadCatch
    )

-- | 'ClientT' specialised to 'IO'.
type Client = ClientT IO

-- | An interface whose result is 'ReturnsM' is called as a 'ClientT' action
-- yielding @r@; the @io@ parameter does not appear in the client type.
instance IsClientType m (ReturnsM io r) where
  type ClientType m (ReturnsM io r) = ClientT m r

-- | An interface whose result is 'Returns' is called as a 'ClientT' action
-- yielding @r@.
instance IsClientType m (Returns r) where
  type ClientType m (Returns r) = ClientT m r


-- | A 'ClientT' action whose result type is 'MessagePack'-decodable can
-- terminate an RPC call.
instance (CMS.MonadIO m, MonadThrow m, MessagePack o)
    => RpcType (ClientT m o) where
  -- @args@ is reversed before the call is issued (presumably callers
  -- collect arguments last-first — confirm against the RpcType class).
  -- The raw result is decoded with 'defaultConfig'; a decode failure is
  -- rethrown as 'ResultTypeError' carrying the rendered error and the raw
  -- result object.
  rpcc name args =
    rpcCall name (reverse args) >>= \reply ->
      either (throwM . flip ResultTypeError reply . T.pack . show)
             pure
             (runValidate (fromObjectWith defaultConfig reply))


-- | Send one request for @methodName@ with @args@ over the current
-- connection, read back one response, and return its result object.
--
-- The connection state (resumed source, incremented message id) is
-- updated as soon as a response has been read, before it is validated.
--
-- Throws 'ProtocolError' when the response cannot be unpacked, its type is
-- not 1, or its message id differs from the request's, and 'RemoteError'
-- when the response's error slot does not decode as @()@.
rpcCall :: (MonadThrow m, CMS.MonadIO m) => Text -> [Object] -> ClientT m Object
rpcCall methodName args = ClientT $ do
  conn <- CMS.get
  let msgid    = connMsgId conn
      request  = packRequest (connMths conn) (0, msgid, methodName, args)
      -- Push the packed request into the sink, then resume the sealed
      -- source just far enough to decode a single response object.
      exchange = do
        runConduit (CB.sourceLbs request .| connSink conn)
        connSource conn $$++ sinkGet Binary.get

  (nextSource, reply) <- CMS.lift exchange
  CMS.modify $ \c -> c { connSource = nextSource, connMsgId = msgid + 1 }

  either (throwM . ProtocolError . ("invalid response data: " <>) . T.pack . show)
         (checkReply reply msgid)
         (unpackResponse reply)
  where
    -- Validate an unpacked response in order: message type, then message
    -- id, then the error slot; yield the result object when all pass.
    checkReply raw expected (rtype, rmsgid, rerror, rresult) = do
      when (rtype /= 1) $
        throwM $ ProtocolError $
          "invalid response type (expect 1, but got "
            <> T.pack (show rtype) <> "): " <> T.pack (show raw)
      when (rmsgid /= expected) $
        throwM $ ProtocolError $
          "message id mismatch: expect " <> T.pack (show expected)
            <> ", but got " <> T.pack (show rmsgid)
      maybe (throwM (RemoteError rerror)) (\() -> pure rresult) (fromObject rerror)


-- | Replace the method list held in the connection state; 'rpcCall' passes
-- it to 'packRequest' when encoding subsequent requests.
setMethodList :: Monad m => [Text] -> ClientT m ()
setMethodList mths = ClientT (CMS.modify (\conn -> conn { connMths = mths }))