{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE NamedFieldPuns #-}
module MagicWormhole.Internal.Rendezvous
(
ping
, list
, allocate
, claim
, release
, open
, close
, runClient
, Session
, ServerError(..)
, ClientError(..)
) where
import Protolude hiding (list, phase)
import Control.Concurrent.STM
( TVar
, newTVar
, modifyTVar'
, readTVar
, writeTVar
, TMVar
, newEmptyTMVar
, putTMVar
, takeTMVar
, tryPutTMVar
, TQueue
, newTQueue
, readTQueue
, writeTQueue
)
import Data.Aeson (eitherDecode, encode)
import Data.Hashable (Hashable)
import Data.HashMap.Lazy (HashMap)
import qualified Data.HashMap.Lazy as HashMap
import Data.String (String)
import qualified Network.Socket as Socket
import qualified Network.WebSockets as WS
import qualified MagicWormhole.Internal.ClientProtocol as ClientProtocol
import qualified MagicWormhole.Internal.Messages as Messages
import MagicWormhole.Internal.WebSockets (WebSocketEndpoint(..))
-- | State for one client connection to a rendezvous server.
--
-- Built by 'new'; 'runClient' creates one per connection and hands it to the
-- application.
data Session
  = Session
  { connection :: WS.Connection
    -- ^ WebSocket used for all traffic to and from the server.
  , sessionAppID :: Messages.AppID
    -- ^ Application ID supplied when the session was created.
  , sessionSide :: Messages.Side
    -- ^ Our side identifier; mailbox messages carrying it are skipped by
    -- 'readFromMailbox'.
  , pendingVar :: TVar (HashMap ResponseType (TMVar Messages.ServerMessage))
    -- ^ Outstanding RPCs: one box per awaited response type, filled by the
    -- reader loop when a matching response arrives.
  , messageChan :: TQueue Messages.MailboxMessage
    -- ^ Mailbox messages received from the server, in arrival order.
  , motd :: TMVar (Maybe Text)
    -- ^ Message of the day, filled from the first welcome message.
  }
-- | Allocate a fresh 'Session' for a WebSocket connection: no pending RPCs,
-- an empty mailbox queue, and no message of the day yet.
new :: WS.Connection
    -> Messages.AppID
    -> Messages.Side
    -> STM Session
new conn appID side = do
  pending <- newTVar mempty
  inbox <- newTQueue
  welcomeMotd <- newEmptyTMVar
  pure (Session conn appID side pending inbox welcomeMotd)
-- | JSON-encode a client message and send it as a binary WebSocket frame.
send :: Session
     -> Messages.ClientMessage
     -> IO ()
send session = WS.sendBinaryData (connection session) . encode
-- | Block until the next frame arrives and decode it as a server message.
--
-- Throws 'ParseError' when the payload does not decode. WebSocket-level
-- exceptions raised by 'WS.receiveData' propagate unchanged.
receive :: Session
        -> IO Messages.ServerMessage
receive session =
  WS.receiveData (connection session) >>= either (throwIO . ParseError) pure . eitherDecode
-- | Connect to a rendezvous server, bind as @side@ of @appID@, and run @app@
-- with the resulting 'Session'.
--
-- When @maybeSock@ is given, the WebSocket handshake runs over that socket;
-- otherwise a new connection is opened to the endpoint.
--
-- Server messages are read on a concurrent thread and dispatched into the
-- session until the server sends a close request. A close frame is sent once
-- @app@ finishes, whether normally or by exception. Errors raised by the
-- reader loop (protocol violations, non-close connection exceptions) are
-- rethrown to the caller.
runClient
  :: HasCallStack
  => WebSocketEndpoint
  -> Messages.AppID
  -> Messages.Side
  -> Maybe Socket.Socket
  -> (Session -> IO a)
  -> IO a
runClient (WebSocketEndpoint host port path) appID side maybeSock app =
  Socket.withSocketsDo (connect runAction)
  where
    -- Establish the WebSocket either over the supplied socket or a fresh one.
    connect client =
      case maybeSock of
        Just sock -> WS.runClientWithSocket sock host path WS.defaultConnectionOptions [] client
        Nothing -> WS.runClient host port path client

    -- Run the reader loop and the application side by side; keep the app's result.
    runAction ws = do
      session <- atomically (new ws appID side)
      snd <$> concurrently (readMessages session) (action ws session)

    action ws session = do
      bind session appID side
      app session `finally` WS.sendClose ws ("Connection closed connection" :: Text)

    -- Dispatch server messages until a close request; any other failure is rethrown.
    readMessages session = do
      received <- try (receive session)
      case received of
        Left (WS.CloseRequest _ _) -> pass
        Left other -> throwIO other
        Right serverMsg ->
          atomically (gotMessage session serverMsg) >>= maybe (readMessages session) throwIO
-- | Send a request that expects a reply, and block until the reply arrives.
--
-- Throws:
--
--   * 'NotAnRPC' if @req@ has no associated response type;
--   * 'AlreadySent' if a request with the same response type is still
--     outstanding;
--   * 'BadRequest' if the server answers with an error.
--
-- The pending-response slot is registered before the request is sent, so the
-- reader loop always has somewhere to deliver the reply. If sending or
-- waiting fails, including by asynchronous exception (e.g. a 'timeout' around
-- this call), the slot is removed again. Otherwise every later request of the
-- same type would be rejected with 'AlreadySent'.
rpc :: HasCallStack
    => Session
    -> Messages.ClientMessage
    -> IO Messages.ServerMessage
rpc session req =
  case expectedResponse req of
    Nothing ->
      throwIO (NotAnRPC req)
    Just responseType -> do
      response <-
        bracketOnError
          (atomically $ expectResponse responseType)
          (atomically . forgetResponse responseType)
          (\box -> do
              send session req
              atomically $ waitForResponse session responseType box)
      case response of
        Messages.Error reason original -> throwIO (BadRequest reason original)
        response' -> pure response'
  where
    -- Claim the slot for this response type, refusing if one is already taken.
    expectResponse :: ResponseType -> STM (TMVar Messages.ServerMessage)
    expectResponse responseType = do
      pending <- readTVar (pendingVar session)
      case HashMap.lookup responseType pending of
        Nothing -> do
          box <- newEmptyTMVar
          writeTVar (pendingVar session) (HashMap.insert responseType box pending)
          pure box
        Just _ -> throwSTM (AlreadySent req)

    -- Release our slot after a failure. Only the entry holding our own box is
    -- removed: waitForResponse may already have dropped it, and another
    -- request may have claimed the slot since.
    forgetResponse :: ResponseType -> TMVar Messages.ServerMessage -> STM ()
    forgetResponse responseType box =
      modifyTVar' (pendingVar session) (HashMap.update keepUnlessOurs responseType)
      where
        keepUnlessOurs existing
          | existing == box = Nothing
          | otherwise = Just existing
-- | Announce our application ID and side to the server. This uses 'send',
-- not 'rpc', so no response is awaited.
bind :: HasCallStack => Session -> Messages.AppID -> Messages.Side -> IO ()
bind session appID = send session . Messages.Bind appID
-- | Ping the server with @n@ and return the number carried by its pong.
ping :: HasCallStack => Session -> Int -> IO Int
ping session n = rpc session request >>= fromPong
  where
    request = Messages.Ping n
    fromPong (Messages.Pong n') = pure n'
    fromPong other = unexpectedMessage request other
-- | Request the list of nameplates from the server.
list :: HasCallStack => Session -> IO [Messages.Nameplate]
list session = rpc session Messages.List >>= fromNameplates
  where
    fromNameplates (Messages.Nameplates nameplates) = pure nameplates
    fromNameplates other = unexpectedMessage Messages.List other
-- | Ask the server to allocate a nameplate, and return it.
allocate :: HasCallStack => Session -> IO Messages.Nameplate
allocate session = rpc session Messages.Allocate >>= fromAllocated
  where
    fromAllocated (Messages.Allocated nameplate) = pure nameplate
    fromAllocated other = unexpectedMessage Messages.Allocate other
-- | Claim @nameplate@ and return the mailbox the server associates with it.
claim :: HasCallStack => Session -> Messages.Nameplate -> IO Messages.Mailbox
claim session nameplate = rpc session request >>= fromClaimed
  where
    request = Messages.Claim nameplate
    fromClaimed (Messages.Claimed mailbox) = pure mailbox
    fromClaimed other = unexpectedMessage request other
-- | Send a release request (optionally naming a nameplate) and wait for the
-- server to confirm it.
release :: HasCallStack => Session -> Maybe Messages.Nameplate -> IO ()
release session nameplate' = rpc session request >>= fromReleased
  where
    request = Messages.Release nameplate'
    fromReleased Messages.Released = pure ()
    fromReleased other = unexpectedMessage request other
-- | Send an open request for @mailbox@ and return a 'ClientProtocol.Connection'
-- bound to this session.
--
-- The open request uses 'send', so no reply is awaited. The returned
-- connection sends with 'add' and receives with 'readFromMailbox', which
-- skips messages carrying our own side.
open :: HasCallStack => Session -> Messages.Mailbox -> IO ClientProtocol.Connection
open session mailbox =
  send session (Messages.Open mailbox) *> pure mailboxConnection
  where
    mailboxConnection =
      ClientProtocol.Connection
        { ClientProtocol.appID = sessionAppID session
        , ClientProtocol.ourSide = sessionSide session
        , ClientProtocol.send = add session
        , ClientProtocol.receive = readFromMailbox session
        }
-- | Send a close request (optionally naming a mailbox and a mood) and wait
-- for the server to confirm it.
close :: HasCallStack => Session -> Maybe Messages.Mailbox -> Maybe Messages.Mood -> IO ()
close session mailbox' mood' = rpc session request >>= fromClosed
  where
    request = Messages.Close mailbox' mood'
    fromClosed Messages.Closed = pure ()
    fromClosed other = unexpectedMessage request other
-- | Send an add request carrying the given phase and body. This uses 'send',
-- so no response is awaited.
add :: HasCallStack => Session -> Messages.Phase -> Messages.Body -> IO ()
add session ph = send session . Messages.Add ph
-- | Take the next mailbox message whose side differs from ours, discarding
-- any that carry our own side. Retries while the queue is empty.
readFromMailbox :: HasCallStack => Session -> STM Messages.MailboxMessage
readFromMailbox session = loop
  where
    loop = readFromMailbox' session >>= skipOwn
    skipOwn msg
      | Messages.side msg == sessionSide session = loop
      | otherwise = pure msg
-- | Take the next message from the session's mailbox queue, whichever side
-- sent it. Retries while the queue is empty.
readFromMailbox' :: HasCallStack => Session -> STM Messages.MailboxMessage
readFromMailbox' = readTQueue . messageChan
-- | Abort via 'panic' when a response does not match the request it answered.
--
-- 'gotMessage' routes responses by constructor, so reaching this indicates a
-- routing bug rather than ordinary server misbehaviour.
unexpectedMessage :: HasCallStack => Messages.ClientMessage -> Messages.ServerMessage -> a
unexpectedMessage request response =
  panic (mconcat ["Unexpected message: ", show response, ", in response to: ", show request])
-- | Wait for a response to land in @box@, then drop the pending entry for its
-- response type so another request of that type can be made.
waitForResponse :: Session -> ResponseType -> TMVar Messages.ServerMessage -> STM Messages.ServerMessage
waitForResponse session responseType box =
  takeTMVar box <* modifyTVar' (pendingVar session) (HashMap.delete responseType)
-- | Deliver @message@ to the request waiting for @responseType@.
--
-- Returns 'ResponseWithoutRequest' when no request is waiting. If the
-- waiting box is still full, the transaction retries until it is emptied.
gotResponse :: Session -> ResponseType -> Messages.ServerMessage -> STM (Maybe ServerError)
gotResponse session responseType message =
  HashMap.lookup responseType <$> readTVar (pendingVar session) >>= deliver
  where
    deliver Nothing = pure (Just (ResponseWithoutRequest message))
    deliver (Just box) = Nothing <$ putTMVar box message
-- | Route one message from the server into the session.
--
-- RPC responses, including error replies, go to the waiting request via
-- 'gotResponse'. Mailbox messages are queued for 'readFromMailbox'. A welcome
-- records the message of the day. Returns the protocol error to raise, if any.
gotMessage :: Session -> Messages.ServerMessage -> STM (Maybe ServerError)
gotMessage session msg =
  case msg of
    Messages.Ack -> pure Nothing
    Messages.Welcome welcome -> onWelcome welcome
    Messages.Error{Messages.errorMessage, Messages.original} ->
      maybe (pure (Just (ErrorForNonRequest errorMessage original)))
            respond
            (expectedResponse original)
    Messages.Message mailboxMsg ->
      Nothing <$ writeTQueue (messageChan session) mailboxMsg
    Messages.Nameplates{} -> respond NameplatesResponse
    Messages.Allocated{} -> respond AllocatedResponse
    Messages.Claimed{} -> respond ClaimedResponse
    Messages.Released -> respond ReleasedResponse
    Messages.Closed -> respond ClosedResponse
    Messages.Pong{} -> respond PongResponse
  where
    -- Hand this message to the request awaiting the given response type.
    respond responseType = gotResponse session responseType msg

    -- A welcome either carries an error (we are refused) or the MOTD.
    onWelcome welcome =
      maybe (recordMotd welcome) (pure . Just . Unwelcome) (Messages.welcomeErrorMessage welcome)

    -- Only the first welcome may set the MOTD; a later one is unexpected.
    recordMotd welcome = do
      firstWelcome <- tryPutTMVar (motd session) (Messages.motd welcome)
      pure $
        if firstWelcome
          then Nothing
          else Just (UnexpectedMessage (Messages.Welcome welcome))
-- | The kinds of server response an RPC can wait for. 'rpc' allows at most one
-- outstanding request per kind; a second one is rejected with 'AlreadySent'.
data ResponseType
  = NameplatesResponse
  | AllocatedResponse
  | ClaimedResponse
  | ReleasedResponse
  | ClosedResponse
  | PongResponse
  deriving (Eq, Show, Generic, Hashable)
-- | The response type a client message waits for, or 'Nothing' for messages
-- this client sends without waiting for a reply. 'gotMessage' reports an
-- error reply to one of the latter as 'ErrorForNonRequest'.
expectedResponse :: Messages.ClientMessage -> Maybe ResponseType
expectedResponse request =
  case request of
    Messages.List -> Just NameplatesResponse
    Messages.Allocate -> Just AllocatedResponse
    Messages.Claim{} -> Just ClaimedResponse
    Messages.Release{} -> Just ReleasedResponse
    Messages.Close{} -> Just ClosedResponse
    Messages.Ping{} -> Just PongResponse
    -- Fire-and-forget: sent with 'send', never through 'rpc'.
    Messages.Bind{} -> Nothing
    Messages.Open{} -> Nothing
    Messages.Add{} -> Nothing
-- | Protocol violations detected while processing messages from the server.
-- Thrown by the reader loop in 'runClient'; 'ParseError' originates in 'receive'.
data ServerError
  = ResponseWithoutRequest Messages.ServerMessage
    -- ^ A response arrived while no request was waiting for it.
  | UnexpectedMessage Messages.ServerMessage
    -- ^ A message arrived that is not valid at this point (e.g. a second welcome).
  | ErrorForNonRequest Text Messages.ClientMessage
    -- ^ The server sent an error for a message that does not expect a reply.
  | Unwelcome Text
    -- ^ The server's welcome carried an error message.
  | ParseError String
    -- ^ A frame from the server could not be decoded.
  deriving (Eq, Show, Typeable)

instance Exception ServerError
-- | Errors reported to callers of the RPC functions; all are thrown by 'rpc'.
data ClientError
  = AlreadySent Messages.ClientMessage
    -- ^ A request with the same response type is still awaiting its reply.
  | NotAnRPC Messages.ClientMessage
    -- ^ The message has no response to wait for, so it cannot be sent as an RPC.
  | BadRequest Text Messages.ClientMessage
    -- ^ The server rejected the request; carries its reason and the request.
  deriving (Eq, Show, Typeable)

instance Exception ClientError