{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}

-- |
-- Module      : Control.Concurrent.NQE.Supervisor
-- Copyright   : No rights reserved
-- License     : UNLICENSE
-- Maintainer  : xenog@protonmail.com
-- Stability   : experimental
-- Portability : POSIX
--
-- Supervisors run and monitor processes, including other supervisors. A supervisor
-- has a corresponding 'Strategy' that controls its behaviour if a child stops.
-- Supervisors deal with exceptions in concurrent processes so that their code does
-- not need to be written in an overly-defensive style. They help prevent problems
-- caused by processes dying quietly in the background, potentially locking an
-- entire application.
module Control.Concurrent.NQE.Supervisor
  ( ChildAction,
    Child,
    SupervisorMessage,
    Supervisor,
    Strategy (..),
    withSupervisor,
    supervisor,
    supervisorProcess,
    addChild,
    removeChild,
  )
where

import Control.Applicative
import Control.Concurrent.NQE.Process
import Control.Concurrent.STM (retry)
import Control.Monad
import Data.List
import UnliftIO

-- | An action the supervisor runs in a thread of its own as one of its
-- children.
type ChildAction = IO ()

-- | Handle to a child thread started by a supervisor.
type Child = Async ()

-- | Requests understood by a supervisor. The constructors are not exported;
-- use 'addChild' and 'removeChild' to build and send them.
data SupervisorMessage
  = -- | Start the action as a new child and reply with its 'Child' handle.
    AddChild !ChildAction !(Listen Child)
  | -- | Stop the given child (if this supervisor owns it), then reply.
    RemoveChild !Child !(Listen ())

-- | Handle to a running supervisor: a 'Process' whose mailbox accepts
-- 'SupervisorMessage' requests.
type Supervisor = Process SupervisorMessage

-- | Supervisor strategies to decide what to do when a child stops.
--
-- Children stopped through 'removeChild' are unregistered before they are
-- cancelled, so no strategy ever sees them.
data Strategy
  = -- | Report the stopped child to the given 'Listen' action as
    -- @(child, 'Nothing')@ if it returned normally or
    -- @(child, 'Just' exception)@ if it crashed. The supervisor keeps running.
    Notify (Listen (Child, Maybe SomeException))
  | -- | Cancel every remaining child and stop the supervisor. If the child
    -- that stopped raised an exception, rethrow it from the supervisor;
    -- otherwise the supervisor stops quietly.
    KillAll
  | -- | Ignore children that stop without raising an exception. If a child
    -- raises one, behave like 'KillAll' (cancel everything and rethrow).
    IgnoreGraceful
  | -- | Keep running when a child stops, however it stopped.
    IgnoreAll

-- | Start a supervisor in a background thread, hand its 'Supervisor' handle
-- to the given function, and return that function's result. The supervisor's
-- lifetime is scoped by 'withProcess'. When the supervisor loop exits, every
-- child it still owns is cancelled.
withSupervisor ::
  (MonadUnliftIO m) =>
  Strategy ->
  (Supervisor -> m a) ->
  m a
withSupervisor strat = withProcess (supervisorProcess strat)

-- | Start a supervisor with the given 'Strategy' as a background 'process'
-- and return its handle.
supervisor :: (MonadUnliftIO m) => Strategy -> m Supervisor
supervisor = process . supervisorProcess

-- | Run a supervisor in the current thread.
--
-- Each iteration waits for the first of two events: a 'SupervisorMessage'
-- arriving in the inbox (tried first) or a registered child terminating.
-- The loop ends when the handler for that event returns 'False' or throws.
-- Either way, every child still registered is cancelled on the way out.
supervisorProcess ::
  (MonadUnliftIO m) =>
  Strategy ->
  Inbox SupervisorMessage ->
  m ()
supervisorProcess strat i = do
  children <- newTVarIO []
  let nextEvent =
        atomically $
          (Right <$> receiveSTM i) <|> (Left <$> waitForChild children)
      dispatch =
        either (processDead children strat) (processMessage children)
      go = do
        keepGoing <- nextEvent >>= dispatch
        when keepGoing go
  go `finally` stopAll children

-- | Ask the supervisor to start a new 'ChildAction' and return the 'Child'
-- handle it creates. This only waits for the child to be started: it does
-- not block on, or raise an exception because of, the child dying.
addChild :: (MonadIO m) => Supervisor -> ChildAction -> m Child
addChild sup action = query (AddChild action) sup

