module Data.Massiv.Core.Scheduler
( Scheduler
, numWorkers
, scheduleWork
, withScheduler
, withScheduler'
, withScheduler_
, divideWork
, divideWork_
) where
import Control.Concurrent (ThreadId, forkOnWithUnmask,
getNumCapabilities, killThread)
import Control.Concurrent.MVar
import Control.DeepSeq
import Control.Exception (SomeException, catch, mask,
mask_, throwIO, try,
uninterruptibleMask_)
import Control.Monad (forM)
import Control.Monad.Primitive (RealWorld)
import Data.IORef (IORef, atomicModifyIORef',
newIORef, readIORef)
import Data.Massiv.Core.Index.Class (Index (totalElem))
import Data.Massiv.Core.Iterator (loop)
import Data.Primitive.Array (Array, MutableArray, indexArray,
newArray, unsafeFreezeArray,
writeArray)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem.Weak
-- | A unit of work pulled from a pool's job queue by a worker thread.
data Job
  = Job (IO ())
    -- ^ Run this action, then go back to the queue for more work.
  | Retire
    -- ^ Stop the worker that picks this entry up.
-- | Handle given to job-submitting callbacks; jobs are added with
-- 'scheduleWork'.
data Scheduler a = Scheduler
  { jobsCountIORef :: !(IORef Int)
    -- ^ Number of jobs scheduled so far; each new job takes the current
    -- value as the index of its result slot.
  , jobQueueMVar   :: !(MVar [Job])
    -- ^ Jobs scheduled but not yet handed to the workers, newest first.
  , resultsMVar    :: !(MVar (MutableArray RealWorld a))
    -- ^ Results array indexed by job number; empty until submission has
    -- finished and the array has been allocated.
  , workers        :: !Workers
    -- ^ Worker pool the scheduled jobs run on.
  , numWorkers     :: !Int
    -- ^ Number of worker threads in 'workers'.
  }
-- | A pool of worker threads that share a single job queue.
data Workers = Workers
  { workerThreadIds :: ![ThreadId]
    -- ^ Threads forked by 'hireWorkers'.
  , workerJobDone   :: !(MVar (Maybe SomeException))
    -- ^ Completion channel: 'Nothing' after each finished job, or @Just e@
    -- when a worker dies with exception @e@.
  , workerJobQueue  :: !(MVar [Job])
    -- ^ Queue the workers take jobs from.
  }
-- | Enqueue @jobAction@ on the scheduler.
--
-- The job takes the next result index at submission time. The job is
-- prepended to the pending queue, which the caller reverses before handing it
-- to the workers. When a worker runs the job, it writes the result into that
-- slot of the results array and then signals completion on the pool's
-- done-MVar.
scheduleWork :: Scheduler a -> IO a -> IO ()
scheduleWork Scheduler {..} jobAction =
  modifyMVar_ jobQueueMVar enqueue
  where
    enqueue pending = do
      slot <- atomicModifyIORef' jobsCountIORef (\count -> (count + 1, count))
      let storeResult result =
            withMVar resultsMVar $ \results -> writeArray results slot result
          signalDone = putMVar (workerJobDone workers) Nothing
      return (Job (jobAction >>= storeResult >> signalDone) : pending)
-- | Placeholder written into every slot of a freshly allocated results array.
-- Forcing it means a slot was read before its job stored a result.
uninitialized :: a
uninitialized = error "Data.Massiv.Core.Scheduler: uncomputed job result"
-- | Like 'bracket', but with separate release actions for success and failure.
--
-- @before@ runs with asynchronous exceptions masked. @thing@ runs with the
-- caller's masking state restored.
--
-- If @thing@ throws, @afterError@ receives the exception and the resource and
-- runs uninterruptibly masked. Any exception it raises is discarded, and the
-- original exception is rethrown.
--
-- Otherwise @afterSuccess@ runs uninterruptibly masked (its exceptions
-- propagate), and the result of @thing@ is returned.
bracketWithException :: forall a b c d .
                        IO a
                     -> (a -> IO b)
                     -> (SomeException -> a -> IO c)
                     -> (a -> IO d)
                     -> IO d
bracketWithException before afterSuccess afterError thing =
  mask $ \restore -> do
    resource <- before
    outcome <- try (restore (thing resource))
    either (onFailure resource) (onSuccess resource) outcome
  where
    onFailure :: a -> SomeException -> IO d
    onFailure resource exc = do
      -- Cleanup failures are deliberately swallowed so the original error wins.
      _ <- try (uninterruptibleMask_ (afterError exc resource))
             :: IO (Either SomeException c)
      throwIO exc
    onSuccess :: a -> d -> IO d
    onSuccess resource result = do
      _ <- uninterruptibleMask_ (afterSuccess resource)
      return result
