{-# LANGUAGE CPP #-}

-- | Gang Primitives.
module Data.Array.Repa.Eval.Gang
        ( theGang
        , Gang, forkGang, gangSize, gangIO, gangST)     
where
import GHC.IO
import GHC.ST
import GHC.Conc                 (forkOn)
import Control.Concurrent.MVar
import Control.Exception        (assert)
import Control.Monad
import GHC.Conc                 (numCapabilities)
import System.IO


-- TheGang --------------------------------------------------------------------
-- | This globally shared gang is auto-initialised at startup and shared by all
--   Repa computations.
--
--   In a data parallel setting, it does not help to have multiple gangs
--   running at the same time. This is because a single data parallel
--   computation should already be able to keep all threads busy. If we had
--   multiple gangs running at the same time, then the system as a whole would
--   run slower as the gangs would contend for cache and thrash the scheduler.
--
--   If, due to laziness or otherwise, you try to start multiple parallel
--   Repa computations at the same time, then you will get the following
--   warning on stderr at runtime:
--
-- @Data.Array.Repa: Performing nested parallel computation sequentially.
--    You've probably called the 'compute' or 'copy' function while another
--    instance was already running. This can happen if the second version
--    was suspended due to lazy evaluation. Use 'deepSeqArray' to ensure that
--    each array is fully evaluated before you 'compute' the next one.
-- @
--
theGang :: Gang
-- NOINLINE matters here: the gang is built with 'unsafePerformIO', and
-- inlining the definition at its use sites could run 'forkGang' more than
-- once instead of sharing a single gang.
{-# NOINLINE theGang #-}
theGang
 = unsafePerformIO
 $ forkGang numCapabilities     -- one worker thread per capability


-- Requests -------------------------------------------------------------------
-- | A work request sent to an individual member of a gang.
data Req
        -- | Run the given action. A worker applies it to its own thread
        --   index (see 'gangWorker').
        = ReqDo (Int -> IO ())

        -- | The gang is being shut down. The worker acknowledges that it
        --   has received the request by writing to its result var, and then
        --   stops waiting for further requests.
        | ReqShutdown


-- Gang -----------------------------------------------------------------------
-- | A 'Gang' is a group of threads that execute arbitrary work requests.
--
--   The request and result lists are built by 'forkGang' with one entry per
--   worker, in worker order.
data Gang
        = Gang
        { -- | Number of threads in the gang.
          _gangThreads           :: !Int

          -- | Workers listen for requests on these vars.
        , _gangRequestVars       :: [MVar Req]

          -- | Workers put their results in these vars.
        , _gangResultVars        :: [MVar ()]

          -- | Indicates that the gang is busy. 'gangIO' sets this to 'True'
          --   for the duration of a parallel run.
        , _gangBusy              :: MVar Bool
        }

-- | Only the thread count is shown, in the form @\<\<N threads\>\>@;
--   the 'MVar' fields are ignored.
instance Show Gang where
  showsPrec p (Gang n _ _ _)
        = showString "<<"
        . showsPrec p n
        . showString " threads>>"


-- | O(1). Yield the number of threads in the 'Gang'.
gangSize :: Gang -> Int
gangSize (Gang n _ _ _)
        = n


-- | Fork a 'Gang' with the given number of threads (at least 1).
--
--   Worker @i@ is started with 'forkOn' @i@ and is handed the @i@-th request
--   var and the @i@-th result var. The thread count is only checked with
--   'assert'. The returned gang starts out idle.
forkGang :: Int -> IO Gang
forkGang n
 = assert (n > 0)
 $ do
        -- Create the vars we'll use to issue work requests.
        mvsRequest     <- replicateM n newEmptyMVar

        -- Create the vars we'll use to signal that threads are done.
        mvsDone        <- replicateM n newEmptyMVar

        -- Add finalisers so we can shut the workers down cleanly if they
        -- become unreachable.
        zipWithM_ (\varReq varDone
                        -> mkWeakMVar varReq (finaliseWorker varReq varDone))
                mvsRequest
                mvsDone

        -- Create all the worker threads.
        -- zipWithM_ stops at the shorter list, so exactly n threads are forked.
        zipWithM_ forkOn [0..]
                $ zipWith3 gangWorker
                        [0 .. n-1] mvsRequest mvsDone

        -- The gang is currently idle.
        busy   <- newMVar False

        return $ Gang n mvsRequest mvsDone busy



-- | The worker thread of a 'Gang'.
--   The thread blocks on the request MVar waiting for a work request.
--
--   For a 'ReqDo' it runs the action with its own thread index, signals
--   completion on the result var and goes back to waiting. For a
--   'ReqShutdown' it signals on the result var and returns.
gangWorker :: Int -> MVar Req -> MVar () -> IO ()
gangWorker threadId varRequest varDone
 = do
        -- Wait for a request
        req     <- takeMVar varRequest

        case req of
         ReqDo action
          -> do -- Run the action we were given.
                action threadId

                -- Signal that the action is complete.
                putMVar varDone ()

                -- Wait for more requests.
                gangWorker threadId varRequest varDone

         ReqShutdown
          ->    -- Acknowledge the shutdown and stop looping.
                putMVar varDone ()


-- | Finaliser for worker threads.
--   We want to shut down the corresponding thread when its MVar becomes
--   unreachable.
--   Without this Repa programs can complain about "Blocked indefinitely
--   on an MVar" because worker threads are still blocked on the request
--   MVars when the program ends. Whether the finalizer is called or not
--   is very racy. It happens about 1 in 10 runs for the
--   repa-edgedetect benchmark, and less often with the others.
--
--   We're relying on the comment in System.Mem.Weak that says
--    "If there are no other threads to run, the runtime system will
--     check for runnable finalizers before declaring the system to be
--     deadlocked."
--
--   If we were creating and destroying the gang cleanly we wouldn't need
--     this, but theGang is created with a top-level unsafePerformIO.
--     Hacks beget hacks beget hacks...
--
finaliseWorker :: MVar Req -> MVar () -> IO ()
finaliseWorker varReq varDone
 = do   -- Ask the worker to stop ...
        putMVar varReq ReqShutdown

        -- ... and wait for it to acknowledge the request.
        takeMVar varDone


-- | Issue work requests for the 'Gang' and wait until they complete.
--
--   If the gang is already busy then print a warning to `stderr` and just
--   run the actions sequentially in the requesting thread.
--
--   NOTE: the busy flag is only cleared after the parallel run returns
--   normally; there is no exception handler around it, so if the run is
--   interrupted by an exception the flag stays set.
gangIO  :: Gang
        -> (Int -> IO ())
        -> IO ()

{-# NOINLINE gangIO #-}
gangIO gang@(Gang _ _ _ busy) action
 = do   -- Mark the gang as busy; the previous value of the flag tells us
        -- whether somebody else was already using it.
        alreadyBusy <- swapMVar busy True
        if alreadyBusy
         then   -- Nested use: leave the flag set for the outer computation.
                seqIO gang action

         else do
                parIO gang action

                -- The gang is idle again.
                void $ swapMVar busy False


-- | Run an action on the gang sequentially.
--
--   Prints the nested-parallelism warning to 'stderr', then runs the action
--   in the calling thread once for every thread index @0 .. n-1@, where @n@
--   is the size of the gang. The worker threads are not involved.
seqIO   :: Gang -> (Int -> IO ()) -> IO ()
seqIO (Gang n _ _ _) action
 = do   hPutStr stderr
         $ unlines
         [ "Data.Array.Repa: Performing nested parallel computation sequentially."
         , "  You've probably called the 'compute' or 'copy' function while another"
         , "  instance was already running. This can happen if the second version"
         , "  was suspended due to lazy evaluation. Use 'deepSeqArray' to ensure"
         , "  that each array is fully evaluated before you 'compute' the next one."
         , "" ]

        mapM_ action [0 .. n-1]

-- | Run an action on the gang in parallel.
--
--   Every worker receives the same action wrapped in a 'ReqDo'; the call
--   returns once a result has been taken from every result var.
parIO   :: Gang -> (Int -> IO ()) -> IO ()
parIO (Gang _ mvsRequest mvsResult _) action
 = do
        -- Send requests to all the threads.
        mapM_ (\v -> putMVar v (ReqDo action)) mvsRequest

        -- Wait for all the requests to complete.
        mapM_ takeMVar mvsResult


-- | Same as 'gangIO' but in the 'ST' monad.
--
--   The per-thread action is converted with 'unsafeSTToIO', run through
--   'gangIO', and the result is converted back with 'unsafeIOToST'.
gangST :: Gang -> (Int -> ST s ()) -> ST s ()
gangST g p = unsafeIOToST . gangIO g $ unsafeSTToIO . p