{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeApplications #-}

module Ki.Scope
  ( Scope (..),
    cancelScope,
    scopeCancelledSTM,
    scopeFork,
    scoped,
    wait,
    waitFor,
    waitSTM,
  )
where

import Control.Exception (finally, fromException, onException, pattern ErrorCall)
import qualified Data.Monoid as Monoid
import qualified Data.Set as Set
import Ki.Context (Context)
import qualified Ki.Context as Context
import Ki.Duration (Duration)
import Ki.Prelude
import Ki.ScopeClosing (ScopeClosing (..))
import Ki.ThreadFailed (ThreadFailedAsync (..))
import Ki.Timeout (timeoutSTM)

-- | A __scope__ delimits the lifetime of all __threads__ created within it.
data Scope = Scope
  { -- | The context owned by this scope. It is derived from a parent context in 'newScope' and cancelled by
    -- 'cancelScope'.
    context :: Context,
    -- | Whether this scope is closed.
    --
    -- Invariant: if closed, no threads are starting.
    closedVar :: TVar Bool,
    -- | The set of threads that are currently running.
    runningVar :: TVar (Set ThreadId),
    -- | The number of threads that are /guaranteed/ to be about to start, in the sense that only the GHC scheduler can
    -- continue to delay; no async exception can strike here and prevent one of these threads from starting.
    --
    -- If this number is non-zero, and that's problematic (e.g. because we're trying to cancel this scope), we always
    -- respect it and wait for it to drop to zero before proceeding.
    startingVar :: TVar Int
  }

-- | Create a new scope whose context is derived from the given parent context.
--
-- The new scope starts open, with no running threads and no threads starting.
newScope :: Context -> IO Scope
newScope parentContext = do
  context <- atomically (Context.deriveContext parentContext)
  closedVar <- newTVarIO False
  runningVar <- newTVarIO Set.empty
  startingVar <- newTVarIO 0
  pure Scope {context, closedVar, runningVar, startingVar}

-- | /Cancel/ all __contexts__ derived from a __scope__.
--
-- This cancels the scope's own context via 'Context.cancelContext'.
cancelScope :: Scope -> IO ()
cancelScope Scope {context} =
  Context.cancelContext context

-- | An @STM@ transaction that yields an action which throws this scope's cancel token.
--
-- The token is obtained from 'Context.contextCancelTokenSTM' on the scope's context.
-- NOTE(review): presumably that transaction 'retry's until the context is cancelled; it is defined in "Ki.Context",
-- so confirm there.
scopeCancelledSTM :: Scope -> STM (IO a)
scopeCancelledSTM Scope {context} =
  throwIO <$> Context.contextCancelTokenSTM context

-- | Close a scope, kill all of the running threads, and return the first async exception delivered to us while doing
-- so, if any.
--
-- Preconditions:
--
--   * The set of threads doesn't include us
--   * We're uninterruptibly masked
closeScope :: Scope -> IO (Maybe SomeException)
closeScope scope@Scope {closedVar, runningVar} = do
  threads <-
    atomically do
      -- Wait for every in-flight 'scopeFork' to register its child, so the snapshot of running threads taken below is
      -- complete. Closing in the same transaction prevents any new fork from starting afterwards.
      blockUntilNoneStarting scope
      writeTVar closedVar True
      readTVar runningVar
  exception <- killThreads (Set.toList threads)
  -- Each child removes itself from the running set when it finishes, so this waits for all of them to exit.
  atomically (blockUntilNoneRunning scope)
  pure exception

-- | Fork a thread within a scope.
--
-- The child thread runs @action@, passing it a function that restores the masking state 'scopeFork' was called with.
-- It then hands its own 'ThreadId' and the outcome of @action@ (any exception is caught) to @k@. Finally it removes
-- itself from the scope's running set. @k@ runs with exceptions uninterruptibly masked, because the child inherits that
-- masking state.
--
-- Throws @ErrorCall "ki: scope closed"@ if the scope is already closed.
scopeFork ::
  Scope ->
  ((forall x. IO x -> IO x) -> IO a) ->
  (ThreadId -> Either SomeException a -> IO ()) ->
  IO ThreadId
