{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Capnp.Rpc.Untyped
(
ConnConfig(..)
, handleConn
, Client
, call
, nullClient
, IsClient(..)
, export
, clientMethodHandler
, RpcError(..)
, R.Exception(..)
, R.Exception'Type(..)
) where
import Control.Concurrent.STM
import Data.Word
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async (concurrently_, race_)
import Control.Concurrent.MVar (MVar, newEmptyMVar)
import Control.Exception.Safe (Exception, bracket, throwIO, try)
import Control.Monad (forever, void, when)
import Data.Default (Default(def))
import Data.Foldable (for_, toList, traverse_)
import Data.Hashable (Hashable, hash, hashWithSalt)
import Data.Maybe (catMaybes)
import Data.String (fromString)
import Data.Text (Text)
import GHC.Generics (Generic)
import Supervisors (Supervisor, superviseSTM, withSupervisor)
import System.Mem.StableName (StableName, hashStableName, makeStableName)
import System.Timeout (timeout)
import qualified Data.Vector as V
import qualified Focus
import qualified ListT
import qualified StmContainers.Map as M
import Capnp.Classes (cerialize, decerialize)
import Capnp.Convert (msgToValue, valueToMsg)
import Capnp.Message (ConstMsg)
import Capnp.Rpc.Errors
( eDisconnected
, eFailed
, eMethodUnimplemented
, eUnimplemented
, wrapException
)
import Capnp.Rpc.Promise
(Fulfiller, breakPromiseSTM, fulfillSTM, newCallbackSTM)
import Capnp.Rpc.Transport (Transport(recvMsg, sendMsg))
import Capnp.TraversalLimit (defaultLimit, evalLimitT)
import Internal.BuildPure (createPure)
import Internal.Rc (Rc)
import Internal.SnocList (SnocList)
import qualified Capnp.Gen.Capnp.Rpc.Pure as R
import qualified Capnp.Message as Message
import qualified Capnp.Rpc.Server as Server
import qualified Capnp.Untyped as UntypedRaw
import qualified Capnp.Untyped.Pure as Untyped
import qualified Internal.Finalizer as Fin
import qualified Internal.Rc as Rc
import qualified Internal.SnocList as SnocList
import qualified Internal.TCloseQ as TCloseQ
-- | A (possibly null) pointer in the pure, decoded representation of a
-- message.
type MPtr = Maybe Untyped.Ptr

-- | A (possibly null) pointer into a raw, read-only ('ConstMsg') message.
type RawMPtr = Maybe (UntypedRaw.Ptr ConstMsg)

-- | Errors which terminate an RPC connection.
data RpcError
    = ReceivedAbort R.Exception
    -- ^ The remote vat sent us an @abort@ message carrying this exception.
    | SentAbort R.Exception
    -- ^ We aborted the connection ('abortConn'); 'handleConn' tries to send
    -- the exception to the remote vat in an @abort@ message, then rethrows.
    deriving(Show, Eq, Generic)

instance Exception RpcError
-- | Identifier of an entry in a connection's embargo table.
newtype EmbargoId = EmbargoId { embargoWord :: Word32 } deriving(Eq, Hashable)

-- | Identifier of an entry in a connection's question or answer table.
newtype QAId = QAId { qaWord :: Word32 } deriving(Eq, Hashable)

-- | Identifier of an entry in a connection's import or export table.
newtype IEId = IEId { ieWord :: Word32 } deriving(Eq, Hashable)

-- Both identifiers render as their bare underlying number (as used in error
-- messages).
instance Show QAId where
    show (QAId word) = show word

instance Show IEId where
    show (IEId word) = show word
-- | A connection to a remote vat.
--
-- The record itself remains usable after the connection shuts down;
-- 'liveState' says whether the per-connection state is still available.
data Conn = Conn
    { stableName :: StableName (MVar ())
    -- ^ Identity of the connection; the 'Eq' and 'Hashable' instances use it.
    , debugMode :: !Bool
    -- ^ Passed to 'wrapException' when an incoming message fails to parse.
    , liveState :: TVar LiveState
    }

-- | Whether a connection is still running.
data LiveState
    = Live Conn'
    | Dead

-- | The state of a live connection.
data Conn' = Conn'
    { sendQ :: TBQueue ConstMsg
    -- ^ Outgoing messages, drained by 'sendLoop'.
    , recvQ :: TBQueue ConstMsg
    -- ^ Incoming messages, filled by 'recvLoop' and consumed by 'coordinator'.
    , supervisor :: Supervisor
    , questionIdPool :: IdPool
    -- ^ Ids for questions; embargo ids are drawn from this pool too.
    , exportIdPool :: IdPool
    , questions :: M.Map QAId EntryQA
    , answers :: M.Map QAId EntryQA
    , exports :: M.Map IEId EntryE
    , imports :: M.Map IEId EntryI
    , embargos :: M.Map EmbargoId (Fulfiller ())
    , pendingCallbacks :: TQueue (IO ())
    -- ^ IO actions queued from STM code (see 'queueIO'), run by
    -- 'callbacksLoop'.
    , bootstrap :: Maybe Client
    -- ^ The interface handed out in response to @bootstrap@ messages, if any.
    }
-- | Connections are compared by their 'StableName'.
instance Eq Conn where
    x == y = stableName x == stableName y

-- | Hash a connection by its 'StableName'.
--
-- 'hashWithSalt' mixes the salt in rather than discarding it. A hash that
-- ignores the salt makes composite hashes (e.g. of a tuple containing a
-- 'Conn') ignore every component hashed before the connection.
instance Hashable Conn where
    hash Conn{stableName} = hashStableName stableName
    hashWithSalt salt conn = hashWithSalt salt (hash conn)
-- | Configuration for a connection; passed to 'handleConn'.
data ConnConfig = ConnConfig
    { maxQuestions :: !Word32
    -- ^ Size of the question id pool (embargo ids share it). This also
    -- bounds the send and receive queues.
    , maxExports :: !Word32
    -- ^ Size of the export id pool.
    , debugMode :: !Bool
    , getBootstrap :: Supervisor -> STM (Maybe Client)
    -- ^ Produces the bootstrap interface we offer to the remote vat, if any.
    , withBootstrap :: Maybe (Supervisor -> Client -> IO ())
    -- ^ If present, this is run on the remote vat's bootstrap interface,
    -- and the connection is shut down when it returns.
    }

-- | 128 questions, 8192 exports, debug mode off. No bootstrap interface is
-- offered, and the remote one is never requested.
instance Default ConnConfig where
    def = ConnConfig
        { maxQuestions = 128
        , maxExports = 8192
        , debugMode = False
        , getBootstrap = const (pure Nothing)
        , withBootstrap = Nothing
        }
-- | Schedule an IO action for 'callbacksLoop'. The action is only queued if
-- the surrounding transaction commits.
queueIO :: Conn' -> IO () -> STM ()
queueIO conn' action = writeTQueue (pendingCallbacks conn') action

-- | Schedule an STM transaction to be run later, in its own 'atomically'.
queueSTM :: Conn' -> STM () -> STM ()
queueSTM conn' txn = queueIO conn' (atomically txn)

-- | Schedule each callback to be applied to the value, each in its own
-- transaction, in list order.
mapQueueSTM :: Conn' -> SnocList (a -> STM ()) -> a -> STM ()
mapQueueSTM conn' callbacks arg =
    for_ callbacks $ \callback -> queueSTM conn' (callback arg)
-- | Allocate a question id, blocking (via 'retry') until one is free.
newQuestion :: Conn' -> STM QAId
newQuestion = fmap QAId . newId . questionIdPool

-- | Return a question id to the pool.
freeQuestion :: Conn' -> QAId -> STM ()
freeQuestion conn = freeId (questionIdPool conn) . qaWord

-- | Allocate an export id, blocking until one is free.
newExport :: Conn' -> STM IEId
newExport = fmap IEId . newId . exportIdPool

-- | Return an export id to the pool.
freeExport :: Conn' -> IEId -> STM ()
freeExport conn = freeId (exportIdPool conn) . ieWord

-- | Allocate an embargo id. Embargo ids are drawn from the question id pool.
newEmbargo :: Conn' -> STM EmbargoId
newEmbargo = fmap EmbargoId . newId . questionIdPool

