{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
module Supervisors
( Supervisor
, withSupervisor
, supervise
, superviseSTM
) where
import Control.Concurrent.STM
import Control.Concurrent (ThreadId, forkIO, myThreadId, throwTo)
import Control.Concurrent.Async (withAsync)
import Control.Exception.Safe
(Exception, SomeException, bracket, bracket_, toException, withException)
import Control.Monad (forever, void)
import Data.Foldable (traverse_)
import qualified Data.Set as S
-- | Handle to a supervisor: the set of threads it monitors, plus a queue of
-- actions waiting to be forked into that set.
data Supervisor = Supervisor
    { stateVar :: TVar (Either SomeException (S.Set ThreadId))
      -- ^ While the supervisor is live this is 'Right' with the currently
      -- registered child threads. After shutdown it is 'Left' with the
      -- exception that stopped it, and further registrations are refused.
    , runQ     :: TQueue (IO ())
      -- ^ Actions enqueued by 'superviseSTM'; the supervisor's own loop
      -- dequeues them and forks each one with 'supervise'.
    }
-- | Allocate a supervisor in its initial state: live, with no registered
-- children and an empty launch queue. Nothing services the queue until
-- 'runSupervisor' is started on the result.
newSupervisor :: IO Supervisor
newSupervisor = do
    children <- newTVarIO (Right S.empty)
    pending <- newTQueueIO
    pure $ Supervisor { stateVar = children, runQ = pending }
-- | The supervisor's main loop. It repeatedly takes the next queued action
-- off 'runQ' and forks it with 'supervise', and it never returns normally.
-- When any exception escapes the loop (for example because this thread is
-- cancelled), 'withException' first passes that exception to 'throwKids',
-- which delivers it to every registered child, and then rethrows it.
runSupervisor :: Supervisor -> IO ()
runSupervisor sup@Supervisor{runQ = queue} =
    serviceQueue `withException` shutdown
  where
    serviceQueue = forever $ do
        action <- atomically (readTQueue queue)
        supervise sup action
    shutdown :: SomeException -> IO ()
    shutdown = throwKids sup
-- | Run @body@ with a fresh supervisor. The supervisor's loop
-- ('runSupervisor') runs in a background thread managed by 'withAsync'.
-- 'withAsync' cancels that thread once @body@ finishes, and the loop then
-- forwards the cancellation exception to every supervised child.
withSupervisor :: (Supervisor -> IO a) -> IO a
withSupervisor body =
    newSupervisor >>= \sup ->
        withAsync (runSupervisor sup) (\_ -> body sup)
-- | Shut the supervisor down with @exn@.
--
-- One transaction replaces the live child set with @'Left' exn@, so later
-- registrations in 'supervise' fail. @exn@ is then thrown to each child
-- that was registered at that moment. If the supervisor was already shut
-- down, the transaction reports no children and nothing is thrown.
--
-- The acquire/release/no-op-body shape of 'bracket' is deliberate: the
-- 'throwTo' calls run as the release action, i.e. with asynchronous
-- exceptions masked, so the fan-out is not cut short by an incoming
-- exception.
throwKids :: Exception e => Supervisor -> e -> IO ()
throwKids Supervisor{stateVar = var} exn =
    bracket markStopped (traverse_ (`throwTo` exn)) (const (pure ()))
  where
    -- Returns the children that must be notified (none if already stopped).
    markStopped = atomically $
        readTVar var >>= either (const (pure S.empty)) retire
    -- Record the shutdown exception and hand back the last live child set.
    retire kids = do
        writeTVar var (Left (toException exn))
        pure kids
-- | Fork @task@ in a new thread monitored by the 'Supervisor'.
--
-- The child adds its own 'ThreadId' to the supervisor's set before running
-- @task@. It removes itself when @task@ ends, whether normally or by
-- exception. While registered, the child receives the supervisor's
-- shutdown exception (see 'throwKids').
--
-- If the supervisor is already stopped, registration throws the stored
-- exception inside the child and @task@ never runs. The caller is not
-- informed in either case: this returns as soon as the thread is forked.
supervise :: Supervisor -> IO () -> IO ()
supervise Supervisor{stateVar = var} task =
    void . forkIO $ bracket_ register unregister task
  where
    -- Join the live child set, or fail if the supervisor has stopped.
    -- The new set is forced before it is written back.
    register = do
        tid <- myThreadId
        atomically $ readTVar var >>= \case
            Left stopped -> throwSTM stopped
            Right kids   -> writeTVar var . Right $! S.insert tid kids
    -- Leave the child set. A stopped ('Left') state is kept unchanged; the
    -- shrunken set is forced so no thunk is stored in the TVar.
    unregister = do
        tid <- myThreadId
        atomically $ modifyTVar' var $
            either Left (\kids -> Right $! S.delete tid kids)
-- | Like 'supervise', but callable from inside an 'STM' transaction.
--
-- The action is only written to the supervisor's launch queue, so it is
-- enqueued only if the surrounding transaction commits. It is forked later,
-- when the supervisor's loop dequeues it.
--
-- NOTE(review): once that loop has exited, nothing in this module reads the
-- queue, so actions enqueued after shutdown appear never to run. Confirm
-- whether callers rely on that.
superviseSTM :: Supervisor -> IO () -> STM ()
superviseSTM Supervisor{runQ = pending} task = writeTQueue pending task