{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Scheduler
(
Scheduler
, SchedulerWS
, trivialScheduler_
, withScheduler
, withScheduler_
, withSchedulerWS
, withSchedulerWS_
, unwrapSchedulerWS
, scheduleWork
, scheduleWork_
, scheduleWorkId
, scheduleWorkId_
, scheduleWorkState
, scheduleWorkState_
, terminate
, terminate_
, terminateWith
, WorkerId(..)
, WorkerStates
, numWorkers
, workerStatesComp
, initWorkerStates
, Comp(..)
, getCompWorkers
, replicateConcurrently
, replicateConcurrently_
, traverseConcurrently
, traverseConcurrently_
, traverse_
, MutexException(..)
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.IO.Unlift
import Control.Scheduler.Computation
import Control.Scheduler.Internal
import Control.Scheduler.Queue
import Data.Atomics (atomicModifyIORefCAS, atomicModifyIORefCAS_)
import qualified Data.Foldable as F (foldl', traverse_)
import Data.IORef
import Data.Maybe (catMaybes)
import Data.Primitive.Array
import Data.Traversable
#if !MIN_VERSION_primitive(0,6,2)
import Control.Monad.ST
#endif
-- | Get the plain 'Scheduler' that is stored inside a 'SchedulerWS'.
unwrapSchedulerWS :: SchedulerWS s m a -> Scheduler m a
unwrapSchedulerWS schedulerWS = _getScheduler schedulerWS
-- | Create a separate state for every worker of the given computation strategy.
--
-- The number of workers is obtained with 'getCompWorkers'. The initializer is run once
-- for each 'WorkerId', starting at @0@ and going up in increasing order. The resulting
-- states are stored in an array and the mutex flag starts out as 'False'.
initWorkerStates :: MonadIO m => Comp -> (WorkerId -> m s) -> m (WorkerStates s)
initWorkerStates comp initState = do
  nWorkers <- getCompWorkers comp
  states <- forM [0 .. nWorkers - 1] (initState . WorkerId)
  mutexRef <- liftIO (newIORef False)
  return $
    WorkerStates
      { _workerStatesComp = comp
      , _workerStatesArray = arrayFromListN nWorkers states
      , _workerStatesMutex = mutexRef
      }
-- | Build an 'Array' of the given size from a list.
--
-- With @primitive >= 0.6.2@ this is just 'fromListN'. For older versions a manual
-- implementation is used, which calls 'error' whenever the length of the list does not
-- match the requested size.
arrayFromListN :: Int -> [a] -> Array a
#if MIN_VERSION_primitive(0,6,2)
arrayFromListN = fromListN
#else
arrayFromListN size elems =
  runST $ do
    marr <- newArray size (error "initWorkerStates: uninitialized element")
    let fill !i (y:ys)
          | i < size = writeArray marr i y >> fill (i + 1) ys
          | otherwise = error "initWorkerStates: list length greater than specified size"
        fill !i []
          | i == size = return ()
          | otherwise = error "initWorkerStates: list length less than specified size"
    fill 0 elems
    -- No writes happen after this point, so freezing without a copy is fine.
    unsafeFreezeArray marr
#endif
-- | Get the computation strategy that the worker states were initialized with.
workerStatesComp :: WorkerStates s -> Comp
workerStatesComp states = _workerStatesComp states
-- | Run a scheduler that has access to the supplied worker states.
--
-- The mutex flag inside of 'WorkerStates' is atomically set to 'True' for the duration of
-- the call. If the flag was already 'True' upon entry, 'MutexException' is thrown and the
-- flag is left as is. Otherwise the flag is reset to 'False' on exit, regardless of
-- whether the action finished normally or with an exception.
--
-- The scheduler itself is created with 'withScheduler', using the computation strategy
-- stored in the worker states.
withSchedulerWS :: MonadUnliftIO m => WorkerStates s -> (SchedulerWS s m a -> m b) -> m [a]
withSchedulerWS states action =
  withRunInIO $ \runInIO ->
    bracket acquire release $ \alreadyLocked ->
      runInIO $
      if alreadyLocked
        then liftIO (throwIO MutexException)
        else withScheduler (_workerStatesComp states) (action . SchedulerWS states)
  where
    mutexRef = _workerStatesMutex states
    -- Mark the states as locked and report whether they were locked before.
    acquire = atomicModifyIORef' mutexRef (\prev -> (True, prev))
    -- Only the invocation that actually took the lock is allowed to release it.
    release alreadyLocked = unless alreadyLocked (writeIORef mutexRef False)
-- | Same as 'withSchedulerWS', except the list of results is discarded.
withSchedulerWS_ :: MonadUnliftIO m => WorkerStates s -> (SchedulerWS s m () -> m b) -> m ()
withSchedulerWS_ states action = void (withSchedulerWS states action)
-- | Schedule a job that gets the state of the worker that ends up running it. The state
-- is looked up in the states array by the index carried in the 'WorkerId'.
scheduleWorkState :: SchedulerWS s m a -> (s -> m a) -> m ()
scheduleWorkState schedulerS withState =
  scheduleWorkId scheduler $ \(WorkerId ix) -> withState (indexArray statesArray ix)
  where
    scheduler = _getScheduler schedulerS
    statesArray = _workerStatesArray (_workerStates schedulerS)
-- | Same as 'scheduleWorkState', but restricted to jobs and schedulers that produce unit.
-- Submits the job with 'scheduleWorkId_'.
scheduleWorkState_ :: SchedulerWS s m () -> (s -> m ()) -> m ()
scheduleWorkState_ schedulerS withState =
  scheduleWorkId_ scheduler $ \(WorkerId ix) -> withState (indexArray statesArray ix)
  where
    scheduler = _getScheduler schedulerS
    statesArray = _workerStatesArray (_workerStates schedulerS)
-- | Get the number of workers recorded in the scheduler.
numWorkers :: Scheduler m a -> Int
numWorkers scheduler = _numWorkers scheduler
-- | Schedule a job. The job receives the 'WorkerId' of the worker that executes it.
scheduleWorkId :: Scheduler m a -> (WorkerId -> m a) -> m ()
scheduleWorkId scheduler action = _scheduleWorkId scheduler action
-- | Invoke the early termination action of the scheduler with the supplied value.
--
-- For schedulers created by 'withSchedulerInternal' in this module the supplied value is
-- combined with whatever results are collected from the queue at that moment, and it is
-- also what this function returns.
terminate :: Scheduler m a -> a -> m a
terminate scheduler a = _terminate scheduler a
-- | Invoke the early termination action of the scheduler that uses only the supplied
-- value.
--
-- For schedulers created by 'withSchedulerInternal' in this module the supplied value
-- becomes the sole element of the result, and it is also what this function returns.
terminateWith :: Scheduler m a -> a -> m a
terminateWith scheduler a = _terminateWith scheduler a
-- | Schedule a job that has no use for the 'WorkerId' of the worker that runs it.
scheduleWork :: Scheduler m a -> m a -> m ()
scheduleWork scheduler f = _scheduleWorkId scheduler (\_ -> f)
-- | Same as 'scheduleWork', but specialized to jobs and schedulers that produce unit.
scheduleWork_ :: Scheduler m () -> m () -> m ()
scheduleWork_ scheduler action = scheduleWork scheduler action
-- | Same as 'scheduleWorkId', but specialized to jobs and schedulers that produce unit.
scheduleWorkId_ :: Scheduler m () -> (WorkerId -> m ()) -> m ()
scheduleWorkId_ scheduler action = _scheduleWorkId scheduler action
-- | Terminate a scheduler that produces unit. Implemented by supplying @()@ to the same
-- scheduler field that 'terminateWith' uses.
terminate_ :: Scheduler m () -> m ()
terminate_ scheduler = _terminateWith scheduler ()
-- | A scheduler with a single worker that does no scheduling at all: a submitted job is
-- simply applied to @'WorkerId' 0@ on the spot. Both termination fields ignore their
-- argument and do nothing except produce unit.
trivialScheduler_ :: Applicative f => Scheduler f ()
trivialScheduler_ =
  Scheduler
    { _numWorkers = 1
    , _scheduleWorkId = ($ WorkerId 0)
    , _terminate = \_ -> pure ()
    , _terminateWith = \_ -> pure ()
    }
-- | Apply an action to every element of a structure from left to right, discarding the
-- results. Implemented as a strict left fold that chains the actions with '*>'.
traverse_ :: (Applicative f, Foldable t) => (a -> f ()) -> t a -> f ()
traverse_ f xs = F.foldl' andThen (pure ()) xs
  where
    andThen acc x = acc *> f x
-- | Map an action over a traversable structure, running the actions as jobs of a
-- scheduler created with 'withScheduler' for the supplied computation strategy.
--
-- One job is scheduled per element, in traversal order. The list of results returned by
-- the scheduler is then placed back into the shape of the original structure.
traverseConcurrently :: (MonadUnliftIO m, Traversable t) => Comp -> (a -> m b) -> t a -> m (t b)
traverseConcurrently comp f xs = fmap (`transList` xs) runAll
  where
    runAll = withScheduler comp $ \scheduler -> traverse_ (\x -> scheduleWork scheduler (f x)) xs
-- | Replace the elements of a traversable structure, in traversal order, with consecutive
-- elements of the supplied list. Calls 'error' if the list runs out before the structure
-- does. Surplus list elements are ignored.
transList :: Traversable t => [a] -> t b -> t a
transList results = snd . mapAccumL step results
  where
    -- The old element is dropped; the head of the remaining list takes its place.
    step remaining _ =
      case remaining of
        y:ys -> (ys, y)
        [] -> error "Impossible<traverseConcurrently> - Mismatched sizes"
-- | Same as 'traverseConcurrently', but all results are discarded and a 'Foldable'
-- structure is sufficient.
--
-- A single job is scheduled up front, and that job in turn schedules one job per element,
-- so the per-element scheduling is itself done by a worker.
traverseConcurrently_ :: (MonadUnliftIO m, Foldable t) => Comp -> (a -> m b) -> t a -> m ()
traverseConcurrently_ comp f xs =
  withScheduler_ comp $ \scheduler ->
    let enqueue x = scheduleWork scheduler (void (f x))
     in scheduleWork scheduler (F.traverse_ enqueue xs)
-- | Schedule the same action @n@ times with a scheduler created by 'withScheduler' and
-- return the list of results it produces.
replicateConcurrently :: MonadUnliftIO m => Comp -> Int -> m a -> m [a]
replicateConcurrently comp n f =
  withScheduler comp (\scheduler -> replicateM_ n (scheduleWork scheduler f))
-- | Same as 'replicateConcurrently', but the results are discarded.
--
-- A single job is scheduled up front, and that job in turn schedules the @n@ copies of
-- the action, so the scheduling of the copies is itself done by a worker.
replicateConcurrently_ :: MonadUnliftIO m => Comp -> Int -> m a -> m ()
replicateConcurrently_ comp n f =
  withScheduler_ comp $ \scheduler ->
    let enqueueOne = scheduleWork scheduler (void f)
     in scheduleWork scheduler (replicateM_ n enqueueOne)
-- | Submit a job whose result is of the scheduler's result type. The job is constructed
-- with 'mkJob'.
scheduleJobs :: MonadIO m => Jobs m a -> (WorkerId -> m a) -> m ()
scheduleJobs jobs action = scheduleJobsWith mkJob jobs action
-- | Submit a job whose result is thrown away. The action is wrapped in the 'Job_'
-- constructor after its result has been discarded with 'void'.
scheduleJobs_ :: MonadIO m => Jobs m a -> (WorkerId -> m b) -> m ()
scheduleJobs_ jobs action = scheduleJobsWith mkJob_ jobs action
  where
    mkJob_ job = pure (Job_ (\wid -> void (job wid)))
-- | Common implementation for submitting a job to the queue.
--
-- The counter of outstanding jobs is incremented before the job is pushed onto the
-- queue. The action is wrapped so that, once it has finished and its result has been
-- evaluated with 'seq', the counter is decremented. When that decrement brings the
-- counter to zero, one 'Retire' instruction per worker is pushed onto the queue.
scheduleJobsWith ::
     MonadIO m => ((WorkerId -> m b) -> m (Job m a)) -> Jobs m a -> (WorkerId -> m b) -> m ()
scheduleJobsWith mkJob' jobs action = do
  liftIO $ atomicModifyIORefCAS_ countRef (+ 1)
  job <- mkJob' runAndCount
  pushJQueue queue job
  where
    countRef = jobsCountRef jobs
    queue = jobsQueue jobs
    -- Run the action, then account for its completion.
    runAndCount wid = do
      res <- action wid
      res `seq` dropCounterOnZero countRef (retireWorkersN queue (jobsNumWorkers jobs))
      return res
-- | Push the given number of 'Retire' instructions onto the queue.
retireWorkersN :: MonadIO m => JQueue m a -> Int -> m ()
retireWorkersN queue count = traverse_ (pushJQueue queue) (replicate count Retire)
-- | Atomically decrement the counter by one and run the supplied action if the new value
-- of the counter is exactly zero.
dropCounterOnZero :: MonadIO m => IORef Int -> m () -> m ()
dropCounterOnZero counterRef onZero = do
  remaining <- liftIO (atomicModifyIORefCAS counterRef decrement)
  when (remaining == 0) onZero
  where
    -- Both the stored and the returned value are forced, so no thunks pile up in the ref.
    decrement !old =
      let !new = old - 1
       in (new, new)
-- | The main loop of a worker. Keeps popping jobs off the queue and running them with
-- this worker's id until the queue yields 'Nothing', at which point the retirement action
-- is executed and the loop ends.
runWorker ::
     MonadIO m
  => WorkerId
  -> JQueue m a
  -> m ()
  -> m ()
runWorker wId jQueue onRetire = loop
  where
    loop = do
      next <- popJQueue jQueue
      case next of
        Nothing -> onRetire
        Just job -> job wId >> loop
-- | Run the supplied action with a fresh scheduler for the given computation strategy and
-- return the results of the scheduled jobs.
--
-- Results are gathered with 'readResults' and the gathered list is reversed before it is
-- returned. The value produced by the supplied action itself is ignored.
withScheduler ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m [a]
withScheduler comp onScheduler =
  withSchedulerInternal comp scheduleJobs readResults reverse onScheduler
-- | Same as 'withScheduler', but results of the jobs are not retained: jobs are submitted
-- with 'scheduleJobs_', nothing is collected from the queue and unit is returned.
withScheduler_ ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m ()
withScheduler_ comp onScheduler =
  void (withSchedulerInternal comp scheduleJobs_ (\_ -> pure []) id onScheduler)
-- | The machinery shared by 'withScheduler' and 'withScheduler_'.
--
-- Arguments, in order: the computation strategy; the function used to submit a job; the
-- function used to collect results from the queue; an adjustment applied to the collected
-- list of results; and the user action that gets the scheduler.
--
-- Order of operations:
--
-- 1. The user action is run to completion first. Workers are spawned only afterwards, so
--    whatever the action schedules directly is already in the queue by then.
--
-- 2. If no job was scheduled at all, a single no-op job is submitted. Retirement of the
--    workers is triggered by the completion of a job, hence at least one is required.
--
-- 3. Workers are spawned according to the strategy. Nothing is forked for 'Seq'; in that
--    case the calling thread runs the worker loop itself as worker @0@.
--
-- 4. The calling thread blocks until an outcome is available. On normal completion the
--    results are collected and adjusted; on early termination the workers are killed and
--    the results supplied at termination are returned; an exception recorded by a worker
--    is rethrown in the calling thread.
withSchedulerInternal ::
     MonadUnliftIO m
  => Comp
  -> (Jobs m a -> (WorkerId -> m a) -> m ())
  -> (JQueue m a -> m [Maybe a])
  -> ([a] -> [a])
  -> (Scheduler m a -> m b)
  -> m [a]
withSchedulerInternal comp submitWork collect adjust onScheduler = do
  -- The names of the next few bindings must match the fields of 'Jobs', since the record
  -- is constructed with a wildcard below.
  jobsNumWorkers <- getCompWorkers comp
  sWorkersCounterRef <- liftIO (newIORef jobsNumWorkers)
  jobsQueue <- newJQueue
  jobsCountRef <- liftIO (newIORef 0)
  workDoneMVar <- liftIO newEmptyMVar
  let jobs = Jobs {..}
      scheduler =
        Scheduler
          { _numWorkers = jobsNumWorkers
          , _scheduleWorkId = submitWork jobs
          , _terminate =
              \a -> do
                collected <- collect jobsQueue
                -- Only the first outcome is recorded, later attempts are no-ops.
                liftIO $
                  void $
                  tryPutMVar workDoneMVar $
                  SchedulerTerminatedEarly (adjust (a : catMaybes collected))
                pure a
          , _terminateWith =
              \a -> do
                liftIO $ void $ tryPutMVar workDoneMVar $ SchedulerTerminatedEarly [a]
                pure a
          }
      -- The worker that brings the counter down to zero reports normal completion.
      onRetire =
        dropCounterOnZero sWorkersCounterRef $
        void $ liftIO (tryPutMVar workDoneMVar SchedulerFinished)
  _ <- onScheduler scheduler
  scheduledCount <- liftIO (readIORef jobsCountRef)
  when (scheduledCount == 0) $ scheduleJobs_ jobs (const (pure ()))
  let spawnWorkersWith fork caps =
        withRunInIO $ \run ->
          forM (zip [0 ..] caps) $ \(workerId, cap) ->
            fork cap $ \unmask ->
              -- Only the worker loop runs unmasked; the handler runs in the masked state
              -- that the thread was forked with.
              unmask (run (runWorker workerId jobsQueue onRetire)) `catch`
              (run . handleWorkerException jobsQueue workDoneMVar jobsNumWorkers)
      spawnWorkers =
        case comp of
          Seq -> return []
          Par -> spawnWorkersWith forkOnWithUnmask [1 .. jobsNumWorkers]
          ParOn ws -> spawnWorkersWith forkOnWithUnmask ws
          ParN _ -> spawnWorkersWith (\_ -> forkIOWithUnmask) [1 .. jobsNumWorkers]
      killWorkers threadIds =
        liftIO $
        traverse_ (\tid -> throwTo tid (SomeAsyncException WorkerTerminateException)) threadIds
      doWork threadIds = do
        when (comp == Seq) $ runWorker 0 jobsQueue onRetire
        outcome <- liftIO (readMVar workDoneMVar)
        case outcome of
          SchedulerFinished -> adjust . catMaybes <$> collect jobsQueue
          SchedulerTerminatedEarly results -> killWorkers threadIds >> pure results
          SchedulerWorkerException (WorkerException exc) -> liftIO (throwIO exc)
  safeBracketOnError spawnWorkers killWorkers doWork
-- | Handler for an exception caught in a worker thread.
--
-- A 'WorkerTerminateException' that arrived as an asynchronous exception is ignored. Any
-- other exception is recorded as the outcome of the scheduler, unless an outcome has been
-- recorded already, and afterwards one 'Retire' instruction is pushed onto the queue for
-- each of the other workers.
handleWorkerException ::
     MonadIO m => JQueue m a -> MVar (SchedulerOutcome a) -> Int -> SomeException -> m ()
handleWorkerException jQueue workDoneMVar nWorkers exc
  | Just WorkerTerminateException <- asyncExceptionFromException exc = return ()
  | otherwise = do
    _ <- liftIO $ tryPutMVar workDoneMVar $ SchedulerWorkerException $ WorkerException exc
    -- This worker is already on its way out, hence one less than the total.
    retireWorkersN jQueue (nWorkers - 1)
-- | Acquire a resource, run a computation with it and, only when that computation throws,
-- run the cleanup before rethrowing the original exception.
--
-- The acquisition runs with asynchronous exceptions masked, while the computation runs
-- with the masking state of the caller restored. The cleanup runs inside of
-- 'uninterruptibleMask_' and any exception it throws is discarded, so it is always the
-- exception from the computation that gets propagated. On success the cleanup is not run.
safeBracketOnError :: MonadUnliftIO m => m a -> (a -> m b) -> (a -> m c) -> m c
safeBracketOnError before after thing =
  withRunInIO $ \runInIO ->
    mask $ \restore -> do
      resource <- runInIO before
      outcome <- try (restore (runInIO (thing resource)))
      case outcome of
        Right result -> return result
        Left (err :: SomeException) -> do
          _ :: Either SomeException b <-
            try (uninterruptibleMask_ (runInIO (after resource)))
          throwIO err