{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Capnp.Rpc.Server
  ( CallHandler,
    MethodHandler,
    UntypedMethodHandler,
    CallInfo (..),
    ServerOps (..),
    Export (..),
    exportToServerOps,
    findMethod,
    SomeServer (..),
    runServer,
    castHandler,

    -- * Helpers for writing method handlers
    handleParsed,
    handleRaw,
    methodUnimplemented,
    toUntypedMethodHandler,

    -- * Internals; exposed only for use by generated code.
    MethodHandlerTree (..),
  )
where

import qualified Capnp.Basics as B
import qualified Capnp.Classes as C
import Capnp.Convert (parsedToRaw)
import Capnp.Message (Mutability (..))
import qualified Capnp.Repr as R
-- import Capnp.Repr.Methods (Client (..))
import Capnp.Rpc.Errors
  ( eFailed,
    eMethodUnimplemented,
    wrapException,
  )
import Capnp.Rpc.Promise
  ( Fulfiller,
    breakPromise,
    fulfill,
  )
import Capnp.TraversalLimit (defaultLimit, evalLimitT)
import qualified Capnp.Untyped as U
import Control.Concurrent.STM (atomically)
import Control.Exception.Safe (withException)
import Data.Function ((&))
import Data.Functor.Contravariant (contramap)
import Data.Kind (Constraint, Type)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy (..))
import Data.Typeable (Typeable)
import qualified Data.Vector as V
import Data.Word
import GHC.Prim (coerce)
import Internal.BuildPure (createPure)
import qualified Internal.TCloseQ as TCloseQ

