{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase       #-}
{-# LANGUAGE TypeFamilies     #-}
-- |
-- Module: Capnp.Rpc.Invoke
-- Description: Invoke remote methods
--
-- Support for invoking 'Server.MethodHandler's
module Capnp.Rpc.Invoke
    (
    -- * Using high-level representations
      invokePurePromise
    , (?)
    , invokePure
    , InvokePureCtx

    -- * Using low level representations
    , invokeRaw
    ) where

import Control.Monad.STM.Class

import Control.Monad.Catch     (MonadThrow)
import Control.Monad.Primitive (PrimMonad, PrimState)

import Capnp.Classes
    ( Cerialize(cerialize)
    , Decerialize(Cerial, decerialize)
    , FromPtr(fromPtr)
    , ToStruct(toStruct)
    )
import Capnp.TraversalLimit (defaultLimit, evalLimitT)
import Data.Mutable         (freeze)

import qualified Capnp.Message     as M
import qualified Capnp.Rpc.Promise as Promise
import qualified Capnp.Rpc.Server  as Server
import qualified Capnp.Untyped     as U

-- | Invoke a method by passing it the low-level representation of its parameter,
-- and a 'Promise.Fulfiller' that can be used to supply (the low-level
-- representation of) its return value.
--
-- The handler is converted with 'Server.toUntypedHandler' and called via
-- 'Server.invoke', with the parameter wrapped as an untyped struct pointer.
-- The untyped outcome is forwarded to the supplied fulfiller: an exception
-- breaks it, while a result pointer is converted with 'fromPtr' and used to
-- fulfill it.
invokeRaw ::
    ( MonadThrow m
    , MonadSTM m
    , PrimMonad m
    , Decerialize r
    , Decerialize p
    , ToStruct M.ConstMsg (Cerial M.ConstMsg p)
    , FromPtr M.ConstMsg (Cerial M.ConstMsg r)
    ) =>
    Server.MethodHandler m p r
    -> Cerial M.ConstMsg p
    -> Promise.Fulfiller (Cerial M.ConstMsg r)
    -> m ()
invokeRaw method params typedFulfiller = do
    -- Only the fulfiller half of the new promise is used; the callback relays
    -- whatever the untyped handler produces to the caller's typed fulfiller.
    (_, untypedFulfiller) <- liftSTM $ Promise.newPromiseWithCallback $ \case
        Left e ->
            Promise.breakPromise typedFulfiller e
        Right v ->
            -- NOTE(review): if 'fromPtr' throws here, this branch neither
            -- fulfills nor breaks 'typedFulfiller' -- confirm how the code
            -- running the callback deals with that.
            evalLimitT defaultLimit (fromPtr M.empty v)
                >>= Promise.fulfill typedFulfiller
    Server.invoke
        (Server.toUntypedHandler method)
        (Just (U.PtrStruct (toStruct params)))
        untypedFulfiller

-- | Shorthand for the class constraints needed to invoke a method using
-- the high-level API.
--
-- This is the context of 'invokePure', 'invokePurePromise' and '(?)', where
-- @m@ is the monad the call runs in, @p@ is the high-level parameter type and
-- @r@ is the high-level result type.
type InvokePureCtx m p r =
    ( MonadThrow m
    , MonadSTM m
    , PrimMonad m
    , Decerialize r
    , ToStruct M.ConstMsg (Cerial M.ConstMsg p)
    , ToStruct (M.MutMsg (PrimState m)) (Cerial (M.MutMsg (PrimState m)) p)
    , Cerialize (PrimState m) p
    , FromPtr M.ConstMsg (Cerial M.ConstMsg r)
    )

-- | Like 'invokeRaw', but uses the high-level representations of the data
-- types.
--
-- The parameters are serialized into a freshly allocated message with
-- 'cerialize' and frozen before being handed to the handler. On completion
-- the untyped result is read with 'fromPtr', converted back to its high-level
-- form with 'decerialize', and used to fulfill the supplied fulfiller; an
-- exception breaks the fulfiller instead.
invokePure
    :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> Promise.Fulfiller r
    -> m ()
invokePure method params pureFulfiller = do
    -- Build the parameter struct in a mutable message, then freeze it so it
    -- can be passed along as an immutable struct.
    struct <- evalLimitT defaultLimit $ do
        msg <- M.newMessage Nothing
        (toStruct <$> cerialize msg params) >>= freeze
    -- Only the fulfiller half of the new promise is used; the callback relays
    -- the handler's outcome to the caller's high-level fulfiller.
    (_, untypedFulfiller) <- liftSTM $ Promise.newPromiseWithCallback $ \case
        Left e ->
            Promise.breakPromise pureFulfiller e
        Right v ->
            -- NOTE(review): if decoding throws here, this branch neither
            -- fulfills nor breaks 'pureFulfiller' -- confirm how the code
            -- running the callback deals with that.
            evalLimitT defaultLimit (fromPtr M.empty v >>= decerialize)
                >>= Promise.fulfill pureFulfiller
    Server.invoke
        (Server.toUntypedHandler method)
        (Just (U.PtrStruct struct))
        untypedFulfiller

-- | Like 'invokePure', but returns a promise instead of accepting a fulfiller.
--
-- A fresh promise/fulfiller pair is created with 'Promise.newPromise'; the
-- fulfiller is passed to 'invokePure' and the promise is returned to the
-- caller.
invokePurePromise
    :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> m (Promise.Promise r)
invokePurePromise method params = do
    (promise, fulfiller) <- Promise.newPromise
    invokePure method params fulfiller
    pure promise

-- | Alias for 'invokePurePromise', usable infix: @handler ? params@.
(?) :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> m (Promise.Promise r)
(?) = invokePurePromise