{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Capnp.Rpc.Untyped
(
ConnConfig(..)
, handleConn
, Client
, call
, nullClient
, newPromiseClient
, IsClient(..)
, Pipeline
, walkPipelinePtr
, pipelineClient
, export
, clientMethodHandler
, unwrapServer
, RpcError(..)
, R.Exception(..)
, R.Exception'Type(..)
) where
import Control.Concurrent.STM
import Control.Monad.STM.Class
import Data.Word
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (concurrently_, race_)
import Control.Concurrent.MVar (MVar, newEmptyMVar)
import Control.Exception.Safe (Exception, bracket, throwIO, try)
import Control.Monad (forever, void, when)
import Data.Default (Default(def))
import Data.Foldable (for_, toList, traverse_)
import Data.Hashable (Hashable, hash, hashWithSalt)
import Data.Maybe (catMaybes)
import Data.String (fromString)
import Data.Text (Text)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Supervisors (Supervisor, superviseSTM, withSupervisor)
import System.Mem.StableName (StableName, hashStableName, makeStableName)
import System.Timeout (timeout)
import qualified Data.Vector as V
import qualified Focus
import qualified ListT
import qualified StmContainers.Map as M
import Capnp.Classes (cerialize, decerialize)
import Capnp.Convert (msgToValue, valueToMsg)
import Capnp.Message (ConstMsg)
import Capnp.Rpc.Errors
( eDisconnected
, eFailed
, eMethodUnimplemented
, eUnimplemented
, wrapException
)
import Capnp.Rpc.Promise
(Fulfiller, breakOrFulfill, breakPromise, fulfill, newCallback)
import Capnp.Rpc.Transport (Transport(recvMsg, sendMsg))
import Capnp.TraversalLimit (defaultLimit, evalLimitT)
import Internal.BuildPure (createPure)
import Internal.Rc (Rc)
import Internal.SnocList (SnocList)
import qualified Capnp.Gen.Capnp.Rpc.Pure as R
import qualified Capnp.Message as Message
import qualified Capnp.Rpc.Server as Server
import qualified Capnp.Untyped as UntypedRaw
import qualified Capnp.Untyped.Pure as Untyped
import qualified Internal.Finalizer as Fin
import qualified Internal.Rc as Rc
import qualified Internal.SnocList as SnocList
import qualified Internal.TCloseQ as TCloseQ
-- | An optional pointer in the decoded representation from
-- "Capnp.Untyped.Pure"; 'Nothing' stands for a null pointer.
type MPtr = Maybe Untyped.Ptr

-- | An optional pointer into a raw, read-only ('ConstMsg') message;
-- 'Nothing' stands for a null pointer.
type RawMPtr = Maybe (UntypedRaw.Ptr ConstMsg)
-- | Exceptions which terminate an RPC connection.
data RpcError
    = ReceivedAbort R.Exception
    -- ^ The remote vat sent us an @abort@ message carrying this exception.
    | SentAbort R.Exception
    -- ^ We aborted the connection; 'handleConn' makes a best-effort attempt
    -- (with a one-second timeout) to send the peer an @abort@ message
    -- carrying this exception before rethrowing.
    deriving(Show, Eq, Generic)

instance Exception RpcError
-- | Identifier for an embargo, as carried in @Disembargo@ messages.
newtype EmbargoId = EmbargoId { embargoWord :: Word32 } deriving(Eq, Hashable)

-- | Identifier for an entry in the questions or answers table.
newtype QAId = QAId { qaWord :: Word32 } deriving(Eq, Hashable)

-- | Identifier for an entry in the imports or exports table.
newtype IEId = IEId { ieWord :: Word32 } deriving(Eq, Hashable)

-- These Show instances print only the underlying number. They are used when
-- formatting error messages (e.g. by 'lookupAbort').
instance Show QAId where
    show = show . qaWord

instance Show IEId where
    show = show . ieWord
-- | A connection to a remote vat.
--
-- Equality and hashing use 'stableName' (see the instances below), so a
-- 'Conn' can serve as a map key, e.g. in 'ExportMap'.
data Conn = Conn
    { stableName :: StableName (MVar ())
    -- ^ Stable name of a fresh, otherwise unused 'MVar' created per
    -- connection; it serves only as an identity token.
    , debugMode :: !Bool
    -- ^ Passed to 'wrapException' when turning exceptions into aborts.
    , liveState :: TVar LiveState
    -- ^ Switched to 'Dead' when the connection is torn down.
    }

-- | Whether a connection is still running, and its state if so.
data LiveState
    = Live Conn'
    | Dead

-- | The state of a live connection.
data Conn' = Conn'
    { sendQ :: TBQueue ConstMsg
    -- ^ Outgoing messages; drained by 'sendLoop'.
    , recvQ :: TBQueue ConstMsg
    -- ^ Incoming messages; filled by 'recvLoop', drained by 'coordinator'.
    , supervisor :: Supervisor
    , questionIdPool :: IdPool
    , exportIdPool :: IdPool
    , questions :: M.Map QAId EntryQA
    -- ^ Questions we have sent to the remote vat (calls, bootstrap).
    , answers :: M.Map QAId EntryQA
    -- ^ Questions the remote vat has sent to us.
    , exports :: M.Map IEId EntryE
    , imports :: M.Map IEId EntryI
    , embargos :: M.Map EmbargoId (Fulfiller ())
    , pendingCallbacks :: TQueue (IO ())
    -- ^ IO actions queued from STM (see 'queueIO'); run by 'callbacksLoop'.
    , bootstrap :: Maybe Client
    -- ^ Capability returned in response to @Bootstrap@ messages, if any.
    }
-- | Connections are equal exactly when they share a 'stableName'.
instance Eq Conn where
    x == y = stableName x == stableName y

-- | Hashes a connection by its 'stableName', consistently with 'Eq'.
--
-- 'hashWithSalt' mixes the salt in rather than discarding it. Discarding it
-- breaks salted hashing and composite hashes (e.g. hashing a tuple that
-- contains a 'Conn' would lose every component hashed before it).
instance Hashable Conn where
    hash Conn{stableName} = hashStableName stableName
    hashWithSalt salt Conn{stableName} =
        hashWithSalt salt (hashStableName stableName)
-- | Configuration for 'handleConn'.
data ConnConfig = ConnConfig
    { maxQuestions :: !Word32
    -- ^ Number of ids in the question-id pool. Also used as the capacity of
    -- the send and receive queues.
    , maxExports :: !Word32
    -- ^ Number of ids in the export-id pool.
    , debugMode :: !Bool
    -- ^ Copied into the 'Conn'; passed to 'wrapException' when converting
    -- exceptions into aborts.
    , getBootstrap :: Supervisor -> STM (Maybe Client)
    -- ^ Produces the capability offered in answer to the remote vat's
    -- @Bootstrap@ requests ('Nothing' means none is offered).
    , withBootstrap :: Maybe (Supervisor -> Client -> IO ())
    -- ^ If present, called with the remote vat's bootstrap capability. When
    -- it returns, the connection is shut down. If absent, the connection
    -- runs until something else ends it.
    }

-- | Defaults: 128 questions, 8192 exports, debug mode off, no bootstrap
-- capability offered, and the remote bootstrap capability is not requested.
instance Default ConnConfig where
    def = ConnConfig
        { maxQuestions = 128
        , maxExports = 8192
        , debugMode = False
        , getBootstrap = \_ -> pure Nothing
        , withBootstrap = Nothing
        }
-- | Schedule an IO action to be run later by 'callbacksLoop'. The action
-- only becomes visible once the enclosing transaction commits.
queueIO :: Conn' -> IO () -> STM ()
queueIO conn' action = writeTQueue (pendingCallbacks conn') action

-- | Schedule an STM action to run later in its own transaction.
queueSTM :: Conn' -> STM () -> STM ()
queueSTM conn' txn = queueIO conn' (atomically txn)

-- | Schedule each callback, applied to the same argument, as a separate
-- later transaction (in list order).
mapQueueSTM :: Conn' -> SnocList (a -> STM ()) -> a -> STM ()
mapQueueSTM conn' callbacks arg =
    for_ callbacks $ \callback -> queueSTM conn' (callback arg)
-- | Allocate a question id from the connection's question-id pool, retrying
-- the transaction while the pool is empty.
newQuestion :: Conn' -> STM QAId
newQuestion = fmap QAId . newId . questionIdPool

-- | Return a question id to the question-id pool.
freeQuestion :: Conn' -> QAId -> STM ()
freeQuestion conn = freeId (questionIdPool conn) . qaWord

-- | Allocate an export id from the connection's export-id pool, retrying the
-- transaction while the pool is empty.
newExport :: Conn' -> STM IEId
newExport = fmap IEId . newId . exportIdPool

-- | Return an export id to the export-id pool.
freeExport :: Conn' -> IEId -> STM ()
freeExport conn = freeId (exportIdPool conn) . ieWord

-- | Allocate an embargo id. Embargo ids are drawn from the question-id pool.
newEmbargo :: Conn' -> STM EmbargoId
newEmbargo = fmap EmbargoId . newId . questionIdPool

