{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Scheduler
(
Scheduler
, SchedulerWS
, Results(..)
, withScheduler
, withScheduler_
, withSchedulerR
, withSchedulerWS
, withSchedulerWS_
, withSchedulerWSR
, unwrapSchedulerWS
, trivialScheduler_
, withTrivialScheduler
, withTrivialSchedulerR
, scheduleWork
, scheduleWork_
, scheduleWorkId
, scheduleWorkId_
, scheduleWorkState
, scheduleWorkState_
, replicateWork
, terminate
, terminate_
, terminateWith
, WorkerId(..)
, WorkerStates
, numWorkers
, workerStatesComp
, initWorkerStates
, Comp(..)
, getCompWorkers
, replicateConcurrently
, replicateConcurrently_
, traverseConcurrently
, traverseConcurrently_
, traverse_
, MutexException(..)
) where
import Control.Concurrent
import Control.Exception
import Control.Monad
import Control.Monad.IO.Unlift
import Control.Monad.Primitive (PrimMonad)
import Control.Scheduler.Computation
import Control.Scheduler.Internal
import Control.Scheduler.Queue
import Data.Atomics (atomicModifyIORefCAS, atomicModifyIORefCAS_)
import qualified Data.Foldable as F (foldl', traverse_, toList)
import Data.IORef
import Data.Maybe (catMaybes)
import Data.Primitive.Array
import Data.Primitive.MutVar
import Data.Traversable
#if !MIN_VERSION_primitive(0,6,2)
import Control.Monad.ST
#endif
-- | Extract the plain 'Scheduler' wrapped inside a 'SchedulerWS', dropping
-- access to the worker states.
unwrapSchedulerWS :: SchedulerWS s m a -> Scheduler m a
unwrapSchedulerWS schedulerWS = _getScheduler schedulerWS
-- | Create one state per worker of the given computation strategy.
--
-- The initializer is called once for each worker id @0 .. n-1@, in that
-- order, where @n@ is the worker count reported by 'getCompWorkers'. The
-- result also carries a mutex flag, initially unlocked. The
-- @withSchedulerWS*@ functions use it to refuse concurrent use of the same
-- states.
initWorkerStates :: MonadIO m => Comp -> (WorkerId -> m s) -> m (WorkerStates s)
initWorkerStates comp initState = do
  workerCount <- getCompWorkers comp
  states <- forM [0 .. workerCount - 1] (initState . WorkerId)
  lockRef <- liftIO (newIORef False)
  pure $
    WorkerStates
      { _workerStatesComp = comp
      , _workerStatesArray = arrayFromListN workerCount states
      , _workerStatesMutex = lockRef
      }
-- | Build an 'Array' of @n@ elements from a list.
--
-- With @primitive >= 0.6.2@ this is simply 'fromListN'. For older versions a
-- hand-written fallback is used. The fallback calls 'error' when the list is
-- shorter or longer than @n@.
arrayFromListN :: Int -> [a] -> Array a
#if MIN_VERSION_primitive(0,6,2)
arrayFromListN = fromListN
#else
arrayFromListN n l =
  runST $ do
    -- Placeholder for slots that have not been written yet.
    ma <- newArray n (error "initWorkerStates: uninitialized element")
    -- Write the list element by element.
    -- The length check happens as we go: running out of list before
    -- index n, or still having elements at index n, is an error.
    let go !ix [] =
          if ix == n
            then return ()
            else error "initWorkerStates: list length less than specified size"
        go !ix (x:xs) =
          if ix < n
            then do
              writeArray ma ix x
              go (ix + 1) xs
            else error "initWorkerStates: list length greater than specified size"
    go 0 l
    -- The mutable array is never touched again, so freezing in place is safe.
    unsafeFreezeArray ma
#endif
-- | The computation strategy these worker states were created for.
workerStatesComp :: WorkerStates s -> Comp
workerStatesComp states = _workerStatesComp states
-- | Shared implementation of the @withSchedulerWS*@ family.
--
-- It atomically takes the mutex flag stored in the 'WorkerStates' for the
-- whole scheduler run. If the flag is already set, it throws
-- 'MutexException' instead of running. Otherwise it runs @withScheduler'@ on
-- the states' computation strategy. The user action receives a 'SchedulerWS'
-- that pairs the states with the scheduler.
withSchedulerWSInternal ::
     MonadUnliftIO m
  => (Comp -> (Scheduler m a -> t) -> m b)
  -> WorkerStates s
  -> (SchedulerWS s m a -> t)
  -> m b
withSchedulerWSInternal withScheduler' states action =
  withRunInIO $ \runInIO ->
    bracket acquireLock releaseLock $ \alreadyLocked ->
      runInIO $
        if alreadyLocked
          then liftIO (throwIO MutexException)
          else withScheduler' (_workerStatesComp states) (action . SchedulerWS states)
  where
    lockRef = _workerStatesMutex states
    -- Set the flag and report its previous value in one atomic step.
    acquireLock = atomicModifyIORef' lockRef (\previous -> (True, previous))
    -- Only the caller that actually acquired the lock may clear it.
    releaseLock alreadyLocked = unless alreadyLocked (writeIORef lockRef False)
-- | Run a scheduler whose jobs can use the per-worker states, and return the
-- job results as a list.
withSchedulerWS :: MonadUnliftIO m => WorkerStates s -> (SchedulerWS s m a -> m b) -> m [a]
withSchedulerWS states action = withSchedulerWSInternal withScheduler states action
-- | Run a scheduler whose jobs can use the per-worker states. Job results are
-- discarded.
withSchedulerWS_ :: MonadUnliftIO m => WorkerStates s -> (SchedulerWS s m () -> m b) -> m ()
withSchedulerWS_ states action = withSchedulerWSInternal withScheduler_ states action
-- | Run a scheduler whose jobs can use the per-worker states, and return the
-- full 'Results', which also record whether the run terminated early.
withSchedulerWSR :: MonadUnliftIO m => WorkerStates s -> (SchedulerWS s m a -> m b) -> m (Results a)
withSchedulerWSR states action = withSchedulerWSInternal withSchedulerR states action
-- | Schedule a job that receives the state belonging to the worker that
-- executes it. The state is looked up by that worker's 'WorkerId'.
scheduleWorkState :: SchedulerWS s m a -> (s -> m a) -> m ()
scheduleWorkState schedulerS withState =
  scheduleWorkId (_getScheduler schedulerS) $ \(WorkerId wid) ->
    let perWorker = _workerStatesArray (_workerStates schedulerS)
     in withState (indexArray perWorker wid)
-- | Same as 'scheduleWorkState', but for jobs that produce no result.
scheduleWorkState_ :: SchedulerWS s m () -> (s -> m ()) -> m ()
scheduleWorkState_ schedulerS withState =
  scheduleWorkId_ (_getScheduler schedulerS) $ \(WorkerId wid) ->
    let perWorker = _workerStatesArray (_workerStates schedulerS)
     in withState (indexArray perWorker wid)
-- | Number of workers the scheduler runs jobs on.
numWorkers :: Scheduler m a -> Int
numWorkers scheduler = _numWorkers scheduler
-- | Schedule a job that receives the 'WorkerId' of the worker running it.
scheduleWorkId :: Scheduler m a -> (WorkerId -> m a) -> m ()
scheduleWorkId scheduler job = _scheduleWorkId scheduler job
-- | Request early termination, with the given value as the final result,
-- alongside the results collected so far. Returns the value it was given.
terminate :: Scheduler m a -> a -> m a
terminate scheduler value = _terminate scheduler value
-- | Request early termination with the given value as the only result;
-- previously collected results are not kept. Returns the value it was given.
terminateWith :: Scheduler m a -> a -> m a
terminateWith scheduler value = _terminateWith scheduler value
-- | Schedule a job that does not care which worker runs it.
scheduleWork :: Scheduler m a -> m a -> m ()
scheduleWork scheduler f = _scheduleWorkId scheduler (\_ -> f)
-- | 'scheduleWork' specialised to jobs that produce no result.
scheduleWork_ :: Scheduler m () -> m () -> m ()
scheduleWork_ scheduler job = scheduleWork scheduler job
-- | 'scheduleWorkId' specialised to jobs that produce no result.
scheduleWorkId_ :: Scheduler m () -> (WorkerId -> m ()) -> m ()
scheduleWorkId_ scheduler job = _scheduleWorkId scheduler job
-- | Schedule the same job @n@ times. A count of zero or less schedules
-- nothing.
replicateWork :: Applicative m => Int -> Scheduler m a -> m a -> m ()
replicateWork !n scheduler f = loop n
  where
    submitOnce = scheduleWork scheduler f
    loop !remaining
      | remaining > 0 = submitOnce *> loop (remaining - 1)
      | otherwise = pure ()
-- | Early termination for a scheduler without results. It goes through
-- '_terminateWith', passing @()@.
terminate_ :: Scheduler m () -> m ()
terminate_ scheduler = _terminateWith scheduler ()
-- | A single-worker scheduler that runs each job immediately, in place, with
-- worker id 0. Termination requests do nothing and return @()@.
trivialScheduler_ :: Applicative f => Scheduler f ()
trivialScheduler_ =
  Scheduler
    { _numWorkers = 1
    , _scheduleWorkId = \job -> job (WorkerId 0)
    , _terminate = \_ -> pure ()
    , _terminateWith = \_ -> pure ()
    }
-- | Run the action with a single-worker scheduler that executes jobs in place,
-- and return the results as a list.
withTrivialScheduler :: PrimMonad m => (Scheduler m a -> m b) -> m [a]
withTrivialScheduler action = fmap F.toList (withTrivialSchedulerR action)
-- | Run the action with a single-worker scheduler that executes every job
-- immediately, in place, with worker id 0, and return the full 'Results'.
--
-- Terminating does not stop the action. Later jobs still run, but they do
-- not change an outcome that has already been recorded.
--
-- Only the first termination request is recorded. This matches the
-- concurrent scheduler, which records its outcome with 'tryPutMVar'.
-- Previously each later @terminate@/@terminateWith@ call overwrote the
-- earlier outcome.
withTrivialSchedulerR :: PrimMonad m => (Scheduler m a -> m b) -> m (Results a)
withTrivialSchedulerR action = do
  -- Job results, most recent first.
  resVar <- newMutVar []
  -- Early-termination outcome, if any has been requested.
  finResVar <- newMutVar Nothing
  -- Keep the first recorded termination outcome; ignore later ones.
  let recordOutcome outcome =
        readMutVar finResVar >>= \case
          Nothing -> writeMutVar finResVar (Just outcome)
          Just _ -> pure ()
  _ <- action $ Scheduler
    { _numWorkers = 1
    , _scheduleWorkId = \f -> do
        r <- f (WorkerId 0)
        modifyMutVar' resVar (r :)
    , _terminate = \r -> do
        rs <- readMutVar resVar
        recordOutcome (FinishedEarly rs r)
        pure r
    , _terminateWith = \r -> do
        recordOutcome (FinishedEarlyWith r)
        pure r
    }
  -- Results were accumulated newest-first; restore scheduling order.
  readMutVar finResVar >>= \case
    Just rs -> pure $ reverseResults rs
    Nothing -> Finished . Prelude.reverse <$> readMutVar resVar
-- | Run an effect for every element of a 'Foldable', left to right, and
-- discard the results.
--
-- Unlike "Data.Foldable"'s version, the chain of actions is built with a
-- strict left fold ('F.foldl'').
traverse_ :: (Applicative f, Foldable t) => (a -> f ()) -> t a -> f ()
traverse_ f xs = F.foldl' step (pure ()) xs
  where
    step acc x = acc *> f x
-- | Apply @f@ to every element concurrently, using a scheduler for the given
-- computation strategy. The results are put back into the shape of the input
-- with 'transList'.
--
-- NOTE(review): this relies on 'withScheduler' returning results in the
-- order the jobs were scheduled; confirm against 'readResults'.
traverseConcurrently :: (MonadUnliftIO m, Traversable t) => Comp -> (a -> m b) -> t a -> m (t b)
traverseConcurrently comp f xs =
  (`transList` xs) <$> withScheduler comp (\scheduler -> traverse_ (scheduleWork scheduler . f) xs)
-- | Fill the shape of a 'Traversable' with the elements of a list, taken in
-- traversal order. The container's own elements are ignored.
--
-- Surplus list elements are dropped. A list that is too short makes the
-- missing slots an error.
transList :: Traversable t => [a] -> t b -> t a
transList source shape = snd (mapAccumL fillSlot source shape)
  where
    fillSlot [] _ = errorWithoutStackTrace "Impossible<traverseConcurrently> - Mismatched sizes"
    fillSlot (next : rest) _ = (rest, next)
-- | Like 'traverseConcurrently', but discards the results.
--
-- The per-element jobs are scheduled from inside one outer job, not directly
-- by the calling action.
traverseConcurrently_ :: (MonadUnliftIO m, Foldable t) => Comp -> (a -> m b) -> t a -> m ()
traverseConcurrently_ comp f xs =
  withScheduler_ comp $ \scheduler ->
    let scheduleOne = scheduleWork scheduler . void . f
     in scheduleWork scheduler (F.traverse_ scheduleOne xs)
-- | Run the same action @n@ times concurrently and collect the results.
replicateConcurrently :: MonadUnliftIO m => Comp -> Int -> m a -> m [a]
replicateConcurrently comp n f =
  withScheduler comp (\scheduler -> replicateM_ n (scheduleWork scheduler f))
-- | Run the same action @n@ times concurrently and discard the results.
--
-- The @n@ jobs are scheduled from inside one outer job, not directly by the
-- calling action.
replicateConcurrently_ :: MonadUnliftIO m => Comp -> Int -> m a -> m ()
replicateConcurrently_ comp n f =
  withScheduler_ comp $ \scheduler ->
    let runOnce = scheduleWork scheduler (void f)
     in scheduleWork scheduler (replicateM_ n runOnce)
-- | Submit a job whose result is kept, wrapped with 'mkJob'.
scheduleJobs :: MonadIO m => Jobs m a -> (WorkerId -> m a) -> m ()
scheduleJobs jobs action = scheduleJobsWith mkJob jobs action
-- | Submit a job whose result is discarded; it is wrapped as a 'Job_'.
scheduleJobs_ :: MonadIO m => Jobs m a -> (WorkerId -> m b) -> m ()
scheduleJobs_ jobs action = scheduleJobsWith wrapDiscarding jobs action
  where
    wrapDiscarding job = pure (Job_ (\wid -> void (job wid)))
-- | Submit a job to the queue.
--
-- The pending-job counter is incremented before the job is built and pushed.
-- When the wrapped job completes, it:
--
-- * forces its result to WHNF;
-- * decrements the counter;
-- * if the counter reaches zero, pushes one 'Retire' per worker.
scheduleJobsWith ::
     MonadIO m => ((WorkerId -> m b) -> m (Job m a)) -> Jobs m a -> (WorkerId -> m b) -> m ()
scheduleJobsWith mkJob' jobs action = do
  liftIO $ atomicModifyIORefCAS_ pendingRef (+ 1)
  wrapped <- mkJob' runAndAccount
  pushJQueue queue wrapped
  where
    pendingRef = jobsCountRef jobs
    queue = jobsQueue jobs
    runAndAccount wid = do
      result <- action wid
      result `seq` dropCounterOnZero pendingRef (retireWorkersN queue (jobsNumWorkers jobs))
      pure result
-- | Push @n@ 'Retire' markers onto the job queue. A non-positive @n@ pushes
-- nothing.
retireWorkersN :: MonadIO m => JQueue m a -> Int -> m ()
retireWorkersN jobsQueue n = replicateM_ n (pushJQueue jobsQueue Retire)
-- | Atomically decrement the counter. Run @onZero@ when the decremented
-- value is exactly zero.
dropCounterOnZero :: MonadIO m => IORef Int -> m () -> m ()
dropCounterOnZero counterRef onZero = do
  remaining <- liftIO (atomicModifyIORefCAS counterRef decrement)
  when (remaining == 0) onZero
  where
    decrement !current =
      let !next = current - 1
       in (next, next)
-- | Worker loop. It repeatedly pops a job and runs it with this worker's id.
-- When 'popJQueue' yields 'Nothing', it runs @onRetire@ and stops.
--
-- NOTE(review): presumably 'Nothing' corresponds to a 'Retire' marker in the
-- queue; confirm in "Control.Scheduler.Queue".
runWorker :: MonadIO m =>
     WorkerId
  -> JQueue m a
  -> m ()
  -> m ()
runWorker wId jQueue onRetire = loop
  where
    loop = do
      next <- popJQueue jQueue
      case next of
        Nothing -> onRetire
        Just job -> job wId >> loop
-- | Run an action that schedules jobs, wait for the jobs to finish, and
-- return their results as a list. The list is built from the final 'Results'
-- with 'resultsToList', then reversed.
withScheduler ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m [a]
withScheduler comp onScheduler =
  withSchedulerInternal comp scheduleJobs readResults (\results -> reverse (resultsToList results)) onScheduler
-- | Like 'withScheduler', but returns the full 'Results'. These distinguish
-- normal completion from early termination.
withSchedulerR ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m (Results a)
withSchedulerR comp onScheduler =
  withSchedulerInternal comp scheduleJobs readResults reverseResults onScheduler
-- | Like 'withScheduler', but job results are neither kept nor collected.
withScheduler_ ::
     MonadUnliftIO m
  => Comp
  -> (Scheduler m a -> m b)
  -> m ()
withScheduler_ comp onScheduler =
  void (withSchedulerInternal comp scheduleJobs_ (\_ -> pure []) (\_ -> ()) onScheduler)
-- | Core driver shared by 'withScheduler', 'withSchedulerR' and
-- 'withScheduler_'.
--
-- * @submitWork@ wraps a scheduled job and queues it.
-- * @collect@ reads the results accumulated in the job queue.
-- * @adjust@ turns the final 'Results' into the caller's result type.
-- * @onScheduler@ is the user action that schedules jobs.
--
-- The user action runs first, before any worker thread is forked. Workers
-- are then spawned according to @comp@. For 'Seq' none are spawned and the
-- calling thread runs the worker loop itself.
--
-- The caller then blocks on @workDoneMVar@ until the first outcome is
-- recorded. The outcome is one of: all workers retired, early termination,
-- or a worker exception (which is rethrown). If that final phase throws, the
-- spawned workers are sent 'WorkerTerminateException'.
withSchedulerInternal ::
     MonadUnliftIO m
  => Comp
  -> (Jobs m a -> (WorkerId -> m a) -> m ())
  -> (JQueue m a -> m [Maybe a])
  -> (Results a -> c)
  -> (Scheduler m a -> m b)
  -> m c
withSchedulerInternal comp submitWork collect adjust onScheduler = do
  jobsNumWorkers <- getCompWorkers comp
  -- Workers still running. 'onRetire' decrements it; the last worker to
  -- retire records 'SchedulerFinished'.
  sWorkersCounterRef <- liftIO $ newIORef jobsNumWorkers
  jobsQueue <- newJQueue
  -- Count of submitted jobs that have not completed. It is maintained by the
  -- job submission functions (see 'scheduleJobsWith').
  jobsCountRef <- liftIO $ newIORef 0
  -- Holds the first recorded outcome. Every writer uses 'tryPutMVar', so
  -- later writes are no-ops.
  workDoneMVar <- liftIO newEmptyMVar
  let jobs = Jobs {..}
      scheduler =
        Scheduler
          { _numWorkers = jobsNumWorkers
          , _scheduleWorkId = submitWork jobs
          , _terminate =
              \a -> do
                -- Snapshot the results collected so far and record an early finish.
                mas <- collect jobsQueue
                let as = catMaybes mas
                liftIO $
                  void $ tryPutMVar workDoneMVar $ SchedulerTerminatedEarly (FinishedEarly as a)
                pure a
          , _terminateWith =
              \a -> do
                liftIO $
                  void $ tryPutMVar workDoneMVar $ SchedulerTerminatedEarly (FinishedEarlyWith a)
                pure a
          }
      onRetire =
        dropCounterOnZero sWorkersCounterRef $
          void $ liftIO (tryPutMVar workDoneMVar SchedulerFinished)
  _ <- onScheduler scheduler
  jc <- liftIO $ readIORef jobsCountRef
  -- Nothing was scheduled. Submit a no-op job: once it completes, the job
  -- counter drops to zero and the workers are retired.
  when (jc == 0) $ scheduleJobs_ jobs (\_ -> pure ())
  -- Fork one worker per entry of @ws@, numbering worker ids from 0. The
  -- worker loop runs under @unmask@. Any exception escaping it is passed to
  -- 'handleWorkerException'.
  let spawnWorkersWith fork ws =
        withRunInIO $ \run ->
          forM (zip [0 ..] ws) $ \(wId, on) ->
            fork on $ \unmask ->
              catch
                (unmask $ run $ runWorker wId jobsQueue onRetire)
                (run . handleWorkerException jobsQueue workDoneMVar jobsNumWorkers)
      spawnWorkers =
        case comp of
          -- No threads are forked; 'doWork' runs the worker loop on the caller.
          Seq -> return []
          Par -> spawnWorkersWith forkOnWithUnmask [1 .. jobsNumWorkers]
          ParOn ws -> spawnWorkersWith forkOnWithUnmask ws
          -- The capability argument is ignored; workers are plain 'forkIOWithUnmask' threads.
          ParN _ -> spawnWorkersWith (\_ -> forkIOWithUnmask) [1 .. jobsNumWorkers]
      terminateWorkers = liftIO . traverse_ (`throwTo` SomeAsyncException WorkerTerminateException)
      doWork tids = do
        when (comp == Seq) $ runWorker 0 jobsQueue onRetire
        mExc <- liftIO (readMVar workDoneMVar)
        case mExc of
          SchedulerFinished -> adjust . Finished . catMaybes <$> collect jobsQueue
          -- Workers may still be running after an early termination; stop them.
          SchedulerTerminatedEarly as -> terminateWorkers tids >> pure (adjust as)
          SchedulerWorkerException (WorkerException exc) -> liftIO $ throwIO exc
  safeBracketOnError spawnWorkers terminateWorkers doWork
-- | Flatten 'Results' into a list.
--
-- For 'FinishedEarly', the terminating value is put at the head, in front of
-- the results collected before termination. For 'FinishedEarlyWith', the
-- terminating value is the only element.
resultsToList :: Results a -> [a]
resultsToList results =
  case results of
    Finished rs -> rs
    FinishedEarly rs r -> r : rs
    FinishedEarlyWith r -> [r]
-- | Reverse the list of collected results inside 'Results'. The early
-- terminating value, if any, is left in place; other variants are returned
-- unchanged.
reverseResults :: Results a -> Results a
reverseResults results =
  case results of
    Finished rs -> Finished (reverse rs)
    FinishedEarly rs r -> FinishedEarly (reverse rs) r
    other -> other
-- | Exception handler installed around each worker loop.
--
-- 'WorkerTerminateException' is the scheduler's own shutdown signal and is
-- ignored. Any other exception is recorded as the scheduler outcome (unless
-- another outcome was recorded first), and the remaining @nWorkers - 1@
-- workers are sent 'Retire' markers.
handleWorkerException ::
     MonadIO m => JQueue m a -> MVar (SchedulerOutcome a) -> Int -> SomeException -> m ()
handleWorkerException jQueue workDoneMVar nWorkers exc
  | Just WorkerTerminateException <- asyncExceptionFromException exc = return ()
  | otherwise = do
      _ <- liftIO (tryPutMVar workDoneMVar (SchedulerWorkerException (WorkerException exc)))
      retireWorkersN jQueue (nWorkers - 1)
-- | Acquire a resource with @before@ and run @thing@ on it. If @thing@
-- throws any exception, including asynchronous ones, run @after@ and rethrow
-- the original exception.
--
-- Acquisition runs with asynchronous exceptions masked, and @thing@ runs
-- with them restored. The cleanup runs under 'uninterruptibleMask_', and any
-- exception it throws is discarded so the original one is not lost. On
-- success @after@ is not run.
safeBracketOnError :: MonadUnliftIO m => m a -> (a -> m b) -> (a -> m c) -> m c
safeBracketOnError before after thing =
  withRunInIO $ \runInIO ->
    mask $ \restore -> do
      resource <- runInIO before
      outcome <- try (restore (runInIO (thing resource)))
      case outcome of
        Right value -> pure value
        Left (err :: SomeException) -> do
          _ :: Either SomeException b <- try (uninterruptibleMask_ (runInIO (after resource)))
          throwIO err