{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeApplications #-}

module Ki.Scope
  ( Scope (..),
    cancel,
    scopeCancelledSTM,
    scopeFork,
    scoped,
    wait,
    waitFor,
    waitSTM,
    ScopeClosing (..),
    ThreadFailed (..),
  )
where

import Control.Exception
  ( Exception (fromException, toException),
    asyncExceptionFromException,
    asyncExceptionToException,
    pattern ErrorCall,
  )
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import Ki.Context (Context)
import qualified Ki.Context as Context
import Ki.Duration (Duration)
import Ki.Prelude
import Ki.Timeout (timeoutSTM)

-- | A __scope__ delimits the lifetime of all __threads__ created within it.
data Scope = Scope
  { context :: Context
  -- ^ This scope's own context, derived from the context the scope was created with; 'cancel' cancels it.
  , closedVar :: TVar Bool
  -- ^ Whether this scope is closed.
  --
  -- Invariant: if closed, no threads are starting.
  , runningVar :: TVar (Set ThreadId)
  -- ^ The set of threads that are currently running.
  , startingVar :: TVar Int
  -- ^ The number of threads that are /guaranteed/ to be about to start: only the GHC scheduler can still delay them,
  -- and no asynchronous exception can strike and prevent any of them from starting.
  --
  -- Whenever a non-zero count would be a problem (e.g. because we are trying to close this scope), we respect it and
  -- wait for it to drop to zero before proceeding.
  }

-- | Create an open __scope__ with no running or starting threads, whose context is derived from the given one.
newScope :: Context -> IO Scope
newScope parentContext =
  Scope
    <$> atomically (Context.deriveContext parentContext)
    <*> newTVarIO False
    <*> newTVarIO Set.empty
    <*> newTVarIO 0

-- | /Cancel/ all __contexts__ derived from a __scope__, by cancelling the scope's own context.
cancel :: Scope -> IO ()
cancel Scope {context = scopeContext} =
  Context.cancelContext scopeContext

-- | Read the cancel token of this scope's context and return an action that throws it.
--
-- NOTE(review): whether this retries until the context is cancelled depends on 'Context.contextCancelTokenSTM',
-- which is not visible here — confirm there.
scopeCancelledSTM :: Scope -> STM (IO a)
scopeCancelledSTM Scope {context = scopeContext} = do
  token <- Context.contextCancelTokenSTM scopeContext
  pure (throwIO token)

-- | Close a scope, kill all of its running threads, and return the first asynchronous exception delivered to us while
-- doing so, if any.
--
-- Preconditions:
--   * The set of threads doesn't include us
--   * We're uninterruptibly masked
closeScope :: Scope -> IO (Maybe SomeException)
closeScope scope@Scope {closedVar = closed, runningVar = running} = do
  -- Let any in-flight forks finish registering, then close the scope so that no further thread can start, and take
  -- a snapshot of everything that is running.
  threadIds <-
    atomically do
      blockUntilNoneStarting scope
      writeTVar closed True
      Set.toList <$> readTVar running
  firstException <- killThreads threadIds
  -- Every killed thread removes itself from the running set on its way out.
  atomically (blockUntilNoneRunning scope)
  pure firstException

-- | Fork a __thread__ within a __scope__.
--
-- Throws @'ErrorCall' "ki: scope closed"@ if the scope is already closed.
--
-- The new thread runs @action@, handing it an @unmask@ function that restores the masking state in effect when
-- 'scopeFork' was called. It then passes the outcome (the result, or whatever exception @action@ threw) to @k@.
-- After that, the thread removes itself from the scope's running set. This happens even if @k@ throws, so 'wait' and
-- 'closeScope' never block on a thread that has already exited. An exception thrown by @k@ is rethrown once the
-- thread has been removed.
scopeFork :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> (Either SomeException a -> IO ()) -> IO ThreadId
scopeFork Scope {closedVar, runningVar, startingVar} action k =
  uninterruptibleMask \restore -> do
    -- Record the thread as being about to start. Refuse if the scope is closed: a closed scope must have no starting
    -- threads.
    atomically do
      readTVar closedVar >>= \case
        False -> modifyTVar' startingVar (+ 1)
        True -> throwSTM (ErrorCall "ki: scope closed")

    -- Fork the thread. GHC's 'forkIO' propagates our (uninterruptible) masking state to the child, so outside of
    -- 'restore' the child cannot be interrupted.
    childThreadId <-
      forkIO do
        childThreadId <- myThreadId
        result <- try (action restore)
        -- Catch anything 'k' throws, so the bookkeeping below always runs.
        continuationResult <- try @SomeException (k result)
        atomically do
          running <- readTVar runningVar
          -- Our parent adds us to the running set only after 'forkIO' returns, so we may have to wait for that.
          case Set.splitMember childThreadId running of
            (xs, True, ys) -> writeTVar runningVar $! Set.union xs ys
            _ -> retry
        case continuationResult of
          Left exception -> throwIO exception
          Right () -> pure ()

    -- Record the thread as having started
    atomically do
      modifyTVar' startingVar (\n -> n - 1)
      modifyTVar' runningVar (Set.insert childThreadId)

    pure childThreadId

-- | Create a __scope__ whose context is derived from the given one, run the callback with it, then close the scope
-- (killing any threads still running in it) before returning.
--
-- If the callback throws, that exception is rethrown. Otherwise, if an asynchronous exception was delivered to us
-- while closing the scope, that one is rethrown. In both cases a 'ThreadFailed' wrapper is removed first.
scoped :: Context -> (Scope -> IO a) -> IO a
scoped parentContext callback = do
  scope <- newScope parentContext
  uninterruptibleMask \unmask -> do
    outcome <- try (unmask (callback scope))
    closingException <- closeScope scope
    -- If the callback failed, we don't care whether we were thrown an async exception while closing the scope.
    -- Otherwise, throw that exception (if it exists).
    case (outcome, closingException) of
      (Left callbackException, _) -> rethrow callbackException
      (Right _, Just exception) -> rethrow exception
      (Right value, Nothing) -> pure value
  where
    -- Unwrap a 'ThreadFailed' (assumed to have come from one of our children) before throwing.
    rethrow :: SomeException -> IO b
    rethrow exception =
      case fromException exception of
        Just (ThreadFailed inner) -> throwIO inner
        Nothing -> throwIO exception

-- | Wait until all __threads__ created within a __scope__ finish.
wait :: Scope -> IO ()
wait scope =
  atomically (waitSTM scope)

-- | Variant of 'wait' that waits for up to the given duration.
waitFor :: Scope -> Duration -> IO ()
waitFor scope duration =
  timeoutSTM duration allFinished (pure ())
  where
    -- Once every thread has finished there is nothing left to do, so the follow-up action is a no-op.
    allFinished :: STM (IO ())
    allFinished = do
      waitSTM scope
      pure (pure ())

-- | @STM@ variant of 'wait': retries until no thread in the __scope__ is running or about to start.
waitSTM :: Scope -> STM ()
waitSTM scope =
  blockUntilNoneRunning scope >> blockUntilNoneStarting scope

--------------------------------------------------------------------------------
-- Scope helpers

-- | Retry until the scope's set of running threads is empty.
blockUntilNoneRunning :: Scope -> STM ()
blockUntilNoneRunning Scope {runningVar = threadsVar} =
  blockUntilTVar threadsVar Set.null

-- | Retry until the scope's count of about-to-start threads is zero.
blockUntilNoneStarting :: Scope -> STM ()
blockUntilNoneStarting Scope {startingVar = pendingVar} =
  blockUntilTVar pendingVar (0 ==)

--------------------------------------------------------------------------------
-- Internal exception types

-- | Exception thrown by a parent __thread__ to its children when the __scope__ is closing.
--
-- It is an asynchronous exception: 'toException' wraps it in 'SomeAsyncException', and 'fromException' only
-- recognizes it when wrapped that way.
data ScopeClosing = ScopeClosing
  deriving stock (Eq, Show)

instance Exception ScopeClosing where
  fromException someException = asyncExceptionFromException someException
  toException scopeClosing = asyncExceptionToException scopeClosing

-- | Exception thrown by a child __thread__ to its parent, if it fails unexpectedly. It carries the child's exception.
--
-- Like 'ScopeClosing', it travels as an asynchronous exception (wrapped in 'SomeAsyncException').
newtype ThreadFailed = ThreadFailed SomeException
  deriving stock (Show)

instance Exception ThreadFailed where
  fromException someException = asyncExceptionFromException someException
  toException threadFailed = asyncExceptionToException threadFailed

--------------------------------------------------------------------------------
-- Misc. utils

-- | Throw 'ScopeClosing' to each of the given threads, one at a time.
--
-- Each 'throwTo' runs unmasked, because we don't want to deadlock with a thread that is concurrently trying to throw
-- an exception to us with exceptions masked. As a result, an asynchronous exception may be delivered to us while we
-- wait. When that happens we remember the first such exception and try the same thread again, because the interrupted
-- 'throwTo' did not necessarily deliver 'ScopeClosing'. Returns the first exception delivered to us, if any.
killThreads :: [ThreadId] -> IO (Maybe SomeException)
killThreads = go mempty
  where
    go :: Monoid.First SomeException -> [ThreadId] -> IO (Maybe SomeException)
    go firstException = \case
      [] -> pure (Monoid.getFirst firstException)
      pending@(threadId : remaining) -> do
        delivery <- try (unsafeUnmask (throwTo threadId ScopeClosing))
        case delivery of
          -- Keep the thread: we didn't (necessarily) deliver the exception to it.
          Left exception -> go (firstException <> Monoid.First (Just exception)) pending
          Right () -> go firstException remaining

-- | Read the variable and 'retry' unless its current value satisfies the predicate.
blockUntilTVar :: TVar a -> (a -> Bool) -> STM ()
blockUntilTVar var predicate =
  readTVar var >>= \value -> if predicate value then pure () else retry