-- | Stop a 'Child' owned by this supervisor, blocking until it has been
-- cancelled. If the supervisor does not own the child, nothing is cancelled
-- and the call returns once the request has been processed.
removeChild :: (MonadIO m) => Supervisor -> Child -> m ()
removeChild sup c = query (RemoveChild c) sup

-- | Internal function to stop all children.
--
-- Cancels, one at a time and with asynchronous exceptions masked, every child
-- currently recorded in @state@. The recorded list itself is not modified.
stopAll :: (MonadUnliftIO m) => TVar [Child] -> m ()
stopAll state = mask_ (readTVarIO state >>= mapM_ cancel)

-- | Internal function to wait for a child process to finish running.
--
-- Retries while no children are registered. Otherwise it yields the first
-- registered child that has finished, paired with its outcome.
waitForChild :: TVar [Child] -> STM (Child, Either SomeException ())
waitForChild state =
  readTVar state >>= \children ->
    if null children
      then retry
      else waitAnyCatchSTM children

-- | Internal function to process incoming supervisor message.
--
-- Starts or stops the requested child and then answers the caller through
-- the message's 'Listen' action. It always returns 'True': handling a message
-- never stops the supervisor.
processMessage ::
  (MonadUnliftIO m) => TVar [Child] -> SupervisorMessage -> m Bool
processMessage state msg =
  case msg of
    AddChild action reply -> do
      child <- startChild state action
      atomically (reply child)
      pure True
    RemoveChild child reply -> do
      stopChild state child
      atomically (reply ())
      pure True

-- | Internal function to run when a child process dies.
--
-- The dead child is always removed from @state@ first; what happens next
-- depends on the 'Strategy'. Returns 'True' if the supervisor should keep
-- running and 'False' if it should stop quietly. Under 'KillAll', and under
-- 'IgnoreGraceful' for a child that crashed, every remaining child is
-- cancelled and the dead child's exception (if any) is rethrown.
processDead ::
  (MonadUnliftIO m) =>
  TVar [Child] ->
  Strategy ->
  (Child, Either SomeException ()) ->
  m Bool
processDead state IgnoreAll (a, _) = do
  atomically $ modifyTVar' state (filter (/= a))
  return True
processDead state KillAll (a, e) = do
  atomically $ modifyTVar' state (filter (/= a))
  stopAll state
  case e of
    Left x -> throwIO x
    Right () -> return False
processDead state IgnoreGraceful (a, Right ()) = do
  atomically $ modifyTVar' state (filter (/= a))
  return True
processDead state IgnoreGraceful (a, Left e) = do
  atomically $ modifyTVar' state (filter (/= a))
  stopAll state
  throwIO e
processDead state (Notify notif) (a, ee) = do
  atomically $ do
    as <- readTVar state
    -- Report the child only if it is still registered.
    forM_ (find (== a) as) $ \p -> notif (p, me)
    -- Strict update, like every other modification of the child list, so
    -- the TVar holds the filtered list rather than a pending thunk.
    modifyTVar' state (filter (/= a))
  return True
  where
    me = either Just (const Nothing) ee

-- | Internal function to start a child process and register it in @state@.
--
-- Forking and registration happen with asynchronous exceptions masked, so an
-- exception cannot arrive between starting the child and recording it.
-- The child action itself runs under @restore@, i.e. with the masking state
-- the caller had before this call, which is how 'withAsync' forks its thread.
-- Forking directly inside 'mask_' would make the child inherit the masked
-- state. 'cancel' could then reach it only at interruptible points and might
-- block indefinitely.
startChild :: (MonadUnliftIO m) => TVar [Child] -> ChildAction -> m Child
startChild state ch =
  liftIO $
    mask $ \restore -> do
      a <- async (restore ch)
      atomically $ modifyTVar' state (a :)
      return a

-- | Internal function to stop a child process.
--
-- Removes the child from @state@ and cancels it only if it was registered
-- there. This runs with asynchronous exceptions masked, so a registered child
-- is never left unregistered but still running.
stopChild :: (MonadUnliftIO m) => TVar [Child] -> Child -> m ()
stopChild state a = mask_ $ do
  wasChild <- atomically $ do
    children <- readTVar state
    writeTVar state (filter (/= a) children)
    return (a `elem` children)
  if wasChild
    then cancel a
    else return ()