module Control.Distributed.Process.Backend.SimpleLocalnet
(
Backend(..)
, initializeBackend
, startSlave
, terminateSlave
, findSlaves
, terminateAllSlaves
, startMaster
) where
import Control.Applicative ((<$>))
import Control.Concurrent (forkIO, threadDelay, ThreadId)
import Control.Concurrent.MVar (MVar, newMVar, readMVar, modifyMVar_)
import Control.Exception (throw, throwIO)
import Control.Monad (forever, replicateM, replicateM_)
import Control.Monad.IO.Class (liftIO)
import Data.Foldable (forM_)
import Data.Maybe (catMaybes)
import Data.Typeable (Typeable)
import System.IO (fixIO)

import Data.Accessor (Accessor, accessor, (^:), (^.))
import Data.Binary (Binary(get, put), getWord8, putWord8)
import Data.Set (Set)
import qualified Data.Set as Set (insert, empty, toList, fromList, delete, member, null)
import Control.Distributed.Process
  ( RemoteTable
  , NodeId
  , Process
  , ProcessId
  , WhereIsReply(..)
  , whereis
  , whereisRemoteAsync
  , getSelfPid
  , register
  , reregister
  , expect
  , nsendRemote
  , receiveWait
  , match
  , matchIf
  , processNodeId
  , monitorNode
  , monitor
  , unmonitor
  , NodeMonitorNotification(..)
  , ProcessMonitorNotification(..)
  , ProcessRegistrationException
  , finally
  , newChan
  , receiveChan
  , nsend
  , SendPort
  , bracket
  , try
  , send
  )
import qualified Control.Distributed.Process.Node as Node
  ( LocalNode
  , newLocalNode
  , localNodeId
  , runProcess
  )
import qualified Network.Socket as N (HostName, ServiceName, SockAddr)
import qualified Network.Transport as NT (Transport)
import qualified Network.Transport.TCP as NT
  ( createTransport
  , defaultTCPParameters
  )

import Control.Distributed.Process.Backend.SimpleLocalnet.Internal.Multicast (initMulticast)
-- | Handle to an initialized SimpleLocalnet backend (see 'initializeBackend').
data Backend = Backend
  { newLocalNode     :: IO Node.LocalNode
    -- ^ Create a local node on the backend's transport and record it so the
    -- discovery daemon advertises it to peers.
  , findPeers        :: Int -> IO [NodeId]
    -- ^ Broadcast a discovery request, wait the given number of
    -- microseconds, then return every peer node id learned so far.
  , redirectLogsHere :: [ProcessId] -> Process ()
    -- ^ Ask the given slave controllers to send their logs to the @logger@
    -- process registered on the caller's node.
  }
-- | Mutable backend state, shared through an 'MVar' between the API
-- functions and the peer-discovery daemon.
data BackendState = BackendState
  { _localNodes     :: [Node.LocalNode]
    -- ^ Nodes created via 'newLocalNode'; advertised on discovery requests.
  , _peers          :: Set NodeId
    -- ^ Node ids collected from discovery replies (never pruned here).
  , discoveryDaemon :: ThreadId
    -- ^ Thread running the discovery daemon. Must stay a lazy field:
    -- 'initializeBackend' ties this knot with 'fixIO'.
  }
