module Ki.Thread
  ( Thread,
    async,
    asyncWithUnmask,
    await,
    awaitSTM,
    awaitFor,
    fork,
    fork_,
    forkWithUnmask,
    forkWithUnmask_,
  )
where

import Control.Exception (Exception (fromException))
import Data.Bifunctor (first)
import qualified Ki.Context as Context
import Ki.Duration (Duration)
import Ki.Prelude
import Ki.Scope (Scope (Scope))
import qualified Ki.Scope as Scope
import Ki.ScopeClosing (ScopeClosing (ScopeClosing))
import Ki.ThreadFailed (ThreadFailed (ThreadFailed), ThreadFailedAsync (ThreadFailedAsync))
import Ki.Timeout (timeoutSTM)

-- | A running __thread__.
--
-- Pairs the 'ThreadId' of the underlying Haskell thread with an @STM@ action that reads the __thread__'s result
-- (see 'awaitSTM').
data Thread a = Thread
  { -- | Identifier of the underlying Haskell thread.
    threadId :: !ThreadId,
    -- | Reads the __thread__'s result.
    action :: !(STM a)
  }
  deriving stock (Functor, Generic)

-- | Two __threads__ are equal when they have the same 'ThreadId'; their result actions are not compared.
instance Eq (Thread a) where
  Thread id1 _ == Thread id2 _ =
    id1 == id2

-- | __Threads__ are ordered by their 'ThreadId' alone; their result actions are not compared.
instance Ord (Thread a) where
  compare (Thread id1 _) (Thread id2 _) =
    compare id1 id2

-- | Create a __thread__ within a __scope__ to compute a value concurrently.
--
-- An exception thrown by the __thread__ is captured rather than rethrown to the calling __thread__: it is wrapped in
-- 'ThreadFailed' and becomes the __thread__'s 'Left' result, as observed by 'await'.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
async :: Scope -> IO a -> IO (Thread (Either ThreadFailed a))
async scope action =
  asyncWithRestore scope \restore -> restore action

-- | Variant of 'async' that provides the __thread__ a function that unmasks asynchronous exceptions.
--
-- The function supplied to the action is 'unsafeUnmask', which runs its argument with asynchronous exceptions
-- unmasked.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
asyncWithUnmask :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread (Either ThreadFailed a))
asyncWithUnmask scope action =
  asyncWithRestore scope \restore -> restore (action unsafeUnmask)

-- Shared implementation of 'async' and 'asyncWithUnmask'.
--
-- Forks the action in the scope. The outcome reported to the 'Scope.scopeFork' callback (any exception wrapped in
-- 'ThreadFailed' along with the child's 'ThreadId') is written to a result variable. The returned 'Thread' just
-- reads that variable, so awaiting it yields the 'Either' without throwing.
asyncWithRestore :: forall a. Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread (Either ThreadFailed a))
asyncWithRestore scope action = do
  resultVar <- newEmptyTMVarIO
  childThreadId <-
    Scope.scopeFork scope action \childThreadId result ->
      putTMVarIO resultVar (first (ThreadFailed childThreadId) result)
  pure (Thread childThreadId (readTMVar resultVar))

-- | Wait for a __thread__ to finish and return its result.
--
-- For a __thread__ created with 'fork' or 'forkWithUnmask', this rethrows the 'ThreadFailed' exception if the
-- __thread__ threw. A __thread__ created with 'async' or 'asyncWithUnmask' instead returns it as a 'Left'.
await :: Thread a -> IO a
await =
  atomically . awaitSTM

-- | @STM@ variant of 'await'.
--
-- Runs the __thread__'s result action. For __threads__ created in this module, that action retries until the result
-- is available, so it composes with other @STM@ actions (e.g. via 'orElse').
awaitSTM :: Thread a -> STM a
awaitSTM Thread {action} =
  action

-- | Variant of 'await' that gives up after the given duration.
--
-- Returns 'Nothing' if the __thread__'s result is not available within the duration.
--
-- @
-- 'awaitFor' thread duration =
--   'timeoutSTM' duration (pure . Just \<$\> 'awaitSTM' thread) (pure Nothing)
-- @
awaitFor :: Thread a -> Duration -> IO (Maybe a)
awaitFor thread duration =
  timeoutSTM duration (pure . Just <$> awaitSTM thread) (pure Nothing)

-- | Create a __thread__ within a __scope__ to compute a value concurrently.
--
-- If the __thread__ throws an exception, the exception is wrapped in 'ThreadFailed' and immediately thrown (as
-- 'ThreadFailedAsync') to the __thread__ that called 'fork', propagating it up the call tree. This does not happen
-- for an exception that merely reflects this __scope__ closing or this __scope__'s own cancellation. In every case,
-- waiting on the returned __thread__ with 'await' rethrows the 'ThreadFailed'.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
fork :: Scope -> IO a -> IO (Thread a)
fork scope action =
  forkWithRestore scope \restore -> restore action

