{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE UnboxedTuples       #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.Execute.Scheduler
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.Execute.Scheduler (

  Action, Job(..), Workers,

  schedule,
  hireWorkers, hireWorkersOn, retireWorkers, fireWorkers, numWorkers,

) where

import qualified Data.Array.Accelerate.LLVM.Native.Debug            as D

import Control.Concurrent
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Data.Concurrent.Queue.MichaelScott
import Data.IORef
import Data.Int
import Data.Sequence                                                ( Seq )
import Text.Printf
import qualified Data.Sequence                                      as Seq

import GHC.Base

#include "MachDeps.h"


-- A single unit of work: an IO computation which is run only for its effects.
--
-- An individual computation (a 'Job') consists of a sequence of such actions,
-- which are executed by the worker threads in parallel.
--
type Action = IO ()

-- An entry in the shared work queue: either an action for a worker thread to
-- run, or a request that the worker which dequeues it should stop.
--
data Task
  = Work Action   -- run the action, then go back to looking for more work
  | Retire        -- the worker receiving this leaves its loop

-- A job is a collection of actions to hand to the worker threads, together
-- with an optional finalisation action.
--
data Job = Job
  { jobTasks :: !(Seq Action)     -- the actions which make up this job
  , jobDone  :: !(Maybe Action)   -- run once the last of the actions has completed
  }

-- The state shared between a gang of worker threads and the threads which
-- submit work to them.
--
data Workers = Workers
  { workerCount     :: {-# UNPACK #-} !Int
      -- number of worker threads (length workerThreadIds)
  , workerActive    :: {-# UNPACK #-} !(IORef (MVar ()))
      -- the MVar held here is filled to signal sleeping threads to wake up
  , workerTaskQueue :: {-# UNPACK #-} !(LinkedQueue Task)
      -- tasks currently being executed; may be from different jobs
  , workerThreadIds :: ![ThreadId]
      -- to send signals to / kill
  , workerException :: !(MVar (Seq (ThreadId, SomeException)))
      -- exceptions caught from the worker threads
      -- XXX: what should we do with these?
  }


-- Schedule a job to be executed by the worker threads. May be called
-- concurrently.
--
-- When the job carries a finalisation action, every task is wrapped so that,
-- after running, it decrements a counter shared by all tasks of the job. The
-- thread which completes the last task is the one which runs the finaliser.
--
-- NOTE: if 'jobTasks' is empty then no task is enqueued, and consequently
-- 'jobDone' is never run.
--
{-# INLINEABLE schedule #-}
schedule :: Workers -> Job -> IO ()
schedule workers Job{ jobTasks = actions, jobDone = finaliser } =
  -- Build the work list, then submit it to the queue to be executed by the
  -- worker threads.
  pushTasks workers =<< maybe plain counted finaliser
  where
    -- No finalisation action: the tasks are just the actions themselves.
    plain :: IO (Seq Task)
    plain = return (fmap Work actions)

    -- With a finalisation action there is a bit of extra work to do after
    -- each step: keep track of how many items are yet to complete.
    counted :: Action -> IO (Seq Task)
    counted done = do
      pending <- newAtomic (Seq.length actions)
      let finishing io = Work $ do
            io
            before <- fetchSubAtomic pending  -- value _before_ the decrement
            when (before == 1) done           -- so 1 means we were the last
      return (fmap finishing actions)


-- The main loop of a worker thread. A worker is either executing a task (real
-- work), waiting for work, or going into retirement (whence the loop returns
-- and the thread will exit).
--
-- So that threads don't spin endlessly on an empty queue, a worker which has
-- failed to dequeue anything on several consecutive attempts goes to sleep by
-- blocking on the signal MVar. Note that 'readMVar' is multiple wake up, so
-- all sleeping threads are woken when new work is added via 'pushTasks'.
--
-- The MVar is stored in an IORef. When new work is pushed, the old MVar is
-- resolved by putting a value into it, and a new (at that moment unresolved)
-- MVar is placed in the IORef. It is essential that the worker extracts the
-- MVar from the IORef _before_ it tries to read from the queue. Done the other
-- way around there is a race: work could be pushed after the failed dequeue
-- but before the IORef is read, in which case the worker would wait on the
-- _new_ MVar and may stall forever.
--
runWorker :: ThreadId -> IORef (MVar ()) -> LinkedQueue Task -> IO ()
runWorker tid ref queue = spin 0
  where
    -- The argument counts consecutive failed attempts to claim a task.
    --
    -- The number of retries before sleeping is a knob which can be tuned.
    -- Too high results in busy work which detracts from time spent adding
    -- new work, whereas too low reduces responsiveness; both reduce
    -- productivity.
    --
    -- TODO: Tune these values a bit
    --
    spin :: Int16 -> IO ()
    spin !misses = do
      -- Grab the activation MVar first, and only then try the queue
      wakeup <- readIORef ref
      next   <- tryPopR queue
      case next of
        Just (Work io) -> io >> spin 0
        Just Retire    -> message $ printf "sched: %s shutting down" (show tid)
        Nothing
          | misses < 16 -> spin (misses + 1)
          | otherwise   -> do
              -- Sleep on the MVar extracted above. Whichever thread next
              -- pushes work also fills that MVar, which wakes us up.
              message $ printf "sched: %s sleeping" (show tid)
              readMVar wakeup   -- blocking
              spin 0


-- Spawn one worker thread per capability, on capabilities 0 through
-- (getNumCapabilities - 1).
--
hireWorkers :: IO Workers
hireWorkers =
  getNumCapabilities >>= \ncpu -> hireWorkersOn [0 .. ncpu - 1]

-- Spawn worker threads on the specified capabilities; one thread is forked
-- for each element of the list.
--
-- Each thread is forked with asynchronous exceptions masked and only unmasks
-- them around the worker loop itself. Any exception which escapes the loop is
-- recorded, together with the id of the thread, in 'workerException'.
--
hireWorkersOn :: [Int] -> IO Workers
hireWorkersOn caps = do
  signal <- newEmptyMVar      -- initial (unresolved) wake-up signal
  wakeup <- newIORef signal
  failed <- newEmptyMVar      -- empty until the first exception is recorded
  queue  <- newQ
  let
      spawn :: Int -> IO ThreadId
      spawn cpu = do
        tid <- mask_ $ forkOnWithUnmask cpu $ \restore -> do
                 self <- myThreadId
                 restore (runWorker self wakeup queue)
                   `catch` \e -> appendMVar failed (self, e :: SomeException)
        --
        message (printf "sched: fork %s on capability %d" (show tid) cpu)
        return tid
  --
  tids <- mapM spawn caps
  tids `deepseq` return Workers
    { workerCount     = length tids
    , workerActive    = wakeup
    , workerTaskQueue = queue
    , workerThreadIds = tids
    , workerException = failed
    }


-- Ask every worker to retire, by pushing one 'Retire' task per worker onto
-- the queue. Tasks which are already pending will be completed first.
--
retireWorkers :: Workers -> IO ()
retireWorkers workers =
  pushTasks workers (Seq.replicate (workerCount workers) Retire)


-- Push work onto the task queue.
--
-- Sleeping worker threads are woken by writing to the MVar which was in
-- 'workerActive' when this function was called. That MVar is replaced by a
-- new, empty one, so that the threads can be woken again the next time work
-- is added.
--
pushTasks :: Workers -> Seq Task -> IO ()
pushTasks workers tasks = do
  -- Add the work to the queue
  forM_ tasks (pushL (workerTaskQueue workers))

  -- Install a fresh MVar for a later call to 'pushTasks' to signal with, and
  -- retrieve the one it replaces.
  --
  -- The swap must be atomic, to prevent race conditions when two threads are
  -- adding new work. Otherwise some MVar may never be resolved, and a worker
  -- thread waiting on that MVar would stall.
  fresh <- newEmptyMVar
  stale <- atomicModifyIORef' (workerActive workers) (\current -> (fresh, current))

  -- Resolve the old MVar to wake up all waiting threads
  putMVar stale ()


-- Kill the worker threads immediately, by calling 'killThread' on each of the
-- recorded thread ids in turn.
--
fireWorkers :: Workers -> IO ()
fireWorkers = mapM_ killThread . workerThreadIds

-- The number of worker threads in the gang
--
numWorkers :: Workers -> Int
numWorkers workers = workerCount workers


-- Utility
-- -------

-- A mutable counter, stored as a single Int at index zero of a mutable byte
-- array so that it can be updated with the atomic array primops.
--
data Atomic
  = Atomic !(MutableByteArray# RealWorld)

-- Create a new counter holding the given initial value. The backing byte
-- array is exactly one machine Int (SIZEOF_HSINT bytes) in size.
--
{-# INLINE newAtomic #-}
newAtomic :: Int -> IO Atomic
newAtomic (I# n#) = IO $ \s0 ->
  case SIZEOF_HSINT of
    I# size# ->
      case newByteArray# size# s0 of
        (# s1, mba# #) ->
          -- a plain (non-atomic) write is ok for the initial value
          case writeIntArray# mba# 0# n# s1 of
            s2 -> (# s2, Atomic mba# #)

-- Atomically subtract one from the counter. The result is the value of the
-- counter _before_ the subtraction took place.
--
{-# INLINE fetchSubAtomic #-}
fetchSubAtomic :: Atomic -> IO Int
fetchSubAtomic (Atomic counter#) = IO $ \s0 ->
  case fetchSubIntArray# counter# 0# 1# s0 of
    (# s1, before# #) -> (# s1, I# before# #)

-- Append a value to the sequence stored in the MVar. An empty MVar is treated
-- as holding the empty sequence, and the MVar is always full on return.
--
-- This is safe to call concurrently from several threads, and never blocks.
-- The contents are taken (if present), our pending entries are added to the
-- end, and the result is written back with 'tryPutMVar'. If another thread
-- refilled the MVar in the meantime the write fails; we then go around again,
-- taking what that thread stored and merging it with what we are still
-- holding, so that no entry is ever dropped. (A blocking 'putMVar' at this
-- point could instead wait forever, since nothing would empty the MVar.)
--
-- Asynchronous exceptions are masked for the duration, so that entries taken
-- out of the MVar are always put back. None of the operations used can block,
-- so there are no interruptible points inside the mask.
--
{-# INLINE appendMVar #-}
appendMVar :: MVar (Seq a) -> a -> IO ()
appendMVar mvar a = mask_ (go (Seq.singleton a))
  where
    -- 'pending' holds every entry which we are responsible for storing: the
    -- new value, plus anything taken out on an earlier, failed, iteration.
    go pending = do
      current <- tryTakeMVar mvar
      let merged = maybe pending (Seq.>< pending) current
      stored <- tryPutMVar mvar merged
      unless stored (go merged)


-- Debug
-- -----

-- Emit a scheduler debugging message, via 'D.traceIO' under the
-- 'D.dump_sched' flag.
--
message :: String -> IO ()
message msg = D.traceIO D.dump_sched msg