module Control.Distributed.Process.Execution.Exchange.Broadcast
(
broadcastExchange
, broadcastExchangeT
, broadcastClient
, bindToBroadcaster
, BroadcastExchange
) where
import Control.Concurrent.STM (STM, atomically)
import Control.Concurrent.STM.TChan
( TChan
, newBroadcastTChanIO
, dupTChan
, readTChan
, writeTChan
)
import Control.DeepSeq (NFData)
import Control.Distributed.Process
( Process
, MonitorRef
, ProcessMonitorNotification(..)
, ProcessId
, SendPort
, processNodeId
, getSelfPid
, getSelfNode
, liftIO
, newChan
, sendChan
, unsafeSend
, unsafeSendChan
, receiveWait
, match
, matchIf
, die
, handleMessage
, Match
)
import qualified Control.Distributed.Process as P
import Control.Distributed.Process.Serializable()
import Control.Distributed.Process.Execution.Exchange.Internal
( startExchange
, configureExchange
, Message(..)
, Exchange(..)
, ExchangeType(..)
, applyHandlers
)
import Control.Distributed.Process.Extras.Internal.Types
( Channel
, ServerDisconnected(..)
)
import Control.Distributed.Process.Extras.Internal.Unsafe
( PCopy
, pCopy
, pUnwrap
, matchChanP
, InputStream(Null)
, newInputStream
)
import Control.Monad (forM_, void)
import Data.Accessor
( Accessor
, accessor
, (^:)
)
import Data.Binary
import qualified Data.Foldable as Foldable (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Typeable (Typeable)
import GHC.Generics
-- | Bind request sent by 'broadcastClient' when the exchange lives on a
-- different node from the client. The exchange replies with 'BindOk' or
-- 'BindFail' and, on success, writes routed messages to 'portSend'.
data BindPort = BindPort
  { portClient :: !ProcessId          -- ^ the client asking to be bound
  , portSend   :: !(SendPort Message) -- ^ channel routed messages are sent on
  } deriving (Typeable, Generic)

instance Binary BindPort
instance NFData BindPort
-- | Bind request sent by 'broadcastClient' when the exchange lives on the
-- client's own node. It has no 'Binary' instance: it is wrapped with 'pCopy'
-- on send and recovered with 'pUnwrap', and the exchange replies on
-- 'stmSend' with the client's 'InputStream' (also wrapped in 'PCopy').
data BindSTM = BindSTM
  { stmClient :: !ProcessId                                -- ^ the client asking to be bound
  , stmSend   :: !(SendPort (PCopy (InputStream Message))) -- ^ where the reply stream goes
  } deriving (Typeable)
-- | How the exchange hands a routed message to one bound client.
data OutputStream
  = WriteChan (SendPort Message)  -- ^ send the message on a typed channel
  | WriteSTM (Message -> STM ())  -- ^ run the given STM action with the message
  | NoWrite                       -- ^ write nothing (used for STM bindings, which read a duplicate of the shared channel)
  deriving (Typeable)
-- | One client's entry in the exchange's routing table.
--
-- NB: the 'outputStream' and 'inputStream' selectors exist only for the
-- 'Binding' constructor; applied to a 'PidBinding' they fail at runtime,
-- so prefer pattern matching.
data Binding
  = Binding
      { outputStream :: !OutputStream          -- ^ how routed messages reach the client
      , inputStream  :: !(InputStream Message) -- ^ stream given back on an STM bind ('Null' for port binds)
      }
  | PidBinding !ProcessId                      -- ^ forward routed payloads to this process
  deriving (Typeable)
-- | Reply to a 'BindPort' request: the client has been bound.
data BindOk = BindOk
  deriving (Typeable, Generic)

instance Binary BindOk
instance NFData BindOk
-- | Reply to a 'BindPort' request that was refused; the string is the reason.
data BindFail = BindFail !String
  deriving (Typeable, Generic)

instance Binary BindFail
instance NFData BindFail
-- | Tag for a request, sent paired with the requester's 'ProcessId', to have
-- the payload of every routed message forwarded to that process.
data BindPlease = BindPlease
  deriving (Typeable, Generic)

instance Binary BindPlease
instance NFData BindPlease
-- | Routing table of bound clients: at most one 'Binding' per client pid.
type BroadcastClients = Map ProcessId Binding
-- | State of a running broadcast exchange.
data BroadcastEx = BroadcastEx
  { _routingTable :: !BroadcastClients -- ^ bound clients; update via 'routingTable'
  , channel       :: !(TChan Message)  -- ^ every routed message is written here; STM clients read duplicates of it
  }
-- | Exchange definition (state plus configure/route handlers) for a
-- broadcast exchange; see 'broadcastExchangeT'.
type BroadcastExchange = ExchangeType BroadcastEx
-- | Build a broadcast exchange definition and start it, returning a handle
-- to the running exchange.
broadcastExchange :: Process Exchange
broadcastExchange = do
  definition <- broadcastExchangeT
  startExchange definition
-- | Build (without starting) a broadcast exchange definition: a fresh
-- broadcast 'TChan', an empty routing table, and the 'apiConfigure' /
-- 'apiRoute' handlers.
broadcastExchangeT :: Process BroadcastExchange
broadcastExchangeT = mkDefinition <$> liftIO newBroadcastTChanIO
  where
    mkDefinition bcast = ExchangeType
      { name        = "BroadcastExchange"
      , state       = BroadcastEx Map.empty bcast
      , configureEx = apiConfigure
      , routeEx     = apiRoute
      }
-- | Bind the calling process to a broadcast exchange and return the stream
-- from which it can read routed messages.
--
-- If the exchange runs on the caller's node, a 'BindSTM' request (wrapped
-- with 'pCopy') is sent. The exchange replies on a typed channel with the
-- stream to use.
--
-- Otherwise a typed channel is created and offered in a 'BindPort' request.
-- The caller then waits for 'BindOk', which yields a stream over that
-- channel, or 'BindFail', on which it dies.
--
-- In both cases the exchange is monitored while waiting. If the exchange
-- exits first, the caller dies with 'ServerDisconnected'. The monitor is
-- always removed afterwards.
broadcastClient :: Exchange -> Process (InputStream Message)
broadcastClient ex@Exchange{..} = do
  here <- getSelfNode
  self <- getSelfPid
  if processNodeId pid /= here
    then bindRemote self
    else bindLocal self
  where
    -- Same-node exchange: the reply carries the stream itself.
    bindLocal self = do
      (replyTo, replies) <- newChan
      configureExchange ex $ pCopy (BindSTM self replyTo)
      awaitBinding [ matchChanP replies ]

    -- Remote exchange: routed messages will arrive on our own channel.
    bindRemote self = do
      (sendTo, recvFrom) <- newChan :: Process (Channel Message)
      configureExchange ex $ BindPort self sendTo
      awaitBinding
        [ match (\(_ :: BindOk) -> return $ newInputStream $ Left recvFrom)
        , match (\(f :: BindFail) -> die f)
        ]

    -- Wait for one of the expected replies, or for the exchange to die.
    awaitBinding expected = do
      ref <- P.monitor pid
      receiveWait (expected ++ [handleServerFailure ref])
        `P.finally` P.unmonitor ref
