{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Event dispatch for a Raft node: first applies the role-independent
-- "newer term" rule, then routes the event to the handler set for the
-- node's current role (follower, candidate or leader).
module Raft.Handle where

import Protolude

import qualified Raft.Follower as Follower
import qualified Raft.Candidate as Candidate
import qualified Raft.Leader as Leader

import Raft.Action
import Raft.Event
import Raft.Monad
import Raft.NodeState
import Raft.Persistent
import Raft.RPC
import Raft.Logging (LogMsg)

-- | Main entry point for handling events.
--
-- Processing happens in two phases, each run via 'runTransitionM':
--
--   1. If the event is an RPC message whose term is greater than the
--      persisted @currentTerm@, adopt that term, clear @votedFor@, reset the
--      election timeout and convert the node to a follower. Any other event
--      passes through this phase unchanged and produces no output.
--   2. Dispatch the event to the role-specific handler ('handleEvent'') for
--      the (possibly converted) node state.
--
-- The actions and log messages of both phases are concatenated, phase 1
-- first.
handleEvent
  :: forall sm v. (RSMP sm v, Show v)
  => RaftNodeState
  -> TransitionEnv sm
  -> PersistentState
  -> Event v
  -> (RaftNodeState, PersistentState, [Action sm v], [LogMsg])
handleEvent raftNodeState@(RaftNodeState initNodeState) transitionEnv persistentState event =
  -- Rules for all servers: handle a newer RPC term before anything else.
  -- The wrapped 'NodeState' values hide their role index, so they are
  -- unpacked with 'case' rather than a lazy 'let' pattern.
  case handleNewerRPCTerm of
    ((RaftNodeState resNodeState, logMsgs), persistentState', outputs) ->
      case handleEvent' resNodeState transitionEnv persistentState' event of
        ((ResultState _ resultState, logMsgs'), persistentState'', outputs') ->
          ( RaftNodeState resultState
          , persistentState''
          , outputs <> outputs'
          , logMsgs <> logMsgs'
          )
  where
    -- Phase 1: the "newer term" rule. Only RPC messages carry a term.
    handleNewerRPCTerm
      :: ((RaftNodeState, [LogMsg]), PersistentState, [Action sm v])
    handleNewerRPCTerm =
      case event of
        MessageEvent (RPCMessageEvent (RPCMessage _ rpc)) ->
          runTransitionM transitionEnv persistentState $ do
            -- If RPC request or response contains term T > currentTerm:
            -- set currentTerm = T, convert to follower
            term <- gets currentTerm
            if term < rpcTerm rpc
              then
                case convertToFollower initNodeState of
                  ResultState _ nodeState -> do
                    modify $ \pstate ->
                      pstate { currentTerm = rpcTerm rpc
                             , votedFor = Nothing
                             }
                    resetElectionTimeout
                    pure (RaftNodeState nodeState)
              else pure raftNodeState
        _ -> ((raftNodeState, []), persistentState, [])

    -- Demote a node state of any role to a follower state. The transition
    -- tag records which role the node held when the higher term was found.
    -- Commit index, last-applied index and last log entry data carry over.
    convertToFollower :: forall s. NodeState s -> ResultState s
    convertToFollower nodeState =
      case nodeState of
        -- NOTE(review): a follower keeps its state unchanged here, including
        -- 'fsCurrentLeader', which may still name the leader of the stale
        -- term; candidates and leaders reset it to 'NoLeader'. Confirm
        -- whether the follower case should do the same.
        NodeFollowerState _ ->
          ResultState HigherTermFoundFollower nodeState
        NodeCandidateState cs ->
          ResultState HigherTermFoundCandidate $
            NodeFollowerState FollowerState
              { fsCurrentLeader = NoLeader
              , fsCommitIndex = csCommitIndex cs
              , fsLastApplied = csLastApplied cs
              , fsLastLogEntryData = csLastLogEntryData cs
              , fsTermAtAEPrevIndex = Nothing
              }
        NodeLeaderState ls ->
          ResultState HigherTermFoundLeader $
            NodeFollowerState FollowerState
              { fsCurrentLeader = NoLeader
              , fsCommitIndex = lsCommitIndex ls
              , fsLastApplied = lsLastApplied ls
              , fsLastLogEntryData =
                  -- The leader tracks a triple for its last entry; a follower
                  -- keeps only the index and term components.
                  let (lastLogEntryIdx, lastLogEntryTerm, _) = lsLastLogEntryData ls
                   in (lastLogEntryIdx, lastLogEntryTerm)
              , fsTermAtAEPrevIndex = Nothing
              }

-- | The handlers a node in role @ns@ uses for each kind of event it can
-- receive: the four RPC kinds, timeouts and client requests.
data RaftHandler ns sm v = RaftHandler
  { handleAppendEntries :: RPCHandler ns sm (AppendEntries v) v
  , handleAppendEntriesResponse :: RPCHandler ns sm AppendEntriesResponse v
  , handleRequestVote :: RPCHandler ns sm RequestVote v
  , handleRequestVoteResponse :: RPCHandler ns sm RequestVoteResponse v
  , handleTimeout :: TimeoutHandler ns sm v
  , handleClientRequest :: ClientReqHandler ns sm v
  }

-- | Handler set for a node in the follower role.
followerRaftHandler :: Show v => RaftHandler 'Follower sm v
followerRaftHandler = RaftHandler
  { handleAppendEntries = Follower.handleAppendEntries
  , handleAppendEntriesResponse = Follower.handleAppendEntriesResponse
  , handleRequestVote = Follower.handleRequestVote
  , handleRequestVoteResponse = Follower.handleRequestVoteResponse
  , handleTimeout = Follower.handleTimeout
  , handleClientRequest = Follower.handleClientRequest
  }

-- | Handler set for a node in the candidate role.
candidateRaftHandler :: Show v => RaftHandler 'Candidate sm v
candidateRaftHandler = RaftHandler
  { handleAppendEntries = Candidate.handleAppendEntries
  , handleAppendEntriesResponse = Candidate.handleAppendEntriesResponse
  , handleRequestVote = Candidate.handleRequestVote
  , handleRequestVoteResponse = Candidate.handleRequestVoteResponse
  , handleTimeout = Candidate.handleTimeout
  , handleClientRequest = Candidate.handleClientRequest
  }

-- | Handler set for a node in the leader role.
leaderRaftHandler :: Show v => RaftHandler 'Leader sm v
leaderRaftHandler = RaftHandler
  { handleAppendEntries = Leader.handleAppendEntries
  , handleAppendEntriesResponse = Leader.handleAppendEntriesResponse
  , handleRequestVote = Leader.handleRequestVote
  , handleRequestVoteResponse = Leader.handleRequestVoteResponse
  , handleTimeout = Leader.handleTimeout
  , handleClientRequest = Leader.handleClientRequest
  }

-- | Select the handler set matching the node's role. Matching on the
-- 'NodeState' constructor fixes @ns@, so each branch can return the handler
-- record indexed by that role.
mkRaftHandler :: forall ns sm v. Show v => NodeState ns -> RaftHandler ns sm v
mkRaftHandler nodeState =
  case nodeState of
    NodeFollowerState _ -> followerRaftHandler
    NodeCandidateState _ -> candidateRaftHandler
    NodeLeaderState _ -> leaderRaftHandler

-- | Dispatch a single event to the handler for the node's current role and
-- run it with 'runTransitionM' against the given environment and persistent
-- state.
handleEvent'
  :: forall ns sm v. (RSMP sm v, Show v)
  => NodeState ns
  -> TransitionEnv sm
  -> PersistentState
  -> Event v
  -> ((ResultState ns, [LogMsg]), PersistentState, [Action sm v])
handleEvent' initNodeState transitionEnv persistentState event =
  runTransitionM transitionEnv persistentState $
    case event of
      MessageEvent mev ->
        case mev of
          RPCMessageEvent rpcMsg -> handleRPCMessage rpcMsg
          ClientRequestEvent cr -> handleClientRequest initNodeState cr
      TimeoutEvent tout -> handleTimeout initNodeState tout
  where
    -- Brings the role-specific handlers into scope under the field names.
    RaftHandler{..} = mkRaftHandler initNodeState

    -- Route an RPC to the handler for its kind, passing along the sender.
    handleRPCMessage :: RPCMessage v -> TransitionM sm v (ResultState ns)
    handleRPCMessage (RPCMessage sender rpc) =
      case rpc of
        AppendEntriesRPC appendEntries ->
          handleAppendEntries initNodeState sender appendEntries
        AppendEntriesResponseRPC appendEntriesResp ->
          handleAppendEntriesResponse initNodeState sender appendEntriesResp
        RequestVoteRPC requestVote ->
          handleRequestVote initNodeState sender requestVote
        RequestVoteResponseRPC requestVoteResp ->
          handleRequestVoteResponse initNodeState sender requestVoteResp