{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

-- | A generalisation of the
-- <https://hackage.haskell.org/package/base/docs/Control-Concurrent.html Control.Concurrent>
-- API to both 'IO' and <https://hackage.haskell.org/package/io-sim IOSim>.
--
module Control.Monad.Class.MonadFork
  ( MonadThread (..)
  , labelThisThread
  , MonadFork (..)
  ) where

import Control.Concurrent qualified as IO
import Control.Exception (AsyncException (ThreadKilled), Exception,
           SomeException)
import Control.Monad.Reader (ReaderT (..), lift)
import Data.Kind (Type)
import GHC.Conc.Sync qualified as IO (labelThread)


-- | Monads that have a notion of thread identity.
--
-- The superclass context requires the monad's thread identifiers to support
-- equality, ordering and rendering via 'Show'.
class ( Monad m
      , Eq   (ThreadId m)
      , Ord  (ThreadId m)
      , Show (ThreadId m)
      ) => MonadThread m where

  -- | Type of thread identifiers used by the monad @m@.
  type ThreadId m :: Type

  -- | Obtain the identifier of the current thread.
  myThreadId  :: m (ThreadId m)

  -- | Attach the given label to the given thread.
  labelThread :: ThreadId m -> String -> m ()

-- | Apply the label to the current thread.
--
-- Looks up the calling thread's identifier with 'myThreadId' and passes it,
-- together with the given label, to 'labelThread'.
labelThisThread :: MonadThread m => String -> m ()
labelThisThread label = myThreadId >>= \tid -> labelThread tid label


-- | Monads in which threads can be forked and sent asynchronous exceptions.
--
-- The method names and shapes follow "Control.Concurrent"; the 'IO' instance
-- in this module implements each method with the function of the same name
-- from that module.
class MonadThread m => MonadFork m where

  -- | Counterpart of @forkIO@ from "Control.Concurrent".
  forkIO           :: m () -> m (ThreadId m)

  -- | Counterpart of @forkOn@ from "Control.Concurrent".
  forkOn           :: Int -> m () -> m (ThreadId m)

  -- | Counterpart of @forkIOWithUnmask@ from "Control.Concurrent"; the forked
  -- action receives a rank-2 function of type @forall a. m a -> m a@.
  forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)

  -- | Counterpart of @forkFinally@ from "Control.Concurrent"; the second
  -- argument receives either the exception or the result of the first.
  forkFinally      :: m a -> (Either SomeException a -> m ()) -> m (ThreadId m)

  -- | Counterpart of @throwTo@ from "Control.Concurrent".
  throwTo          :: Exception e => ThreadId m -> e -> m ()

  -- | Kill a thread.  The default implementation throws 'ThreadKilled' to the
  -- given thread using 'throwTo'.
  killThread       :: ThreadId m -> m ()
  killThread tid = throwTo tid ThreadKilled

  -- | Counterpart of @yield@ from "Control.Concurrent".
  yield            :: m ()


-- | In 'IO', thread identifiers are 'IO.ThreadId' from "Control.Concurrent";
-- 'myThreadId' comes from "Control.Concurrent" and 'labelThread' from
-- "GHC.Conc.Sync".
instance MonadThread IO where
  type ThreadId IO = IO.ThreadId
  myThreadId  = IO.myThreadId
  labelThread = IO.labelThread

-- | Every method is implemented by the function of the same name from
-- "Control.Concurrent".  'killThread' is given explicitly rather than left
-- to the class default.
instance MonadFork IO where
  forkIO           = IO.forkIO
  forkOn           = IO.forkOn
  forkIOWithUnmask = IO.forkIOWithUnmask
  forkFinally      = IO.forkFinally
  throwTo          = IO.throwTo
  killThread       = IO.killThread
  yield            = IO.yield

-- | 'ReaderT' reuses the thread identifiers of the underlying monad and
-- lifts both operations into it; the environment is not consulted.
instance MonadThread m => MonadThread (ReaderT r m) where
  type ThreadId (ReaderT r m) = ThreadId m
  myThreadId      = lift myThreadId
  labelThread t l = lift (labelThread t l)

-- | Forking in 'ReaderT' runs the forked action in the underlying monad with
-- the environment that was in scope at the fork point.
instance MonadFork m => MonadFork (ReaderT e m) where
  forkIO (ReaderT f)   = ReaderT $ \e -> forkIO (f e)
  forkOn n (ReaderT f) = ReaderT $ \e -> forkOn n (f e)

  forkIOWithUnmask k   = ReaderT $ \e -> forkIOWithUnmask $ \restore ->
    -- Wrap the underlying monad's @restore@ so that it can be applied to
    -- 'ReaderT' actions: run the action with its environment and apply
    -- @restore@ to the resulting action of the underlying monad.
    -- (@e@ and @m@ refer to the instance head via ScopedTypeVariables.)
    let restore' :: ReaderT e m a -> ReaderT e m a
        restore' (ReaderT f) = ReaderT $ restore . f
    in runReaderT (k restore') e

  -- Both the forked action and the finaliser run with the forking thread's
  -- environment.
  forkFinally f k      = ReaderT $ \e ->
    forkFinally (runReaderT f e) $ \err -> runReaderT (k err) e

  throwTo tid ex       = lift (throwTo tid ex)

  -- Forward to the underlying monad's 'killThread' instead of relying on the
  -- class default, so that an instance-specific implementation is respected.
  killThread           = lift . killThread

  yield                = lift yield