-- | Initialize the backend.
--
-- Creates a TCP transport bound to @host@/@port@, joins the peer-discovery
-- multicast group (224.0.0.99, port 9999) and forks the discovery daemon.
-- Throws the transport-creation error if the transport cannot be created.
initializeBackend :: N.HostName -> N.ServiceName -> RemoteTable -> IO Backend
initializeBackend host port rtable = do
  -- Create the transport first: if it fails we must not have already joined
  -- the multicast group and forked a daemon thread that nobody can stop.
  transport <- either throwIO return
           =<< NT.createTransport host port NT.defaultTCPParameters
  -- NOTE(review): meaning of the third argument (1024) is defined in
  -- Internal.Multicast; presumably a receive buffer size — confirm.
  (recv, sendp) <- initMulticast "224.0.0.99" 9999 1024
  -- fixIO lets the state record hold the daemon's own ThreadId; the lazy
  -- pattern keeps 'tid' from being forced before forkIO has produced it.
  (_, backendState) <- fixIO $ \ ~(tid, _) -> do
    backendState <- newMVar BackendState
      { _localNodes     = []
      , _peers          = Set.empty
      , discoveryDaemon = tid
      }
    tid' <- forkIO $ peerDiscoveryDaemon backendState recv sendp
    return (tid', backendState)
  -- 'backend' refers to itself (redirectLogsHere receives the handle).
  let backend = Backend
        { newLocalNode     = apiNewLocalNode transport rtable backendState
        , findPeers        = apiFindPeers sendp backendState
        , redirectLogsHere = apiRedirectLogsHere backend
        }
  return backend
-- | Create a new local node on @transport@ and prepend it to the backend's
-- list of local nodes, so the discovery daemon will advertise it.
apiNewLocalNode :: NT.Transport
                -> RemoteTable
                -> MVar BackendState
                -> IO Node.LocalNode
apiNewLocalNode transport rtable backendState = do
  node <- Node.newLocalNode transport rtable
  modifyMVar_ backendState $ \st -> return ((localNodes ^: (node :)) st)
  return node
-- | Broadcast a discovery request, sleep @delay@ microseconds to give peers
-- time to answer, and return the accumulated set of peer node ids.
apiFindPeers :: (PeerDiscoveryMsg -> IO ())
             -> MVar BackendState
             -> Int
             -> IO [NodeId]
apiFindPeers sendfn backendState delay = do
  sendfn PeerDiscoveryRequest
  threadDelay delay
  st <- readMVar backendState
  return (Set.toList (st ^. peers))
-- | Messages exchanged over the peer-discovery multicast group.
data PeerDiscoveryMsg
  = PeerDiscoveryRequest
    -- ^ Ask every listener to announce its local nodes.
  | PeerDiscoveryReply NodeId
    -- ^ Announce a single local node.
-- | Wire format: one tag byte (0 = request, 1 = reply), followed by the
-- node id for replies.
instance Binary PeerDiscoveryMsg where
  put msg = case msg of
    PeerDiscoveryRequest   -> putWord8 0
    PeerDiscoveryReply nid -> putWord8 1 >> put nid
  get = getWord8 >>= decodeTag
    where
      decodeTag 0 = return PeerDiscoveryRequest
      decodeTag 1 = PeerDiscoveryReply <$> get
      decodeTag _ = fail "PeerDiscoveryMsg.get: invalid"
-- | Discovery daemon loop, run forever.
--
-- On a request, replies once per local node with that node's id; on a reply,
-- adds the announced node id to the peer set. The sender address is ignored.
peerDiscoveryDaemon :: MVar BackendState
                    -> IO (PeerDiscoveryMsg, N.SockAddr)
                    -> (PeerDiscoveryMsg -> IO ())
                    -> IO ()
peerDiscoveryDaemon backendState recv sendfn = forever step
  where
    step = do
      (msg, _sender) <- recv
      respond msg

    respond PeerDiscoveryRequest = do
      st <- readMVar backendState
      mapM_ (sendfn . PeerDiscoveryReply . Node.localNodeId) (st ^. localNodes)
    respond (PeerDiscoveryReply nid) =
      modifyMVar_ backendState (return . (peers ^: Set.insert nid))
-- | Redirect the logs of the given slave controllers to the @logger@ process
-- registered on the caller's node.
--
-- Sends a 'RedirectLogsTo' request to each slave controller and blocks until
-- every one of them has either acknowledged with a 'RedirectLogsReply' or
-- been reported dead by its process monitor. Does nothing when no @logger@
-- is registered locally.
apiRedirectLogsHere :: Backend -> [ProcessId] -> Process ()
apiRedirectLogsHere _backend slavecontrollers = do
  mLogger <- whereis "logger"
  myPid <- getSelfPid
  forM_ mLogger $ \logger ->
    bracket
      (mapM monitor slavecontrollers)
      (mapM unmonitor)
      $ \_ -> do
        forM_ slavecontrollers $ \pid -> send pid (RedirectLogsTo logger myPid)
        awaitAll (Set.fromList slavecontrollers)
  where
    -- Wait until every pending controller has replied or died. We monitor
    -- *processes* here, so deaths arrive as 'ProcessMonitorNotification'
    -- (waiting for a node notification would hang if a controller died).
    -- Tracking pending pids, rather than counting messages, keeps a
    -- controller that replies and then dies from being counted twice.
    awaitAll pending
      | Set.null pending = return ()
      | otherwise = do
          done <- receiveWait
            [ match (\(RedirectLogsReply from _) -> return from)
            , matchIf (\(ProcessMonitorNotification _ pid _) -> pid `Set.member` pending)
                      (\(ProcessMonitorNotification _ pid _) -> return pid)
            ]
          awaitAll (Set.delete done pending)
-- | Control messages understood by 'slaveController'.
data SlaveControllerMsg
  = SlaveTerminate
    -- ^ Stop the slave controller loop.
  | RedirectLogsTo ProcessId ProcessId
    -- ^ @RedirectLogsTo logger requester@: register @logger@ as the slave's
    -- local @logger@ and acknowledge to @requester@.
  deriving (Typeable, Show)
-- | Wire format: tag byte 0 for 'SlaveTerminate'; tag byte 1 followed by the
-- two process ids (encoded as a pair) for 'RedirectLogsTo'.
instance Binary SlaveControllerMsg where
  put SlaveTerminate = putWord8 0
  put (RedirectLogsTo logger requester) = putWord8 1 >> put (logger, requester)
  get = getWord8 >>= \tag -> case tag of
    0 -> return SlaveTerminate
    1 -> uncurry RedirectLogsTo <$> get
    _ -> fail "SlaveControllerMsg.get: invalid"
-- | Acknowledgement sent by a slave controller: its own pid and whether the
-- @logger@ registration succeeded.
data RedirectLogsReply = RedirectLogsReply ProcessId Bool
  deriving (Typeable, Show)
-- | Wire format: the sender pid followed by the success flag (the same bytes
-- as encoding the pair).
instance Binary RedirectLogsReply where
  put (RedirectLogsReply from ok) = put from >> put ok
  get = uncurry RedirectLogsReply <$> get
-- | Start a slave: create a fresh local node on the backend and run
-- 'slaveController' on it.
startSlave :: Backend -> IO ()
startSlave backend =
  newLocalNode backend >>= \node -> Node.runProcess node slaveController
-- | Main loop of a slave.
--
-- Registers itself under @slaveController@ and then serves
-- 'SlaveControllerMsg's until it receives 'SlaveTerminate'. On
-- @RedirectLogsTo logger requester@ it points the local @logger@ name at
-- @logger@ and tells @requester@, via 'RedirectLogsReply', whether that
-- succeeded.
slaveController :: Process ()
slaveController = do
  me <- getSelfPid
  register "slaveController" me
  serve me
  where
    serve me = do
      request <- expect
      case request of
        SlaveTerminate -> return ()
        RedirectLogsTo loggerPid from -> do
          ok <- installLogger loggerPid
          send from (RedirectLogsReply me ok)
          serve me

    -- Try to re-point an existing "logger" registration first; fall back to
    -- a fresh registration when that fails.
    installLogger loggerPid = do
      reregistered <- succeeded <$> try (reregister "logger" loggerPid)
      if reregistered
        then return True
        else succeeded <$> try (register "logger" loggerPid)

    succeeded :: Either ProcessRegistrationException () -> Bool
    succeeded = either (const False) (const True)
-- | Ask the slave controller registered on node @nid@ to terminate.
terminateSlave :: NodeId -> Process ()
terminateSlave nid = nsendRemote nid controllerName SlaveTerminate
  where
    controllerName = "slaveController"
-- | Find the slave controllers of all discoverable peers.
--
-- Runs peer discovery (waiting 1 second for replies), then asks each
-- discovered node where its @slaveController@ is. Nodes without a registered
-- controller, and nodes that disconnect while we wait, are skipped.
findSlaves :: Backend -> Process [ProcessId]
findSlaves backend = do
  nodes <- liftIO $ findPeers backend 1000000
  bracket
    (mapM monitorNode nodes)
    (mapM unmonitor)
    $ \_ -> do
      forM_ nodes $ \nid -> whereisRemoteAsync nid "slaveController"
      -- Only accept replies for our label. A plain 'match' with a literal
      -- pattern would dequeue any other WhereIsReply and crash the process
      -- with a pattern-match failure.
      -- NOTE(review): a node that replies and then disconnects is counted
      -- twice, which may end the wait before another node's reply arrives.
      catMaybes <$> replicateM (length nodes) (
        receiveWait
          [ matchIf (\(WhereIsReply label _) -> label == "slaveController")
                    (\(WhereIsReply _ mPid) -> return mPid)
          , match (\(NodeMonitorNotification {}) -> return Nothing)
          ])
-- | Send 'SlaveTerminate' to every slave controller found by 'findSlaves',
-- then pause for one second.
terminateAllSlaves :: Backend -> Process ()
terminateAllSlaves backend = do
  findSlaves backend >>= mapM_ (`send` SlaveTerminate)
  liftIO (threadDelay 1000000)
-- | Start a master on a fresh local node.
--
-- Finds the slaves, redirects their logs to this node, and runs @proc@ on
-- the slaves' node ids; the local logger is shut down afterwards, even if
-- @proc@ fails.
startMaster :: Backend -> ([NodeId] -> Process ()) -> IO ()
startMaster backend proc =
  newLocalNode backend >>= \node -> Node.runProcess node master
  where
    master = do
      slaves <- findSlaves backend
      redirectLogsHere backend slaves
      finally (proc (map processNodeId slaves)) shutdownLogger
-- | Send the locally registered @logger@ a reply port and block until it
-- answers on that port.
shutdownLogger :: Process ()
shutdownLogger = do
  (replyTo, replies) <- newChan
  nsend "logger" (replyTo :: SendPort ())
  receiveChan replies
-- | Accessor for the backend's list of local nodes.
localNodes :: Accessor BackendState [Node.LocalNode]
localNodes = accessor getter setter
  where
    getter = _localNodes
    setter ns st = st { _localNodes = ns }

-- | Accessor for the set of discovered peer node ids.
peers :: Accessor BackendState (Set NodeId)
peers = accessor getter setter
  where
    getter = _peers
    setter ps st = st { _peers = ps }