{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module Control.Monad.Class.MonadFork
  ( MonadThread (..)
  , labelThisThread
  , MonadFork (..)
  ) where

import qualified Control.Concurrent as IO
import           Control.Exception (AsyncException (ThreadKilled), Exception, SomeException)
import           Control.Monad.Reader (ReaderT (..), lift)
import           Data.Kind (Type)
import qualified GHC.Conc.Sync as IO (labelThread)


class (Monad m, Eq   (ThreadId m),
                Ord  (ThreadId m),
                Show (ThreadId m)) => MonadThread m where

  type ThreadId m :: Type

  myThreadId     :: m (ThreadId m)
  labelThread    :: ThreadId m -> String -> m ()

-- | Apply the label to the current thread
labelThisThread :: MonadThread m => String -> m ()
labelThisThread :: forall (m :: * -> *). MonadThread m => String -> m ()
labelThisThread String
label = forall (m :: * -> *). MonadThread m => m (ThreadId m)
myThreadId forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \ThreadId m
tid -> forall (m :: * -> *). MonadThread m => ThreadId m -> String -> m ()
labelThread ThreadId m
tid String
label


class MonadThread m => MonadFork m where

  forkIO           :: m () -> m (ThreadId m)
  forkOn           :: Int -> m () -> m (ThreadId m)
  forkIOWithUnmask :: ((forall a. m a -> m a) -> m ()) -> m (ThreadId m)
  forkFinally      :: m a -> (Either SomeException a -> m ()) -> m (ThreadId m)
  throwTo          :: Exception e => ThreadId m -> e -> m ()

  killThread       :: ThreadId m -> m ()
  killThread ThreadId m
tid = forall (m :: * -> *) e.
(MonadFork m, Exception e) =>
ThreadId m -> e -> m ()
throwTo ThreadId m
tid AsyncException
ThreadKilled

  yield            :: m ()


instance MonadThread IO where
  type ThreadId IO = IO.ThreadId
  myThreadId :: IO (ThreadId IO)
myThreadId   = IO ThreadId
IO.myThreadId
  labelThread :: ThreadId IO -> String -> IO ()
labelThread  = ThreadId -> String -> IO ()
IO.labelThread

instance MonadFork IO where
  forkIO :: IO () -> IO (ThreadId IO)
forkIO           = IO () -> IO ThreadId
IO.forkIO
  forkOn :: Int -> IO () -> IO (ThreadId IO)
forkOn           = Int -> IO () -> IO ThreadId
IO.forkOn
  forkIOWithUnmask :: ((forall a. IO a -> IO a) -> IO ()) -> IO (ThreadId IO)
forkIOWithUnmask = ((forall a. IO a -> IO a) -> IO ()) -> IO ThreadId
IO.forkIOWithUnmask
  forkFinally :: forall a.
IO a -> (Either SomeException a -> IO ()) -> IO (ThreadId IO)
forkFinally      = forall a. IO a -> (Either SomeException a -> IO ()) -> IO ThreadId
IO.forkFinally
  throwTo :: forall e. Exception e => ThreadId IO -> e -> IO ()
throwTo          = forall e. Exception e => ThreadId -> e -> IO ()
IO.throwTo
  killThread :: ThreadId IO -> IO ()
killThread       = ThreadId -> IO ()
IO.killThread
  yield :: IO ()
yield            = IO ()
IO.yield

instance MonadThread m => MonadThread (ReaderT r m) where
  type ThreadId (ReaderT r m) = ThreadId m
  myThreadId :: ReaderT r m (ThreadId (ReaderT r m))
myThreadId      = forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift forall (m :: * -> *). MonadThread m => m (ThreadId m)
myThreadId
  labelThread :: ThreadId (ReaderT r m) -> String -> ReaderT r m ()
labelThread ThreadId (ReaderT r m)
t String
l = forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (forall (m :: * -> *). MonadThread m => ThreadId m -> String -> m ()
labelThread ThreadId (ReaderT r m)
t String
l)

instance MonadFork m => MonadFork (ReaderT e m) where
  forkIO :: ReaderT e m () -> ReaderT e m (ThreadId (ReaderT e m))
forkIO (ReaderT e -> m ()
f)   = forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT forall a b. (a -> b) -> a -> b
$ \e
e -> forall (m :: * -> *). MonadFork m => m () -> m (ThreadId m)
forkIO (e -> m ()
f e
e)
  forkOn :: Int -> ReaderT e m () -> ReaderT e m (ThreadId (ReaderT e m))
forkOn Int
n (ReaderT e -> m ()
f) = forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT forall a b. (a -> b) -> a -> b
$ \e
e -> forall (m :: * -> *). MonadFork m => Int -> m () -> m (ThreadId m)
forkOn Int
n (e -> m ()
f e
e)
  forkIOWithUnmask :: ((forall a. ReaderT e m a -> ReaderT e m a) -> ReaderT e m ())
-> ReaderT e m (ThreadId (ReaderT e m))
forkIOWithUnmask (forall a. ReaderT e m a -> ReaderT e m a) -> ReaderT e m ()
k   = forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT forall a b. (a -> b) -> a -> b
$ \e
e -> forall (m :: * -> *).
MonadFork m =>
((forall a. m a -> m a) -> m ()) -> m (ThreadId m)
forkIOWithUnmask forall a b. (a -> b) -> a -> b
$ \forall a. m a -> m a
restore ->
                         let restore' :: ReaderT e m a -> ReaderT e m a
                             restore' :: forall a. ReaderT e m a -> ReaderT e m a
restore' (ReaderT e -> m a
f) = forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT forall a b. (a -> b) -> a -> b
$ forall a. m a -> m a
restore forall b c a. (b -> c) -> (a -> b) -> a -> c
. e -> m a
f
                         in forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT ((forall a. ReaderT e m a -> ReaderT e m a) -> ReaderT e m ()
k forall a. ReaderT e m a -> ReaderT e m a
restore') e
e
  forkFinally :: forall a.
ReaderT e m a
-> (Either SomeException a -> ReaderT e m ())
-> ReaderT e m (ThreadId (ReaderT e m))
forkFinally ReaderT e m a
f Either SomeException a -> ReaderT e m ()
k     = forall r (m :: * -> *) a. (r -> m a) -> ReaderT r m a
ReaderT forall a b. (a -> b) -> a -> b
$ \e
e -> forall (m :: * -> *) a.
MonadFork m =>
m a -> (Either SomeException a -> m ()) -> m (ThreadId m)
forkFinally (forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT ReaderT e m a
f e
e)
                                      forall a b. (a -> b) -> a -> b
$ \Either SomeException a
err -> forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Either SomeException a -> ReaderT e m ()
k Either SomeException a
err) e
e
  throwTo :: forall e.
Exception e =>
ThreadId (ReaderT e m) -> e -> ReaderT e m ()
throwTo ThreadId (ReaderT e m)
e e
t = forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (forall (m :: * -> *) e.
(MonadFork m, Exception e) =>
ThreadId m -> e -> m ()
throwTo ThreadId (ReaderT e m)
e e
t)
  yield :: ReaderT e m ()
yield       = forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift forall (m :: * -> *). MonadFork m => m ()
yield