-- | Return an embargo id to the pool it came from.
--
-- This must be the same pool 'newEmbargo' draws from. Returning it to the
-- export pool instead would leak question ids. It could also hand out an
-- export id that is still in use.
freeEmbargo :: Conn' -> EmbargoId -> STM ()
freeEmbargo conn = freeId (questionIdPool conn) . embargoWord
-- | Handle a connection to another vat.
--
-- Returns normally only when the 'withBootstrap' callback (if any) returns.
-- Otherwise it runs until one of the connection's threads throws. If we
-- aborted the connection ('SentAbort'), we make a best-effort attempt
-- (waiting at most one second) to send the abort message to the remote
-- vat, then rethrow. Other 'RpcError's are rethrown unchanged.
handleConn :: Transport -> ConnConfig -> IO ()
handleConn
        transport
        cfg@ConnConfig
            { maxQuestions
            , maxExports
            , withBootstrap
            , debugMode
            }
    = withSupervisor $ \sup ->
        bracket
            (newConn sup)
            stopConn
            runConn
  where
    newConn sup = do
        stableName <- makeStableName =<< newEmptyMVar
        atomically $ do
            bootstrap <- getBootstrap cfg sup
            questionIdPool <- newIdPool maxQuestions
            exportIdPool <- newIdPool maxExports

            sendQ <- newTBQueue $ fromIntegral maxQuestions
            recvQ <- newTBQueue $ fromIntegral maxQuestions

            questions <- M.new
            answers <- M.new
            exports <- M.new
            imports <- M.new
            embargos <- M.new
            pendingCallbacks <- newTQueue

            let conn' = Conn'
                    { supervisor = sup
                    , questionIdPool
                    , exportIdPool
                    , recvQ
                    , sendQ
                    , questions
                    , answers
                    , exports
                    , imports
                    , embargos
                    , pendingCallbacks
                    , bootstrap
                    }
            liveState <- newTVar (Live conn')
            let conn = Conn
                    { stableName
                    , debugMode
                    , liveState
                    }
            pure (conn, conn')

    runConn (conn, conn') = do
        result <- try $
            ( coordinator conn
                `concurrently_` sendLoop transport conn'
                `concurrently_` recvLoop transport conn'
                `concurrently_` callbacksLoop conn'
            ) `race_`
                useBootstrap conn conn'
        case result of
            Left (SentAbort e) -> do
                -- Tell the remote vat why we are hanging up, but don't wait
                -- on it forever.
                rawMsg <- createPure maxBound $ valueToMsg $ R.Message'abort e
                void $ timeout 1000000 $ sendMsg transport rawMsg
                throwIO $ SentAbort e
            Left e ->
                throwIO e
            Right _ ->
                pure ()

    stopConn
            ( conn@Conn{liveState}
            , conn'@Conn'{questions, exports, embargos}
            ) = do
        atomically $ do
            let walk table = flip ListT.traverse_ (M.listT table)
            -- Drop everything in the exports table. If the remote vat asked
            -- for (and has not released) our bootstrap interface, it is in
            -- this table too and is dropped here. We must not drop it
            -- separately: 'dropConnExport' errors out for a client that is
            -- not currently exported on this connection.
            walk exports $ \(_, EntryE{client}) ->
                dropConnExport conn client
            -- Fail every question that is still waiting on a return:
            walk questions $ \(QAId qid, entry) ->
                let raiseDisconnected onReturn =
                        mapQueueSTM conn' onReturn $ R.Return
                            { answerId = qid
                            , releaseParamCaps = False
                            , union' = R.Return'exception eDisconnected
                            }
                in case entry of
                    NewQA{onReturn} -> raiseDisconnected onReturn
                    HaveFinish{onReturn} -> raiseDisconnected onReturn
                    _ -> pure ()
            -- Break all outstanding embargos:
            walk embargos $ \(_, fulfiller) ->
                breakPromiseSTM fulfiller eDisconnected
            writeTVar liveState Dead
        -- Run whatever callbacks the above queued.
        flushCallbacks conn'

    useBootstrap conn conn' = case withBootstrap of
        Nothing ->
            forever $ threadDelay maxBound
        Just f ->
            atomically (requestBootstrap conn) >>= f (supervisor conn')
-- | A pool of available ids.
newtype IdPool = IdPool (TVar [Word32])

-- | Create a pool containing the ids @0 .. size - 1@.
--
-- A size of zero yields an empty pool. (Computing @size - 1@ directly on a
-- 'Word32' would wrap around to 'maxBound' and produce ~4 billion ids.)
newIdPool :: Word32 -> STM IdPool
newIdPool size = IdPool <$> newTVar ids
  where
    ids
        | size == 0 = []
        | otherwise = [0 .. size - 1]

-- | Take an id from the pool, retrying (i.e. blocking) until one is
-- available.
newId :: IdPool -> STM Word32
newId (IdPool pool) = do
    available <- readTVar pool
    case available of
        [] -> retry
        (next:rest) -> do
            writeTVar pool $! rest
            pure next

-- | Return an id to the pool. Double frees are not detected; an id freed
-- twice can be handed out twice.
freeId :: IdPool -> Word32 -> STM ()
freeId (IdPool pool) ident = modifyTVar' pool (ident:)
-- | An entry in the question or answer table.
--
-- Each entry tracks which of the @Return@ and @Finish@ messages for it have
-- been seen, along with callbacks waiting on the one(s) that have not. The
-- entry is removed once both have been seen (see 'updateQAReturn' and
-- 'updateQAFinish').
data EntryQA
    -- | Neither the return nor the finish has been seen yet.
    = NewQA
        { onFinish :: SnocList (R.Finish -> STM ())
        , onReturn :: SnocList (R.Return -> STM ())
        }
    -- | The return has been seen, but not the finish.
    | HaveReturn
        { returnMsg :: R.Return
        , onFinish :: SnocList (R.Finish -> STM ())
        }
    -- | The finish has been seen, but not the return.
    | HaveFinish
        { finishMsg :: R.Finish
        , onReturn :: SnocList (R.Return -> STM ())
        }

-- | An entry in the imports table.
data EntryI = EntryI
    { localRc :: Rc ()
    , remoteRc :: !Word32
    , proxies :: ExportMap
    , promiseState :: Maybe
        ( TVar PromiseState
        , TmpDest
        )
    -- ^ 'Just' if the import is a promise; 'handleResolveMsg' uses this to
    -- resolve it.
    }

-- | An entry in the exports table.
data EntryE = EntryE
    { client :: Client'
    -- ^ The exported capability.
    , refCount :: !Word32
    -- ^ Bumped each time the capability is sent ('getConnExport') and
    -- decreased by release messages. The entry is dropped when this reaches
    -- zero (see 'releaseExport').
    }
-- | Types which can be converted to and from an untyped 'Client'.
class IsClient a where
    -- | Convert to an untyped client.
    toClient :: a -> Client
    -- | Convert from an untyped client.
    fromClient :: Client -> a

instance Show Client where
    show (Client maybeClient) = case maybeClient of
        Nothing -> "nullClient"
        Just _ -> "({- capability; not statically representable -})"

-- | A reference to a capability. 'Nothing' is the null capability
-- (see 'nullClient').
newtype Client =
    Client (Maybe Client')
    deriving(Eq)
-- | The non-null part of a 'Client'.
data Client'
    -- | A capability served by this vat (see 'export').
    = LocalClient
        { exportMap :: ExportMap
        -- ^ This client's export id on each connection it has been sent over.
        , qCall :: Rc (Server.CallInfo -> STM ())
        -- ^ Delivers a call to the server. Once this has been released,
        -- 'call' fails calls with 'eDisconnected'.
        , finalizerKey :: Fin.Cell ()
        -- ^ A finalizer registered on this cell releases 'qCall'
        -- (see 'export').
        }
    -- | A promise for a capability.
    | PromiseClient
        { pState :: TVar PromiseState
        , exportMap :: ExportMap
        , origTarget :: TmpDest
        -- ^ The temporary destination the promise was created with.
        }
    -- | A capability imported from a remote vat.
    | ImportClient (Fin.Cell ImportRef)

-- | The state of a 'PromiseClient'; 'call' dispatches on this.
data PromiseState
    -- | Resolved: calls are forwarded to 'target'.
    = Ready
        { target :: Client
        }
    -- | Resolved, but waiting on a disembargo: calls are buffered.
    | Embargo
        { callBuffer :: TQueue Server.CallInfo
        }
    -- | Not yet resolved: calls go to the temporary destination.
    | Pending
        { tmpDest :: TmpDest
        }
    -- | Resolved to an error: calls fail with this exception.
    | Error R.Exception

-- | Where calls to an unresolved promise are sent in the meantime.
data TmpDest
    = LocalDest LocalDest
    | RemoteDest RemoteDest

-- | A local temporary destination, which buffers calls until resolution.
newtype LocalDest
    = LocalBuffer { callBuffer :: TQueue Server.CallInfo }

-- | A remote temporary destination.
data RemoteDest
    -- | The answer to a question on a connection.
    = AnswerDest
        { conn :: Conn
        , answer :: PromisedAnswer
        }
    -- | An import on some connection.
    | ImportDest (Fin.Cell ImportRef)

-- | A reference to an entry in a connection's import table.
data ImportRef = ImportRef
    { conn :: Conn
    , importId :: !IEId
    , proxies :: ExportMap
    }
-- | Import references are equal when they name the same import id on the
-- same connection.
instance Eq ImportRef where
    lhs == rhs =
        connOf lhs == connOf rhs && idOf lhs == idOf rhs
      where
        connOf ImportRef{conn} = conn
        idOf ImportRef{importId} = importId

-- | Local clients compare by their call handle, promises by their state
-- variable, and imports by their cell. Clients of different kinds are
-- never equal.
instance Eq Client' where
    (==) a b = case (a, b) of
        (LocalClient { qCall = x }, LocalClient { qCall = y }) -> x == y
        (PromiseClient { pState = x }, PromiseClient { pState = y }) -> x == y
        (ImportClient x, ImportClient y) -> x == y
        _ -> False
-- | For a capability, the export id it has on each connection it has been
-- exported over.
newtype ExportMap = ExportMap (M.Map Conn IEId)

-- | The target of an outgoing call.
data MsgTarget
    = ImportTgt !IEId
    | AnswerTgt PromisedAnswer

-- | A reference into the result of a question. 'transform' lists the
-- pointer-field indices to follow into the results (see
-- 'marshalPromisedAnswer').
data PromisedAnswer = PromisedAnswer
    { answerId :: !QAId
    , transform :: SnocList Word16
    }
-- | Make a call on a client.
--
-- Calls on the null client fail with 'eMethodUnimplemented'. Calls on a
-- local client whose server has been released fail with 'eDisconnected'.
-- Calls on promises are forwarded, buffered, or failed according to the
-- promise's state.
call :: Server.CallInfo -> Client -> STM ()
call info@Server.CallInfo { response } (Client maybeClient) =
    maybe (breakPromiseSTM response eMethodUnimplemented) dispatch maybeClient
  where
    dispatch = \case
        LocalClient { qCall } -> do
            handler <- Rc.get qCall
            case handler of
                Just deliver -> deliver info
                Nothing -> breakPromiseSTM response eDisconnected
        PromiseClient { pState } ->
            readTVar pState >>= callPromise
        ImportClient (Fin.get -> ImportRef { conn, importId }) ->
            callRemote conn info (ImportTgt importId)

    callPromise = \case
        Ready { target = resolved } ->
            call info resolved
        Embargo { callBuffer } ->
            writeTQueue callBuffer info
        Pending { tmpDest = LocalDest LocalBuffer { callBuffer } } ->
            writeTQueue callBuffer info
        Pending { tmpDest = RemoteDest AnswerDest { conn, answer } } ->
            callRemote conn info (AnswerTgt answer)
        Pending { tmpDest = RemoteDest (ImportDest (Fin.get -> ImportRef { conn, importId })) } ->
            callRemote conn info (ImportTgt importId)
        Error exn ->
            breakPromiseSTM response exn
-- | Send a call to a target on a remote connection.
--
-- A new question is registered; its return settles the call's response
-- (see 'cbCallReturn').
callRemote :: Conn -> Server.CallInfo -> MsgTarget -> STM ()
callRemote
        conn
        Server.CallInfo{ interfaceId, methodId, arguments, response }
        target = do
    conn'@Conn'{questions} <- getLive conn
    qid <- newQuestion conn'
    payload@R.Payload{capTable} <- makeOutgoingPayload conn arguments
    sendPureMsg conn' $ R.Message'call def
        { R.questionId = qaWord qid
        , R.target = marshalMsgTarget target
        , R.params = payload
        , R.interfaceId = interfaceId
        , R.methodId = methodId
        }
    -- Our own exports that went out in the parameters. The callee may ask
    -- us to release these in its return.
    let paramCaps =
            [ IEId eid
            | desc <- V.toList capTable
            , Just eid <- [senderExportId desc]
            ]
        senderExportId = \case
            R.CapDescriptor'senderHosted eid -> Just eid
            R.CapDescriptor'senderPromise eid -> Just eid
            _ -> Nothing
    M.insert
        NewQA
            { onReturn = SnocList.singleton $ cbCallReturn paramCaps conn response
            , onFinish = SnocList.empty
            }
        qid
        questions
-- | Callback run when a question created by 'callRemote' receives its
-- return.
--
-- If the callee asks for it, this releases the exports we sent as
-- parameters. It then settles the call's response according to the
-- return, and finally sends a finish message for the question.
-- Unsupported or invalid return variants abort the connection.
cbCallReturn :: [IEId] -> Conn -> Fulfiller RawMPtr -> R.Return -> STM ()
cbCallReturn
        paramCaps
        conn
        response
        R.Return{ answerId, union', releaseParamCaps } = do
    conn'@Conn'{answers} <- getLive conn
    when releaseParamCaps $
        traverse_ (releaseExport conn 1) paramCaps
    case union' of
        R.Return'exception exn ->
            breakPromiseSTM response exn
        R.Return'results R.Payload{ content } -> do
            rawPtr <- createPure defaultLimit $ do
                msg <- Message.newMessage Nothing
                cerialize msg content
            fulfillSTM response rawPtr
        R.Return'canceled ->
            breakPromiseSTM response $ eFailed "Canceled"
        R.Return'resultsSentElsewhere ->
            -- 'callRemote' leaves sendResultsTo at its default (presumably
            -- caller), so a conforming peer should never send this.
            abortConn conn' $ eFailed $ mconcat
                [ "Received Return.resultsSentElsewhere for a call "
                , "with sendResultsTo = caller."
                ]
        R.Return'takeFromOtherQuestion (QAId -> qid) ->
            -- The result is that of one of our answers; wait for it.
            subscribeReturn "answer" conn' answers qid $
                cbCallReturn [] conn response
        R.Return'acceptFromThirdParty _ ->
            abortConn conn' $ eUnimplemented
                "This vat does not support level 3."
        R.Return'unknown' ordinal ->
            abortConn conn' $ eUnimplemented $
                "Unknown return variant #" <> fromString (show ordinal)
    finishQuestion conn' def
        { R.questionId = answerId
        , R.releaseResultCaps = False
        }
-- | Convert a 'MsgTarget' to its wire representation.
marshalMsgTarget :: MsgTarget -> R.MessageTarget
marshalMsgTarget (ImportTgt importId) =
    R.MessageTarget'importedCap (ieWord importId)
marshalMsgTarget (AnswerTgt tgt) =
    R.MessageTarget'promisedAnswer (marshalPromisedAnswer tgt)

-- | Convert a 'PromisedAnswer' to its wire representation. Each transform
-- step becomes a @getPointerField@ op.
marshalPromisedAnswer :: PromisedAnswer -> R.PromisedAnswer
marshalPromisedAnswer PromisedAnswer{ answerId, transform } =
    R.PromisedAnswer
        { R.questionId = qaWord answerId
        , R.transform = V.fromList
            [ R.PromisedAnswer'Op'getPointerField idx | idx <- toList transform ]
        }
-- | Parse a wire 'R.PromisedAnswer'. Fails on the first unknown transform
-- op.
unmarshalPromisedAnswer :: R.PromisedAnswer -> Either R.Exception PromisedAnswer
unmarshalPromisedAnswer R.PromisedAnswer { questionId, transform } =
    toAnswer <$> unmarshalOps (toList transform)
  where
    toAnswer idxes = PromisedAnswer
        { answerId = QAId questionId
        , transform = SnocList.fromList idxes
        }

-- | Convert transform ops to the pointer-field indices they describe.
-- @noop@s are skipped; the first unknown op is an error.
unmarshalOps :: [R.PromisedAnswer'Op] -> Either R.Exception [Word16]
unmarshalOps = fmap catMaybes . traverse unmarshalOp
  where
    unmarshalOp = \case
        R.PromisedAnswer'Op'noop ->
            Right Nothing
        R.PromisedAnswer'Op'getPointerField i ->
            Right (Just i)
        R.PromisedAnswer'Op'unknown' tag ->
            Left $ eFailed $ "Unknown PromisedAnswer.Op: " <> fromString (show tag)
-- | The null capability; every call on it fails with
-- 'eMethodUnimplemented'.
nullClient :: Client
nullClient = Client Nothing

-- | Run a local server (under the given supervisor) and return a client
-- for it.
export :: Supervisor -> Server.ServerOps IO -> STM Client
export sup ops = do
    callQueue <- TCloseQ.new
    qCall <- Rc.new (TCloseQ.write callQueue) (TCloseQ.close callQueue)
    exportMap <- ExportMap <$> M.new
    finalizerKey <- Fin.newCell ()
    superviseSTM sup $ do
        -- When finalizerKey is finalized (presumably once it is garbage
        -- collected -- see Internal.Finalizer), release our reference to
        -- the call handle.
        Fin.addFinalizer finalizerKey $ atomically $ Rc.release qCall
        Server.runServer callQueue ops
    pure $ Client $ Just LocalClient
        { qCall
        , exportMap
        , finalizerKey
        }

-- | A method handler which forwards each call to the given client.
clientMethodHandler :: Word64 -> Word16 -> Client -> Server.MethodHandler IO p r
clientMethodHandler interfaceId methodId client =
    Server.fromUntypedHandler $ Server.untypedHandler forward
  where
    forward arguments response =
        atomically $ call Server.CallInfo
            { interfaceId
            , methodId
            , arguments
            , response
            }
            client
-- | Run queued callbacks forever, blocking while there are none.
callbacksLoop :: Conn' -> IO ()
callbacksLoop Conn'{pendingCallbacks} = forever $ do
    batch <- atomically $ do
        queued <- flushTQueue pendingCallbacks
        if null queued then retry else pure queued
    sequence_ batch

-- | Run whatever callbacks are queued right now, without blocking.
flushCallbacks :: Conn' -> IO ()
flushCallbacks Conn'{pendingCallbacks} = do
    queued <- atomically $ flushTQueue pendingCallbacks
    sequence_ queued

-- | Forever send messages from the send queue over the transport.
sendLoop :: Transport -> Conn' -> IO ()
sendLoop transport Conn'{sendQ} = forever $ do
    msg <- atomically $ readTBQueue sendQ
    sendMsg transport msg

-- | Forever receive messages from the transport into the receive queue.
recvLoop :: Transport -> Conn' -> IO ()
recvLoop transport Conn'{recvQ} = forever $ do
    msg <- recvMsg transport
    atomically $ writeTBQueue recvQ msg
-- | Process incoming messages forever, one transaction per message.
--
-- A message that fails to parse aborts the connection. Message kinds we do
-- not handle are answered with @unimplemented@.
coordinator :: Conn -> IO ()
coordinator conn@Conn{debugMode} = forever $ atomically $ do
    conn' <- getLive conn
    msg <- receive conn' `catchSTM` (abortConn conn' . wrapException debugMode)
    dispatch conn' msg
  where
    receive Conn'{recvQ} = readTBQueue recvQ >>= parseWithCaps conn

    dispatch conn' = \case
        R.Message'abort exn -> handleAbortMsg conn exn
        R.Message'unimplemented oldMsg -> handleUnimplementedMsg conn oldMsg
        R.Message'bootstrap bs -> handleBootstrapMsg conn bs
        R.Message'call callMsg -> handleCallMsg conn callMsg
        R.Message'return ret -> handleReturnMsg conn ret
        R.Message'finish finish -> handleFinishMsg conn finish
        R.Message'resolve res -> handleResolveMsg conn res
        R.Message'release release -> handleReleaseMsg conn release
        R.Message'disembargo disembargo -> handleDisembargoMsg conn disembargo
        other -> sendPureMsg conn' $ R.Message'unimplemented other
-- | Decode an incoming message.
--
-- Calls and result-carrying returns have capability tables. For those, the
-- caps are first accepted ('fixCapTable') and attached to the message, and
-- the message is decoded again.
parseWithCaps :: Conn -> ConstMsg -> STM R.Message
parseWithCaps conn msg = do
    pureMsg <- msgToValue msg
    let reparseWith capTable = fixCapTable capTable conn msg >>= msgToValue
    case pureMsg of
        R.Message'call R.Call{params=R.Payload{capTable}} ->
            reparseWith capTable
        R.Message'return R.Return{union'=R.Return'results R.Payload{capTable}} ->
            reparseWith capTable
        _ ->
            pure pureMsg
-- | The remote vat aborted the connection; end it by throwing
-- 'ReceivedAbort'.
handleAbortMsg :: Conn -> R.Exception -> STM ()
handleAbortMsg _ = throwSTM . ReceivedAbort

-- | The remote vat did not understand a message we sent it.
--
-- An echoed @unimplemented@ is ignored. Anything else is treated as a
-- required message, so the connection is aborted.
handleUnimplementedMsg :: Conn -> R.Message -> STM ()
handleUnimplementedMsg conn msg = do
    conn' <- getLive conn
    case msg of
        R.Message'unimplemented _ ->
            pure ()
        R.Message'abort _ ->
            abortConn conn' $ eFailed $
                "Your vat sent an 'unimplemented' message for an abort message " <>
                "that its remote peer never sent. This is likely a bug in your " <>
                "capnproto library."
        _ ->
            abortConn conn' $
                eFailed "Received unimplemented response for required message."
-- | Answer a bootstrap request with our bootstrap interface, or with an
-- error if we have none. The answer is recorded in the answers table;
-- a duplicate question id aborts the connection.
--
-- If the remote vat later finishes the question with @releaseResultCaps@
-- set, the export created for the bootstrap interface is released.
handleBootstrapMsg :: Conn -> R.Bootstrap -> STM ()
handleBootstrapMsg conn R.Bootstrap{ questionId } = getLive conn >>= \conn' -> do
    ret <- case bootstrap conn' of
        Nothing ->
            pure $ R.Return
                { R.answerId = questionId
                , R.releaseParamCaps = True
                , R.union' =
                    R.Return'exception $
                        eFailed "No bootstrap interface for this connection."
                }
        Just client -> do
            capDesc <- emitCap conn client
            pure $ R.Return
                { R.answerId = questionId
                , R.releaseParamCaps = True
                , R.union' =
                    R.Return'results R.Payload
                        { content = Just (Untyped.PtrCap client)
                        , capTable = V.singleton capDesc
                        }
                }
    M.focus
        (Focus.alterM $ insertBootstrap conn' ret)
        (QAId questionId)
        (answers conn')
    sendPureMsg conn' $ R.Message'return ret
  where
    insertBootstrap _ ret Nothing =
        pure $ Just HaveReturn
            { returnMsg = ret
            , onFinish = SnocList.fromList
                [ \R.Finish{releaseResultCaps} ->
                    when releaseResultCaps $
                        for_ (bootstrapExport ret) $ releaseExport conn 1
                ]
            }
    insertBootstrap conn' _ (Just _) =
        abortConn conn' $ eFailed "Duplicate question ID"

    -- The export id of the bootstrap capability in our return, if any.
    --
    -- The descriptors come from 'emitCap', written from our point of view as
    -- the sender of the return. Capabilities we host (entries in our exports
    -- table) therefore appear as senderHosted/senderPromise. A
    -- receiverHosted/receiverAnswer descriptor names something the remote
    -- vat hosts, and its id must not be released from our exports table.
    bootstrapExport = \case
        R.Return
            { union' = R.Return'results R.Payload
                { capTable = (V.toList -> [desc])
                }
            } -> case desc of
                R.CapDescriptor'senderHosted eid -> Just (IEId eid)
                R.CapDescriptor'senderPromise eid -> Just (IEId eid)
                _ -> Nothing
        _ -> Nothing
-- | Handle an incoming call.
--
-- A new entry is recorded in the answers table, then the call is delivered
-- to its target: one of our exports, or a (possibly pending) promised
-- answer. When the call's response is settled, the results or the error are
-- sent back with 'returnAnswer'.
handleCallMsg :: Conn -> R.Call -> STM ()
handleCallMsg
        conn
        R.Call
            { questionId
            , target
            , interfaceId
            , methodId
            , params=R.Payload{content, capTable}
            }
        = getLive conn >>= \conn'@Conn'{exports, answers} -> do
    -- Record the new answer; a duplicate question id aborts the connection.
    insertNewAbort
        "answer"
        conn'
        (QAId questionId)
        NewQA
            { onReturn = SnocList.empty
            , onFinish = SnocList.fromList
                -- NOTE(review): this walks the cap table of the call's
                -- *parameters* looking for receiverHosted descriptors, but
                -- releaseResultCaps refers to the caps in the *results* we
                -- return. Confirm this is intended.
                [ \R.Finish{releaseResultCaps} ->
                    when releaseResultCaps $
                        for_ capTable $ \case
                            R.CapDescriptor'receiverHosted (IEId -> importId) ->
                                releaseExport conn 1 importId
                            _ ->
                                pure ()
                ]
            }
        answers

    -- Re-encode the parameters as a raw message for the callee.
    callParams <- createPure defaultLimit $ do
        msg <- Message.newMessage Nothing
        cerialize msg content

    -- When the response is settled, send the corresponding return message.
    fulfiller <- newCallbackSTM $ \case
        Left e ->
            returnAnswer conn' def
                { R.answerId = questionId
                , R.releaseParamCaps = False
                , R.union' = R.Return'exception e
                }
        Right v -> do
            content <- evalLimitT defaultLimit (decerialize v)
            capTable <- genSendableCapTable conn content
            returnAnswer conn' def
                { R.answerId = questionId
                , R.releaseParamCaps = False
                , R.union' = R.Return'results def
                    { R.content = content
                    , R.capTable = capTable
                    }
                }

    let callInfo = Server.CallInfo
            { interfaceId
            , methodId
            , arguments = callParams
            , response = fulfiller
            }

    -- Deliver the call to its target.
    case target of
        R.MessageTarget'importedCap exportId ->
            lookupAbort "export" conn' exports (IEId exportId) $
                \EntryE{client} -> call callInfo $ Client $ Just client
        R.MessageTarget'promisedAnswer R.PromisedAnswer { questionId = targetQid, transform } ->
            -- Wait for the targeted answer's return. On success, follow the
            -- transform to the target capability. On error or cancellation,
            -- forward the same return as the answer to this call.
            let onReturn ret@R.Return{union'} =
                    case union' of
                        R.Return'exception _ ->
                            returnAnswer conn' ret { R.answerId = questionId }
                        R.Return'canceled ->
                            returnAnswer conn' ret { R.answerId = questionId }
                        R.Return'results R.Payload{content} ->
                            transformClient transform content conn' >>= call callInfo
                        R.Return'resultsSentElsewhere ->
                            abortConn conn' $ eFailed $
                                "Tried to call a method on a promised answer that " <>
                                "returned resultsSentElsewhere"
                        R.Return'takeFromOtherQuestion otherQid ->
                            subscribeReturn "answer" conn' answers (QAId otherQid) onReturn
                        R.Return'acceptFromThirdParty _ ->
                            error "BUG: our implementation unexpectedly used a level 3 feature"
                        R.Return'unknown' tag ->
                            error $
                                "BUG: our implemented unexpectedly returned unknown " ++
                                "result variant #" ++ show tag
            in
            subscribeReturn "answer" conn' answers (QAId targetQid) onReturn
        R.MessageTarget'unknown' ordinal ->
            abortConn conn' $ eUnimplemented $
                "Unknown MessageTarget ordinal #" <> fromString (show ordinal)
-- | Follow a promised-answer transform from a pointer, expecting to end at
-- a capability.
--
-- A null result yields 'nullClient'. A bad transform, or a result that is
-- not a capability, aborts the connection.
transformClient :: V.Vector R.PromisedAnswer'Op -> MPtr -> Conn' -> STM Client
transformClient transform ptr conn =
    either (abortConn conn) asClient $ do
        idxes <- unmarshalOps (V.toList transform)
        followPtrs idxes ptr
  where
    asClient Nothing = pure nullClient
    asClient (Just (Untyped.PtrCap client)) = pure client
    asClient (Just _) =
        abortConn conn $ eFailed "Tried to call method on non-capability."

-- | Follow a sequence of pointer-field indices from a pointer.
--
-- Following anything from a null pointer yields null. Indexing into a
-- non-struct is an error.
followPtrs :: [Word16] -> MPtr -> Either R.Exception MPtr
followPtrs idxes ptr = case (idxes, ptr) of
    ([], _) ->
        Right ptr
    (_, Nothing) ->
        Right Nothing
    (i:rest, Just (Untyped.PtrStruct (Untyped.Struct _ ptrs))) ->
        followPtrs rest (Untyped.sliceIndex (fromIntegral i) ptrs)
    (_, Just _) ->
        Left (eFailed "Tried to access pointer field of non-struct.")
-- | Record a return message for one of our questions.
handleReturnMsg :: Conn -> R.Return -> STM ()
handleReturnMsg conn ret = do
    conn'@Conn'{questions} <- getLive conn
    updateQAReturn conn' questions "question" ret

-- | Record a finish message for one of our answers.
handleFinishMsg :: Conn -> R.Finish -> STM ()
handleFinishMsg conn finish = do
    conn'@Conn'{answers} <- getLive conn
    updateQAFinish conn' answers "answer" finish
-- | Handle a resolve message for a promise we imported.
--
-- If the import is still present, the promise is resolved to the new
-- capability or to the exception. Resolving an import that is not a promise
-- aborts the connection.
handleResolveMsg :: Conn -> R.Resolve -> STM ()
handleResolveMsg conn R.Resolve{promiseId, union'} =
    getLive conn >>= \conn'@Conn'{imports} -> do
        let releaseImport importId =
                sendPureMsg conn' $ R.Message'release def
                    { R.id = importId
                    , R.referenceCount = 1
                    }
        entry <- M.lookup (IEId promiseId) imports
        case entry of
            Nothing ->
                -- We no longer hold the promise, so we will never use the
                -- capability it resolved to. If the remote vat hosts that
                -- capability (senderHosted/senderPromise from its point of
                -- view), receiving it gave us a reference under an id in
                -- our imports namespace, and we must release it. A
                -- receiverHosted capability is one of our own exports; no
                -- reference was transferred, so there is nothing to release.
                case union' of
                    R.Resolve'cap (R.CapDescriptor'senderHosted importId) ->
                        releaseImport importId
                    R.Resolve'cap (R.CapDescriptor'senderPromise importId) ->
                        releaseImport importId
                    _ -> pure ()
            Just EntryI{ promiseState = Nothing } ->
                abortConn conn' $ eFailed $ mconcat
                    [ "Received a resolve message for export id #", fromString (show promiseId)
                    , ", but that capability is not a promise!"
                    ]
            Just EntryI { promiseState = Just (tvar, tmpDest) } ->
                case union' of
                    R.Resolve'cap cap -> do
                        client <- acceptCap conn cap
                        resolveClientClient tmpDest (writeTVar tvar) client
                    R.Resolve'exception exn ->
                        resolveClientExn tmpDest (writeTVar tvar) exn
                    R.Resolve'unknown' tag ->
                        abortConn conn' $ eUnimplemented $ mconcat
                            [ "Resolve variant #"
                            , fromString (show tag)
                            , " not understood"
                            ]
-- | Handle a release message by dropping references to one of our exports.
handleReleaseMsg :: Conn -> R.Release -> STM ()
handleReleaseMsg conn R.Release{ id = exportId, referenceCount } =
    releaseExport conn referenceCount (IEId exportId)

-- | Drop references to an export.
--
-- When the count reaches zero the export is removed. Releasing more
-- references than are held, or releasing an unknown export, aborts the
-- connection.
releaseExport :: Conn -> Word32 -> IEId -> STM ()
releaseExport conn refCountDiff eid = do
    conn'@Conn'{exports} <- getLive conn
    lookupAbort "export" conn' exports eid $ \EntryE{client, refCount} ->
        if refCount < refCountDiff then
            abortConn conn' $ eFailed $
                "Received release for export with referenceCount " <>
                "greater than our recorded total ref count."
        else if refCount == refCountDiff then
            dropConnExport conn client
        else
            M.insert
                EntryE { client, refCount = refCount - refCountDiff }
                eid
                exports
-- | Handle a disembargo message.
--
-- * @receiverLoopback@: the remote vat is echoing back a disembargo we sent
--   (see 'disembargo'). Fulfill the matching embargo, remove it, and free
--   its id.
-- * @senderLoopback@: the remote vat wants the disembargo reflected back.
--   First check that the target has resolved to a capability it hosts on
--   this connection, then send it back as @receiverLoopback@. Anything else
--   aborts the connection.
-- * Any other context is answered with an @unimplemented@ message.
handleDisembargoMsg :: Conn -> R.Disembargo -> STM ()
handleDisembargoMsg conn d = getLive conn >>= go d
  where
    go
        R.Disembargo { context=R.Disembargo'context'receiverLoopback (EmbargoId -> eid) }
        conn'@Conn'{embargos}
        = do
            result <- M.lookup eid embargos
            case result of
                Nothing ->
                    abortConn conn' $ eFailed $
                        "No such embargo: " <> fromString (show $ embargoWord eid)
                Just fulfiller -> do
                    queueSTM conn' (fulfillSTM fulfiller ())
                    M.delete eid embargos
                    freeEmbargo conn' eid
    go
        R.Disembargo{ target, context=R.Disembargo'context'senderLoopback embargoId }
        conn'@Conn'{exports, answers}
        = case target of
            R.MessageTarget'importedCap exportId ->
                lookupAbort "export" conn' exports (IEId exportId) $ \EntryE{ client } ->
                    disembargoPromise client
            R.MessageTarget'promisedAnswer R.PromisedAnswer{ questionId, transform } ->
                lookupAbort "answer" conn' answers (QAId questionId) $ \case
                    HaveReturn { returnMsg=R.Return{union'=R.Return'results R.Payload{content} } } ->
                        transformClient transform content conn' >>= \case
                            Client (Just client') -> disembargoClient client'
                            Client Nothing -> abortDisembargo "targets a null capability"
                    _ ->
                        abortDisembargo $
                            "does not target an answer which has resolved to a value hosted by"
                            <> " the sender."
            R.MessageTarget'unknown' ordinal ->
                abortConn conn' $ eUnimplemented $
                    "Unknown MessageTarget ordinal #" <> fromString (show ordinal)
        where
            -- An exported target must be a promise that has resolved to a
            -- non-null capability.
            disembargoPromise PromiseClient{ pState } = readTVar pState >>= \case
                Ready (Client (Just client)) ->
                    disembargoClient client
                Ready (Client Nothing) ->
                    abortDisembargo "targets a promise which resolved to null."
                _ ->
                    abortDisembargo "targets a promise which has not resolved."
            disembargoPromise _ =
                abortDisembargo "targets something that is not a promise."

            -- Only a capability imported over this same connection may be
            -- reflected back. If the guard fails, matching falls through to
            -- the next equation, which aborts.
            disembargoClient (ImportClient (Fin.get -> ImportRef {conn=targetConn, importId}))
                | conn == targetConn =
                    sendPureMsg conn' $ R.Message'disembargo R.Disembargo
                        { context = R.Disembargo'context'receiverLoopback embargoId
                        , target = R.MessageTarget'importedCap (ieWord importId)
                        }
            disembargoClient _ =
                abortDisembargo $
                    "targets a promise which has not resolved to a capability"
                    <> " hosted by the sender."

            abortDisembargo info =
                abortConn conn' $ eFailed $ mconcat
                    [ "Disembargo #"
                    , fromString (show embargoId)
                    , " with context = senderLoopback "
                    , info
                    ]
    go d conn' =
        sendPureMsg conn' $ R.Message'unimplemented $ R.Message'disembargo d
-- | Accept each descriptor in a cap table, then attach the resulting
-- clients to the message as its capability table.
fixCapTable :: V.Vector R.CapDescriptor -> Conn -> ConstMsg -> STM ConstMsg
fixCapTable capDescs conn msg =
    flip Message.withCapTable msg <$> traverse (acceptCap conn) capDescs

-- | Look up a key and pass its value to the continuation. A missing key
-- aborts the connection with an error naming the table.
lookupAbort
    :: (Eq k, Hashable k, Show k)
    => Text -> Conn' -> M.Map k v -> k -> (v -> STM a) -> STM a
lookupAbort keyTypeName conn m key f =
    M.lookup key m >>= maybe missing f
  where
    missing = abortConn conn $ eFailed $ mconcat
        [ "No such "
        , keyTypeName
        , ": "
        , fromString (show key)
        ]

-- | Insert an entry that must be new. An existing key aborts the
-- connection.
insertNewAbort :: (Eq k, Hashable k) => Text -> Conn' -> k -> v -> M.Map k v -> STM ()
insertNewAbort keyTypeName conn key value table =
    M.focus (Focus.alterM insertIfAbsent) key table
  where
    insertIfAbsent Nothing = pure (Just value)
    insertIfAbsent (Just _) =
        abortConn conn $ eFailed $
            "duplicate entry in " <> keyTypeName <> " table."
-- | Build the cap table to send with a pure pointer. Capabilities are
-- exported as needed (via 'emitCap').
genSendableCapTable :: Conn -> MPtr -> STM (V.Vector R.CapDescriptor)
genSendableCapTable conn ptr = genSendableCapTableRaw conn =<< serialized
  where
    serialized = createPure defaultLimit $ do
        msg <- Message.newMessage Nothing
        cerialize msg ptr

-- | Build the cap table to send with a raw pointer: one descriptor per
-- capability in the pointer's message. A null pointer gets an empty table.
genSendableCapTableRaw
    :: Conn
    -> Maybe (UntypedRaw.Ptr ConstMsg)
    -> STM (V.Vector R.CapDescriptor)
genSendableCapTableRaw conn =
    maybe (pure V.empty) $ \ptr ->
        traverse (emitCap conn) (Message.getCapTable (UntypedRaw.message ptr))

-- | Build an outgoing payload (content plus cap table) from a raw pointer.
makeOutgoingPayload :: Conn -> RawMPtr -> STM R.Payload
makeOutgoingPayload conn rawContent = do
    descs <- genSendableCapTableRaw conn rawContent
    decoded <- evalLimitT defaultLimit (decerialize rawContent)
    pure R.Payload { content = decoded, capTable = descs }

-- | Encode a message and enqueue it for 'sendLoop'.
sendPureMsg :: Conn' -> R.Message -> STM ()
sendPureMsg Conn'{sendQ} msg = do
    rawMsg <- createPure maxBound (valueToMsg msg)
    writeTBQueue sendQ rawMsg
-- | Finish one of our questions: send the finish message and record it in
-- the questions table.
--
-- The question id is returned to the pool only after the question has
-- also received its return. If the return already arrived, the free is
-- queued right away.
finishQuestion :: Conn' -> R.Finish -> STM ()
finishQuestion conn@Conn'{questions} finish@R.Finish{questionId} = do
    let qid = QAId questionId
    subscribeReturn "question" conn questions qid $ const $
        freeQuestion conn qid
    sendPureMsg conn $ R.Message'finish finish
    updateQAFinish conn questions "question" finish

-- | Send a return message for one of our answers and record it in the
-- answers table.
returnAnswer :: Conn' -> R.Return -> STM ()
returnAnswer conn@Conn'{answers} ret = do
    sendPureMsg conn (R.Message'return ret)
    updateQAReturn conn answers "answer" ret
-- | Record a return message in a question/answer table and queue any
-- callbacks waiting on it.
--
-- The entry is removed once both the return and the finish have been seen.
-- A second return aborts the connection.
updateQAReturn :: Conn' -> M.Map QAId EntryQA -> Text -> R.Return -> STM ()
updateQAReturn conn table tableName ret@R.Return{answerId} =
    lookupAbort tableName conn table key $ \case
        NewQA{onFinish, onReturn} -> do
            mapQueueSTM conn onReturn ret
            M.insert (HaveReturn { returnMsg = ret, onFinish }) key table
        HaveFinish{onReturn} -> do
            mapQueueSTM conn onReturn ret
            M.delete key table
        HaveReturn{} ->
            abortConn conn $ eFailed $
                "Duplicate return message for " <> tableName <> " #"
                <> fromString (show answerId)
  where
    key = QAId answerId

-- | Record a finish message in a question/answer table and queue any
-- callbacks waiting on it.
--
-- The entry is removed once both the return and the finish have been seen.
-- A second finish aborts the connection.
updateQAFinish :: Conn' -> M.Map QAId EntryQA -> Text -> R.Finish -> STM ()
updateQAFinish conn table tableName finish@R.Finish{questionId} =
    lookupAbort tableName conn table key $ \case
        NewQA{onFinish, onReturn} -> do
            mapQueueSTM conn onFinish finish
            M.insert (HaveFinish { finishMsg = finish, onReturn }) key table
        HaveReturn{onFinish} -> do
            mapQueueSTM conn onFinish finish
            M.delete key table
        HaveFinish{} ->
            abortConn conn $ eFailed $
                "Duplicate finish message for " <> tableName <> " #"
                <> fromString (show questionId)
  where
    key = QAId questionId
-- | Register a callback for the return message of a question/answer table
-- entry.
--
-- If the return has already been seen, the callback is queued immediately.
-- A missing entry aborts the connection.
subscribeReturn :: Text -> Conn' -> M.Map QAId EntryQA -> QAId -> (R.Return -> STM ()) -> STM ()
subscribeReturn tableName conn table qaId onRet =
    lookupAbort tableName conn table qaId $ \entry -> do
        updated <- case entry of
            NewQA{onFinish, onReturn} ->
                pure NewQA { onFinish, onReturn = SnocList.snoc onReturn onRet }
            HaveFinish{finishMsg, onReturn} ->
                pure HaveFinish { finishMsg, onReturn = SnocList.snoc onReturn onRet }
            HaveReturn{returnMsg} -> do
                queueSTM conn (onRet returnMsg)
                pure entry
        M.insert updated qaId table
-- | Abort the connection by throwing 'SentAbort'. 'handleConn' then tries
-- to send the exception to the remote vat.
abortConn :: Conn' -> R.Exception -> STM a
abortConn _ = throwSTM . SentAbort

-- | Get the state of a live connection. Throws 'eDisconnected' if the
-- connection has shut down.
getLive :: Conn -> STM Conn'
getLive Conn{liveState} = do
    state <- readTVar liveState
    case state of
        Live conn' -> pure conn'
        Dead -> throwSTM eDisconnected

-- | Run an action on a connection's live state. Does nothing if the
-- connection has shut down.
whenLive :: Conn -> (Conn' -> STM ()) -> STM ()
whenLive Conn{liveState} f = do
    state <- readTVar liveState
    case state of
        Live conn' -> f conn'
        Dead -> pure ()
-- | Ask the remote vat for its bootstrap interface.
--
-- Returns a promise that resolves once the answer arrives. On a dead
-- connection, returns 'nullClient' instead.
requestBootstrap :: Conn -> STM Client
requestBootstrap conn@Conn{liveState} = do
    state <- readTVar liveState
    case state of
        Dead -> pure nullClient
        Live conn' -> ask conn'
  where
    ask conn'@Conn'{questions} = do
        qid <- newQuestion conn'
        let dest = RemoteDest AnswerDest
                { conn
                , answer = PromisedAnswer
                    { answerId = qid
                    , transform = SnocList.empty
                    }
                }
        pState <- newTVar (Pending { tmpDest = dest })
        sendPureMsg conn' $
            R.Message'bootstrap def { R.questionId = qaWord qid }
        M.insert
            NewQA
                { onReturn = SnocList.singleton $
                    resolveClientReturn dest (writeTVar pState) conn' []
                , onFinish = SnocList.empty
                }
            qid
            questions
        exportMap <- ExportMap <$> M.new
        pure $ Client $ Just PromiseClient
            { pState
            , exportMap
            , origTarget = dest
            }
-- | Resolve a promise to an error. Calls buffered in a local temporary
-- destination fail with the same exception.
resolveClientExn :: TmpDest -> (PromiseState -> STM ()) -> R.Exception -> STM ()
resolveClientExn tmpDest resolve exn = do
    failBuffered tmpDest
    resolve (Error exn)
  where
    failBuffered (LocalDest LocalBuffer { callBuffer }) =
        flushTQueue callBuffer >>=
            traverse_ (\Server.CallInfo{response} -> breakPromiseSTM response exn)
    failBuffered (RemoteDest AnswerDest {}) =
        pure ()
    failBuffered (RemoteDest (ImportDest _)) =
        pure ()

-- | Resolve a promise to the capability at a pointer. A null pointer
-- resolves to the null client; a non-capability pointer resolves to an
-- error.
resolveClientPtr :: TmpDest -> (PromiseState -> STM ()) -> MPtr -> STM ()
resolveClientPtr tmpDest resolve = \case
    Nothing ->
        resolveClientClient tmpDest resolve nullClient
    Just (Untyped.PtrCap c) ->
        resolveClientClient tmpDest resolve c
    Just _ ->
        resolveClientExn tmpDest resolve $
            eFailed "Promise resolved to non-capability pointer"
-- | Resolve a promise to a client.
--
-- If the temporary destination is remote and the new client is not
-- reached over that same connection, a disembargo is sent to the old
-- destination. Calls are buffered until the disembargo is echoed back;
-- presumably this preserves call ordering, which is the purpose of
-- Cap'n Proto embargos.
--
-- If the embargo is broken instead (e.g. the connection went down), the
-- buffered calls fail and the promise resolves to that error. Without
-- this, later calls would keep being buffered in a queue nobody drains.
resolveClientClient :: TmpDest -> (PromiseState -> STM ()) -> Client -> STM ()
resolveClientClient tmpDest resolve (Client client) =
    case (client, tmpDest) of
        -- Resolved to a local capability, a local promise, or null, while
        -- calls may still be in flight to a remote destination: embargo.
        ( Just LocalClient{}, RemoteDest dest ) ->
            disembargoAndResolve dest
        ( Just PromiseClient { origTarget=LocalDest _ }, RemoteDest dest) ->
            disembargoAndResolve dest
        ( Nothing, RemoteDest dest ) ->
            disembargoAndResolve dest

        -- Resolved to something remote: embargo only if it lives on a
        -- different connection than the old destination.
        ( Just PromiseClient { origTarget=RemoteDest newDest }, RemoteDest oldDest )
            | destConn newDest /= destConn oldDest ->
                disembargoAndResolve oldDest
            | otherwise ->
                releaseAndResolve
        ( Just (ImportClient (Fin.get -> ImportRef { conn=newConn })), RemoteDest oldDest )
            | newConn /= destConn oldDest ->
                disembargoAndResolve oldDest
            | otherwise ->
                releaseAndResolve

        -- Local temporary destination: deliver the buffered calls, then
        -- resolve.
        ( _, LocalDest LocalBuffer { callBuffer } ) ->
            flushAndResolve callBuffer
  where
    destConn AnswerDest { conn } = conn
    destConn (ImportDest (Fin.get -> ImportRef { conn })) = conn

    destTarget AnswerDest { answer } = AnswerTgt answer
    destTarget (ImportDest (Fin.get -> ImportRef { importId })) = ImportTgt importId

    releaseAndResolve = do
        releaseTmpDest tmpDest
        resolve $ Ready (Client client)

    flushAndResolve callBuffer = do
        flushTQueue callBuffer >>= traverse_ (`call` Client client)
        resolve $ Ready (Client client)

    -- Fail the buffered calls and make the promise itself an error, so
    -- that calls arriving later fail instead of being buffered forever.
    flushAndRaise callBuffer e = do
        flushTQueue callBuffer >>=
            traverse_ (\Server.CallInfo{response} -> breakPromiseSTM response e)
        resolve $ Error e

    disembargoAndResolve dest@(destConn -> Conn{liveState}) =
        readTVar liveState >>= \case
            Live conn' -> do
                callBuffer <- newTQueue
                disembargo conn' (destTarget dest) $ \case
                    Right () ->
                        flushAndResolve callBuffer
                    Left e ->
                        flushAndRaise callBuffer e
                resolve $ Embargo { callBuffer }
            Dead ->
                resolveClientExn tmpDest resolve eDisconnected
-- | Send a disembargo (context senderLoopback) to a target.
--
-- The callback runs when the remote vat echoes it back, or when the
-- embargo is broken (e.g. by connection shutdown).
disembargo :: Conn' -> MsgTarget -> (Either R.Exception () -> STM ()) -> STM ()
disembargo conn@Conn'{embargos} tgt onEcho = do
    fulfiller <- newCallbackSTM onEcho
    eid <- newEmbargo conn
    M.insert fulfiller eid embargos
    sendPureMsg conn $ R.Message'disembargo R.Disembargo
        { target = marshalMsgTarget tgt
        , context = R.Disembargo'context'senderLoopback (embargoWord eid)
        }

-- | Release a temporary destination that is no longer needed.
--
-- For an answer on a live connection, this finishes the question.
-- Otherwise it does nothing.
releaseTmpDest :: TmpDest -> STM ()
releaseTmpDest = \case
    RemoteDest AnswerDest { conn, answer = PromisedAnswer { answerId } } ->
        whenLive conn $ \conn' ->
            finishQuestion conn' def
                { R.questionId = qaWord answerId
                , R.releaseResultCaps = False
                }
    LocalDest LocalBuffer{} ->
        pure ()
    RemoteDest (ImportDest _) ->
        pure ()
-- | Resolve a promise from the return message of a question, following the
-- given pointer-field transform into the results.
--
-- Unsupported or invalid return variants abort the connection.
resolveClientReturn :: TmpDest -> (PromiseState -> STM ()) -> Conn' -> [Word16] -> R.Return -> STM ()
resolveClientReturn tmpDest resolve conn@Conn'{answers} transform R.Return { union' } =
    case union' of
        R.Return'exception exn ->
            failWith exn
        R.Return'results R.Payload{ content } ->
            either failWith (resolveClientPtr tmpDest resolve) (followPtrs transform content)
        R.Return'canceled ->
            failWith $ eFailed "Canceled"
        R.Return'resultsSentElsewhere ->
            abortConn conn $ eFailed $ mconcat
                [ "Received Return.resultsSentElsewhere for a call "
                , "with sendResultsTo = caller."
                ]
        R.Return'takeFromOtherQuestion (QAId -> qid) ->
            subscribeReturn "answer" conn answers qid $
                resolveClientReturn tmpDest resolve conn transform
        R.Return'acceptFromThirdParty _ ->
            abortConn conn $ eUnimplemented
                "This vat does not support level 3."
        R.Return'unknown' ordinal ->
            abortConn conn $ eUnimplemented $
                "Unknown return variant #" <> fromString (show ordinal)
  where
    failWith = resolveClientExn tmpDest resolve
-- | Get the export id under which @client@ is exported on @conn@.
--
-- If the client has no export on this connection yet, a new export id is
-- allocated and recorded in the client's per-connection export map. In both
-- cases the entry in the connection's exports table gets one more reference
-- via 'addBumpExport'.
getConnExport :: Conn -> Client' -> STM IEId
getConnExport conn client = getLive conn >>= \conn'@Conn'{exports} -> do
    let ExportMap perConn = clientExportMap client
    M.lookup conn perConn >>= \case
        Just known ->
            known <$ addBumpExport known client exports
        Nothing -> do
            fresh <- newExport conn'
            addBumpExport fresh client exports
            M.insert fresh conn perConn
            pure fresh
-- | Remove @client'@'s export on @conn@.
--
-- The connection is always removed from the client's export map. If the
-- connection is still live, the export entry is also deleted from its
-- exports table and the export id is freed. Calls 'error' when the client
-- has no export on @conn@, which indicates an internal bookkeeping bug.
dropConnExport :: Conn -> Client' -> STM ()
dropConnExport conn client' =
    M.lookup conn eMap >>= maybe missing release
  where
    ExportMap eMap = clientExportMap client'

    missing = error "BUG: tried to drop an export that doesn't exist."

    release eid = do
        M.delete conn eMap
        whenLive conn $ \conn'@Conn'{exports} -> do
            M.delete eid exports
            freeExport conn' eid
-- | The export map (keyed by connection) associated with a client.
--
-- Local and promise clients carry it directly. For an imported client it is
-- the @proxies@ map of the underlying 'ImportRef'.
clientExportMap :: Client' -> ExportMap
clientExportMap = \case
    LocalClient{ exportMap = m } -> m
    PromiseClient{ exportMap = m } -> m
    ImportClient (Fin.get -> ImportRef{ proxies = m }) -> m
-- | Add one reference to export @exportId@ for @client@ in an exports table.
--
-- If there is no entry yet, one is inserted with a reference count of 1.
-- Otherwise the existing entry's count is incremented.
--
-- Calls 'error' if the existing entry holds a different client; that means
-- the export bookkeeping is inconsistent (an internal bug).
addBumpExport :: IEId -> Client' -> M.Map IEId EntryE -> STM ()
addBumpExport exportId client =
    M.focus (Focus.alter go) exportId
  where
    go Nothing = Just EntryE { client, refCount = 1 }
    go (Just EntryE{ client = oldClient, refCount } )
        | client /= oldClient =
            error $
                "BUG: addBumpExport called for export #" ++ show exportId ++
                " with a client that is different " ++
                "from what is already in our exports table."
        | otherwise =
            -- Force the new count before storing it, so repeated bumps do
            -- not accumulate a chain of (+ 1) thunks inside the map entry.
            let refCount' = refCount + 1
            in refCount' `seq` Just EntryE { client, refCount = refCount' }
-- | Describe @client@ as a capability descriptor to send to the peer on
-- @targetConn@.
--
-- * The null client becomes @none@.
-- * If the client already refers to something on @targetConn@, the
--   descriptor points back at it instead of exporting anything. That covers
--   an import from that connection, or a pending promise whose temporary
--   destination is an answer or import on that connection.
-- * Otherwise the client is exported through 'getConnExport', which also
--   bumps the export's reference count. It is described as @senderPromise@
--   for promise clients and as @senderHosted@ for everything else.
emitCap :: Conn -> Client -> STM R.CapDescriptor
emitCap _targetConn (Client Nothing) =
    pure R.CapDescriptor'none
emitCap targetConn (Client (Just client')) =
    case client' of
        LocalClient{} ->
            exportAs R.CapDescriptor'senderHosted
        PromiseClient{ pState } ->
            readTVar pState >>= \case
                Pending { tmpDest = RemoteDest AnswerDest { conn, answer } }
                    | conn == targetConn ->
                        pure $ R.CapDescriptor'receiverAnswer (marshalPromisedAnswer answer)
                Pending { tmpDest = RemoteDest (ImportDest (Fin.get -> ImportRef { conn, importId = IEId iid })) }
                    | conn == targetConn ->
                        pure $ R.CapDescriptor'receiverHosted iid
                _ ->
                    exportAs R.CapDescriptor'senderPromise
        ImportClient (Fin.get -> ImportRef { conn = hostConn, importId })
            | hostConn == targetConn ->
                pure $ R.CapDescriptor'receiverHosted (ieWord importId)
            | otherwise ->
                exportAs R.CapDescriptor'senderHosted
  where
    -- Export the client on the target connection and wrap its id.
    exportAs mkDesc = mkDesc . ieWord <$> getConnExport targetConn client'
-- | Convert a capability descriptor received from the peer on @conn@ into a
-- 'Client', updating the connection's imports bookkeeping as needed.
--
-- The connection is aborted via 'abortConn' in these cases:
--
-- * the descriptor contradicts the imports table (hosted vs. promise);
-- * a promised answer cannot be unmarshalled;
-- * the variant is unsupported (third-party hosted, or unknown).
acceptCap :: Conn -> R.CapDescriptor -> STM Client
acceptCap conn cap = getLive conn >>= \conn' -> go conn' cap
  where
    -- No capability: the null client.
    go _ R.CapDescriptor'none = pure (Client Nothing)
    -- An object hosted by the peer, i.e. an import on our side.
    go conn'@Conn'{imports} (R.CapDescriptor'senderHosted (IEId -> importId)) = do
        entry <- M.lookup importId imports
        case entry of
            -- The imports table records this id as a promise, which
            -- contradicts senderHosted: abort.
            Just EntryI{ promiseState=Just _ } ->
                let imp = fromString (show importId)
                in abortConn conn' $ eFailed $
                    "received senderHosted capability #" <> imp <>
                    ", but the imports table says #" <> imp <> " is senderPromise."
            -- Known import: take another local reference and count one more
            -- remote reference. The finalizer queued on the new cell drops the
            -- local reference again. The remote count is what releaseImport
            -- later reports back to the peer in its Release message.
            Just ent@EntryI{ localRc, remoteRc, proxies } -> do
                Rc.incr localRc
                M.insert ent { localRc, remoteRc = remoteRc + 1 } importId imports
                cell <- Fin.newCell ImportRef
                    { conn
                    , importId
                    , proxies
                    }
                queueIO conn' $ Fin.addFinalizer cell $ atomically (Rc.decr localRc)
                pure $ Client $ Just $ ImportClient cell
            -- First time we see this id: create a fresh (non-promise) import.
            Nothing ->
                Client . Just . ImportClient <$> newImport importId conn Nothing
    -- A promise hosted by the peer.
    go conn'@Conn'{imports} (R.CapDescriptor'senderPromise (IEId -> importId)) = do
        entry <- M.lookup importId imports
        case entry of
            -- The imports table records this id as a plain hosted object,
            -- which contradicts senderPromise: abort.
            Just EntryI { promiseState=Nothing } ->
                let imp = fromString (show importId)
                in abortConn conn' $ eFailed $
                    "received senderPromise capability #" <> imp <>
                    ", but the imports table says #" <> imp <> " is senderHosted."
            -- Known promise import: reuse its promise state and original
            -- target, counting one more remote reference (no new local
            -- reference or cell is created on this path).
            Just ent@EntryI { remoteRc, proxies, promiseState=Just (pState, origTarget) } -> do
                M.insert ent { remoteRc = remoteRc + 1 } importId imports
                pure $ Client $ Just PromiseClient
                    { pState
                    , exportMap = proxies
                    , origTarget
                    }
            -- New promise import. The import entry stores the promise state
            -- and temporary destination, and those in turn refer to the import
            -- cell itself. The three values are therefore tied together with a
            -- recursive (RecursiveDo) binding.
            Nothing -> do
                rec imp@(Fin.get -> ImportRef{proxies}) <- newImport importId conn (Just (pState, tmpDest))
                    let tmpDest = RemoteDest (ImportDest imp)
                    pState <- newTVar Pending { tmpDest }
                pure $ Client $ Just PromiseClient
                    { pState
                    , exportMap = proxies
                    , origTarget = tmpDest
                    }
    -- Refers to one of our own exports; hand back the exported client.
    -- (lookupAbort presumably aborts the connection if the id is unknown.)
    go conn'@Conn'{exports} (R.CapDescriptor'receiverHosted exportId) =
        lookupAbort "export" conn' exports (IEId exportId) $
            \EntryI{client} ->
                pure $ Client $ Just client
    -- Refers to a promised answer in our answers table. Build a local
    -- promise client that resolves when that answer's Return arrives.
    go conn' (R.CapDescriptor'receiverAnswer pa) =
        case unmarshalPromisedAnswer pa of
            Left e ->
                abortConn conn' e
            Right pa ->
                newLocalAnswerClient conn' pa
    -- Third-party capabilities (level 3) are not supported.
    go conn' (R.CapDescriptor'thirdPartyHosted _) =
        abortConn conn' $ eUnimplemented
            "thirdPartyHosted unimplemented; level 3 is not supported."
    go conn' (R.CapDescriptor'unknown' tag) =
        abortConn conn' $ eUnimplemented $
            "Unimplemented CapDescriptor variant #" <> fromString (show tag)
-- | Create a brand new entry for @importId@ in @conn@'s imports table and
-- return a finalizable cell referring to it.
--
-- * The entry starts with a remote reference count of 1 and a fresh, empty
--   proxies export map.
-- * Its local 'Rc' is created with 'releaseImport' as its release action,
--   presumably run when the local count reaches zero.
-- * A finalizer that decrements that 'Rc' is queued on the returned cell.
--
-- @promiseState@ is 'Nothing' for a plain hosted import. For a promise
-- import it holds the promise state and temporary destination.
newImport :: IEId -> Conn -> Maybe (TVar PromiseState, TmpDest) -> STM (Fin.Cell ImportRef)
newImport importId conn promiseState = getLive conn >>= \conn'@Conn'{imports} -> do
    rc <- Rc.new () (releaseImport importId conn')
    proxyMap <- ExportMap <$> M.new
    M.insert
        EntryI
            { localRc = rc
            , remoteRc = 1
            , proxies = proxyMap
            , promiseState
            }
        importId
        imports
    cell <- Fin.newCell ImportRef { conn, importId, proxies = proxyMap }
    queueIO conn' (Fin.addFinalizer cell (atomically (Rc.decr rc)))
    pure cell
-- | Give an import back to the peer.
--
-- Sends a @Release@ message carrying the entry's accumulated remote
-- reference count, then removes the entry from the imports table. The
-- lookup goes through 'lookupAbort', which presumably aborts the connection
-- when the id is missing.
releaseImport :: IEId -> Conn' -> STM ()
releaseImport importId conn'@Conn'{imports} =
    lookupAbort "imports" conn' imports importId sendRelease
        *> M.delete importId imports
  where
    sendRelease EntryI { remoteRc } =
        sendPureMsg conn' $ R.Message'release R.Release
            { id = ieWord importId
            , referenceCount = remoteRc
            }
-- | Build a promise client for an entry in @conn@'s answers table.
--
-- * The promise starts 'Pending' on a fresh local call buffer. Presumably
--   that buffer holds calls made before the answer is available.
-- * It subscribes to the answer's return. When the return arrives,
--   'resolveClientReturn' resolves the promise by following @transform@
--   through the results.
-- * The client gets a fresh, empty export map.
newLocalAnswerClient :: Conn' -> PromisedAnswer -> STM Client
newLocalAnswerClient conn@Conn'{answers} PromisedAnswer{ answerId, transform } = do
    queue <- newTQueue
    let dest = LocalDest LocalBuffer { callBuffer = queue }
    pState <- newTVar Pending { tmpDest = dest }
    subscribeReturn "answer" conn answers answerId
        (resolveClientReturn dest (writeTVar pState) conn (toList transform))
    emptyExports <- ExportMap <$> M.new
    pure $ Client $ Just PromiseClient
        { pState
        , exportMap = emptyExports
        , origTarget = dest
        }