{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Array.Accelerate.LLVM.Native.Execute.Scheduler (
Action, Job(..), Workers,
schedule,
hireWorkers, hireWorkersOn, retireWorkers, fireWorkers, numWorkers,
) where
import qualified Data.Array.Accelerate.LLVM.Native.Debug as D
import Control.Concurrent
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Data.Concurrent.Queue.MichaelScott
import Data.IORef
import Data.Int
import Data.Sequence ( Seq )
import Text.Printf
import qualified Data.Sequence as Seq
import GHC.Base
#include "MachDeps.h"
-- | A unit of work; tasks wrapping an 'Action' are run by the worker threads.
type Action = IO ()
-- | An item on the shared work queue consumed by the worker threads.
data Task
  = Work Action   -- ^ run this action, then look for more work
  | Retire        -- ^ the worker that dequeues this leaves its loop
-- | A batch of work handed to 'schedule': independent tasks, plus an
-- optional finaliser that 'schedule' arranges to run after the last task
-- has completed.
data Job = Job
  { jobTasks :: !(Seq Action)
    -- ^ the actions that make up this job
  , jobDone  :: !(Maybe Action)
    -- ^ if present, executed once all of 'jobTasks' have finished
  }
-- | A pool of worker threads that all pull from one shared task queue.
data Workers = Workers
  { workerCount     :: {-# UNPACK #-} !Int
    -- ^ number of worker threads (set to the length of 'workerThreadIds')
  , workerActive    :: {-# UNPACK #-} !(IORef (MVar ()))
    -- ^ current wake-up signal; 'pushTasks' swaps in a fresh empty MVar and
    -- fills the previous one to wake sleeping workers
  , workerTaskQueue :: {-# UNPACK #-} !(LinkedQueue Task)
    -- ^ pending tasks, possibly belonging to several different jobs
  , workerThreadIds :: ![ThreadId]
    -- ^ the worker threads themselves (used by 'fireWorkers')
  , workerException :: !(MVar (Seq (ThreadId, SomeException)))
    -- ^ exceptions that terminated a worker; the MVar stays empty until the
    -- first one is recorded
  }
{-# INLINEABLE schedule #-}
-- | Submit a 'Job' to the worker pool. The tasks are pushed onto the shared
-- queue (via 'pushTasks') and this call then returns; it does not run the
-- tasks itself. It may be called concurrently.
--
-- When the job carries a finaliser ('jobDone'), every task is wrapped so
-- that, after it completes, it atomically decrements a shared counter of
-- outstanding tasks. The task that takes the counter from 1 to 0 runs the
-- finaliser. A job with a finaliser but no tasks enqueues the finaliser as
-- a task of its own, because no task would ever reach zero otherwise.
--
-- NOTE(review): if a wrapped task throws, it never decrements the counter,
-- so the finaliser will not run. The exception also escapes the worker loop
-- and ends the worker thread that ran it (see the handler installed in
-- 'hireWorkersOn').
schedule :: Workers -> Job -> IO ()
schedule workers Job{..} = do
  tasks <- case jobDone of
    Nothing -> return (fmap Work jobTasks)
    Just done
      | Seq.null jobTasks ->
          -- Nothing would ever bring a counter to zero, so hand the
          -- finaliser to the workers directly.
          return (Seq.singleton (Work done))
      | otherwise -> do
          -- Number of tasks from this job that have not yet finished.
          count <- newAtomic (Seq.length jobTasks)
          return $ flip fmap jobTasks $ \io -> Work $ do
            _result   <- io
            remaining <- fetchSubAtomic count   -- value *before* the decrement
            when (remaining == 1) done
  pushTasks workers tasks
-- | Main loop of a worker thread.
--
-- Repeatedly dequeues a 'Task' and runs it. 'Work' is executed and the loop
-- continues. 'Retire' ends the loop. When the queue is empty the worker
-- polls again, allowing 16 retries, and then blocks on the wake-up MVar
-- until 'pushTasks' signals that new work has arrived.
--
-- Ordering matters: the wake-up MVar is read out of the IORef BEFORE the
-- queue is polled. 'pushTasks' pushes its tasks first, then swaps a fresh
-- MVar into the IORef and fills the old one. Reading the IORef first means
-- any push that lands after our failed poll fills the very MVar we are
-- about to wait on. If we polled first, we could pick up the fresh, empty
-- MVar after such a push and sleep although work is queued.
runWorker :: ThreadId -> IORef (MVar ()) -> LinkedQueue Task -> IO ()
runWorker tid ref queue = go 0
  where
    go :: Int16 -> IO ()
    go !attempts = do
      signal <- readIORef ref
      next   <- tryPopR queue
      case next of
        Just (Work io) -> io >> go 0
        Just Retire    -> message $ printf "sched: %s shutting down" (show tid)
        Nothing
          | attempts < 16 -> go (attempts + 1)
          | otherwise -> do
              message $ printf "sched: %s sleeping" (show tid)
              -- blocks until the signal is filled by a later push
              () <- readMVar signal
              go 0
-- | Start a worker pool with one thread on each capability, numbered from 0
-- up to (but excluding) the current 'getNumCapabilities'.
hireWorkers :: IO Workers
hireWorkers =
  getNumCapabilities >>= \ncpu -> hireWorkersOn [0 .. ncpu - 1]
-- | Start one worker thread on each of the given capabilities (via
-- 'forkOnWithUnmask'). All threads share a single task queue and a single
-- wake-up signal.
--
-- Each thread is forked with asynchronous exceptions masked, so its handler
-- is installed before anything can interrupt it; only the worker loop itself
-- runs with the parent's masking state restored. Any exception that escapes
-- the loop, including the one delivered by 'killThread' in 'fireWorkers', is
-- appended together with the thread id to 'workerException', and the thread
-- then finishes.
hireWorkersOn :: [Int] -> IO Workers
hireWorkersOn caps = do
  workerActive    <- newIORef =<< newEmptyMVar
  workerException <- newEmptyMVar
  workerTaskQueue <- newQ
  let spawn :: Int -> IO ThreadId
      spawn cpu = do
        tid <- mask_ $ forkOnWithUnmask cpu $ \restore -> do
          self <- myThreadId
          restore (runWorker self workerActive workerTaskQueue)
            `catch` \(err :: SomeException) ->
              appendMVar workerException (self, err)
        message $ printf "sched: fork %s on capability %d" (show tid) cpu
        return tid
  workerThreadIds <- traverse spawn caps
  -- fully evaluate the thread-id list before building the record
  workerThreadIds `deepseq`
    return Workers { workerCount = length workerThreadIds, .. }
-- | Ask the workers to exit by enqueueing one 'Retire' per worker thread.
-- The requests are pushed behind whatever is already on the queue. Each
-- worker stops after dequeuing one, so a worker that retires never takes
-- a second.
retireWorkers :: Workers -> IO ()
retireWorkers workers =
  pushTasks workers (Seq.replicate (workerCount workers) Retire)
-- | Enqueue tasks and wake every sleeping worker.
--
-- The tasks are pushed first. Only then is a fresh, empty MVar swapped into
-- 'workerActive' and the previous one filled, waking all threads blocked on
-- it. Each MVar is removed from the IORef atomically before it is filled, so
-- exactly one caller ever fills any given MVar.
pushTasks :: Workers -> Seq Task -> IO ()
pushTasks Workers{ workerTaskQueue = queue, workerActive = active } tasks = do
  forM_ tasks (pushL queue)
  fresh <- newEmptyMVar
  stale <- atomicModifyIORef' active (\old -> (fresh, old))
  putMVar stale ()
-- | Immediately kill every worker thread with 'killThread', regardless of
-- any work still queued or running.
fireWorkers :: Workers -> IO ()
fireWorkers = mapM_ killThread . workerThreadIds
-- | The number of worker threads in the pool.
numWorkers :: Workers -> Int
numWorkers Workers{ workerCount = n } = n
-- | A mutable counter stored in a byte array sized for one machine 'Int'
-- (see 'newAtomic'), updated with atomic primops such as in 'fetchSubAtomic'.
data Atomic = Atomic !(MutableByteArray# RealWorld)
{-# INLINE newAtomic #-}
-- | Allocate a new 'Atomic' counter holding the given initial value.
-- The backing byte array is exactly one Int wide (SIZEOF_HSINT bytes).
newAtomic :: Int -> IO Atomic
newAtomic (I# initial#) = IO $ \s0 ->
  case SIZEOF_HSINT of
    I# bytes# ->
      case newByteArray# bytes# s0 of
        (# s1, arr# #) ->
          case writeIntArray# arr# 0# initial# s1 of
            s2 -> (# s2, Atomic arr# #)
{-# INLINE fetchSubAtomic #-}
-- | Atomically decrement the counter by one, returning the value it held
-- immediately BEFORE the decrement.
fetchSubAtomic :: Atomic -> IO Int
fetchSubAtomic (Atomic arr#) = IO $ \s0 ->
  case fetchSubIntArray# arr# 0# 1# s0 of
    (# s1, before# #) -> (# s1, I# before# #)
{-# INLINE appendMVar #-}
-- | Append an element to a sequence held in an 'MVar', where an empty 'MVar'
-- stands for "no elements yet".
--
-- Safe to call from many threads at once, and it never blocks. This matters
-- because it is called from dying worker threads, potentially all at the
-- same moment (e.g. when every worker is killed together). A separate
-- take-then-'putMVar' is not atomic: another appender can fill the MVar in
-- between, leaving this thread blocked forever in 'putMVar' with its element
-- unrecorded. Instead the current contents are taken (if any), merged with
-- ours, and offered back with the non-blocking 'tryPutMVar'. If someone
-- else filled the MVar in the meantime, we take their contents too and try
-- again. No element is lost or duplicated; elements from racing appenders
-- may end up in either order.
appendMVar :: MVar (Seq a) -> a -> IO ()
appendMVar mvar a = mask_ (go (Seq.singleton a))
  where
    -- 'pending' holds everything this call has removed from the MVar so
    -- far, plus the new element.
    go pending = do
      current <- tryTakeMVar mvar
      let merged = maybe pending (Seq.>< pending) current
      stored  <- tryPutMVar mvar merged
      unless stored (go merged)
-- | Emit a scheduler trace message through 'D.traceIO', tagged with the
-- 'D.dump_sched' debug flag (presumably only printed when that flag is
-- enabled; see the Debug module).
message :: String -> IO ()
message msg = D.traceIO D.dump_sched msg