-- | A 'CallInfo' contains information about a method call.
data CallInfo = CallInfo
  { -- | The id of the interface whose method is being called.
    interfaceId :: !Word64,
    -- | The method id of the method being called.
    methodId :: !Word16,
    -- | The arguments to the method call.
    arguments :: Maybe (U.Ptr 'Const),
    -- | A 'Fulfiller' which accepts the method's return value.
    response :: Fulfiller (Maybe (U.Ptr 'Const))
  }

-- | The operations necessary to receive and handle method calls, i.e.
-- to implement an object.
data ServerOps = ServerOps
  { -- | Handle a method call; takes the interface and method id and returns
    -- a handler for the specific method.
    handleCall :: Word64 -> Word16 -> UntypedMethodHandler,
    -- | Handle shutting-down the receiver; this is called when the last
    -- reference to the capability is dropped.
    handleStop :: IO (),
    -- | Used to unwrap the server when reflecting on a local client.
    handleCast :: forall a. Typeable a => Maybe a
  }

-- | A handler for arbitrary RPC calls. Maps an interface id to the vector of
-- 'UntypedMethodHandler's for that interface; 'findMethod' indexes that
-- vector by method id.
type CallHandler = M.Map Word64 (V.Vector UntypedMethodHandler)

-- | Type alias for a handler for a particular rpc method. A handler receives
-- the raw parameters (of type @p@) and a 'Fulfiller' for the raw result (of
-- type @r@), and runs in 'IO'.
type MethodHandler p r =
  R.Raw p 'Const ->
  Fulfiller (R.Raw r 'Const) ->
  IO ()

-- | Re-type a 'MethodHandler' to different parameter and result types that
-- share the same representations ('R.ReprFor') as the original ones.
-- Implemented with 'coerce', so the handler itself is reused as-is.
castHandler ::
  forall p q r s.
  (R.ReprFor p ~ R.ReprFor q, R.ReprFor r ~ R.ReprFor s) =>
  MethodHandler p r ->
  MethodHandler q s
castHandler = coerce

-- | Type alias for a handler for an untyped RPC method: a 'MethodHandler'
-- whose parameter and result types are both @'Maybe' 'B.AnyPointer'@.
type UntypedMethodHandler = MethodHandler (Maybe B.AnyPointer) (Maybe B.AnyPointer)

-- | Base class for things that can act as capnproto servers.
class SomeServer a where
  -- | Called when the last live reference to a server is dropped.
  -- The default implementation does nothing.
  shutdown :: a -> IO ()
  shutdown _ = pure ()

  -- | Try to extract a value of a given type. The default implementation
  -- always fails (returns 'Nothing'). If an instance chooses to implement
  -- this, it will be possible to use "reflection" on clients that point
  -- at local servers to dynamically unwrap the server value. A typical
  -- implementation will just call Typeable's @cast@ method, but this
  -- needn't be the case -- a server may wish to allow local peers to
  -- unwrap some value that is not exactly the data the server has access
  -- to.
  unwrap :: Typeable b => a -> Maybe b
  unwrap _ = Nothing

-- | Generated interface types have instances of 'Export', which allows a server
-- for that interface to be exported as a 'Client'.
class (R.IsCap i, C.HasTypeId i) => Export i where
  -- | The constraint needed for a server to implement an interface;
  -- if @'Server' i s@ is satisfied, @s@ is a server for interface @i@.
  -- The code generator generates a type class for each interface, and
  -- this will always be an alias for that type class.
  type Server i :: Type -> Constraint

  -- | Convert the server to a 'MethodHandlerTree' populated with appropriate
  -- 'MethodHandler's for the interface. This is really only exported for use
  -- by generated code; users of the library will generally prefer to use
  -- 'export'.
  methodHandlerTree :: Server i s => Proxy i -> s -> MethodHandlerTree

-- NB: the proxy helps disambiguate types; for some reason TypeApplications
-- doesn't seem to be enough in the face of a type alias of kind 'Constraint'.
-- The inconsistency is a bit ugly, but this method isn't intended to be called
-- by users directly, only by generated code and our helper in this module,
-- so it's less of a big deal.

-- | Lazily computed tree of the method handlers exposed by an interface. Only
-- of interest to generated code.
data MethodHandlerTree = MethodHandlerTree
  { -- | Type id for the primary interface.
    mhtId :: Word64,
    -- | Method handlers for methods of the primary interface.
    mhtHandlers :: [UntypedMethodHandler],
    -- | Trees for parent interfaces. In the case of diamond dependencies,
    -- there may be duplicates, which are eliminated by 'mhtToCallHandler'.
    mhtParents :: [MethodHandlerTree]
  }

-- | Flatten a 'MethodHandlerTree' into a 'CallHandler' keyed by interface id.
--
-- Starting from the root, each tree's handlers are stored under its 'mhtId'
-- and its parents are queued ahead of the remaining work. A tree whose id is
-- already present in the accumulated map is skipped without visiting its
-- parents again, which de-duplicates diamond dependencies.
mhtToCallHandler :: MethodHandlerTree -> CallHandler
mhtToCallHandler = go M.empty . pure
  where
    go :: CallHandler -> [MethodHandlerTree] -> CallHandler
    go accum [] = accum
    go accum (t : ts)
      | mhtId t `M.member` accum = go accum ts -- dedup diamond dependencies
      | otherwise =
          go
            (M.insert (mhtId t) (V.fromList (mhtHandlers t)) accum)
            (mhtParents t ++ ts)

-- | Look up a particular 'MethodHandler' in the 'CallHandler'.
--
-- Returns 'Nothing' if the interface id has no entry in the map, or if the
-- method id is out of bounds for that interface's vector of handlers.
findMethod :: Word64 -> Word16 -> CallHandler -> Maybe UntypedMethodHandler
findMethod interfaceId methodId handler = do
  iface <- M.lookup interfaceId handler
  iface V.!? fromIntegral methodId

-- | Convert a typed method handler to an untyped one. Mostly intended for
-- use by generated code.
--
-- If the untyped parameter is a struct pointer, it is re-wrapped and passed
-- to the typed handler, and the typed handler's result is wrapped back into
-- a struct pointer before being handed to the untyped 'Fulfiller'. For any
-- other parameter (including 'Nothing'), the promise is broken with an
-- 'eFailed' error instead of invoking the handler.
toUntypedMethodHandler ::
  forall p r.
  (R.IsStruct p, R.IsStruct r) =>
  MethodHandler p r ->
  UntypedMethodHandler
toUntypedMethodHandler h =
  \case
    R.Raw (Just (U.PtrStruct param)) -> \ret ->
      h
        (R.Raw param)
        ( contramap
            (\(R.Raw s) -> R.Raw $ Just $ U.PtrStruct s)
            ret
        )
    _ ->
      \ret -> breakPromise ret (eFailed "Parameter was not a struct")

-- | Build 'ServerOps' from a 'SomeServer' instance and a 'CallHandler'.
--
-- Stopping and casting are delegated to the server's 'shutdown' and 'unwrap'
-- methods. Calls are dispatched by looking the method up in the
-- 'CallHandler', falling back to 'methodUnimplemented' when it is not found.
someServerToServerOps :: SomeServer a => a -> CallHandler -> ServerOps
someServerToServerOps srv callHandler =
  ServerOps
    { handleStop = shutdown srv,
      handleCast = unwrap srv,
      handleCall = \interfaceId methodId ->
        findMethod interfaceId methodId callHandler
          & fromMaybe methodUnimplemented
    }

-- | Build the 'ServerOps' for a server @s@ implementing interface @i@.
--
-- The server's 'methodHandlerTree' is flattened into a 'CallHandler', which
-- is then combined with the server's 'SomeServer' instance. The 'Proxy'
-- selects which interface's handlers to use.
exportToServerOps :: forall i s. (Export i, Server i s, SomeServer s) => Proxy i -> s -> ServerOps
exportToServerOps proxy srv =
  mhtToCallHandler (methodHandlerTree proxy srv)
    & someServerToServerOps srv

-- Helpers for writing method handlers

-- | Handle a method, working with the parsed form of parameters and
-- results.
--
-- The raw parameters are parsed under 'defaultLimit', the supplied function
-- is run on the parsed value, and its result is converted back to raw form
-- and used to fulfill the response. The whole sequence runs under
-- 'propagateExceptions', so an exception in any of these steps breaks the
-- promise.
handleParsed ::
  ( C.Parse p pp,
    R.IsStruct p,
    C.Parse r pr,
    R.IsStruct r
  ) =>
  (pp -> IO pr) ->
  MethodHandler p r
handleParsed handler param = propagateExceptions $ \f -> do
  p <- evalLimitT defaultLimit $ C.parse param
  r <- handler p
  -- TODO: Figure out how to add an instance of Thaw for
  -- Raw so we can skip the (un)wrapping here.
  struct <- createPure maxBound $ R.fromRaw <$> parsedToRaw r
  fulfill f (R.Raw struct)

-- | Handle a method, working with the raw (unparsed) form of
-- parameters and results.
--
-- The supplied function is run on the raw parameters and its result is used
-- to fulfill the response. This runs under 'propagateExceptions', so an
-- exception thrown by the function breaks the promise.
handleRaw ::
  (R.IsStruct p, R.IsStruct r) =>
  (R.Raw p 'Const -> IO (R.Raw r 'Const)) ->
  MethodHandler p r
handleRaw handler param = propagateExceptions $ \f ->
  handler param >>= fulfill f

-- Helper for handle*; breaks the promise if the handler throws.
--
-- The exception is converted with @wrapException False@ before being passed
-- to 'breakPromise'.
propagateExceptions :: (Fulfiller a -> IO b) -> Fulfiller a -> IO b
propagateExceptions h f =
  h f `withException` (breakPromise f . wrapException False)

-- | 'MethodHandler' that always throws unimplemented: it ignores its
-- parameters and breaks the promise with 'eMethodUnimplemented'.
methodUnimplemented :: MethodHandler p r
methodUnimplemented _ f = breakPromise f eMethodUnimplemented

-- | Handle incoming messages for a given object.
--
-- Accepts a queue of calls to handle, and the 'ServerOps' used to handle
-- them. Each 'CallInfo' read from the queue is dispatched through
-- 'handleCall', one at a time, in the order read. Returns when reading from
-- the queue yields 'Nothing' (presumably once the queue has been closed;
-- see "Internal.TCloseQ").
runServer :: TCloseQ.Q CallInfo -> ServerOps -> IO ()
runServer q ops = go
  where
    go :: IO ()
    go =
      atomically (TCloseQ.read q) >>= \case
        Nothing ->
          pure ()
        Just CallInfo {interfaceId, methodId, arguments, response} ->
          do
            -- The fulfiller carries an untyped pointer; 'coerce' re-types it
            -- to the 'R.Raw'-wrapped form the handler expects.
            handleCall ops interfaceId methodId (R.Raw arguments) (coerce response)
            go