-- | Run @submitJobs@ against a worker pool, wait for every scheduled job to
-- finish, and return the number of jobs together with their results. Results
-- are indexed by job number, i.e. submission order.
--
-- When @wss@ is empty, the shared global pool is borrowed if it is free and
-- still alive. Otherwise, or when @wss@ is non-empty, a private pool is hired
-- with 'hireWorkers' and retired afterwards.
--
-- If submission or any job throws, the pool in use is torn down and the
-- exception is rethrown. Private workers are killed; the global pool is
-- finalized and replaced.
withScheduler :: [Int]
              -> (Scheduler a -> IO b)
              -> IO (Int, Array a)
withScheduler wss submitJobs = do
  jobsCountIORef <- newIORef 0
  jobQueueMVar <- newMVar []
  resultsMVar <- newEmptyMVar
  bracketWithException
    acquireWorkers
    releaseWorkers
    discardWorkers
    (\(_, workers) -> do
       let scheduler =
             Scheduler {numWorkers = length $ workerThreadIds workers, ..}
       _ <- submitJobs scheduler
       jobCount <- readIORef jobsCountIORef
       marr <- newArray jobCount uninitialized
       -- Scheduled jobs write their results through resultsMVar, so publish
       -- the array before the queue is released to the workers.
       putMVar resultsMVar marr
       -- Jobs were prepended as they were scheduled; reverse to run them in
       -- submission order.
       jobQueue <- takeMVar jobQueueMVar
       putMVar (workerJobQueue workers) $ reverse jobQueue
       waitTillDone scheduler
       arr <- unsafeFreezeArray marr
       return (jobCount, arr))
  where
    -- Borrow the global pool when possible; otherwise hire a private one.
    -- A 'Just' tag means the pool must be handed back to globalWorkersMVar.
    acquireWorkers :: IO (Maybe (Weak Workers), Workers)
    acquireWorkers
      | null wss = do
          mWeakWorkers <- tryTakeMVar globalWorkersMVar
          mGlobalWorkers <- maybe (return Nothing) deRefWeak mWeakWorkers
          case (mWeakWorkers, mGlobalWorkers) of
            -- Global pool was free and is still alive: borrow it.
            (Just weakWorkers, Just globalWorkers) ->
              return (Just weakWorkers, globalWorkers)
            -- We took the global slot but its weak handle is dead. Never
            -- hand the dead handle back; install a fresh global pool and
            -- borrow it. If even that is dead, use a private pool, which
            -- is always retired or killed.
            (Just _deadWeakWorkers, Nothing) -> do
              freshWeakWorkers <- hireWeakWorkers globalWorkersMVar
              mFreshWorkers <- deRefWeak freshWeakWorkers
              maybe hirePrivateWorkers
                    (\freshWorkers -> return (Just freshWeakWorkers, freshWorkers))
                    mFreshWorkers
            -- Global pool is currently in use: hire a private pool.
            (Nothing, _) -> hirePrivateWorkers
      | otherwise = hirePrivateWorkers

    -- Fresh pool owned exclusively by this call.
    hirePrivateWorkers :: IO (Maybe (Weak Workers), Workers)
    hirePrivateWorkers = do
      privateWorkers <- hireWorkers wss
      return (Nothing, privateWorkers)

    -- Success path: retire a private pool, or return the global one.
    releaseWorkers :: (Maybe (Weak Workers), Workers) -> IO ()
    releaseWorkers (mWeakWorkers, workers) =
      case mWeakWorkers of
        -- One Retire per worker makes every private thread exit its loop.
        Nothing ->
          putMVar (workerJobQueue workers) $
            replicate (length (workerThreadIds workers)) Retire
        Just weak -> putMVar globalWorkersMVar weak

    -- Failure path: kill a private pool, or finalize and replace the global
    -- one (its finalizer kills the worker threads).
    discardWorkers :: SomeException -> (Maybe (Weak Workers), Workers) -> IO ()
    discardWorkers _ (mWeakWorkers, workers) =
      case mWeakWorkers of
        Nothing -> mapM_ killThread (workerThreadIds workers)
        Just weakWorkers -> do
          finalize weakWorkers
          newWeakWorkers <- hireWeakWorkers globalWorkersMVar
          putMVar globalWorkersMVar newWeakWorkers
-- | Like 'withScheduler', but return the job results as a list in job-index
-- (submission) order.
withScheduler' :: [Int] -> (Scheduler a -> IO b) -> IO [a]
withScheduler' wss submitJobs = do
  (jc, arr) <- withScheduler wss submitJobs
  -- Walk the indices from the last job down to 0, consing each result, so the
  -- list comes out in ascending index order.
  return $
    loop (jc - 1) (>= 0) (subtract 1) [] $ \i acc -> indexArray arr i : acc
-- | Like 'withScheduler', but discard the job count and results.
withScheduler_ :: [Int] -> (Scheduler a -> IO b) -> IO ()
withScheduler_ wss submitJobs = do
  _ <- withScheduler wss submitJobs
  pure ()
-- | Like 'divideWork', but discard the job results.
divideWork_ :: Index ix
            => [Int] -> ix -> (Scheduler a -> Int -> Int -> Int -> IO b) -> IO ()