-- | Variant of 'fork' that does not return a handle to the created __thread__.
--
-- Failures are propagated to the calling __thread__ exactly as with 'fork'. Since there is no handle, a result is
-- never recorded.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
fork_ :: Scope -> IO () -> IO ()
fork_ scope action =
  forkWithRestore_ scope \restore -> restore action

-- | Variant of 'fork' that provides the __thread__ a function that unmasks asynchronous exceptions.
--
-- The function supplied to the action is 'unsafeUnmask', which runs its argument with asynchronous exceptions
-- unmasked.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
forkWithUnmask :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread a)
forkWithUnmask scope action =
  forkWithRestore scope \restore -> restore (action unsafeUnmask)

-- | Variant of 'forkWithUnmask' that does not return a handle to the created __thread__.
--
-- The function supplied to the action is 'unsafeUnmask', which runs its argument with asynchronous exceptions
-- unmasked.
--
-- /Throws/:
--
--   * Calls 'error' if the __scope__ is /closed/.
forkWithUnmask_ :: Scope -> ((forall x. IO x -> IO x) -> IO ()) -> IO ()
forkWithUnmask_ scope action =
  forkWithRestore_ scope \restore -> restore (action unsafeUnmask)

-- Shared implementation of 'fork' and 'forkWithUnmask'.
--
-- When the child fails, its exception is wrapped in 'ThreadFailed'. If 'shouldPropagateException' allows it, the
-- failure is first thrown (as 'ThreadFailedAsync') to the thread that called this function. In every case it is
-- then recorded as the child's result. The returned 'Thread' rethrows a recorded failure with 'throwSTM' and
-- returns a recorded success as-is.
forkWithRestore :: Scope -> ((forall x. IO x -> IO x) -> IO a) -> IO (Thread a)
forkWithRestore scope action = do
  parentThreadId <- myThreadId
  resultVar <- newEmptyTMVarIO
  childThreadId <-
    Scope.scopeFork scope action \childThreadId -> \case
      Left exception -> do
        let threadFailedException :: ThreadFailed
            threadFailedException = ThreadFailed childThreadId exception
        -- Notify the parent (when appropriate) before publishing the failure as this thread's result.
        whenM
          (shouldPropagateException scope exception)
          (throwTo parentThreadId (ThreadFailedAsync threadFailedException))
        putTMVarIO resultVar (Left threadFailedException)
      Right result -> putTMVarIO resultVar (Right result)
  pure (Thread childThreadId (readTMVar resultVar >>= either throwSTM pure))

-- Shared implementation of 'fork_' and 'forkWithUnmask_'.
--
-- Unlike 'forkWithRestore', no result is recorded. A failure is wrapped in 'ThreadFailed' and thrown (as
-- 'ThreadFailedAsync') to the thread that called this function, but only if 'shouldPropagateException' allows it.
-- A successful result is ignored.
forkWithRestore_ :: Scope -> ((forall x. IO x -> IO x) -> IO ()) -> IO ()
forkWithRestore_ scope action = do
  parentThreadId <- myThreadId
  _ <-
    Scope.scopeFork scope action \childThreadId ->
      onLeft \exception ->
        whenM
          (shouldPropagateException scope exception)
          (throwTo parentThreadId (ThreadFailedAsync (ThreadFailed childThreadId exception)))
  pure ()

-- Decide whether a child's exception should be rethrown to the thread that forked it.
--
--   * A 'ScopeClosing' is propagated only while this scope's @closedVar@ is still 'False'.
--   * A 'CancelToken' is propagated unless it equals the token produced by 'Context.contextCancelTokenSTM' for this
--     scope's context. If that action retries (presumably: the context has not been cancelled), it is propagated.
--   * Any other exception is propagated.
shouldPropagateException :: Scope -> SomeException -> IO Bool
shouldPropagateException Scope {closedVar, context} exception =
  case fromException exception of
    -- Our scope is (presumably) closing, so don't propagate this exception that presumably just came from our parent.
    -- But if our scope's closedVar isn't True, that means this 'ScopeClosing' definitely came from somewhere else...
    Just ScopeClosing -> not <$> readTVarIO closedVar
    Nothing ->
      case fromException exception of
        -- We (presumably) are honoring our own cancellation request, so don't propagate that either.
        -- It's a bit complicated looking because we *do* want to throw this token if we (somehow) threw it
        -- "inappropriately" in the sense that it wasn't ours to throw - it was smuggled from elsewhere.
        Just token ->
          atomically (((/= token) <$> Context.contextCancelTokenSTM context) <|> pure True)
        Nothing -> pure True