{-# LANGUAGE CPP
           , DeriveDataTypeable
           , NoImplicitPrelude
           , ImpredicativeTypes
           , RankNTypes #-}

#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif

--------------------------------------------------------------------------------
-- |
-- Module     : Control.Concurrent.Thread.Group
-- Copyright  : (c) 2010-2012 Bas van Dijk & Roel van Dijk
-- License    : BSD3 (see the file LICENSE)
-- Maintainer : Bas van Dijk <v.dijk.bas@gmail.com>
--            , Roel van Dijk <vandijk.roel@gmail.com>
--
-- This module extends @Control.Concurrent.Thread@ with the ability to wait for
-- a group of threads to terminate.
--
-- This module exports functions with the same names as functions in
-- @Control.Concurrent@ (and @GHC.Conc@) and @Control.Concurrent.Thread@. Avoid
-- ambiguities by importing this module qualified. May we suggest:
--
-- @
-- import Control.Concurrent.Thread.Group ( ThreadGroup )
-- import qualified Control.Concurrent.Thread.Group as ThreadGroup ( ... )
-- @
--
--------------------------------------------------------------------------------

module Control.Concurrent.Thread.Group
    ( ThreadGroup
    , new
    , nrOfRunning
    , wait
    , waitN

      -- * Forking threads
    , forkIO
    , forkOS
    , forkOn
    , forkIOWithUnmask
    , forkOnWithUnmask
    ) where


--------------------------------------------------------------------------------
-- Imports
--------------------------------------------------------------------------------

-- from base:
import qualified Control.Concurrent     ( forkOS
                                        , forkIOWithUnmask
                                        , forkOnWithUnmask
                                        )
import Control.Concurrent               ( ThreadId )
import Control.Concurrent.MVar          ( newEmptyMVar, putMVar, readMVar )
import Control.Exception                ( try, mask, onException )
import Control.Monad                    ( return, (>>=), when )
import Data.Function                    ( (.), ($) )
import Data.Functor                     ( fmap )
import Data.Eq                          ( Eq )
import Data.Ord                         ( (>=) )
import Data.Int                         ( Int )
import Data.Typeable                    ( Typeable )
import Prelude                          ( ($!), (+), subtract )
import System.IO                        ( IO )

-- from stm:
import Control.Concurrent.STM.TVar      ( TVar, newTVarIO, readTVar, writeTVar )
import Control.Concurrent.STM           ( STM, atomically, retry )

-- from threads:
import Control.Concurrent.Thread        ( Result )
import Control.Concurrent.Raw           ( rawForkIO, rawForkOn )
#ifdef __HADDOCK_VERSION__
import qualified Control.Concurrent.Thread as Thread ( forkIO
                                                     , forkOS
                                                     , forkOn
                                                     , forkIOWithUnmask
                                                     , forkOnWithUnmask
                                                     )
#endif


--------------------------------------------------------------------------------
-- * Thread groups
--------------------------------------------------------------------------------

{-| A @ThreadGroup@ can be understood as a counter which counts the number of
threads that were added to the group minus the ones that have terminated.

More formally a @ThreadGroup@ has the following semantics:

* 'new' initializes the counter to 0.

* Forking a thread increments the counter.

* When a forked thread terminates, whether normally or by raising an exception,
  the counter is decremented.

* 'nrOfRunning' yields a transaction that returns the counter.

* 'wait' blocks as long as the counter is greater than 0.

* 'waitN' blocks as long as the counter is greater than or equal to the
  specified number.

The 'Eq' instance is inherited from the underlying 'TVar'. Two groups therefore
compare equal only when they wrap the very same counter.
-}
newtype ThreadGroup = ThreadGroup (TVar Int) deriving (Eq, Typeable)

-- | Create an empty group of threads.
--
-- The group's counter starts at 0, so 'wait' on a fresh group returns
-- immediately.
new :: IO ThreadGroup
new = fmap ThreadGroup (newTVarIO 0)

{-| Yield a transaction that returns the number of running threads in the
group.

Note that because this function yields an 'STM' computation, the returned number
is guaranteed to be consistent inside the transaction.
-}
nrOfRunning :: ThreadGroup -> STM Int
nrOfRunning (ThreadGroup numThreadsTV) = readTVar numThreadsTV

-- | Block until all threads in the group have terminated.
--
-- Note that: @wait = 'waitN' 1@.
wait :: ThreadGroup -> IO ()
wait = waitN 1

-- | Block until there are fewer than @n@ running threads in the group.
--
-- The check runs inside a single 'STM' transaction. The transaction 'retry's
-- (and so sleeps until the counter changes) for as long as the counter is
-- @>= n@.
--
-- The counter never drops below 0, so a non-positive @n@ never lets this
-- return normally.
waitN :: Int -> ThreadGroup -> IO ()
waitN n tg = atomically $ do
  running <- nrOfRunning tg
  when (running >= n) retry


--------------------------------------------------------------------------------
-- * Forking threads
--------------------------------------------------------------------------------

-- | Same as @Control.Concurrent.Thread.'Thread.forkIO'@ but additionally adds
-- the thread to the group.
forkIO :: ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkIO = fork rawForkIO

-- | Same as @Control.Concurrent.Thread.'Thread.forkOS'@ but additionally adds
-- the thread to the group.
forkOS :: ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkOS = fork Control.Concurrent.forkOS

-- | Same as @Control.Concurrent.Thread.'Thread.forkOn'@ but
-- additionally adds the thread to the group.
--
-- The 'Int' argument is passed straight through to 'rawForkOn'.
forkOn :: Int -> ThreadGroup -> IO a -> IO (ThreadId, IO (Result a))
forkOn = fork . rawForkOn

-- | Same as @Control.Concurrent.Thread.'Thread.forkIOWithUnmask'@ but
-- additionally adds the thread to the group.
forkIOWithUnmask
    :: ThreadGroup
    -> ((forall b. IO b -> IO b) -> IO a)
    -> IO (ThreadId, IO (Result a))
forkIOWithUnmask = forkWithUnmask Control.Concurrent.forkIOWithUnmask

-- | Like @Control.Concurrent.Thread.'Thread.forkOnWithUnmask'@ but
-- additionally adds the thread to the group.
--
-- The 'Int' argument is passed straight through to
-- 'Control.Concurrent.forkOnWithUnmask'.
forkOnWithUnmask
    :: Int
    -> ThreadGroup
    -> ((forall b. IO b -> IO b) -> IO a)
    -> IO (ThreadId, IO (Result a))
-- Deliberately eta-expanded. The point-free form
-- @forkWithUnmask . Control.Concurrent.forkOnWithUnmask@ instantiates '(.)' at a
-- rank-2 type, which needs impredicative type inference. A direct application
-- type-checks with plain RankNTypes.
forkOnWithUnmask cap =
    forkWithUnmask (Control.Concurrent.forkOnWithUnmask cap)


--------------------------------------------------------------------------------
-- Utils
--------------------------------------------------------------------------------

-- | Fork a thread with @doFork@ and register it in the group.
--
-- The group's counter is incremented, with asynchronous exceptions masked,
-- before the thread is forked. The forked thread does three things in order:
--
-- 1. Runs the action with the caller's masking state restored.
-- 2. Stores the outcome (as produced by 'try') in an 'MVar'.
-- 3. Decrements the counter.
--
-- Returns the new 'ThreadId' and an action that blocks until the outcome is
-- available.
--
-- If @doFork@ itself throws, the increment is undone before the exception
-- propagates. 'Control.Concurrent.forkOS' does this on an RTS without
-- bound-thread support. Without the undo, no thread would ever decrement the
-- counter and 'wait' would block forever.
fork :: (IO () -> IO ThreadId)
     -> ThreadGroup
     -> IO a
     -> IO (ThreadId, IO (Result a))
fork doFork (ThreadGroup numThreadsTV) a = do
  res <- newEmptyMVar
  tid <- mask $ \restore -> do
    atomically $ modifyTVar numThreadsTV (+ 1)
    -- The result is published before the counter is decremented. Someone
    -- reading the result may therefore briefly still see this thread counted.
    -- NOTE(review): relies on @doFork@ making the child inherit the masked
    -- state (as forkIO does), so these two steps cannot be interrupted.
    let child = do
          try (restore a) >>= putMVar res
          decrement
    doFork child `onException` decrement
  return (tid, readMVar res)
  where
    decrement = atomically $ modifyTVar numThreadsTV (subtract 1)

-- | Like @fork@, but for forking primitives that hand the child an @unmask@
-- function, which is passed on to the user's action.
--
-- The counter is incremented, with asynchronous exceptions masked, before the
-- thread is forked. The child then does three things in order:
--
-- 1. Runs @f unmask@ with the caller's masking state restored.
-- 2. Stores the outcome (as produced by 'try') in an 'MVar'.
-- 3. Decrements the counter.
--
-- If @doForkWithUnmask@ itself throws, the increment is undone before the
-- exception propagates. Otherwise 'wait' would block forever on a thread that
-- was never started.
forkWithUnmask
    :: (((forall b. IO b -> IO b) -> IO ()) -> IO ThreadId)
    -> ThreadGroup
    -> ((forall b. IO b -> IO b) -> IO a)
    -> IO (ThreadId, IO (Result a))
forkWithUnmask doForkWithUnmask (ThreadGroup numThreadsTV) f = do
  res <- newEmptyMVar
  tid <- mask $ \restore -> do
    atomically $ modifyTVar numThreadsTV (+ 1)
    -- The result is published before the counter is decremented.
    doForkWithUnmask (\unmask -> do
                        try (restore (f unmask)) >>= putMVar res
                        decrement)
      `onException` decrement
  return (tid, readMVar res)
  where
    decrement = atomically $ modifyTVar numThreadsTV (subtract 1)

-- | Strictly modify the contents of a 'TVar'.
--
-- The new value is forced to weak head normal form (via '.!') before it is
-- written. This keeps the 'TVar' from accumulating a chain of unevaluated
-- thunks.
modifyTVar :: TVar a -> (a -> a) -> STM ()
modifyTVar tv f = readTVar tv >>= writeTVar tv .! f

-- | Strict function composition.
--
-- @(f .! g) x@ evaluates @g x@ to weak head normal form before passing it to
-- @f@.
(.!) :: (b -> c) -> (a -> b) -> (a -> c)
f .! g = \x -> f $! g x