{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module      : Streamly.Internal.Control.Concurrent
-- Copyright   : (c) 2017 Composewell Technologies
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC

module Streamly.Internal.Control.Concurrent
    (
      MonadAsync
    , RunInIO(..)
    , doFork
    , fork
    , forkManaged
    )
where

import Control.Concurrent (ThreadId, forkIO, killThread)
import Control.Exception (SomeException(..), catch, mask)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Control
       (MonadBaseControl, control, StM, liftBaseDiscard)
import Data.Functor (void)
import GHC.Conc (ThreadId(..))
import GHC.Exts
import GHC.IO (IO(..))
import System.Mem.Weak (addFinalizer)

-- /Since: 0.8.0 ("Streamly.Prelude")/
--
-- | Constraint synonym required of the underlying monad of any stream that
-- is composed concurrently. Such a monad must be able to:
--
-- * embed arbitrary 'IO' actions ('MonadIO'),
-- * be run inside 'IO' and have its state restored afterwards, which is
--   what forking a thread from @m@ relies on ('MonadBaseControl' 'IO'),
-- * throw exceptions ('MonadThrow').
--
-- /Since: 0.1.0 ("Streamly")/
--
-- @since 0.8.0
type MonadAsync m =
    ( MonadIO m
    , MonadBaseControl IO m
    , MonadThrow m
    )

-- | A wrapped rank-2 "run in IO" function. Given any @m b@ action, it
-- executes the action in 'IO' and returns the result together with the
-- captured monadic state @'StM' m b@. It has the same shape as
-- monad-control's @RunInBase m IO@. In this module it is consumed by
-- 'doFork' to run the forked action in 'IO'.
newtype RunInIO m = RunInIO
    { runInIO :: forall b. m b -> IO (StM m b)
    }

-- A stripped-down 'forkIO', borrowed from the async package. Unlike
-- 'forkIO', it does not wrap the child in the default top-level exception
-- handler. That saves a little work when the caller installs its own
-- handler anyway, as 'doFork' does. The gain is modest: about 2% on a
-- thread-heavy benchmark (parallel composition of no-op computations).
--
-- NOTE(review): this calls @fork#@ with a boxed 'IO' action, matching the
-- primop type this module is built against. In newer base versions
-- (async guards this with MIN_VERSION_base(4,17,0)), @fork#@ takes the
-- unwrapped State#-passing function instead. Confirm against the supported
-- GHC range.
{-# INLINE rawForkIO #-}
rawForkIO :: IO () -> IO ThreadId
rawForkIO io = IO forkPrim
  where
    -- Thread the RealWorld token through the fork# primop and box the
    -- primitive thread id it hands back.
    forkPrim st0 =
        case fork# io st0 of
            (# st1, rawTid #) -> (# st1, ThreadId rawTid #)

-- | Run an @m ()@ action in a new thread and route any exception it throws
-- to the supplied 'IO' handler. Works in any monad with a
-- @'MonadBaseControl' 'IO'@ instance.
--
-- The thread is created while asynchronous exceptions are masked, and the
-- handler is installed (via 'catch') before anything is unmasked. So no
-- exception can slip in between thread creation and handler installation.
-- Inside the child, the action itself runs under the caller's original
-- masking state (via @restore@).
--
-- The child is started with 'rawForkIO', which omits 'forkIO''s default
-- outer handler, so the handler given here is the one that sees the
-- action's exceptions. Any monadic state the action ends with in the child
-- is discarded. The 'ThreadId' is returned in the caller's monadic state as
-- captured at the fork.
--
-- TODO: the RunInIO argument can be removed, we can directly pass the action
-- as "mrun action" instead.
{-# INLINE doFork #-}
doFork :: MonadBaseControl IO m
    => m ()
    -> RunInIO m
    -> (SomeException -> IO ())
    -> m ThreadId
doFork action (RunInIO mrun) exHandler =
    control (\runInBase ->
        mask (\restore -> do
            -- Unmask only inside 'catch', so the handler is already in
            -- place by the time asynchronous exceptions can be delivered.
            let child = restore (void (mrun action)) `catch` exHandler
            tid <- rawForkIO child
            runInBase (return tid)))

-- | 'forkIO' for any monad with a @'MonadBaseControl' 'IO'@ instance.
-- Whatever monadic state the forked action ends with is thrown away
-- ('liftBaseDiscard'). The caller gets only the new thread's 'ThreadId'.
-- Unlike 'doFork', the child is started with plain 'forkIO', so it keeps
-- the default outer exception handler that 'forkIO' installs.
{-# INLINABLE fork #-}
fork :: MonadBaseControl IO m => m () -> m ThreadId
fork act = liftBaseDiscard forkIO act

-- | Fork a thread (via 'fork') that is killed once the returned 'ThreadId'
-- is garbage collected. A finalizer calling 'killThread' is attached to the
-- 'ThreadId' with 'addFinalizer'. Finalizers are not guaranteed to run
-- promptly, so the thread may keep running for some time after the last
-- reference to its 'ThreadId' is dropped.
--
-- NOTE(review): "System.Mem.Weak" warns that weak pointers and finalizers
-- on ordinary boxed values such as 'ThreadId' are fragile, because GHC may
-- unbox or duplicate them. Confirm the finalizer cannot fire while a caller
-- still holds the returned 'ThreadId'.
{-# INLINABLE forkManaged #-}
forkManaged :: (MonadIO m, MonadBaseControl IO m) => m () -> m ThreadId
forkManaged action = fork action >>= killWhenCollected
  where
    -- Register the kill-on-GC finalizer, then hand the id back unchanged.
    killWhenCollected threadId =
        liftIO (addFinalizer threadId (killThread threadId))
            >> return threadId