{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase       #-}
{-# LANGUAGE TypeFamilies     #-}
-- |
-- Module: Capnp.Rpc.Invoke
-- Description: Invoke remote methods
--
-- Support for invoking 'Server.MethodHandler's
module Capnp.Rpc.Invoke
    (
    -- * Using high-level representations
      invokePurePromise
    , (?)
    , invokePure
    , InvokePureCtx

    -- * Using low level representations
    , invokeRaw
    ) where

import Control.Monad.STM.Class

import Control.Monad.Catch     (MonadThrow)
import Control.Monad.Primitive (PrimMonad, PrimState)

import Capnp.Classes
    ( Cerialize(cerialize)
    , Decerialize(Cerial, decerialize)
    , FromPtr(fromPtr)
    , ToStruct(toStruct)
    )
import Capnp.Message        (Mutability(..))
import Capnp.TraversalLimit (defaultLimit, evalLimitT)
import Data.Mutable         (freeze)

import qualified Capnp.Message     as M
import qualified Capnp.Rpc.Promise as Promise
import qualified Capnp.Rpc.Server  as Server
import qualified Capnp.Untyped     as U

-- | Call a method using the low-level (\"cerial\") representation of its
-- parameter struct. The method's result is delivered, again in low-level
-- form, through the supplied 'Promise.Fulfiller'.
--
-- Internally an untyped promise is created whose callback forwards the
-- outcome to @typedFulfiller@:
--
-- * if the untyped promise is broken, @typedFulfiller@ is broken with the
--   same exception;
--
-- * otherwise the returned pointer is decoded with 'fromPtr' (under
--   'defaultLimit') and the decoded value fulfills @typedFulfiller@.
invokeRaw ::
    ( MonadThrow m
    , MonadSTM m
    , PrimMonad m
    , Decerialize r
    , Decerialize p
    , ToStruct 'Const (Cerial 'Const p)
    , FromPtr 'Const (Cerial 'Const r)
    ) =>
    Server.MethodHandler m p r
    -> Cerial 'Const p
    -> Promise.Fulfiller (Cerial 'Const r)
    -> m ()
invokeRaw method params typedFulfiller = do
    let paramPtr = Just (U.PtrStruct (toStruct params))
        onResult = either (Promise.breakPromise typedFulfiller) deliver
        -- NOTE(review): if decoding fails here, the error is thrown inside
        -- the STM callback rather than breaking typedFulfiller -- confirm
        -- this is the intended behaviour.
        deliver ptr =
            Promise.fulfill typedFulfiller
                =<< evalLimitT defaultLimit (fromPtr M.empty ptr)
    (_, untypedFulfiller) <- liftSTM (Promise.newPromiseWithCallback onResult)
    Server.invoke (Server.toUntypedHandler method) paramPtr untypedFulfiller

-- | Shorthand for the class constraints needed to call a method through the
-- high-level API ('invokePure', 'invokePurePromise'). It bundles the
-- requirements on the monad @m@, the parameter type @p@ and the result
-- type @r@.
type InvokePureCtx m p r =
    ( MonadSTM m
    , MonadThrow m
    , PrimMonad m
    , Cerialize (PrimState m) p
    , ToStruct 'Const (Cerial 'Const p)
    , ToStruct ('Mut (PrimState m)) (Cerial ('Mut (PrimState m)) p)
    , Decerialize r
    , FromPtr 'Const (Cerial 'Const r)
    )

-- | Like 'invokeRaw', but works with the high-level representations of the
-- parameter and result types.
--
-- The steps are as follows:
--
-- 1. The parameters are serialised into a fresh message, and the resulting
--    struct is frozen (under 'defaultLimit').
--
-- 2. When the method completes, its result pointer is decoded with
--    'fromPtr' and 'decerialize' (again under 'defaultLimit') and handed to
--    @pureFulfiller@.
--
-- 3. If the method's untyped promise is broken, @pureFulfiller@ is broken
--    with the same exception.
invokePure
    :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> Promise.Fulfiller r
    -> m ()
invokePure method params pureFulfiller = do
    -- Build the parameter struct in a new mutable message, then freeze it so
    -- it can be passed along as a 'Const pointer.
    frozenParams <- evalLimitT defaultLimit $ do
        msg <- M.newMessage Nothing
        mutParams <- toStruct <$> cerialize msg params
        freeze mutParams
    let onResult = either (Promise.breakPromise pureFulfiller) deliver
        -- NOTE(review): if decoding fails here, the error is thrown inside
        -- the STM callback rather than breaking pureFulfiller -- confirm
        -- this is the intended behaviour.
        deliver ptr = do
            value <- evalLimitT defaultLimit (fromPtr M.empty ptr >>= decerialize)
            Promise.fulfill pureFulfiller value
    (_, untypedFulfiller) <- liftSTM (Promise.newPromiseWithCallback onResult)
    Server.invoke
        (Server.toUntypedHandler method)
        (Just (U.PtrStruct frozenParams))
        untypedFulfiller

-- | Like 'invokePure', but creates the promise itself and returns it, instead
-- of taking a 'Promise.Fulfiller' from the caller.
invokePurePromise
    :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> m (Promise.Promise r)
invokePurePromise method params = do
    (result, resolver) <- Promise.newPromise
    invokePure method params resolver
    return result

-- | Infix synonym for 'invokePurePromise': @method ? params@ is the same as
-- @invokePurePromise method params@.
(?) :: InvokePureCtx m p r
    => Server.MethodHandler m p r
    -> p
    -> m (Promise.Promise r)
method ? params = invokePurePromise method params