{-# LANGUAGE RankNTypes #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE GADTs #-} module Raft.Monad where import Protolude hiding (pass) import Control.Monad.RWS import qualified Data.Set as Set import Raft.Action import Raft.Client import Raft.Config import Raft.Event import Raft.Log import Raft.Persistent import Raft.NodeState import Raft.RPC import Raft.Types import Raft.Logging (RaftLogger, runRaftLoggerT, RaftLoggerT(..), LogMsg) import qualified Raft.Logging as Logging -------------------------------------------------------------------------------- -- State Machine -------------------------------------------------------------------------------- -- | Interface to handle commands in the underlying state machine. Functional --dependency permitting only a single state machine command to be defined to --update the state machine. 
class RSMP sm v | sm -> v where
  -- | Error type returned by 'applyCmdRSMP' (and by 'validateCmd' in 'RSM').
  data RSMPError sm v
  -- | Context required to apply a command. The injectivity annotation means
  -- the context type determines both @sm@ and @v@.
  type RSMPCtx sm v = ctx | ctx -> sm v
  -- | Purely apply a command to a state machine, yielding either an error or
  -- the updated state machine.
  applyCmdRSMP :: RSMPCtx sm v -> sm -> v -> Either (RSMPError sm v) sm

-- | Monadic side of the state machine interface: a monad @m@ in which
-- commands can be validated and the context for 'applyCmdRSMP' obtained.
class (Monad m, RSMP sm v) => RSM sm v m | m sm -> v where
  -- | Validate a command; @Right ()@ means it may be applied.
  validateCmd :: v -> m (Either (RSMPError sm v) ())
  -- | Fetch the context that 'applyCmdRSMP' needs.
  askRSMPCtx :: m (RSMPCtx sm v)

-- | Apply a log entry to the state machine.
--
-- An entry without a value leaves the state machine untouched. An entry
-- carrying a command is first checked with 'validateCmd'; a validation error
-- is returned as is, otherwise the command is applied with 'applyCmdRSMP'
-- using the context from 'askRSMPCtx'.
applyEntryRSM :: RSM sm v m => sm -> Entry v -> m (Either (RSMPError sm v) sm)
applyEntryRSM sm entry =
  case entryValue entry of
    EntryValue cmd ->
      validateCmd cmd >>= either (pure . Left) (\() -> applyValidated cmd)
    NoValue -> pure (Right sm)
  where
    -- Only reached once validation has succeeded.
    applyValidated cmd = do
      ctx <- askRSMPCtx
      pure (applyCmdRSMP ctx sm cmd)

--------------------------------------------------------------------------------
-- Raft Monad
--------------------------------------------------------------------------------

-- | Emit a single action from a transition.
tellAction :: Action sm v -> TransitionM sm v ()
tellAction action = tellActions [action]

-- | Emit a list of actions from a transition, in order.
tellActions :: [Action sm v] -> TransitionM sm v ()
tellActions = tell

-- | Read-only environment available to every transition.
data TransitionEnv sm = TransitionEnv
  { nodeConfig :: NodeConfig
    -- ^ Configuration of this node; source of the node id and node id set.
  , stateMachine :: sm
    -- ^ The state machine value visible to the transition.
  , nodeState :: RaftNodeState
    -- ^ The node state reported to the logger.
  }

-- | The transition monad: a logging layer over an 'RWS' that reads a
-- 'TransitionEnv', writes a list of 'Action's and carries 'PersistentState'.
newtype TransitionM sm v a = TransitionM
  { unTransitionM :: RaftLoggerT (RWS (TransitionEnv sm) [Action sm v] PersistentState) a
  } deriving (Functor, Applicative, Monad)

-- Every method below is delegated to the wrapped computation by peeling off
-- and re-applying the 'TransitionM' and 'RaftLoggerT' wrappers.
instance MonadWriter [Action sm v] (TransitionM sm v) where
  tell actions = TransitionM (RaftLoggerT (tell actions))
  listen (TransitionM logged) =
    TransitionM (RaftLoggerT (listen (unRaftLoggerT logged)))
  pass (TransitionM logged) =
    TransitionM (RaftLoggerT (pass (unRaftLoggerT logged)))

instance MonadReader (TransitionEnv sm) (TransitionM sm v) where
  ask = TransitionM (RaftLoggerT ask)
  local modifyEnv (TransitionM logged) =
    TransitionM (RaftLoggerT (local modifyEnv (unRaftLoggerT logged)))

-- 'get' and 'put' are lifted inside the 'RaftLoggerT' wrapper so that they
-- address the 'PersistentState' of the underlying 'RWS'.
instance MonadState PersistentState (TransitionM sm v) where
  get = TransitionM (RaftLoggerT (lift get))
  put newState = TransitionM (RaftLoggerT (lift (put newState)))

-- The logger reads the node id and node state from the transition environment.
instance RaftLogger (RWS (TransitionEnv sm) [Action sm v] PersistentState) where
  loggerNodeId = asks (configNodeId . nodeConfig)
  loggerNodeState = nodeState <$> ask

-- | Run a transition against an environment and an initial persistent state.
--
-- The result is the transition's value paired with the collected log
-- messages, the final persistent state, and the emitted actions.
runTransitionM
  :: TransitionEnv sm
  -> PersistentState
  -> TransitionM sm v a
  -> ((a, [LogMsg]), PersistentState, [Action sm v])
runTransitionM env initialState (TransitionM logged) =
  runRWS (runRaftLoggerT logged) env initialState

-- | The id of the current node, taken from the node configuration.
askNodeId :: TransitionM sm v NodeId
askNodeId = configNodeId . nodeConfig <$> ask

--------------------------------------------------------------------------------
-- Handlers
--------------------------------------------------------------------------------

-- | Handler for an RPC @r@ received from a node while in node state @ns@.
type RPCHandler ns sm r v =
  RPCType r v => NodeState ns -> NodeId -> r -> TransitionM sm v (ResultState ns)

-- | Handler for a timeout while in node state @ns@.
type TimeoutHandler ns sm v =
  NodeState ns -> Timeout -> TransitionM sm v (ResultState ns)

-- | Handler for a client request while in node state @ns@.
type ClientReqHandler ns sm v =
  NodeState ns -> ClientRequest v -> TransitionM sm v (ResultState ns)

--------------------------------------------------------------------------------
-- RWS Helpers
--------------------------------------------------------------------------------

-- | Emit a 'BroadcastRPC' action addressed to every configured node id except
-- this node's own.
broadcast :: SendRPCAction v -> TransitionM sm v ()
broadcast sendRPC = do
  selfNodeId <- askNodeId
  allNodeIds <- asks (configNodeIds . nodeConfig)
  let peerNodeIds = Set.filter (selfNodeId /=) allNodeIds
  tellAction (BroadcastRPC peerNodeIds sendRPC)

-- | Emit a 'SendRPC' action addressed to a single node.
send :: NodeId -> SendRPCAction v -> TransitionM sm v ()
send nodeId = tellAction . SendRPC nodeId

-- | Resets the election timeout.
resetElectionTimeout :: TransitionM sm v ()
resetElectionTimeout = tellAction $ ResetTimeoutTimer ElectionTimeout

-- | Emit an action that resets the heartbeat timeout timer.
resetHeartbeatTimeout :: TransitionM sm v ()
resetHeartbeatTimeout = tellAction $ ResetTimeoutTimer HeartbeatTimeout

-- | Respond to a client with a redirect response carrying the given
-- current-leader value.
redirectClientToLeader :: ClientId -> CurrentLeader -> TransitionM sm v ()
redirectClientToLeader clientId =
  tellAction
    . RespondToClient clientId
    . ClientRedirectResponse
    . ClientRedirResp

-- | Respond to a client read with the state machine held in the transition
-- environment.
respondClientRead :: ClientId -> TransitionM sm v ()
respondClientRead clientId =
  asks stateMachine >>= \sm ->
    tellAction (RespondToClient clientId (ClientReadResponse (ClientReadResp sm)))

-- | Emit an action asking for the given entries to be appended to the log.
appendLogEntries :: Show v => Seq (Entry v) -> TransitionM sm v ()
appendLogEntries entries = tellAction (AppendLogEntries entries)

--------------------------------------------------------------------------------

-- | Begin an election and produce the resulting candidate state.
--
-- In order: the current term is incremented (clearing any recorded vote),
-- the node records a vote for itself, the election timeout is reset, and a
-- @RequestVote@ RPC for the new term is broadcast. The returned candidate
-- state starts with a single vote, this node's own.
startElection
  :: Index -- ^ Commit index carried into the candidate state
  -> Index -- ^ Last applied index carried into the candidate state
  -> (Index, Term) -- ^ Last log entry data
  -> TransitionM sm v CandidateState
startElection commitIndex lastApplied lastLogEntryData = do
  -- The node id comes from the read-only environment, so it can be fetched
  -- once up front and shared by the steps below.
  selfNodeId <- askNodeId
  advanceTerm
  castVoteFor selfNodeId
  resetElectionTimeout
  -- Read the term back only after it has been advanced.
  newTerm <- gets currentTerm
  broadcast (voteRequest selfNodeId newTerm)
  pure CandidateState
    { csCommitIndex = commitIndex
    , csLastApplied = lastApplied
    , csVotes = Set.singleton selfNodeId
    , csLastLogEntryData = lastLogEntryData
    }
  where
    (lastLogIndex, lastLogTerm) = lastLogEntryData

    -- Step into the next term and drop the vote recorded so far.
    advanceTerm =
      modify $ \ps ->
        ps { currentTerm = incrTerm (currentTerm ps), votedFor = Nothing }

    -- Record a vote for the given node in the persistent state.
    castVoteFor nodeId =
      modify $ \ps -> ps { votedFor = Just nodeId }

    -- The RequestVote RPC announcing this node's candidacy.
    voteRequest candidateId term =
      SendRequestVoteRPC RequestVote
        { rvTerm = term
        , rvCandidateId = candidateId
        , rvLastLogIndex = lastLogIndex
        , rvLastLogTerm = lastLogTerm
        }

--------------------------------------------------------------------------------
-- Logging
--------------------------------------------------------------------------------

-- | 'Logging.logInfo' with its result wrapped in the 'TransitionM'
-- constructor.
logInfo = TransitionM . Logging.logInfo

-- | 'Logging.logDebug' with its result wrapped in the 'TransitionM'
-- constructor.
logDebug = TransitionM . Logging.logDebug