scopeFork Scope {closedVar, runningVar, startingVar} action k =
  uninterruptibleMask \restore -> do
    let -- Remove a finished child from the running set. The parent may not have inserted it yet (see below), in
        -- which case we 'retry' until it has.
        deregister :: ThreadId -> STM ()
        deregister childThreadId = do
          running <- readTVar runningVar
          if Set.member childThreadId running
            then writeTVar runningVar $! Set.delete childThreadId running
            else retry

        -- The body of the child thread. Deregistration runs even if @k@ throws. Otherwise the child would stay in
        -- the running set forever and closing or waiting on the scope would block.
        child :: IO ()
        child = do
          childThreadId <- myThreadId
          result <- try (action restore)
          k childThreadId result `finally` atomically (deregister childThreadId)

    -- Record the thread as being about to start
    atomically do
      readTVar closedVar >>= \case
        False -> modifyTVar' startingVar (+ 1)
        True -> throwSTM (ErrorCall "ki: scope closed")

    -- Fork the thread. If forking itself fails, release the reservation made above; otherwise
    -- 'blockUntilNoneStarting' would block forever.
    childThreadId <-
      forkIO child `onException` atomically (modifyTVar' startingVar \n -> n - 1)

    -- Record the thread as having started
    atomically do
      modifyTVar' startingVar \n -> n - 1
      modifyTVar' runningVar (Set.insert childThreadId)

    pure childThreadId

-- | Create a scope whose context is derived from the given one, run the callback with it, then close the scope.
--
-- The callback runs with the caller's masking state restored. Closing the scope is uninterruptibly masked. If the
-- callback throws, that exception is rethrown. Otherwise, the first async exception received while closing the scope
-- (if any) is thrown. In both cases a 'ThreadFailedAsync' wrapper is removed first.
scoped :: Context -> (Scope -> IO a) -> IO a
scoped context f = do
  scope <- newScope context
  uninterruptibleMask \restore -> do
    result <- try (restore (f scope))
    closeScopeException <- closeScope scope
    -- If the callback failed, we don't care if we were thrown an async exception while closing the scope.
    -- Otherwise, throw that exception (if it exists).
    case result of
      Left exception -> rethrow exception
      Right value -> do
        whenJust closeScopeException rethrow
        pure value
  where
    -- If applicable, unwrap the 'ThreadFailedAsync' (assumed to have come from one of our children) and throw the
    -- 'ThreadFailed' inside it; otherwise throw the exception as-is.
    rethrow :: SomeException -> IO b
    rethrow exception =
      case fromException exception of
        Just (ThreadFailedAsync threadFailedException) -> throwIO threadFailedException
        Nothing -> throwIO exception

-- | Wait until all __threads__ created within a __scope__ finish.
--
-- This runs 'waitSTM' in its own transaction.
wait :: Scope -> IO ()
wait scope =
  atomically (waitSTM scope)

-- | Variant of 'wait' that waits for up to the given duration.
--
-- Delegates to 'timeoutSTM'. Whether the scope's threads finished or the duration elapsed, the result is @()@.
waitFor :: Scope -> Duration -> IO ()
waitFor scope duration =
  timeoutSTM duration (pure <$> waitSTM scope) (pure ())

-- | @STM@ variant of 'wait'.
--
-- Retries until no threads are running and no threads are about to start.
waitSTM :: Scope -> STM ()
waitSTM scope = do
  blockUntilNoneRunning scope
  blockUntilNoneStarting scope

--------------------------------------------------------------------------------
-- Scope helpers

-- | Retry until the scope's set of running threads is empty.
blockUntilNoneRunning :: Scope -> STM ()
blockUntilNoneRunning Scope {runningVar} =
  blockUntilTVar runningVar Set.null

-- | Retry until no threads are in the middle of being started in the scope.
blockUntilNoneStarting :: Scope -> STM ()
blockUntilNoneStarting Scope {startingVar} =
  blockUntilTVar startingVar (== 0)

--------------------------------------------------------------------------------
-- Misc. utils

-- | Throw 'ScopeClosing' to each thread in turn. Return the first exception that was delivered to /us/ while doing so,
-- if any.
--
-- When an exception interrupts a 'throwTo', the same thread is tried again, because 'ScopeClosing' may not have reached
-- it.
killThreads :: [ThreadId] -> IO (Maybe SomeException)
killThreads = go mempty
  where
    go :: Monoid.First SomeException -> [ThreadId] -> IO (Maybe SomeException)
    go acc = \case
      [] -> pure (Monoid.getFirst acc)
      threadIds@(threadId : rest) ->
        -- We unmask because we don't want to deadlock with a thread
        -- that is concurrently trying to throw an exception to us with
        -- exceptions masked.
        try @SomeException (unsafeUnmask (throwTo threadId ScopeClosing)) >>= \case
          -- don't drop thread we didn't (necessarily) deliver the exception to
          Left exception -> go (acc <> Monoid.First (Just exception)) threadIds
          Right () -> go acc rest

-- | Read the variable and 'retry' the enclosing transaction unless the predicate holds for its current value.
blockUntilTVar :: TVar a -> (a -> Bool) -> STM ()
blockUntilTVar var predicate = do
  value <- readTVar var
  unless (predicate value) retry