{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module Control.Monad.Class.MonadFork
  ( MonadThread (..)
  , labelThisThread
  , MonadFork (..)
  ) where

import qualified Control.Concurrent as IO
import           Control.Exception (AsyncException (ThreadKilled), Exception)
import           Control.Monad.Reader (ReaderT (..), lift)
import           Data.Kind (Type)
import qualified GHC.Conc.Sync as IO (labelThread)


-- | Monads with a notion of threads, each identified by a 'ThreadId'.
--
-- Thread identifiers are required to be comparable ('Eq', 'Ord') and
-- printable ('Show').
class ( Monad m
      , Eq   (ThreadId m)
      , Ord  (ThreadId m)
      , Show (ThreadId m)
      ) => MonadThread m where

  -- | The type of thread identifiers in @m@.
  type ThreadId m :: Type

  -- | The identifier of the calling thread.
  myThreadId  :: m (ThreadId m)

  -- | Attach a textual label to the given thread.
  labelThread :: ThreadId m -> String -> m ()

-- | Apply the label to the current thread.
--
-- Looks up the caller's identifier with 'myThreadId' and hands it, together
-- with the label, to 'labelThread'.
labelThisThread :: MonadThread m => String -> m ()
labelThisThread label = myThreadId >>= \tid -> labelThread tid label


-- | Monads in which new threads can be forked, sent exceptions, and in which
-- the current thread can yield.
class MonadThread m => MonadFork m where

  -- | Run the given action in a new thread and return the new thread's id.
  forkIO           :: m () -> m (ThreadId m)

  -- | Like 'forkIO', with an extra 'Int' placement argument.  In the 'IO'
  -- instance this is 'Control.Concurrent.forkOn', where the 'Int' selects
  -- the capability the thread runs on.
  forkOn           :: Int -> m () -> m (ThreadId m)

  -- | Like 'forkIO', but the forked action is passed a function it can use
  -- to unmask asynchronous exceptions within a sub-computation.
  forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)

  -- | Raise the given exception in the target thread.
  throwTo          :: Exception e => ThreadId m -> e -> m ()

  -- | Kill the target thread.  The default implementation throws
  -- 'ThreadKilled' to it with 'throwTo'.
  killThread       :: ThreadId m -> m ()
  killThread tid = throwTo tid ThreadKilled

  -- | Yield the current thread.
  yield            :: m ()


-- | 'IO' threads are GHC threads: every method delegates directly to
-- "Control.Concurrent" / "GHC.Conc.Sync".
instance MonadThread IO where
  type ThreadId IO = IO.ThreadId
  myThreadId  = IO.myThreadId
  labelThread = IO.labelThread

-- | Forking in 'IO': every method delegates directly to the corresponding
-- "Control.Concurrent" function.
instance MonadFork IO where
  forkIO           = IO.forkIO
  forkOn           = IO.forkOn
  forkIOWithUnmask = IO.forkIOWithUnmask
  throwTo          = IO.throwTo
  killThread       = IO.killThread
  yield            = IO.yield

-- | Thread identification lifted through 'ReaderT': thread ids are those of
-- the underlying monad, and both methods ignore the environment.
instance MonadThread m => MonadThread (ReaderT r m) where
  type ThreadId (ReaderT r m) = ThreadId m
  myThreadId            = lift myThreadId
  labelThread tid label = lift (labelThread tid label)

-- | Forking in 'ReaderT': a forked thread runs its action with the
-- environment the parent had at the time of the fork.  'throwTo' and 'yield'
-- are lifted from the underlying monad; 'killThread' uses the class default.
instance MonadFork m => MonadFork (ReaderT e m) where
  forkIO (ReaderT f)   = ReaderT $ \env -> forkIO (f env)
  forkOn n (ReaderT f) = ReaderT $ \env -> forkOn n (f env)
  forkIOWithUnmask k   = ReaderT $ \env ->
    forkIOWithUnmask (\restore ->
      -- The underlying monad hands us @restore :: forall a. m a -> m a@;
      -- lift it to 'ReaderT' so the user's action can apply it to reader
      -- computations.  @e@ and @m@ here are the instance-head variables
      -- (brought into scope by ScopedTypeVariables).
      let restore' :: ReaderT e m a -> ReaderT e m a
          restore' (ReaderT g) = ReaderT (restore . g)
      in runReaderT (k restore') env)
  throwTo tid ex = lift (throwTo tid ex)
  yield          = lift yield