{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# Language ConstraintKinds #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Raft.Monad where
import Protolude hiding (STM, TChan, readTChan, writeTChan, newTChan, atomically)
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Trans.Class
import qualified Control.Monad.Conc.Class as Conc
import Control.Concurrent.Classy.STM.TChan
import Raft.Config
import Raft.Event
import Raft.Logging
import Raft.NodeState
import Test.DejaFu.Conc (ConcIO)
import qualified Test.DejaFu.Types as TDT
-- | Everything a Raft node needs from its base monad: the ability to fork
-- threads ('MonadRaftFork') and an event channel implementation
-- ('MonadRaftChan').
type MonadRaft v m =
  ( MonadRaftFork m
  , MonadRaftChan v m
  )
-- | Monads that can allocate, read from, and write to a channel carrying
-- Raft 'Event's.
class Monad m => MonadRaftChan v m where
  -- | The concrete channel type monad @m@ uses to carry @'Event' v@ values.
  type RaftEventChan v m

  -- | Allocate a new event channel.
  newRaftChan :: m (RaftEventChan v m)

  -- | Take the next event from the channel.
  readRaftChan :: RaftEventChan v m -> m (Event v)

  -- | Put an event onto the channel.
  writeRaftChan :: RaftEventChan v m -> Event v -> m ()
-- | In 'IO' the event channel is a classy 'TChan' over IO's STM. Each
-- channel operation runs in its own 'Conc.atomically' transaction.
instance MonadRaftChan v IO where
  type RaftEventChan v IO = TChan (Conc.STM IO) (Event v)

  newRaftChan = Conc.atomically newTChan

  readRaftChan chan = Conc.atomically (readTChan chan)

  writeRaftChan chan ev = Conc.atomically (writeTChan chan ev)
-- | Under DejaFu's 'ConcIO' the event channel is a classy 'TChan' over
-- ConcIO's STM. Each channel operation runs in its own 'Conc.atomically'
-- transaction.
instance MonadRaftChan v ConcIO where
  type RaftEventChan v ConcIO = TChan (Conc.STM ConcIO) (Event v)

  newRaftChan = Conc.atomically newTChan

  readRaftChan chan = Conc.atomically (readTChan chan)

  writeRaftChan chan ev = Conc.atomically (writeTChan chan ev)
-- | Label describing the purpose of a thread started with 'raftFork'.
data RaftThreadRole
  = RPCHandler
    -- ^ the RPC handler thread
  | ClientRequestHandler
    -- ^ the client request handler thread
  | CustomThreadRole Text
    -- ^ any other role, described by free-form text
  deriving (Show)
-- | Monads that can run an action on a new thread.
class Monad m => MonadRaftFork m where
  -- | The thread identifier type that 'raftFork' returns in monad @m@.
  type RaftThreadId m

  -- | Run the given action on a new thread and return that thread's id.
  -- The 'RaftThreadRole' describes the thread's purpose. An instance may
  -- use it (for example, to name the thread) or ignore it.
  raftFork :: RaftThreadRole -> m () -> m (RaftThreadId m)
-- | In plain 'IO', threads are started with 'forkIO' and the role label
-- is ignored.
instance MonadRaftFork IO where
  type RaftThreadId IO = Protolude.ThreadId

  raftFork _role action = forkIO action
-- | Under DejaFu's 'ConcIO', each forked thread is named after the 'show'
-- output of its role, using 'Conc.forkN'.
instance MonadRaftFork ConcIO where
  type RaftThreadId ConcIO = TDT.ThreadId

  raftFork threadRole action = Conc.forkN (show threadRole) action
-- | Environment that 'RaftT' reads (it is the 'ReaderT' layer of the
-- transformer stack).
data RaftEnv v m = RaftEnv
  { eventChan :: RaftEventChan v m
    -- ^ the node's event channel
  , resetElectionTimer :: m ()
    -- ^ action that resets the election timer
  , resetHeartbeatTimer :: m ()
    -- ^ action that resets the heartbeat timer
  , raftNodeConfig :: RaftNodeConfig
    -- ^ this node's configuration
  , raftNodeLogCtx :: LogCtx (RaftT v m)
    -- ^ logging context used by 'logDebug', 'logCritical' and 'logAndPanic'
  }
-- | The Raft node monad transformer. It is a reader over 'RaftEnv' stacked
-- on a state over 'RaftNodeState', running in the base monad @m@. The
-- instances listed here are newtype-derived from that stack.
newtype RaftT v m a = RaftT
  { unRaftT :: ReaderT (RaftEnv v m) (StateT (RaftNodeState v) m) a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFail
    , Alternative
    , MonadPlus
    , MonadReader (RaftEnv v m)
    , MonadState (RaftNodeState v)
    )
-- | Lift a base-monad action through the state layer and then through the
-- reader layer.
instance MonadTrans (RaftT v) where
  lift action = RaftT (lift (lift action))
-- Exception-throwing, exception-catching and IO capabilities are passed
-- through from the base monad whenever it provides them.
deriving instance MonadThrow m => MonadThrow (RaftT v m)
deriving instance MonadCatch m => MonadCatch (RaftT v m)
deriving instance MonadIO m => MonadIO (RaftT v m)
-- | Forking from 'RaftT' runs the child with 'runRaftT' in the base monad,
-- starting from the current environment and a snapshot of the current node
-- state. 'runRaftT' discards the child's final state, so the child's state
-- changes are never written back to the parent.
instance MonadRaftFork m => MonadRaftFork (RaftT v m) where
  type RaftThreadId (RaftT v m) = RaftThreadId m

  raftFork threadRole action = do
    env <- ask
    snapshot <- get
    let child = runRaftT snapshot env action
    lift (raftFork threadRole child)
-- | The logger context is the configured node id paired with the current
-- node state.
instance Monad m => RaftLogger v (RaftT v m) where
  loggerCtx = do
    nid <- asks (configNodeId . raftNodeConfig)
    st <- get
    pure (nid, st)
-- | Run a 'RaftT' computation in the base monad, given the initial node
-- state and the environment. Only the computation's result is returned;
-- the final node state is discarded.
runRaftT
  :: Monad m
  => RaftNodeState v
  -> RaftEnv v m
  -> RaftT v m a
  -> m a
runRaftT raftNodeState raftEnv action =
  evalStateT (runReaderT (unRaftT action) raftEnv) raftNodeState
-- | Pass a message to 'logDebugIO' along with the node's logging context
-- from 'RaftEnv'.
logDebug :: MonadIO m => Text -> RaftT v m ()
logDebug msg = do
  ctx <- asks raftNodeLogCtx
  logDebugIO ctx msg
-- | Pass a message to 'logCriticalIO' along with the node's logging context
-- from 'RaftEnv'.
logCritical :: MonadIO m => Text -> RaftT v m ()
logCritical msg = do
  ctx <- asks raftNodeLogCtx
  logCriticalIO ctx msg
-- | Pass a message to 'logAndPanicIO' along with the node's logging context
-- from 'RaftEnv'. The fully polymorphic result type @a@ suggests this does
-- not return normally; confirm against 'logAndPanicIO'.
logAndPanic :: MonadIO m => Text -> RaftT v m a
logAndPanic msg = do
  ctx <- asks raftNodeLogCtx
  logAndPanicIO ctx msg