-- | Return an embargo id to the pool it came from.
--
-- This must be the same pool 'newEmbargo' draws from (the question-id pool).
-- Freeing into the export-id pool instead would permanently leak a question
-- id. It could also hand out an export id that is already in use.
freeEmbargo :: Conn' -> EmbargoId -> STM ()
freeEmbargo conn = freeId (questionIdPool conn) . embargoWord
-- | Handle a connection to another vat over the given transport, returning
-- when the connection ends.
--
-- Runs the coordinator, send loop, receive loop and callbacks loop
-- concurrently, raced against 'withBootstrap' (if configured). When any of
-- these finishes or throws, the connection state is torn down by @stopConn@.
--
-- Only 'RpcError' exceptions are caught here. For 'SentAbort', an @abort@
-- message is sent on a best-effort basis (one-second timeout) and the
-- exception is then rethrown. Other 'RpcError's are rethrown unchanged.
handleConn :: Transport -> ConnConfig -> IO ()
handleConn
    transport
    cfg@ConnConfig
        { maxQuestions
        , maxExports
        , withBootstrap
        , debugMode
        }
    = withSupervisor $ \sup ->
        bracket
            (newConn sup)
            stopConn
            runConn
  where
    -- Allocate all per-connection state. The stable name of a fresh MVar
    -- is used as the connection's identity.
    newConn sup = do
        stableName <- makeStableName =<< newEmptyMVar
        atomically $ do
            bootstrap <- getBootstrap cfg sup
            questionIdPool <- newIdPool maxQuestions
            exportIdPool <- newIdPool maxExports
            -- Both queue capacities are taken from maxQuestions.
            sendQ <- newTBQueue $ fromIntegral maxQuestions
            recvQ <- newTBQueue $ fromIntegral maxQuestions
            questions <- M.new
            answers <- M.new
            exports <- M.new
            imports <- M.new
            embargos <- M.new
            pendingCallbacks <- newTQueue
            let conn' = Conn'
                    { supervisor = sup
                    , questionIdPool
                    , exportIdPool
                    , recvQ
                    , sendQ
                    , questions
                    , answers
                    , exports
                    , imports
                    , embargos
                    , pendingCallbacks
                    , bootstrap
                    }
            liveState <- newTVar (Live conn')
            let conn = Conn
                    { stableName
                    , debugMode
                    , liveState
                    }
            pure (conn, conn')
    -- Run the worker threads; the first to finish (or throw) ends them all.
    runConn (conn, conn') = do
        result <- try $
            ( coordinator conn
                `concurrently_` sendLoop transport conn'
                `concurrently_` recvLoop transport conn'
                `concurrently_` callbacksLoop conn'
            ) `race_`
                useBootstrap conn conn'
        case result of
            Left (SentAbort e) -> do
                -- Tell the peer why we are hanging up, without blocking
                -- forever if the transport is stuck.
                rawMsg <- createPure maxBound $ valueToMsg $ R.Message'abort e
                void $ timeout 1000000 $ sendMsg transport rawMsg
                throwIO $ SentAbort e
            Left e ->
                throwIO e
            Right _ ->
                pure ()
    -- Tear down: drop exports (including the bootstrap capability), fail
    -- outstanding questions that have not yet returned, break pending
    -- embargoes, mark the connection dead, then run remaining callbacks.
    stopConn
            ( conn@Conn{liveState}
            , conn'@Conn'{questions, exports, embargos}
            ) = do
        atomically $ do
            let walk table = flip ListT.traverse_ (M.listT table)
            case bootstrap conn' of
                Just (Client (Just client')) -> dropConnExport conn client'
                _ -> pure ()
            walk exports $ \(_, EntryE{client}) ->
                dropConnExport conn client
            walk questions $ \(QAId qid, entry) ->
                let raiseDisconnected onReturn =
                        mapQueueSTM conn' onReturn $ R.Return
                            { answerId = qid
                            , releaseParamCaps = False
                            , union' = R.Return'exception eDisconnected
                            }
                in case entry of
                    NewQA{onReturn} -> raiseDisconnected onReturn
                    HaveFinish{onReturn} -> raiseDisconnected onReturn
                    _ -> pure ()
            walk embargos $ \(_, fulfiller) ->
                breakPromise fulfiller eDisconnected
            writeTVar liveState Dead
        flushCallbacks conn'
    -- Without 'withBootstrap', block forever so the race is decided by the
    -- connection threads alone.
    useBootstrap conn conn' = case withBootstrap of
        Nothing ->
            forever $ threadDelay maxBound
        Just f ->
            atomically (requestBootstrap conn) >>= f (supervisor conn')
-- | A pool of currently-unused 32-bit ids.
newtype IdPool = IdPool (TVar [Word32])

-- | Create a pool holding the ids @0 .. size-1@. A size of zero yields an
-- empty pool.
newIdPool :: Word32 -> STM IdPool
newIdPool size
    -- Guard against Word32 underflow: with size == 0, (size - 1) wraps to
    -- maxBound, which would fill the pool with ~4 billion ids.
    | size == 0 = IdPool <$> newTVar []
    | otherwise = IdPool <$> newTVar [0 .. size - 1]

-- | Take an id from the pool, retrying the transaction while the pool is
-- empty.
newId :: IdPool -> STM Word32
newId (IdPool pool) = do
    available <- readTVar pool
    case available of
        [] -> retry
        (ident : rest) -> do
            writeTVar pool $! rest
            pure ident

-- | Put an id back into the pool. No check is made that the id came from
-- this pool or is not already free.
freeId :: IdPool -> Word32 -> STM ()
freeId (IdPool pool) ident = modifyTVar' pool (ident :)
-- | An entry in the questions or answers table.
--
-- It tracks which of the call's @Return@ and @Finish@ messages have been
-- seen, plus callbacks waiting on the one that has not.
data EntryQA
    -- Neither a return nor a finish has been seen yet.
    = NewQA
        { onFinish :: SnocList (R.Finish -> STM ())
        , onReturn :: SnocList (R.Return -> STM ())
        }
    -- The return has been seen; still waiting on the finish.
    | HaveReturn
        { returnMsg :: R.Return
        , onFinish :: SnocList (R.Finish -> STM ())
        }
    -- The finish has been seen; still waiting on the return.
    | HaveFinish
        { finishMsg :: R.Finish
        , onReturn :: SnocList (R.Return -> STM ())
        }

-- | An entry in the imports table.
data EntryI = EntryI
    { localRc :: Rc ()
    , remoteRc :: !Word32
    -- ^ NOTE(review): presumably the number of references to this import
    -- the remote vat has handed us — confirm.
    , proxies :: ExportMap
    , promiseState :: Maybe
        ( TVar PromiseState
        , TmpDest
        )
    -- ^ 'Just' when the import is a promise; 'handleResolveMsg' resolves it
    -- through this. 'Nothing' for non-promise imports.
    }

-- | An entry in the exports table.
data EntryE = EntryE
    { client :: Client'
    , refCount :: !Word32
    -- ^ References held by the remote vat; reduced by @Release@ messages
    -- (see 'releaseExport').
    }
-- | Types which can be converted to and from an untyped 'Client'.
class IsClient a where
    toClient :: a -> Client
    fromClient :: Client -> a

instance IsClient Client where
    toClient = id
    fromClient = id

-- | Only the null client has a meaningful textual form; every other
-- capability shows as a placeholder comment.
instance Show Client where
    show (Client Nothing) = "nullClient"
    show _ = "({- capability; not statically representable -})"

-- | A reference to a capability. @Client Nothing@ is the null client.
newtype Client =
    Client (Maybe Client')
    deriving(Eq)
-- | A non-null capability.
data Client'
    -- A capability served by this vat (see 'export').
    = LocalClient
        { exportMap :: ExportMap
        , qCall :: Rc (Server.CallInfo -> STM ())
        -- ^ Reference-counted function that enqueues a call for the server.
        , finalizerKey :: Fin.Cell ()
        , unwrapper :: forall a. Typeable a => Maybe a
        -- ^ Built from 'Server.handleCast'; presumably yields the server
        -- value when it has the requested type (see 'unwrapServer').
        }
    -- A promise that will later resolve to another capability.
    | PromiseClient
        { pState :: TVar PromiseState
        , exportMap :: ExportMap
        , origTarget :: TmpDest
        -- ^ The temporary destination the promise was created with.
        }
    -- A capability imported from a remote vat.
    | ImportClient (Fin.Cell ImportRef)
-- | A pipeline: shared state for a call's (possibly not yet available)
-- result, plus a path of pointer-field indices into that result.
data Pipeline = Pipeline
    { state :: TVar PipelineState
    , steps :: SnocList Word16
    -- ^ Pointer-field indices to follow, in order, from the result root.
    }

-- | State shared by all pipelines derived from one call.
data PipelineState
    -- Call sent to a remote vat; clients requested via 'pipelineClient' are
    -- cached per path in 'clientMap'.
    = PendingRemotePipeline
        { answerId :: !QAId
        , clientMap :: M.Map (SnocList Word16) Client
        , conn :: Conn
        }
    -- Waiting on a local result; these fulfillers are notified when it
    -- arrives (see 'makeLocalPipeline').
    | PendingLocalPipeline (SnocList (Fulfiller MPtr))
    -- The result (or the error) is available.
    | ReadyPipeline (Either R.Exception MPtr)
-- | Extend a pipeline's path by one more pointer field. The shared call
-- state is unchanged.
walkPipelinePtr :: Pipeline -> Word16 -> Pipeline
walkPipelinePtr pipeline idx =
    pipeline { steps = steps pipeline `SnocList.snoc` idx }
-- | Get a client for the capability at the pipeline's path.
--
-- * Remote call still pending: returns a promise client targeting the
--   promised answer. One such client is created and cached per path.
-- * Local call still pending: returns a promise client that is settled
--   when the result arrives.
-- * Result available: follows the path now. On failure, returns a promise
--   client that has already been broken with the error.
pipelineClient :: MonadSTM m => Pipeline -> m Client
pipelineClient Pipeline{state, steps} = liftSTM $ do
    readTVar state >>= \case
        PendingRemotePipeline{answerId, clientMap, conn} -> do
            maybeClient <- M.lookup steps clientMap
            case maybeClient of
                Nothing -> do
                    client <- promisedAnswerClient
                        conn
                        PromisedAnswer { answerId, transform = steps }
                    M.insert client steps clientMap
                    pure client
                Just client ->
                    pure client
        PendingLocalPipeline subscribers -> do
            (ret, retFulfiller) <- newPromiseClient
            -- Subscribe: when the local result arrives, record it and settle
            -- the returned promise with the capability at our path.
            ptrFulfiller <- newCallback $ \r -> do
                writeTVar state (ReadyPipeline r)
                breakOrFulfill retFulfiller (r >>= followPtrs (toList steps) >>= ptrClient)
            writeTVar state $ PendingLocalPipeline $ SnocList.snoc subscribers ptrFulfiller
            pure ret
        ReadyPipeline r ->
            case r >>= followPtrs (toList steps) >>= ptrClient of
                Right v -> pure v
                Left e -> do
                    (p, f) <- newPromiseClient
                    breakPromise f e
                    pure p
-- | Create a promise client targeting a promised answer on a remote vat.
--
-- If the connection is already dead, the promise is resolved at once to
-- 'eDisconnected'. Otherwise it is resolved when the question's @Return@
-- arrives.
promisedAnswerClient :: Conn -> PromisedAnswer -> STM Client
promisedAnswerClient conn answer@PromisedAnswer{answerId, transform} = do
    let tmpDest = RemoteDest AnswerDest { conn, answer }
    pState <- newTVar Pending { tmpDest }
    exportMap <- ExportMap <$> M.new
    let client = Client $ Just PromiseClient
            { pState
            , exportMap
            , origTarget = tmpDest
            }
    readTVar (liveState conn) >>= \case
        Dead ->
            resolveClientExn tmpDest (writeTVar pState) eDisconnected
        Live conn'@Conn'{questions} ->
            subscribeReturn "questions" conn' questions answerId $
                resolveClientReturn tmpDest (writeTVar pState) conn' (toList transform)
    pure client
-- | The state of a promise client.
data PromiseState
    -- Resolved to another client; 'call' forwards to it.
    = Ready
        { target :: Client
        }
    -- Calls are buffered in 'callBuffer' (presumably until an embargo is
    -- lifted — confirm against the resolution code).
    | Embargo
        { callBuffer :: TQueue Server.CallInfo
        }
    -- Not yet resolved; calls go to the temporary destination.
    | Pending
        { tmpDest :: TmpDest
        }
    -- Resolved to an error.
    | Error R.Exception

-- | Where calls on a not-yet-resolved promise are sent.
data TmpDest
    = LocalDest LocalDest
    | RemoteDest RemoteDest

-- | A local buffer of calls; emptied when the promise is resolved.
newtype LocalDest
    = LocalBuffer { callBuffer :: TQueue Server.CallInfo }

-- | A remote destination for calls on an unresolved promise.
data RemoteDest
    -- A promised answer to a question on the given connection.
    = AnswerDest
        { conn :: Conn
        , answer :: PromisedAnswer
        }
    -- An imported capability.
    | ImportDest (Fin.Cell ImportRef)
-- | A reference to an entry in a connection's imports table.
data ImportRef = ImportRef
    { conn :: Conn
    , importId :: !IEId
    , proxies :: ExportMap
    }

-- | Import refs are equal when they name the same import id on the same
-- connection; 'proxies' is not compared.
instance Eq ImportRef where
    ImportRef { conn=cx, importId=ix } == ImportRef { conn=cy, importId=iy } =
        cx == cy && ix == iy

-- | Clients of the same kind are compared by their 'qCall', 'pState', or
-- import cell respectively. Clients of different kinds are never equal.
instance Eq Client' where
    LocalClient { qCall = x } == LocalClient { qCall = y } =
        x == y
    PromiseClient { pState = x } == PromiseClient { pState = y } =
        x == y
    ImportClient x == ImportClient y =
        x == y
    _ == _ =
        False
-- | Per-connection export ids for a client.
-- NOTE(review): inferred from the key/value types — presumably the id under
-- which the client is exported on each connection; confirm.
newtype ExportMap = ExportMap (M.Map Conn IEId)

-- | The target of an outgoing call.
data MsgTarget
    -- A capability we imported.
    = ImportTgt !IEId
    -- A capability within the answer to one of our questions.
    | AnswerTgt PromisedAnswer

-- | A pointer within the answer to a question, given as a path of
-- pointer-field indices from the answer's root.
data PromisedAnswer = PromisedAnswer
    { answerId :: !QAId
    , transform :: SnocList Word16
    }
-- | Call a method on a client, returning a pipeline for the result.
--
-- The result is delivered through the 'Server.response' fulfiller in the
-- 'Server.CallInfo'. Calls on the null client fail with
-- 'eMethodUnimplemented'.
call :: MonadSTM m => Server.CallInfo -> Client -> m Pipeline
call Server.CallInfo { response } (Client Nothing) = liftSTM $ do
    breakPromise response eMethodUnimplemented
    state <- newTVar $ ReadyPipeline (Left eMethodUnimplemented)
    pure Pipeline{state, steps = mempty}
call info@Server.CallInfo { response } (Client (Just client')) = liftSTM $ do
    -- A local pipeline is used when the call is handled or buffered in this
    -- vat. Remote calls use 'callRemote', which builds its own pipeline from
    -- the original (unwrapped) info.
    (localPipeline, response') <- makeLocalPipeline response
    let info' = info { Server.response = response' }
    case client' of
        LocalClient { qCall } -> do
            -- Nothing here means the server's call function is no longer
            -- available; fail the call as disconnected.
            Rc.get qCall >>= \case
                Just q -> do
                    q info'
                Nothing ->
                    breakPromise response' eDisconnected
            pure localPipeline
        PromiseClient { pState } -> readTVar pState >>= \case
            Ready { target } ->
                call info target
            Embargo { callBuffer } -> do
                writeTQueue callBuffer info'
                pure localPipeline
            Pending { tmpDest } -> case tmpDest of
                LocalDest LocalBuffer { callBuffer } -> do
                    writeTQueue callBuffer info'
                    pure localPipeline
                RemoteDest AnswerDest { conn, answer } ->
                    callRemote conn info $ AnswerTgt answer
                RemoteDest (ImportDest cell) -> do
                    ImportRef { conn, importId } <- Fin.get cell
                    callRemote conn info (ImportTgt importId)
            Error exn -> do
                breakPromise response' exn
                pure localPipeline
        ImportClient cell -> do
            ImportRef { conn, importId } <- Fin.get cell
            callRemote conn info (ImportTgt importId)
-- | Wrap a response fulfiller so that the result also feeds a local pipeline.
--
-- Returns the pipeline and the wrapping fulfiller. When the wrapper is
-- settled, four things happen: the raw result is decoded, the pipeline state
-- becomes 'ReadyPipeline', the original fulfiller @f@ receives the raw
-- result, and the subscribers registered by 'pipelineClient' receive the
-- decoded result.
makeLocalPipeline :: Fulfiller RawMPtr -> STM (Pipeline, Fulfiller RawMPtr)
makeLocalPipeline f = do
    state <- newTVar $ PendingLocalPipeline mempty
    f' <- newCallback $ \r -> do
        s <- readTVar state
        case s of
            PendingLocalPipeline fs -> do
                pr <- case r of
                    Left e -> pure (Left e)
                    Right v -> Right <$> evalLimitT defaultLimit (decerialize v)
                writeTVar state (ReadyPipeline pr)
                breakOrFulfill f r
                traverse_ (`breakOrFulfill` pr) fs
            _ ->
                -- NOTE(review): relies on the callback running at most once,
                -- so the state is still pending here — confirm newCallback
                -- guarantees this.
                error "impossible"
    pure (Pipeline{state, steps = mempty}, f')
-- | Send a call to a remote vat, returning a pipeline for its result.
--
-- Allocates a question id, sends the @Call@, and registers a question entry.
-- The entry's return handler ('cbCallReturn') settles the call's response;
-- that in turn updates the pipeline state.
callRemote :: Conn -> Server.CallInfo -> MsgTarget -> STM Pipeline
callRemote
        conn
        Server.CallInfo{ interfaceId, methodId, arguments, response }
        target = do
    conn'@Conn'{questions} <- getLive conn
    qid <- newQuestion conn'
    payload@R.Payload{capTable} <- makeOutgoingPayload conn arguments
    sendPureMsg conn' $ R.Message'call def
        { R.questionId = qaWord qid
        , R.target = marshalMsgTarget target
        , R.params = payload
        , R.interfaceId = interfaceId
        , R.methodId = methodId
        }
    -- Ids of capabilities we exported in the params. They are released if
    -- the Return arrives with releaseParamCaps set.
    let paramCaps = catMaybes $ flip map (V.toList capTable) $ \R.CapDescriptor{union'} -> case union' of
            R.CapDescriptor'senderHosted eid -> Just (IEId eid)
            R.CapDescriptor'senderPromise eid -> Just (IEId eid)
            _ -> Nothing
    clientMap <- M.new
    rp <- newTVar PendingRemotePipeline
        { answerId = qid
        , clientMap
        , conn
        }
    -- Settle the caller's response first, then record the decoded result
    -- (or error) in the pipeline state.
    response' <- newCallback $ \r -> do
        breakOrFulfill response r
        case r of
            Left e -> writeTVar rp $ ReadyPipeline (Left e)
            Right v -> do
                content <- evalLimitT defaultLimit (decerialize v)
                writeTVar rp $ ReadyPipeline (Right content)
    M.insert
        NewQA
            { onReturn = SnocList.singleton $ cbCallReturn paramCaps conn response'
            , onFinish = SnocList.empty
            }
        qid
        questions
    pure Pipeline { state = rp, steps = mempty }
-- | Handle the @Return@ for a question sent by 'callRemote'.
--
-- Releases the param capabilities if the return asks for it, settles the
-- response according to the return variant, and queues a @Finish@ for the
-- question. Protocol violations abort the connection; because the abort
-- throws, no Finish is queued in that case.
cbCallReturn :: [IEId] -> Conn -> Fulfiller RawMPtr -> R.Return -> STM ()
cbCallReturn
        paramCaps
        conn
        response
        R.Return{ answerId, union', releaseParamCaps } = do
    conn'@Conn'{answers} <- getLive conn
    when releaseParamCaps $
        traverse_ (releaseExport conn 1) paramCaps
    case union' of
        R.Return'exception exn ->
            breakPromise response exn
        R.Return'results R.Payload{ content } -> do
            -- Re-serialize the decoded content into a raw message for the
            -- response fulfiller.
            rawPtr <- createPure defaultLimit $ do
                msg <- Message.newMessage Nothing
                cerialize msg content
            fulfill response rawPtr
        R.Return'canceled ->
            breakPromise response $ eFailed "Canceled"
        R.Return'resultsSentElsewhere ->
            abortConn conn' $ eFailed $ mconcat
                [ "Received Return.resultsSentElswhere for a call "
                , "with sendResultsTo = caller."
                ]
        R.Return'takeFromOtherQuestion (QAId -> qid) ->
            -- The results live in one of our answers; wait for that answer's
            -- Return and handle it the same way (with no param caps).
            subscribeReturn "answer" conn' answers qid $
                cbCallReturn [] conn response
        R.Return'acceptFromThirdParty _ ->
            abortConn conn' $ eUnimplemented
                "This vat does not support level 3."
        R.Return'unknown' ordinal ->
            abortConn conn' $ eUnimplemented $
                "Unknown return variant #" <> fromString (show ordinal)
    -- The Finish is queued to run as a separate later transaction.
    queueSTM conn' $ finishQuestion conn' def
        { R.questionId = answerId
        , R.releaseResultCaps = False
        }
-- | Encode a 'MsgTarget' in its wire form.
marshalMsgTarget :: MsgTarget -> R.MessageTarget
marshalMsgTarget (ImportTgt importId) =
    R.MessageTarget'importedCap (ieWord importId)
marshalMsgTarget (AnswerTgt tgt) =
    R.MessageTarget'promisedAnswer (marshalPromisedAnswer tgt)
-- | Encode a 'PromisedAnswer' in its wire form. Each step of the path
-- becomes a @getPointerField@ op.
marshalPromisedAnswer :: PromisedAnswer -> R.PromisedAnswer
marshalPromisedAnswer PromisedAnswer{ answerId = qid, transform = path } =
    R.PromisedAnswer
        { R.questionId = qaWord qid
        , R.transform =
            V.fromList [ R.PromisedAnswer'Op'getPointerField idx | idx <- toList path ]
        }
-- | Decode a wire-form promised answer. Fails if the transform contains an
-- op this implementation does not understand.
unmarshalPromisedAnswer :: R.PromisedAnswer -> Either R.Exception PromisedAnswer
unmarshalPromisedAnswer R.PromisedAnswer { questionId, transform } =
    toAnswer <$> unmarshalOps (toList transform)
  where
    toAnswer idxes = PromisedAnswer
        { answerId = QAId questionId
        , transform = SnocList.fromList idxes
        }
-- | Decode transform ops into pointer-field indices. @noop@s are skipped;
-- the first unknown op (scanning left to right) produces an error.
unmarshalOps :: [R.PromisedAnswer'Op] -> Either R.Exception [Word16]
unmarshalOps = fmap catMaybes . traverse decodeOp
  where
    decodeOp R.PromisedAnswer'Op'noop = Right Nothing
    decodeOp (R.PromisedAnswer'Op'getPointerField idx) = Right (Just idx)
    decodeOp (R.PromisedAnswer'Op'unknown' tag) =
        Left $ eFailed $ "Unknown PromisedAnswer.Op: " <> fromString (show tag)
-- | The null capability. Calls on it fail with 'eMethodUnimplemented' (see
-- 'call').
nullClient :: Client
nullClient = Client Nothing

-- | Create a new promise client and the fulfiller that settles it.
--
-- Calls made before settlement are buffered locally. Breaking the fulfiller
-- resolves the promise to the error and fails the buffered calls.
-- Fulfilling it resolves the promise to the given client.
newPromiseClient :: (MonadSTM m, IsClient c) => m (c, Fulfiller c)
newPromiseClient = liftSTM $ do
    callBuffer <- newTQueue
    let tmpDest = LocalDest LocalBuffer { callBuffer }
    pState <- newTVar Pending { tmpDest }
    exportMap <- ExportMap <$> M.new
    f <- newCallback $ \case
        Left e -> resolveClientExn tmpDest (writeTVar pState) e
        Right v -> resolveClientClient tmpDest (writeTVar pState) (toClient v)
    let p = Client $ Just $ PromiseClient
            { pState
            , exportMap
            , origTarget = tmpDest
            }
    pure (fromClient p, f)
-- | Recover the underlying server value from a client, if the client is
-- served by this vat and the server has the requested type. Any other
-- client yields 'Nothing'.
unwrapServer :: (IsClient c, Typeable a) => c -> Maybe a
unwrapServer c
    | Client (Just LocalClient { unwrapper }) <- toClient c = unwrapper
    | otherwise = Nothing
-- | Start a local server with the given operations and return a client for
-- it.
--
-- Calls are written to a closeable queue that the server, running under the
-- given supervisor, reads from. 'qCall' is a reference-counted handle whose
-- release action closes that queue. A finalizer on the client's
-- 'finalizerKey' releases 'qCall'.
-- NOTE(review): exact release semantics live in "Internal.Rc" and
-- "Internal.Finalizer" — confirm.
export :: MonadSTM m => Supervisor -> Server.ServerOps IO -> m Client
export sup ops = liftSTM $ do
    q <- TCloseQ.new
    qCall <- Rc.new (TCloseQ.write q) (TCloseQ.close q)
    exportMap <- ExportMap <$> M.new
    finalizerKey <- Fin.newCell ()
    let client' = LocalClient
            { qCall
            , exportMap
            , finalizerKey
            , unwrapper = Server.handleCast ops
            }
    superviseSTM sup $ do
        Fin.addFinalizer finalizerKey $ Rc.release qCall
        Server.runServer q ops
    pure $ Client (Just client')
-- | A method handler that forwards every invocation of the given method to
-- another client, ignoring the resulting pipeline.
clientMethodHandler :: Word64 -> Word16 -> Client -> Server.MethodHandler IO p r
clientMethodHandler interfaceId methodId client =
    Server.fromUntypedHandler . Server.untypedHandler $ \arguments response ->
        void . atomically $
            call Server.CallInfo { interfaceId, methodId, arguments, response } client
-- | Loop forever: wait for at least one queued callback, then run the whole
-- batch in queue order.
callbacksLoop :: Conn' -> IO ()
callbacksLoop Conn'{pendingCallbacks} = forever $
    atomically nextBatch >>= sequence_
  where
    nextBatch = do
        batch <- flushTQueue pendingCallbacks
        if null batch then retry else pure batch

-- | Run any callbacks currently queued, without waiting for more.
flushCallbacks :: Conn' -> IO ()
flushCallbacks Conn'{pendingCallbacks} = do
    batch <- atomically (flushTQueue pendingCallbacks)
    sequence_ batch
-- | Loop forever: take the next outgoing message and write it to the
-- transport.
sendLoop :: Transport -> Conn' -> IO ()
sendLoop transport Conn'{sendQ} = forever $ do
    outgoing <- atomically (readTBQueue sendQ)
    sendMsg transport outgoing

-- | Loop forever: read a message from the transport and push it onto the
-- receive queue, blocking while the queue is full.
recvLoop :: Transport -> Conn' -> IO ()
recvLoop transport Conn'{recvQ} = forever $ do
    incoming <- recvMsg transport
    atomically (writeTBQueue recvQ incoming)
-- | The connection's main loop.
--
-- Each iteration, in one transaction: take a message from the receive
-- queue, decode it (attaching capabilities), and dispatch it to its handler.
-- Decoding errors abort the connection. Message kinds with no handler are
-- answered with an @unimplemented@ message.
coordinator :: Conn -> IO ()
coordinator conn@Conn{debugMode} = forever $ atomically $ do
    conn'@Conn'{recvQ} <- getLive conn
    msg <- (readTBQueue recvQ >>= parseWithCaps conn)
        `catchSTM`
        (abortConn conn' . wrapException debugMode)
    case msg of
        R.Message'abort exn ->
            handleAbortMsg conn exn
        R.Message'unimplemented oldMsg ->
            handleUnimplementedMsg conn oldMsg
        R.Message'bootstrap bs ->
            handleBootstrapMsg conn bs
        R.Message'call call ->
            handleCallMsg conn call
        R.Message'return ret ->
            handleReturnMsg conn ret
        R.Message'finish finish ->
            handleFinishMsg conn finish
        R.Message'resolve res ->
            handleResolveMsg conn res
        R.Message'release release ->
            handleReleaseMsg conn release
        R.Message'disembargo disembargo ->
            handleDisembargoMsg conn disembargo
        _ ->
            sendPureMsg conn' $ R.Message'unimplemented msg
-- | Decode a raw message.
--
-- For @call@ messages, and @return@ messages carrying results, the
-- payload's cap table is first turned into clients and attached to the
-- message (via 'fixCapTable'). The message is then decoded again.
parseWithCaps :: Conn -> ConstMsg -> STM R.Message
parseWithCaps conn msg = msgToValue msg >>= \case
    R.Message'call R.Call{params = R.Payload{capTable}} ->
        redecodeWith capTable
    R.Message'return R.Return{union' = R.Return'results R.Payload{capTable}} ->
        redecodeWith capTable
    decoded ->
        pure decoded
  where
    redecodeWith caps = fixCapTable caps conn msg >>= msgToValue
-- | The remote vat aborted the connection; raise 'ReceivedAbort'.
handleAbortMsg :: Conn -> R.Exception -> STM ()
handleAbortMsg _ = throwSTM . ReceivedAbort
-- | Handle an @unimplemented@ reply from the peer.
--
-- A reply echoing an @unimplemented@ message is ignored. Any other reply
-- means the peer cannot handle something we require, so the connection is
-- aborted.
handleUnimplementedMsg :: Conn -> R.Message -> STM ()
handleUnimplementedMsg conn msg = do
    conn' <- getLive conn
    let abortWith = abortConn conn' . eFailed
    case msg of
        R.Message'unimplemented _ -> pure ()
        R.Message'abort _ -> abortWith $
            "Your vat sent an 'unimplemented' message for an abort message " <>
            "that its remote peer never sent. This is likely a bug in your " <>
            "capnproto library."
        _ -> abortWith "Received unimplemented response for required message."
-- | Handle a @Bootstrap@ message.
--
-- Replies with our bootstrap capability, or with an exception if none is
-- configured, and records the answer. The answer entry starts out holding
-- the return. Its finish handler releases the exported capability when the
-- @Finish@ has @releaseResultCaps@ set. A duplicate question id aborts the
-- connection.
handleBootstrapMsg :: Conn -> R.Bootstrap -> STM ()
handleBootstrapMsg conn R.Bootstrap{ questionId } = getLive conn >>= \conn' -> do
    ret <- case bootstrap conn' of
        Nothing ->
            pure $ R.Return
                { R.answerId = questionId
                , R.releaseParamCaps = True
                , R.union' =
                    R.Return'exception $
                        eFailed "No bootstrap interface for this connection."
                }
        Just client -> do
            capDesc <- emitCap conn client
            pure $ R.Return
                { R.answerId = questionId
                , R.releaseParamCaps = True
                , R.union' =
                    R.Return'results R.Payload
                        { content = Just (Untyped.PtrCap client)
                        , capTable = V.singleton (def :: R.CapDescriptor) { R.union' = capDesc }
                        }
                }
    M.focus
        (Focus.alterM $ insertBootstrap conn' ret)
        (QAId questionId)
        (answers conn')
    sendPureMsg conn' $ R.Message'return ret
  where
    insertBootstrap _ ret Nothing =
        pure $ Just HaveReturn
            { returnMsg = ret
            , onFinish = SnocList.fromList
                -- NOTE(review): this only releases when the single cap
                -- descriptor is receiverHosted. If 'emitCap' describes our
                -- own exports as senderHosted, this never matches and the
                -- bootstrap export is never released here — confirm.
                [ \R.Finish{releaseResultCaps} ->
                    case ret of
                        R.Return
                            { union' = R.Return'results R.Payload
                                { capTable = (V.toList -> [ R.CapDescriptor { union' = R.CapDescriptor'receiverHosted (IEId -> eid) } ])
                                }
                            } ->
                            when releaseResultCaps $
                                releaseExport conn 1 eid
                        _ ->
                            pure ()
                ]
            }
    insertBootstrap conn' _ (Just _) =
        abortConn conn' $ eFailed "Duplicate question ID"
-- | Handle an incoming @Call@.
--
-- Steps:
--
-- 1. Record a new answer entry; a duplicate question id aborts.
-- 2. Re-serialize the params.
-- 3. Build a response fulfiller that sends the @Return@ via 'returnAnswer'.
-- 4. Dispatch the call: directly to an exported capability, or, for a
--    promised answer, once that answer has returned.
handleCallMsg :: Conn -> R.Call -> STM ()
handleCallMsg
    conn
    R.Call
        { questionId
        , target
        , interfaceId
        , methodId
        , params=R.Payload{content, capTable}
        }
    = getLive conn >>= \conn'@Conn'{exports, answers} -> do
        insertNewAbort
            "answer"
            conn'
            (QAId questionId)
            NewQA
                { onReturn = SnocList.empty
                -- NOTE(review): this walks the *params* cap table
                -- (receiverHosted entries, i.e. our own exports passed back
                -- to us), not the caps in our results. Confirm this matches
                -- the intended releaseResultCaps semantics.
                , onFinish = SnocList.fromList
                    [ \R.Finish{releaseResultCaps} ->
                        when releaseResultCaps $
                            for_ capTable $ \R.CapDescriptor{union'} -> case union' of
                                R.CapDescriptor'receiverHosted (IEId -> importId) ->
                                    releaseExport conn 1 importId
                                _ ->
                                    pure ()
                    ]
                }
            answers
        callParams <- createPure defaultLimit $ do
            msg <- Message.newMessage Nothing
            cerialize msg content
        fulfiller <- newCallback $ \case
            Left e ->
                returnAnswer conn' def
                    { R.answerId = questionId
                    , R.releaseParamCaps = False
                    , R.union' = R.Return'exception e
                    }
            Right v -> do
                content <- evalLimitT defaultLimit (decerialize v)
                capTable <- genSendableCapTable conn content
                returnAnswer conn' def
                    { R.answerId = questionId
                    , R.releaseParamCaps = False
                    , R.union' = R.Return'results def
                        { R.content = content
                        , R.capTable = capTable
                        }
                    }
        let callInfo = Server.CallInfo
                { interfaceId
                , methodId
                , arguments = callParams
                , response = fulfiller
                }
        case target of
            R.MessageTarget'importedCap exportId ->
                lookupAbort "export" conn' exports (IEId exportId) $
                    \EntryE{client} -> void $ call callInfo $ Client $ Just client
            R.MessageTarget'promisedAnswer R.PromisedAnswer { questionId = targetQid, transform } ->
                -- Wait for the target answer's Return. Errors and
                -- cancellations are propagated as our own Return; results
                -- are walked to the target capability, which is then called.
                let onReturn ret@R.Return{union'} =
                        case union' of
                            R.Return'exception _ ->
                                returnAnswer conn' ret { R.answerId = questionId }
                            R.Return'canceled ->
                                returnAnswer conn' ret { R.answerId = questionId }
                            R.Return'results R.Payload{content} ->
                                void $ transformClient transform content conn' >>= call callInfo
                            R.Return'resultsSentElsewhere ->
                                abortConn conn' $ eFailed $
                                    "Tried to call a method on a promised answer that " <>
                                    "returned resultsSentElsewhere"
                            R.Return'takeFromOtherQuestion otherQid ->
                                subscribeReturn "answer" conn' answers (QAId otherQid) onReturn
                            R.Return'acceptFromThirdParty _ ->
                                error "BUG: our implementation unexpectedly used a level 3 feature"
                            R.Return'unknown' tag ->
                                error $
                                    "BUG: our implemented unexpectedly returned unknown " ++
                                    "result variant #" ++ show tag
                in
                subscribeReturn "answer" conn' answers (QAId targetQid) onReturn
            R.MessageTarget'unknown' ordinal ->
                abortConn conn' $ eUnimplemented $
                    "Unknown MessageTarget ordinal #" <> fromString (show ordinal)
-- | Follow a wire-form transform from a pointer to a client. Aborts the
-- connection if an op is unknown, the path is invalid, or the target is not
-- a capability.
transformClient :: V.Vector R.PromisedAnswer'Op -> MPtr -> Conn' -> STM Client
transformClient transform ptr conn =
    either (abortConn conn) pure $ do
        idxes <- unmarshalOps (V.toList transform)
        endPtr <- followPtrs idxes ptr
        ptrClient endPtr
-- | Interpret a pointer as a client. A null pointer is the null client, a
-- capability pointer is its client, and anything else is an error.
ptrClient :: MPtr -> Either R.Exception Client
ptrClient mptr = case mptr of
    Nothing -> Right nullClient
    Just (Untyped.PtrCap capClient) -> Right capClient
    Just _ -> Left $ eFailed "Tried to call method on non-capability."
-- | Follow a path of pointer-field indices from a pointer.
--
-- A null pointer anywhere along the path yields null. Indexing into a
-- non-struct is an error.
followPtrs :: [Word16] -> MPtr -> Either R.Exception MPtr
followPtrs [] ptr = Right ptr
followPtrs (idx : rest) ptr = case ptr of
    Nothing ->
        Right Nothing
    Just (Untyped.PtrStruct (Untyped.Struct _ fields)) ->
        followPtrs rest (Untyped.sliceIndex (fromIntegral idx) fields)
    Just _ ->
        Left (eFailed "Tried to access pointer field of non-struct.")
-- | Record an incoming @Return@ against our questions table.
handleReturnMsg :: Conn -> R.Return -> STM ()
handleReturnMsg conn ret = do
    conn' <- getLive conn
    updateQAReturn conn' (questions conn') "question" ret

-- | Record an incoming @Finish@ against our answers table.
handleFinishMsg :: Conn -> R.Finish -> STM ()
handleFinishMsg conn finish = do
    conn' <- getLive conn
    updateQAFinish conn' (answers conn') "answer" finish
-- | Handle a @Resolve@ message for a promise the remote vat exported to us.
--
-- If we still hold the import, the promise is resolved to the new
-- capability (or to the exception).
--
-- If we no longer hold it (e.g. our @Release@ crossed this message in
-- flight), any capability the peer exported to us in the resolution is
-- released immediately, since nothing will ever use it.
--
-- Resolving a non-promise import, or an unknown resolve variant, aborts the
-- connection.
handleResolveMsg :: Conn -> R.Resolve -> STM ()
handleResolveMsg conn R.Resolve{promiseId, union'} =
    getLive conn >>= \conn'@Conn'{imports} -> do
        -- Release one reference to an entry in our import table.
        let releaseImport importId =
                sendPureMsg conn' $ R.Message'release def
                    { R.id = importId
                    , R.referenceCount = 1
                    }
        entry <- M.lookup (IEId promiseId) imports
        case entry of
            Nothing ->
                case union' of
                    -- senderHosted / senderPromise caps are exported by the
                    -- peer, i.e. they are (new) entries in OUR import table,
                    -- so they are what we must Release. A receiverHosted cap
                    -- is one *we* export; a Release naming its id would
                    -- instead drop an unrelated import on the peer's side.
                    R.Resolve'cap R.CapDescriptor{union' = R.CapDescriptor'senderHosted importId} ->
                        releaseImport importId
                    R.Resolve'cap R.CapDescriptor{union' = R.CapDescriptor'senderPromise importId} ->
                        releaseImport importId
                    _ -> pure ()
            Just EntryI{ promiseState = Nothing } ->
                abortConn conn' $ eFailed $ mconcat
                    [ "Received a resolve message for export id #", fromString (show promiseId)
                    , ", but that capability is not a promise!"
                    ]
            Just EntryI { promiseState = Just (tvar, tmpDest) } ->
                case union' of
                    R.Resolve'cap R.CapDescriptor{union' = cap} -> do
                        client <- acceptCap conn cap
                        resolveClientClient tmpDest (writeTVar tvar) client
                    R.Resolve'exception exn ->
                        resolveClientExn tmpDest (writeTVar tvar) exn
                    R.Resolve'unknown' tag ->
                        abortConn conn' $ eUnimplemented $ mconcat
                            [ "Resolve variant #"
                            , fromString (show tag)
                            , " not understood"
                            ]
-- | Handle a @Release@ message by dropping the given number of references to
-- one of our exports.
handleReleaseMsg :: Conn -> R.Release -> STM ()
handleReleaseMsg conn R.Release{ id = exportId, referenceCount } =
    releaseExport conn referenceCount (IEId exportId)
-- | Drop @refCountDiff@ remote references to the export @eid@.
--
-- Outcomes:
--
-- * Count reaches exactly zero: the export is handed to 'dropConnExport'.
-- * Count stays positive: the decremented count is stored back.
-- * More references released than recorded, or an unknown id: the
--   connection is aborted.
releaseExport :: Conn -> Word32 -> IEId -> STM ()
releaseExport conn refCountDiff eid =
    getLive conn >>= \conn'@Conn'{exports} ->
        lookupAbort "export" conn' exports eid $
            \EntryE{client, refCount=oldRefCount} ->
                case compare oldRefCount refCountDiff of
                    LT ->
                        abortConn conn' $ eFailed $
                            "Received release for export with referenceCount " <>
                            "greater than our recorded total ref count."
                    EQ ->
                        dropConnExport conn client
                    GT ->
                        M.insert
                            EntryE
                                { client
                                , refCount = oldRefCount - refCountDiff
                                }
                            eid
                            exports
-- | Handle a @Disembargo@ message.
--
-- * @receiverLoopback@: one of our embargoes has come back. Its fulfiller
--   is fulfilled in a later transaction and the embargo id is freed.
--   Unknown ids abort.
-- * @senderLoopback@: if the target resolves to a capability imported over
--   this same connection, the disembargo is reflected back as
--   @receiverLoopback@. Otherwise the connection is aborted.
-- * Any other context is answered with @unimplemented@.
handleDisembargoMsg :: Conn -> R.Disembargo -> STM ()
handleDisembargoMsg conn d = getLive conn >>= go d
  where
    go
        R.Disembargo { context=R.Disembargo'context'receiverLoopback (EmbargoId -> eid) }
        conn'@Conn'{embargos}
        = do
            result <- M.lookup eid embargos
            case result of
                Nothing ->
                    abortConn conn' $ eFailed $
                        "No such embargo: " <> fromString (show $ embargoWord eid)
                Just fulfiller -> do
                    queueSTM conn' (fulfill fulfiller ())
                    M.delete eid embargos
                    freeEmbargo conn' eid
    go
        R.Disembargo{ target, context=R.Disembargo'context'senderLoopback embargoId }
        conn'@Conn'{exports, answers}
        = case target of
            R.MessageTarget'importedCap exportId ->
                lookupAbort "export" conn' exports (IEId exportId) $ \EntryE{ client } ->
                    disembargoPromise client
            R.MessageTarget'promisedAnswer R.PromisedAnswer{ questionId, transform } ->
                lookupAbort "answer" conn' answers (QAId questionId) $ \case
                    HaveReturn { returnMsg=R.Return{union'=R.Return'results R.Payload{content} } } ->
                        transformClient transform content conn' >>= \case
                            Client (Just client') -> disembargoClient client'
                            Client Nothing -> abortDisembargo "targets a null capability"
                    _ ->
                        abortDisembargo $
                            "does not target an answer which has resolved to a value hosted by"
                            <> " the sender."
            R.MessageTarget'unknown' ordinal ->
                abortConn conn' $ eUnimplemented $
                    "Unknown MessageTarget ordinal #" <> fromString (show ordinal)
      where
        -- An exported target must be a promise that has already resolved to
        -- a non-null client.
        disembargoPromise PromiseClient{ pState } = readTVar pState >>= \case
            Ready (Client (Just client)) ->
                disembargoClient client
            Ready (Client Nothing) ->
                abortDisembargo "targets a promise which resolved to null."
            _ ->
                abortDisembargo "targets a promise which has not resolved."
        disembargoPromise _ =
            abortDisembargo "targets something that is not a promise."
        -- Reflect the disembargo only for imports from this same connection.
        disembargoClient (ImportClient cell) = do
            client <- Fin.get cell
            case client of
                ImportRef {conn=targetConn, importId}
                    | conn == targetConn ->
                        sendPureMsg conn' $ R.Message'disembargo R.Disembargo
                            { context = R.Disembargo'context'receiverLoopback embargoId
                            , target = R.MessageTarget'importedCap (ieWord importId)
                            }
                _ ->
                    abortDisembargoClient
        disembargoClient _ = abortDisembargoClient
        abortDisembargoClient =
            abortDisembargo $
                "targets a promise which has not resolved to a capability"
                <> " hosted by the sender."
        abortDisembargo info =
            abortConn conn' $ eFailed $ mconcat
                [ "Disembargo #"
                , fromString (show embargoId)
                , " with context = senderLoopback "
                , info
                ]
    go d conn' =
        sendPureMsg conn' $ R.Message'unimplemented $ R.Message'disembargo d
-- | Turn each cap descriptor into a client (via 'acceptCap'), in order, and
-- attach the resulting clients as the message's capability table.
fixCapTable :: V.Vector R.CapDescriptor -> Conn -> ConstMsg -> STM ConstMsg
fixCapTable capDescs conn msg =
    (`Message.withCapTable` msg) <$> V.mapM accept capDescs
  where
    accept R.CapDescriptor{union'} = acceptCap conn union'
-- | Look up a key and pass the value to the continuation. A missing key
-- aborts the connection with an error naming the table kind and the key.
lookupAbort
    :: (Eq k, Hashable k, Show k)
    => Text -> Conn' -> M.Map k v -> k -> (v -> STM a) -> STM a
lookupAbort keyTypeName conn m key f =
    M.lookup key m >>= maybe missing f
  where
    missing = abortConn conn $ eFailed $ mconcat
        [ "No such "
        , keyTypeName
        , ": "
        , fromString (show key)
        ]
-- | Insert a value under a key that must not already be present. If it is
-- present, abort the connection with a "duplicate entry" error naming the
-- table.
insertNewAbort :: (Eq k, Hashable k) => Text -> Conn' -> k -> v -> M.Map k v -> STM ()
insertNewAbort keyTypeName conn key value table =
    M.focus (Focus.alterM insertIfAbsent) key table
  where
    insertIfAbsent existing = case existing of
        Nothing -> pure (Just value)
        Just _ ->
            abortConn conn $ eFailed $
                "duplicate entry in " <> keyTypeName <> " table."
-- | Build the cap table for an outgoing decoded pointer. The pointer is
-- serialized into a fresh message, and each capability in that message's
-- cap table is then described.
genSendableCapTable :: Conn -> MPtr -> STM (V.Vector R.CapDescriptor)
genSendableCapTable conn ptr = do
    rawPtr <- createPure defaultLimit
        (Message.newMessage Nothing >>= \msg -> cerialize msg ptr)
    genSendableCapTableRaw conn rawPtr
-- | Describe, in order, every capability in the cap table of the raw
-- pointer's message. A null pointer yields an empty table.
genSendableCapTableRaw
    :: Conn
    -> Maybe (UntypedRaw.Ptr ConstMsg)
    -> STM (V.Vector R.CapDescriptor)
genSendableCapTableRaw conn mptr = case mptr of
    Nothing -> pure V.empty
    Just ptr -> V.forM (Message.getCapTable (UntypedRaw.message ptr)) describe
  where
    describe cap = do
        desc <- emitCap conn cap
        pure (def :: R.CapDescriptor) { R.union' = desc }
-- | Build an outgoing payload from raw content. The cap table is generated
-- first, then the content is decoded.
makeOutgoingPayload :: Conn -> RawMPtr -> STM R.Payload
makeOutgoingPayload conn rawContent = do
    table <- genSendableCapTableRaw conn rawContent
    decoded <- evalLimitT defaultLimit (decerialize rawContent)
    pure R.Payload { R.content = decoded, R.capTable = table }
-- | Serialize a message and enqueue it for 'sendLoop'. Blocks (retries)
-- while the send queue is full.
sendPureMsg :: Conn' -> R.Message -> STM ()
sendPureMsg Conn'{sendQ} msg = do
    encoded <- createPure maxBound (valueToMsg msg)
    writeTBQueue sendQ encoded
-- | Send a @Finish@ for one of our questions and record it locally.
--
-- The question id is returned to the pool once the question's @Return@ has
-- been seen. If it already has, the freeing is queued immediately (see
-- 'subscribeReturn').
finishQuestion :: Conn' -> R.Finish -> STM ()
finishQuestion conn@Conn'{questions} finish@R.Finish{questionId} = do
    subscribeReturn "question" conn questions (QAId questionId) $ \_ ->
        freeQuestion conn (QAId questionId)
    sendPureMsg conn $ R.Message'finish finish
    updateQAFinish conn questions "question" finish
-- | Send a @Return@ for one of our answers and record it in the answers
-- table.
returnAnswer :: Conn' -> R.Return -> STM ()
returnAnswer conn ret = do
    sendPureMsg conn (R.Message'return ret)
    updateQAReturn conn (answers conn) "answer" ret
-- | Record that a @Return@ has been seen for a question/answer entry.
--
-- Callbacks waiting on the return are queued. A 'NewQA' entry becomes
-- 'HaveReturn'; a 'HaveFinish' entry is complete and is removed. A
-- duplicate return, or an unknown id, aborts the connection.
updateQAReturn :: Conn' -> M.Map QAId EntryQA -> Text -> R.Return -> STM ()
updateQAReturn conn table tableName ret@R.Return{answerId} =
    lookupAbort tableName conn table (QAId answerId) $ \case
        NewQA{onFinish, onReturn} -> do
            mapQueueSTM conn onReturn ret
            M.insert
                HaveReturn
                    { returnMsg = ret
                    , onFinish
                    }
                (QAId answerId)
                table
        HaveFinish{onReturn} -> do
            mapQueueSTM conn onReturn ret
            M.delete (QAId answerId) table
        HaveReturn{} ->
            abortConn conn $ eFailed $
                "Duplicate return message for " <> tableName <> " #"
                    <> fromString (show answerId)
-- | Record that a @Finish@ has been seen for a question/answer entry.
--
-- Callbacks waiting on the finish are queued. A 'NewQA' entry becomes
-- 'HaveFinish'; a 'HaveReturn' entry is complete and is removed. A
-- duplicate finish, or an unknown id, aborts the connection.
updateQAFinish :: Conn' -> M.Map QAId EntryQA -> Text -> R.Finish -> STM ()
updateQAFinish conn table tableName finish@R.Finish{questionId} =
    lookupAbort tableName conn table (QAId questionId) $ \case
        NewQA{onFinish, onReturn} -> do
            mapQueueSTM conn onFinish finish
            M.insert
                HaveFinish
                    { finishMsg = finish
                    , onReturn
                    }
                (QAId questionId)
                table
        HaveReturn{onFinish} -> do
            mapQueueSTM conn onFinish finish
            M.delete (QAId questionId) table
        HaveFinish{} ->
            abortConn conn $ eFailed $
                "Duplicate finish message for " <> tableName <> " #"
                    <> fromString (show questionId)
-- | Arrange for a callback to run on an entry's @Return@.
--
-- If the return has already been seen, the callback is queued right away.
-- Otherwise it is appended to the entry's pending return callbacks. An
-- unknown id aborts the connection.
subscribeReturn :: Text -> Conn' -> M.Map QAId EntryQA -> QAId -> (R.Return -> STM ()) -> STM ()
subscribeReturn tableName conn table qaId onRet =
    lookupAbort tableName conn table qaId $ \qa -> do
        new <- go qa
        M.insert new qaId table
  where
    go = \case
        NewQA{onFinish, onReturn} ->
            pure NewQA
                { onFinish
                , onReturn = SnocList.snoc onReturn onRet
                }
        HaveFinish{finishMsg, onReturn} ->
            pure HaveFinish
                { finishMsg
                , onReturn = SnocList.snoc onReturn onRet
                }
        val@HaveReturn{returnMsg} -> do
            queueSTM conn (onRet returnMsg)
            pure val
-- | Abort the connection by throwing 'SentAbort'. If the exception reaches
-- 'handleConn', an @abort@ message is sent to the peer.
abortConn :: Conn' -> R.Exception -> STM a
abortConn _ = throwSTM . SentAbort

-- | The live state of a connection; throws 'eDisconnected' if it is dead.
getLive :: Conn -> STM Conn'
getLive conn = do
    st <- readTVar (liveState conn)
    case st of
        Live conn' -> pure conn'
        Dead -> throwSTM eDisconnected

-- | Run an action on the live state of a connection; do nothing if it is
-- dead.
whenLive :: Conn -> (Conn' -> STM ()) -> STM ()
whenLive conn f = do
    st <- readTVar (liveState conn)
    case st of
        Live conn' -> f conn'
        Dead -> pure ()
-- | Ask the remote vat for its bootstrap capability.
--
-- Sends a @Bootstrap@ question and returns a promise client that resolves
-- when the question's @Return@ arrives. If the connection is already dead,
-- returns 'nullClient' instead.
requestBootstrap :: Conn -> STM Client
requestBootstrap conn@Conn{liveState} = readTVar liveState >>= \case
    Dead ->
        pure nullClient
    Live conn'@Conn'{questions} -> do
        qid <- newQuestion conn'
        -- The bootstrap capability is the root of the answer (empty path).
        let tmpDest = RemoteDest AnswerDest
                { conn
                , answer = PromisedAnswer
                    { answerId = qid
                    , transform = SnocList.empty
                    }
                }
        pState <- newTVar Pending { tmpDest }
        sendPureMsg conn' $
            R.Message'bootstrap def { R.questionId = qaWord qid }
        M.insert
            NewQA
                { onReturn = SnocList.singleton $
                    resolveClientReturn tmpDest (writeTVar pState) conn' []
                , onFinish = SnocList.empty
                }
            qid
            questions
        exportMap <- ExportMap <$> M.new
        pure $ Client $ Just PromiseClient
            { pState
            , exportMap
            , origTarget = tmpDest
            }
-- | Resolve a promise to an error.
--
-- If the promise was buffering calls locally, every buffered call's
-- response promise is broken with @exn@ first. Remote destinations have no
-- buffered calls to touch here. The promise state is then set to
-- @'Error' exn@.
resolveClientExn :: TmpDest -> (PromiseState -> STM ()) -> R.Exception -> STM ()
resolveClientExn tmpDest resolve exn = do
    breakBuffered tmpDest
    resolve (Error exn)
  where
    breakBuffered (LocalDest LocalBuffer{callBuffer}) =
        flushTQueue callBuffer >>= traverse_ breakCall
    breakBuffered (RemoteDest AnswerDest{}) =
        pure ()
    breakBuffered (RemoteDest (ImportDest _)) =
        pure ()

    breakCall Server.CallInfo{response} = breakPromise response exn
-- | Resolve a promise according to the pointer it resolved to:
--
-- * a capability pointer resolves to that capability;
-- * a null pointer resolves to 'nullClient';
-- * any other kind of pointer resolves the promise to an error.
resolveClientPtr :: TmpDest -> (PromiseState -> STM ()) -> MPtr -> STM ()
resolveClientPtr tmpDest resolve = \case
    Just (Untyped.PtrCap c) ->
        toClient c
    Nothing ->
        toClient nullClient
    Just _ ->
        resolveClientExn tmpDest resolve $
            eFailed "Promise resolved to non-capability pointer"
  where
    toClient = resolveClientClient tmpDest resolve
-- | Resolve a promise, whose pending calls are described by @tmpDest@, to
-- the given client. Set up an embargo when call ordering requires it.
--
-- * The promise buffered calls locally ('LocalDest'): no embargo is ever
--   needed. Buffered calls are forwarded to the new client and the promise
--   becomes 'Ready'.
-- * The promise pointed at the peer and resolved to something local to us:
--   a 'LocalClient', a promise whose original target is local, or a
--   capability on a /different/ connection, which we must be proxying.
--   Calls already sent to the peer may still be in flight, so we send a
--   senderLoopback @Disembargo@. The promise stays in 'Embargo' until the
--   echo arrives.
-- * The promise pointed at the peer and resolved to a capability on the
--   /same/ connection: the peer keeps ordering. We release our temporary
--   destination and mark the promise 'Ready'.
-- * The promise resolved to the null capability: no disembargo is sent.
--   A senderLoopback disembargo is only valid for a promise that resolved
--   back to the sender. Receivers treat one aimed at a null capability as
--   a protocol error and abort. Calls on a null client all fail, so
--   skipping the embargo cannot make out-of-order delivery observable.
resolveClientClient :: TmpDest -> (PromiseState -> STM ()) -> Client -> STM ()
resolveClientClient tmpDest resolve (Client client) =
    case (client, tmpDest) of
        -- Remote promise resolved to something local: embargo.
        ( Just LocalClient{}, RemoteDest dest ) ->
            disembargoAndResolve dest
        ( Just PromiseClient { origTarget=LocalDest _ }, RemoteDest dest) ->
            disembargoAndResolve dest
        -- Resolved to null: a disembargo here would make the peer abort
        -- the connection. Just release the temp destination and resolve.
        ( Nothing, RemoteDest _ ) ->
            releaseAndResolve
        -- Local promises never need embargos; forward buffered calls.
        ( _, LocalDest LocalBuffer { callBuffer } ) ->
            flushAndResolve callBuffer
        -- A "remote" target on a different connection is one we proxy, so
        -- treat it as local and embargo. On the same connection the peer
        -- preserves ordering and no embargo is needed.
        ( Just PromiseClient { origTarget=RemoteDest newDest }, RemoteDest oldDest ) -> do
            newConn <- destConn newDest
            oldConn <- destConn oldDest
            if newConn == oldConn
                then releaseAndResolve
                else disembargoAndResolve oldDest
        ( Just (ImportClient cell), RemoteDest oldDest ) -> do
            ImportRef { conn=newConn } <- Fin.get cell
            oldConn <- destConn oldDest
            if newConn == oldConn
                then releaseAndResolve
                else disembargoAndResolve oldDest
  where
    -- The connection a remote destination lives on.
    destConn AnswerDest { conn } = pure conn
    destConn (ImportDest cell) = do
        ImportRef { conn } <- Fin.get cell
        pure conn

    -- The message target to put in a Disembargo for a remote destination.
    destTarget AnswerDest { answer } = pure $ AnswerTgt answer
    destTarget (ImportDest cell) = do
        ImportRef { importId } <- Fin.get cell
        pure $ ImportTgt importId

    -- No embargo needed: drop the temporary destination and resolve.
    releaseAndResolve = do
        releaseTmpDest tmpDest
        resolve $ Ready (Client client)

    -- Forward buffered calls to the resolved client, then resolve.
    flushAndResolve callBuffer = do
        flushTQueue callBuffer >>= traverse_ (`call` Client client)
        resolve $ Ready (Client client)

    -- Fail every buffered call with e.
    flushAndRaise callBuffer e =
        flushTQueue callBuffer >>=
            traverse_ (\Server.CallInfo{response} -> breakPromise response e)

    -- Send a Disembargo toward dest and buffer calls in a fresh queue until
    -- the echo arrives. A dead connection resolves the promise to an error.
    disembargoAndResolve dest = do
        Conn{liveState} <- destConn dest
        readTVar liveState >>= \case
            Live conn' -> do
                callBuffer <- newTQueue
                target <- destTarget dest
                disembargo conn' target $ \case
                    Right () ->
                        flushAndResolve callBuffer
                    Left e ->
                        flushAndRaise callBuffer e
                resolve $ Embargo { callBuffer }
            Dead ->
                resolveClientExn tmpDest resolve eDisconnected
-- | Send a senderLoopback @Disembargo@ for @tgt@. Register @onEcho@ to run
-- when the matching reply arrives.
--
-- A fresh embargo id is allocated. A callback wrapping @onEcho@ is stored
-- under that id in the connection's @embargos@ table before the message is
-- sent.
disembargo :: Conn' -> MsgTarget -> (Either R.Exception () -> STM ()) -> STM ()
disembargo conn@Conn'{embargos} tgt onEcho = do
    callback <- newCallback onEcho
    eid <- newEmbargo conn
    M.insert callback eid embargos
    let request = R.Disembargo
            { target = marshalMsgTarget tgt
            , context = R.Disembargo'context'senderLoopback (embargoWord eid)
            }
    sendPureMsg conn (R.Message'disembargo request)
-- | Release whatever a temporary promise destination holds.
--
-- Only an outstanding question on the peer needs action. If the connection
-- is still live, the question is finished without releasing its result
-- caps. Local buffers and imports need nothing here.
releaseTmpDest :: TmpDest -> STM ()
releaseTmpDest = \case
    RemoteDest AnswerDest { conn, answer=PromisedAnswer{ answerId } } ->
        whenLive conn $ \conn' ->
            finishQuestion conn' def
                { R.questionId = qaWord answerId
                , R.releaseResultCaps = False
                }
    LocalDest LocalBuffer{} ->
        pure ()
    RemoteDest (ImportDest _) ->
        pure ()
-- | Handle the @Return@ message for a question whose result the promise
-- (described by @tmpDest@) stands for.
--
-- @transform@ is the list of pointer indices passed to 'followPtrs' to
-- reach the capability within the results.
--
-- * Results are followed along that path and the promise is resolved to
--   what they point at.
-- * Exceptions and cancellation resolve the promise to an error.
-- * @takeFromOtherQuestion@ re-subscribes to the named answer.
-- * @resultsSentElsewhere@, level-3 and unknown variants abort the
--   connection.
resolveClientReturn :: TmpDest -> (PromiseState -> STM ()) -> Conn' -> [Word16] -> R.Return -> STM ()
resolveClientReturn tmpDest resolve conn@Conn'{answers} transform R.Return { union' } =
    case union' of
        R.Return'results R.Payload{ content } ->
            either failWith (resolveClientPtr tmpDest resolve) $
                followPtrs transform content
        R.Return'exception exn ->
            failWith exn
        R.Return'canceled ->
            failWith $ eFailed "Canceled"
        R.Return'takeFromOtherQuestion (QAId -> qid) ->
            subscribeReturn "answer" conn answers qid $
                resolveClientReturn tmpDest resolve conn transform
        R.Return'resultsSentElsewhere ->
            abortConn conn $ eFailed $ mconcat
                [ "Received Return.resultsSentElsewhere for a call "
                , "with sendResultsTo = caller."
                ]
        R.Return'acceptFromThirdParty _ ->
            abortConn conn $ eUnimplemented
                "This vat does not support level 3."
        R.Return'unknown' ordinal ->
            abortConn conn $ eUnimplemented $
                "Unknown return variant #" <> fromString (show ordinal)
  where
    failWith = resolveClientExn tmpDest resolve
-- | Get the export id under which @client@ is exported on @conn@. Allocate
-- a new id if the client is not exported there yet.
--
-- In both cases the export's entry in the connection's @exports@ table is
-- created or has its reference count bumped. Throws 'eDisconnected' (via
-- 'getLive') if the connection is dead.
getConnExport :: Conn -> Client' -> STM IEId
getConnExport conn client = getLive conn >>= \conn'@Conn'{exports} -> do
    ExportMap m <- clientExportMap client
    existing <- M.lookup conn m
    eid <- maybe (newExport conn') pure existing
    addBumpExport eid client exports
    case existing of
        Nothing -> M.insert eid conn m
        Just _ -> pure ()
    pure eid
-- | Stop exporting @client'@ on @conn@.
--
-- The entry is removed from the client's export map. If the connection is
-- still live, it is also removed from the connection's @exports@ table and
-- the export id is freed. Calling this for a client that is not exported
-- on @conn@ is a bug and raises 'error'.
dropConnExport :: Conn -> Client' -> STM ()
dropConnExport conn client' = do
    ExportMap eMap <- clientExportMap client'
    M.lookup conn eMap >>= \case
        Nothing ->
            error "BUG: tried to drop an export that doesn't exist."
        Just eid -> do
            M.delete conn eMap
            whenLive conn $ \conn'@Conn'{exports} -> do
                M.delete eid exports
                freeExport conn' eid
-- | The client's map from connection to export id. Imported clients keep
-- it in their 'ImportRef' as @proxies@.
clientExportMap :: Client' -> STM ExportMap
clientExportMap = \case
    LocalClient{exportMap} -> pure exportMap
    PromiseClient{exportMap} -> pure exportMap
    ImportClient cell -> (\ImportRef{proxies} -> proxies) <$> Fin.get cell
-- | Add a reference to export @exportId@ for @client@ in the given exports
-- table.
--
-- A missing entry is created with a reference count of 1; an existing one
-- has its count incremented. An existing entry for a different client is a
-- bug and raises 'error'.
addBumpExport :: IEId -> Client' -> M.Map IEId EntryE -> STM ()
addBumpExport exportId client = M.focus (Focus.alter bump) exportId
  where
    bump = \case
        Nothing ->
            Just EntryE { client, refCount = 1 }
        Just EntryE{ client = existing, refCount }
            | client == existing ->
                Just EntryE { client, refCount = refCount + 1 }
            | otherwise ->
                error $
                    "BUG: addExportRef called with a client that is different " ++
                    "from what is already in our exports table."
-- | Encode @client@ as a capability descriptor to send over @targetConn@.
--
-- * A null client becomes @none@.
-- * A local client is exported on @targetConn@ (via 'getConnExport') and
--   sent as @senderHosted@.
-- * A pending promise whose temporary destination is an answer or import
--   on @targetConn@ itself is pointed back at the peer as @receiverAnswer@
--   / @receiverHosted@. Any other promise state is exported and sent as
--   @senderPromise@.
-- * An imported client is sent as @receiverHosted@ if it was imported from
--   @targetConn@; otherwise it is exported and sent as @senderHosted@.
emitCap :: Conn -> Client -> STM R.CapDescriptor'
emitCap _targetConn (Client Nothing) =
    pure R.CapDescriptor'none
emitCap targetConn (Client (Just client')) = case client' of
    LocalClient{} ->
        R.CapDescriptor'senderHosted . ieWord <$> getConnExport targetConn client'
    PromiseClient{ pState } -> readTVar pState >>= \case
        -- Pending on an answer of the connection we're sending to: refer
        -- the peer to its own answer. If the guard fails, matching falls
        -- through to the wildcard case below.
        Pending { tmpDest = RemoteDest AnswerDest { conn, answer } }
            | conn == targetConn ->
                pure $ R.CapDescriptor'receiverAnswer (marshalPromisedAnswer answer)
        -- Pending on an import: if it came from the connection we're
        -- sending to, refer the peer to its own export.
        Pending { tmpDest = RemoteDest (ImportDest cell) } -> do
            ImportRef { conn, importId = IEId iid } <- Fin.get cell
            if conn == targetConn
                then pure (R.CapDescriptor'receiverHosted iid)
                else newSenderPromise
        -- Everything else, including answer-pending promises on another
        -- connection, is exported as a promise.
        _ ->
            newSenderPromise
    ImportClient cell -> do
        ImportRef { conn=hostConn, importId } <- Fin.get cell
        if hostConn == targetConn
            then pure (R.CapDescriptor'receiverHosted (ieWord importId))
            else R.CapDescriptor'senderHosted . ieWord <$> getConnExport targetConn client'
  where
    -- Export the client on targetConn and describe it as a senderPromise.
    newSenderPromise = R.CapDescriptor'senderPromise . ieWord <$> getConnExport targetConn client'
-- | Decode a capability descriptor received on @conn@ into a 'Client'.
--
-- Throws 'eDisconnected' (via 'getLive') if the connection is dead.
-- Protocol violations and unsupported variants abort the connection.
acceptCap :: Conn -> R.CapDescriptor' -> STM Client
acceptCap conn cap = getLive conn >>= \conn' -> go conn' cap
  where
    go _ R.CapDescriptor'none = pure (Client Nothing)
    -- The peer is sending us one of its (non-promise) exports.
    go conn'@Conn'{imports} (R.CapDescriptor'senderHosted (IEId -> importId)) = do
        entry <- M.lookup importId imports
        case entry of
            -- Known import, but recorded as a promise: the peer contradicted
            -- itself.
            Just EntryI{ promiseState=Just _ } ->
                let imp = fromString (show importId)
                in abortConn conn' $ eFailed $
                    "received senderHosted capability #" <> imp <>
                    ", but the imports table says #" <> imp <> " is senderPromise."
            -- Known import: bump both reference counts and hand out a new
            -- cell whose finalizer (registered after the transaction via
            -- queueIO) decrements the local count again.
            Just ent@EntryI{ localRc, remoteRc, proxies } -> do
                Rc.incr localRc
                M.insert ent { localRc, remoteRc = remoteRc + 1 } importId imports
                cell <- Fin.newCell ImportRef
                    { conn
                    , importId
                    , proxies
                    }
                queueIO conn' $ Fin.addFinalizer cell $ Rc.decr localRc
                pure $ Client $ Just $ ImportClient cell

            -- First time we see this import id.
            Nothing ->
                Client . Just . ImportClient <$> newImport importId conn Nothing
    -- The peer is sending us a promise it hosts.
    go conn'@Conn'{imports} (R.CapDescriptor'senderPromise (IEId -> importId)) = do
        entry <- M.lookup importId imports
        case entry of
            -- Known import, but recorded as non-promise: contradiction.
            Just EntryI { promiseState=Nothing } ->
                let imp = fromString (show importId)
                in abortConn conn' $ eFailed $
                    "received senderPromise capability #" <> imp <>
                    ", but the imports table says #" <> imp <> " is senderHosted."
            -- Known promise import: share its existing promise state.
            Just ent@EntryI { remoteRc, proxies, promiseState=Just (pState, origTarget) } -> do
                M.insert ent { remoteRc = remoteRc + 1 } importId imports
                pure $ Client $ Just PromiseClient
                    { pState
                    , exportMap = proxies
                    , origTarget
                    }
            Nothing -> do
                -- Knot-tying via RecursiveDo: the import entry stores the
                -- promise state and temp destination, which themselves refer
                -- to the import cell being created. This only works because
                -- newImport stores that tuple without forcing its contents.
                rec imp <- newImport importId conn (Just (pState, tmpDest))
                    ImportRef{proxies} <- Fin.get imp
                    let tmpDest = RemoteDest (ImportDest imp)
                    pState <- newTVar Pending { tmpDest }
                pure $ Client $ Just PromiseClient
                    { pState
                    , exportMap = proxies
                    , origTarget = tmpDest
                    }
    -- The peer refers to one of our exports; unknown ids abort.
    go conn'@Conn'{exports} (R.CapDescriptor'receiverHosted exportId) =
        lookupAbort "export" conn' exports (IEId exportId) $
            \EntryE{client} ->
                pure $ Client $ Just client
    -- The peer refers to (a capability within) one of our answers.
    go conn' (R.CapDescriptor'receiverAnswer pa) =
        case unmarshalPromisedAnswer pa of
            Left e ->
                abortConn conn' e
            Right pa ->
                newLocalAnswerClient conn' pa
    go conn' (R.CapDescriptor'thirdPartyHosted _) =
        abortConn conn' $ eUnimplemented
            "thirdPartyHosted unimplemented; level 3 is not supported."
    go conn' (R.CapDescriptor'unknown' tag) =
        abortConn conn' $ eUnimplemented $
            "Unimplemented CapDescriptor variant #" <> fromString (show tag)
-- | Create an entry for a newly received import and return a cell that
-- references it.
--
-- The entry goes into the connection's @imports@ table with a remote
-- reference count of 1. The local reference count is created with
-- 'releaseImport' as its release action. A finalizer registered on the
-- returned cell (after the transaction, via @queueIO@) decrements that
-- count. Throws 'eDisconnected' (via 'getLive') if the connection is dead.
newImport :: IEId -> Conn -> Maybe (TVar PromiseState, TmpDest) -> STM (Fin.Cell ImportRef)
newImport importId conn promiseState = getLive conn >>= \conn'@Conn'{imports} -> do
    localRc <- Rc.new () (releaseImport importId conn')
    proxies <- ExportMap <$> M.new
    let entry = EntryI { localRc, remoteRc = 1, proxies, promiseState }
    M.insert entry importId imports
    cell <- Fin.newCell ImportRef { conn, importId, proxies }
    queueIO conn' $ Fin.addFinalizer cell $ Rc.decr localRc
    pure cell
-- | Drop an import.
--
-- Send the peer a @Release@ whose reference count is the entry's remote
-- count, then delete the entry from the connection's @imports@ table. An
-- unknown import id aborts the connection via 'lookupAbort'.
releaseImport :: IEId -> Conn' -> STM ()
releaseImport importId conn'@Conn'{imports} = do
    lookupAbort "imports" conn' imports importId sendRelease
    M.delete importId imports
  where
    sendRelease EntryI { remoteRc } =
        sendPureMsg conn' $ R.Message'release
            R.Release
                { id = ieWord importId
                , referenceCount = remoteRc
                }
-- | Make a promise client for a capability inside one of our own answers:
-- @answerId@, narrowed by the pointer path @transform@.
--
-- The promise starts 'Pending' with a fresh local call buffer. The answer's
-- return is subscribed to, and when it arrives 'resolveClientReturn'
-- resolves the promise.
newLocalAnswerClient :: Conn' -> PromisedAnswer -> STM Client
newLocalAnswerClient conn@Conn'{answers} PromisedAnswer{ answerId, transform } = do
    callBuffer <- newTQueue
    let tmpDest = LocalDest LocalBuffer { callBuffer }
    pState <- newTVar Pending { tmpDest }
    let onRet = resolveClientReturn tmpDest (writeTVar pState) conn (toList transform)
    subscribeReturn "answer" conn answers answerId onRet
    exportMap <- ExportMap <$> M.new
    pure $ Client $ Just PromiseClient
        { pState
        , exportMap
        , origTarget = tmpDest
        }