divideWork_ wss sz submit = do
  _ <- divideWork wss sz submit
  pure ()
-- | Split @totalElem sz@ elements into one equal chunk per worker and let
-- @submit@ schedule the jobs.
--
-- @submit@ receives, in order:
--
-- * the scheduler;
-- * the chunk length (@total `quot` numWorkers@);
-- * the total length;
-- * the index where the leftover slack starts (@chunkLength * numWorkers@).
--
-- The job results are returned in submission order. An empty size returns
-- @[]@ without starting a scheduler.
divideWork :: Index ix
           => [Int]
           -> ix
           -> (Scheduler a -> Int -> Int -> Int -> IO b)
           -> IO [a]
divideWork wss sz submit
  | totalLength == 0 = return []
  | otherwise =
      withScheduler' wss $ \scheduler ->
        let !workerCount = numWorkers scheduler
            !chunkLength = totalLength `quot` workerCount
            !slackStart = chunkLength * workerCount
         in submit scheduler chunkLength totalLength slackStart
  where
    totalLength = totalElem sz
-- | Block until the pool has signalled completion once per scheduled job.
-- If a worker reports a failure (a @Just@ on the done-MVar), rethrow that
-- exception instead of waiting for the rest.
waitTillDone :: Scheduler a -> IO ()
waitTillDone Scheduler {..} = do
  totalJobs <- readIORef jobsCountIORef
  let collect finished
        | finished == totalJobs = return ()
        | otherwise =
            takeMVar (workerJobDone workers)
              >>= maybe (collect (finished + 1)) throwIO
  collect 0
-- | Worker loop. Take the whole pending list from the queue and act on its
-- head:
--
-- * 'Job': put the tail back, run the action, loop.
-- * 'Retire': put the tail back and stop.
-- * Empty list: consume it without putting it back (leaving the MVar empty
--   for the next batch), then loop.
runWorker :: MVar [Job] -> IO ()
runWorker jobsMVar = takeMVar jobsMVar >>= dispatch
  where
    dispatch (Job job : remaining) =
      putMVar jobsMVar remaining >> job >> runWorker jobsMVar
    dispatch (Retire : remaining) = putMVar jobsMVar remaining
    dispatch [] = runWorker jobsMVar
-- | Fork one worker thread per capability number in @wss@. When @wss@ is
-- empty, fork one per capability, pinned to capabilities @0 .. n - 1@.
--
-- Every worker runs 'runWorker' on a shared, initially empty job queue. A
-- worker that dies with an exception reports it as @Just e@ on
-- 'workerJobDone'.
hireWorkers :: [Int] -> IO Workers
hireWorkers wss = do
  wss' <-
    if null wss
      then do
        wNum <- getNumCapabilities
        return [0 .. wNum - 1]
      else return wss
  workerJobQueue <- newEmptyMVar
  workerJobDone <- newEmptyMVar
  workerThreadIds <-
    forM wss' $ \ws ->
      -- Fork masked so the handler is installed before any asynchronous
      -- exception can reach the new thread; unmask re-enables them inside.
      mask_ $
        forkOnWithUnmask ws $ \unmask ->
          catch
            (unmask $ runWorker workerJobQueue)
            (unmask . putMVar workerJobDone . Just)
  -- Fully evaluate the thread-id list before building the record.
  workerThreadIds `deepseq` return Workers {..}
-- | Slot holding the weak handle to the process-wide default worker pool.
-- It is full while that pool is idle. 'withScheduler' takes the handle to
-- borrow the pool and puts it (or a replacement) back afterwards.
--
-- Built with 'unsafePerformIO' so there is exactly one top-level MVar. The
-- NOINLINE pragma is essential: without it GHC may inline the definition,
-- allocating a separate MVar and pool at each use site. Taking from one copy
-- and putting into another could then block forever.
globalWorkersMVar :: MVar (Weak Workers)
globalWorkersMVar = unsafePerformIO $ do
  workersMVar <- newEmptyMVar
  weakWorkers <- hireWeakWorkers workersMVar
  putMVar workersMVar weakWorkers
  return workersMVar
{-# NOINLINE globalWorkersMVar #-}
-- | Hire a default pool (@hireWorkers []@) and return a weak pointer to it,
-- keyed on @k@. Its finalizer kills every worker thread in the pool.
--
-- NOTE(review): the System.Mem.Weak docs warn that 'mkWeak' keyed on an
-- ordinary boxed value (an MVar box included) may be finalized earlier than
-- expected, and recommend 'mkWeakMVar' for MVar keys. Confirm callers cope
-- with a handle whose 'deRefWeak' returns Nothing.
hireWeakWorkers :: key -> IO (Weak Workers)
hireWeakWorkers k = do
  pool <- hireWorkers []
  let dismissAll = mapM_ killThread (workerThreadIds pool)
  mkWeak k pool (Just dismissAll)