-- | Ask the exchange to forward the payload of every routed message straight
-- to the calling process's mailbox, by sending @(BindPlease, self)@ as a
-- configuration message.
bindToBroadcaster :: Exchange -> Process ()
bindToBroadcaster ex@Exchange{} =
  getSelfPid >>= \self -> configureExchange ex (BindPlease, self)
-- | Route one message through the exchange.
--
-- The message is first written to the shared broadcast channel, which
-- STM-bound clients read through their duplicates. It is then delivered to
-- each routing-table entry: pid bindings get the raw payload via
-- 'P.forward', and other bindings get it through their 'OutputStream'.
-- The state is returned unchanged.
apiRoute :: BroadcastEx -> Message -> Process BroadcastEx
apiRoute ex@BroadcastEx{..} msg = do
  liftIO . atomically $ writeTChan channel msg
  mapM_ deliver (Map.elems _routingTable)
  return ex
  where
    deliver (PidBinding target) = P.forward (payload msg) target
    deliver (Binding out _)     = writeToStream out msg
-- | Handle one configuration message sent to the broadcast exchange.
--
-- The message is offered to each handler in turn (via 'applyHandlers'):
--
-- * 'BindPort': bind a client on another node to a typed channel
--   (see 'handleBindPort').
--
-- * 'BindSTM': bind a client on the exchange's node to a duplicate of the
--   broadcast channel. It arrives wrapped by 'pCopy', so it is recovered
--   with 'pUnwrap' rather than 'handleMessage'.
--
-- * @(BindPlease, ProcessId)@: forward routed payloads to that process.
--
-- * 'ProcessMonitorNotification': a monitored client has exited, so its
--   routing-table entry is dropped.
--
-- Any other message leaves the state unchanged.
apiConfigure :: BroadcastEx -> P.Message -> Process BroadcastEx
apiConfigure ex msg =
  applyHandlers ex msg
    [ \m -> handleMessage m (handleBindPort ex)
    , \m -> handleBindSTM ex m
    , \m -> handleMessage m (handleBindPlease ex)
    , \m -> handleMessage m (handleMonitorSignal ex)
    , const $ return $ Just ex
    ]
  where
    -- Register a pid binding. The client is monitored, as on the port and
    -- STM paths, so that handleMonitorSignal drops the entry once the
    -- client exits instead of leaving it in the table forever.
    handleBindPlease ex' (BindPlease, p) =
      case lookupBinding ex' p of
        Just _  -> return ex'
        Nothing -> do
          void $ P.monitor p
          return $ (routingTable ^: Map.insert p (PidBinding p)) ex'

    handleMonitorSignal bx (ProcessMonitorNotification _ p _) =
      return $ (routingTable ^: Map.delete p) bx

    -- Yields Nothing when msg' is not a pCopy-wrapped BindSTM, so that the
    -- next handler gets a chance at it.
    handleBindSTM ex' msg' = do
      bind' <- pUnwrap msg' :: Process (Maybe BindSTM)
      case bind' of
        Nothing -> return Nothing
        Just s  ->
          case lookupBinding ex' (stmClient s) of
            -- already bound: hand back whatever stream it has
            Just b  -> sendBinding (stmSend s) b >> return (Just ex')
            Nothing -> do
              (bnd, ex'') <- createBinding ex' s
              sendBinding (stmSend s) bnd
              return (Just ex'')

    -- Monitor the client and give it its own duplicate of the broadcast
    -- channel. Its output is NoWrite because apiRoute already writes every
    -- message to the shared channel.
    createBinding bEx@BroadcastEx{..} BindSTM{..} = do
      void $ P.monitor stmClient
      nch <- liftIO $ atomically $ dupTChan channel
      let bnd = Binding NoWrite (newInputStream $ Right (readTChan nch))
      return (bnd, (routingTable ^: Map.insert stmClient bnd) bEx)

    -- Match instead of using the partial 'inputStream' selector. A
    -- PidBinding has no input stream, so the client receives Null, just as
    -- it would for a port binding (whose input is Null).
    sendBinding sp' (Binding _ istr) = unsafeSendChan sp' $ pCopy istr
    sendBinding sp' (PidBinding _)   = unsafeSendChan sp' $ pCopy Null
-- | Handle a 'BindPort' request from a client on another node.
--
-- A client that already has any binding is refused with
-- @BindFail "DuplicateBinding"@. Otherwise the client is:
--
-- * monitored;
-- * sent 'BindOk';
-- * added to the routing table with a 'WriteChan' output on its
--   'portSend' channel and a 'Null' input stream.
handleBindPort :: BroadcastEx -> BindPort -> Process BroadcastEx
handleBindPort x@BroadcastEx{} BindPort{..}
  | Just _ <- lookupBinding x portClient = do
      unsafeSend portClient (BindFail "DuplicateBinding")
      return x
  | otherwise = do
      void $ P.monitor portClient
      unsafeSend portClient BindOk
      let entry = Binding (WriteChan portSend) Null
      return $ (routingTable ^: Map.insert portClient entry) x
-- | Look up a client's routing-table entry, if it has one.
lookupBinding :: BroadcastEx -> ProcessId -> Maybe Binding
lookupBinding BroadcastEx{..} k = Map.lookup k _routingTable
-- | Deliver one routed message through an 'OutputStream'. It is sent on
-- the typed channel, passed to the STM action (run atomically), or, for
-- 'NoWrite', dropped.
writeToStream :: OutputStream -> Message -> Process ()
writeToStream out msg =
  case out of
    WriteChan port -> sendChan port msg
    WriteSTM act   -> liftIO (atomically (act msg))
    NoWrite        -> return ()
-- | Match the monitor notification for the given reference and die with
-- 'ServerDisconnected', carrying the exit reason from the notification.
handleServerFailure :: MonitorRef -> Match (InputStream Message)
handleServerFailure mRef = matchIf isOurs serverDied
  where
    isOurs (ProcessMonitorNotification ref _ _) = ref == mRef
    serverDied (ProcessMonitorNotification _ _ reason) =
      die $ ServerDisconnected reason
-- | 'Accessor' for the exchange's routing table, used with '(^:)' to modify
-- it in place.
routingTable :: Accessor BroadcastEx BroadcastClients
routingTable = accessor getTable setTable
  where
    getTable = _routingTable
    setTable newTable bx = bx { _routingTable = newTable }