{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Control.Concurrent.Async.Pool.Internal where
import Control.Applicative (Applicative((<*>), pure), (<$>))
import Control.Arrow (first)
import Control.Concurrent (ThreadId)
import qualified Control.Concurrent.Async as Async (withAsync)
import Control.Concurrent.Async.Pool.Async
import Control.Concurrent.STM
import Control.Exception (SomeException, throwIO, finally)
import Control.Monad hiding (forM, forM_)
import Control.Monad.Base
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Control
import Data.Foldable (Foldable(foldMap), toList, forM_, all)
import Data.Graph.Inductive.Graph as Gr (Graph(empty))
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.List (delete)
import Data.Monoid (Monoid(mempty), (<>))
import Data.Traversable (Traversable(sequenceA), forM)
import Prelude hiding (mapM_, mapM, foldr, all, any, concatMap, foldl1)
import Unsafe.Coerce
-- | Atomically claim up to @avail p@ ready tasks from the group's pending
-- queue.
--
-- A pending task is ready when every incoming edge of its node in @g@ is
-- labelled 'Completed'. Tasks are considered in ascending key order and at
-- most the number of available slots are taken. Claimed tasks are removed
-- from the pending map and the slot counter is decreased by the number
-- claimed.
--
-- Retries when there are no free slots, the pending queue is empty, or no
-- pending task is ready yet.
getReadyNodes :: TaskGroup -> TaskGraph -> STM (IntMap (IO ThreadId, TMVar a))
getReadyNodes p g = do
  availSlots <- readTVar (avail p)
  check (availSlots > 0)
  taskQueue <- readTVar (pending p)
  check (not (IntMap.null taskQueue))
  -- Scan the queue lazily: 'IntMap.filterWithKey' is spine-strict and would
  -- evaluate 'isReady' (an 'inn' lookup) for every pending task on every
  -- scheduling round, although only 'availSlots' of them can be taken.
  -- 'toAscList' yields distinct keys in ascending order, and 'filter'/'take'
  -- preserve that, so 'fromDistinctAscList' is valid here.
  let readyNodes = IntMap.fromDistinctAscList
                 . take availSlots
                 . filter (isReady . fst)
                 . IntMap.toAscList
                 $ taskQueue
  check (not (IntMap.null readyNodes))
  writeTVar (avail p) (availSlots - IntMap.size readyNodes)
  writeTVar (pending p) (taskQueue IntMap.\\ readyNodes)
  return readyNodes
  where
    -- A node is ready once all edges pointing into it are 'Completed'.
    isReady = all isCompleted . inn g
    isCompleted (_, _, Completed) = True
    isCompleted _ = False
-- | Claim the currently runnable tasks of a group (see 'getReadyNodes') and
-- pair each task's launcher and result variable with the task's state
-- 'TVar', looked up in the pool's task graph.
getReadyTasks :: TaskGroup -> STM [(TVar State, (IO ThreadId, TMVar a))]
getReadyTasks p = do
  graph <- readTVar (tasks (pool p))
  claimed <- getReadyNodes p graph
  return [ (getTaskVar graph node, job) | (node, job) <- IntMap.toList claimed ]
-- | Create a new pool with an empty task graph and a second counter 'TVar'
-- initialised to 0 (presumably the next task id; confirm against 'Pool').
createPool :: IO Pool
createPool = do
  graphVar <- newTVarIO Gr.empty
  counterVar <- newTVarIO 0
  return (Pool graphVar counterVar)
-- | Run an action with a fresh pool. After the action returns, block on
-- 'syncPool' before handing back the action's result. If the action
-- throws, 'syncPool' is not run.
withPool :: (Pool -> IO a) -> IO a
withPool f =
  createPool >>= \pl ->
    f pl >>= \result ->
      atomically (syncPool pl) >> return result
-- | Create a new 'TaskGroup' in pool @p@.
--
-- @cnt@ is the initial value of the group's slot counter. 'getReadyNodes'
-- reads that counter as @avail@ and decrements it when tasks are claimed,
-- so it presumably acts as the group's concurrency limit.
--
-- NOTE(review): the pending-task map is created with an unconstrained
-- element type and passed through 'unsafeCoerce' to fit the 'TaskGroup'
-- field (declared elsewhere). This is presumably safe only because the map
-- is empty at this point. Confirm against the 'TaskGroup' definition and
-- document the invariant there.
createTaskGroup :: Pool -> Int -> IO TaskGroup
createTaskGroup p cnt = do
  c <- newTVarIO cnt
  m <- newTVarIO IntMap.empty
  return $ TaskGroup p c (unsafeCoerce m)
-- | Scheduler loop for a 'TaskGroup'. It never returns; 'withTaskGroupIn'
-- runs it under 'Async.withAsync', so it presumably stops when that thread
-- is cancelled.
--
-- Each iteration:
--
-- 1. Blocks until a slot is free and at least one task is ready.
-- 2. In the same transaction, marks every claimed task 'Starting'.
-- 3. Outside the transaction, launches each task and records it as
--    'Started' with its 'ThreadId' and result variable.
runTaskGroup :: TaskGroup -> IO ()
runTaskGroup p = forever $ do
  ready <- atomically $ do
    cnt <- readTVar (avail p)
    check (cnt > 0)
    ready <- getReadyTasks p
    check (not (null ready))
    forM_ ready $ \(tv, _) -> writeTVar tv Starting
    return ready
  forM_ ready $ \(tv, (go, var)) -> do
    t <- go
    -- The previous state is not needed, so a plain write suffices
    -- (the old code used 'swapTVar' and discarded its result).
    atomically $ writeTVar tv (Started t var)
-- | Create a task group with @n@ slots in pool @p@ and run @f@ on it.
-- While @f@ runs, the group's scheduler ('runTaskGroup') runs in a
-- background thread. 'cancelAll' is applied to the group when @f@ finishes
-- or throws.
withTaskGroupIn :: Pool -> Int -> (TaskGroup -> IO b) -> IO b
withTaskGroupIn p n f = do
  tg <- createTaskGroup p n
  Async.withAsync (runTaskGroup tg) $ \_ ->
    f tg `finally` cancelAll tg
-- | Like 'withTaskGroupIn', but on a freshly created pool.
withTaskGroup :: Int -> (TaskGroup -> IO b) -> IO b
withTaskGroup n f = do
  freshPool <- createPool
  withTaskGroupIn freshPool n f
-- | Make @child@ depend on @parent@ by inserting the edge
-- @(parent, child, Pending)@ into the pool's task graph.
--
-- If @parent@ is not in the graph, nothing is added. Presumably that means
-- it has already finished and been removed; confirm against the code that
-- removes nodes.
--
-- Calls 'error' if a path from @child@ to @parent@ already exists. The new
-- edge would then close a cycle. This relies on fgl's 'esp' returning an
-- empty path when there is none.
makeDependent :: Pool
              -> Handle
              -> Handle
              -> STM ()
makeDependent p child parent = do
  g <- readTVar (tasks p)
  when (gelem parent g) $
    case esp child parent g of
      -- Strict modify: avoid leaving an unevaluated 'insEdge' thunk in the
      -- shared graph TVar for some later transaction to force.
      [] -> modifyTVar' (tasks p) (insEdge (parent, child, Pending))
      _  -> error "makeDependent: Cycle in task graph"
-- | Like 'makeDependent', but without the cycle check.
--
-- The caller must guarantee the new edge cannot close a cycle. 'mapReduce'
-- calls it only with a @child@ task created in the same transaction. If
-- @parent@ is not in the graph, nothing is added.
unsafeMakeDependent :: Pool
                    -> Handle
                    -> Handle
                    -> STM ()
unsafeMakeDependent p child parent = do
  g <- readTVar (tasks p)
  when (gelem parent g) $
    -- Strict modify: do not accumulate an 'insEdge' thunk in the graph TVar.
    modifyTVar' (tasks p) (insEdge (parent, child, Pending))
-- | Register an action as a task of the group within the current STM
-- transaction. The task is launched later by the group's scheduler, using
-- 'rawForkIO'.
asyncSTM :: TaskGroup -> IO a -> STM (Async a)
asyncSTM p action = asyncUsing p rawForkIO action
-- | Submit @t@ as a task that depends on every handle in @parents@.
-- Everything happens in one transaction. Fails via 'makeDependent' if any
-- dependency would create a cycle.
asyncAfterAll :: TaskGroup -> [Handle] -> IO a -> IO (Async a)
asyncAfterAll p parents t = atomically $ do
  child <- asyncUsing p rawForkIO t
  let childHandle = taskHandle child
  forM_ parents (\parentHandle -> makeDependent (pool p) childHandle parentHandle)
  return child
-- | Submit an action that depends on a single existing task.
asyncAfter :: TaskGroup -> Async b -> IO a -> IO (Async a)
asyncAfter p parent action = asyncAfterAll p [taskHandle parent] action
-- | Shared driver for the @mapTasks*@ family.
--
-- Submits each action as its own task, one transaction per action. It then
-- passes @f@ the action that collects every task's result with @g@, in the
-- container's traversal order.
mapTasksWorker :: Traversable t
               => TaskGroup
               -> t (IO a)
               -> (IO (t b) -> IO (t c))
               -> (Async a -> IO b)
               -> IO (t c)
mapTasksWorker p fs f g =
  forM fs (atomically . asyncUsing p rawForkIO) >>= \handles ->
    f (forM handles g)
-- | Run every action as a task in the group and collect the results with
-- 'wait', preserving the container's shape.
mapTasks :: Traversable t => TaskGroup -> t (IO a) -> IO (t a)
mapTasks tg actions = mapTasksWorker tg actions id wait
-- | Like 'mapTasks', but collects with 'waitCatch'. Each task's outcome is
-- returned as 'Left' (exception) or 'Right' (value).
mapTasksE :: Traversable t => TaskGroup -> t (IO a) -> IO (t (Either SomeException a))
mapTasksE tg actions = mapTasksWorker tg actions id waitCatch
-- | Submit every action as a task and return without waiting for any
-- result. Each action is submitted in its own transaction.
mapTasks_ :: Foldable t => TaskGroup -> t (IO a) -> IO ()
mapTasks_ p fs = forM_ fs (\action -> atomically (asyncUsing p rawForkIO action))
-- | Run every action as a task and report only failures: 'Just' the
-- exception for a task that failed, 'Nothing' for one that succeeded.
mapTasksE_ :: Traversable t => TaskGroup -> t (IO a) -> IO (t (Maybe SomeException))
mapTasksE_ p fs = mapTasksWorker p fs (fmap (fmap failureOnly)) waitCatch
  where
    failureOnly :: Either a b -> Maybe a
    failureOnly outcome = case outcome of
      Left e  -> Just e
      Right _ -> Nothing
-- | Submit all actions as tasks in a single transaction, then wait with
-- 'waitAnyCatchCancel'. Returns the first task to finish together with its
-- outcome. The remaining tasks are presumably cancelled by
-- 'waitAnyCatchCancel'.
mapRace :: Foldable t
        => TaskGroup -> t (IO a) -> IO (Async a, Either SomeException a)
mapRace p fs =
  waitAnyCatchCancel
    =<< atomically (sequenceA (foldMap (\action -> [asyncUsing p rawForkIO action]) fs))
-- | Build a parallel map/reduce tree over the given actions, entirely inside
-- the current STM transaction, and return the 'Async' of the final result.
--
-- Every action in @fs@ is registered as a task in @p@. Adjacent pairs of
-- tasks are then combined with '<>' by new tasks that depend on both
-- inputs. Each round halves the list, until a single task remains. An empty
-- input yields a single task returning 'mempty'. If either input of a
-- combining task fails, the combining task rethrows that exception.
mapReduce :: (Foldable t, Monoid a)
          => TaskGroup
          -> t (IO a)
          -> STM (Async a)
mapReduce p fs = do
    hs <- sequenceA $ foldMap ((:[]) <$> asyncUsing p rawForkIO) fs
    loopM hs
  where
    -- Repeat reduction rounds until exactly one task remains.
    -- 'squeeze' never returns an empty list, hence "impossible".
    loopM hs = do
      hs' <- squeeze hs
      case hs' of
        [] -> error "mapReduce: impossible"
        [x] -> return x
        xs -> loopM xs
    -- One reduction round: pair up tasks left to right. A trailing odd task
    -- passes through unchanged.
    squeeze [] = (:[]) <$> asyncUsing p rawForkIO (return mempty)
    squeeze [x] = return [x]
    squeeze (x:y:xs) = do
      t <- asyncUsing p rawForkIO $ do
        -- Retry until both inputs have a result ('pollSTM' presumably yields
        -- 'Nothing' while a task is unfinished), then combine them.
        -- 'liftM2' on 'Either' short-circuits to the first 'Left'.
        meres <- atomically $ do
          eres1 <- pollSTM x
          eres2 <- pollSTM y
          case liftM2 (<>) <$> eres1 <*> eres2 of
            Nothing -> retry
            Just a -> return a
        case meres of
          Left e -> throwIO e
          Right a -> return a
      -- Add edges x -> t and y -> t so the scheduler does not start t while
      -- its inputs are still pending. The cycle check is unnecessary because
      -- t was just created.
      forM_ [x, y] (unsafeMakeDependent (pool p) (taskHandle t) . taskHandle)
      case xs of
        [] -> return [t]
        _ -> (t :) <$> squeeze xs
-- | Start every action in @fs@ as a task in @p@, then fold @f@ over each
-- task's outcome (exception or value) as the tasks finish.
--
-- The accumulator starts at 'mempty' and each new result is appended on the
-- right with '<>'. When several tasks have already finished, the one
-- earliest in the original order is taken first. The monadic state of @m@
-- is threaded through every call to @f@ via 'control' / 'restoreM'.
scatterFoldMapM :: (Foldable t, Monoid b, MonadBaseControl IO m)
                => TaskGroup -> t (IO a) -> (Either SomeException a -> m b) -> m b
scatterFoldMapM p fs f = do
  handles <- liftBase . atomically . sequenceA $
    foldMap ((:[]) <$> asyncUsing p rawForkIO) fs
  control $ \(run :: m b -> IO (StM m b)) ->
    drain run (run (return mempty)) (toList handles)
  where
    -- Consume the remaining handles one finished task at a time.
    drain _ acc [] = acc
    drain run acc remaining = do
      (done, outcome) <- atomically $
        foldM firstFinished Nothing remaining >>= maybe retry return
      prev <- acc
      next <- run $ do
        soFar <- restoreM prev
        extra <- f outcome
        return (soFar <> extra)
      drain run (return next) (delete done remaining)

    -- Keep the first finished task found; otherwise probe the next handle.
    firstFinished found@(Just _) _ = return found
    firstFinished Nothing h = fmap (\outcome -> (h, outcome)) <$> pollSTM h
-- | Apply @f@ to every element and run the resulting actions as tasks in the
-- group, collecting results in the container's shape (see 'mapTasks').
mapConcurrently :: Traversable t => TaskGroup -> (a -> IO b) -> t a -> IO (t b)
mapConcurrently tg f xs = mapTasks tg (f <$> xs)
-- | A composable computation run against a 'TaskGroup'.
--
-- 'runTask'' takes the group and performs the setup step, such as
-- submitting work to the group (see the 'Applicative' instance). It returns
-- an 'IO' action that produces the final result.
newtype Task a = Task { runTask' :: TaskGroup -> IO (IO a) }
-- | Run a 'Task' in the given group: perform its setup step, then run the
-- returned action to get the result.
runTask :: TaskGroup -> Task a -> IO a
runTask tg t = runTask' t tg >>= id
-- | Lift a plain 'IO' action into a 'Task'. The setup step does nothing, and
-- the action runs when the task's result is demanded.
task :: IO a -> Task a
task action = Task (const (return action))
-- | Map over the result action; the setup step is unchanged.
instance Functor Task where
  fmap f (Task setup) = Task $ \tg -> liftM f <$> setup tg
-- | 'pure' has no setup. With '<*>', the argument's action is submitted to
-- the group via 'async' before the function side's setup runs. The
-- combined action runs the function side's action and waits for the
-- argument's result.
instance Applicative Task where
  pure x = Task (const (return (return x)))
  Task setupF <*> Task setupX = Task $ \tg -> do
    argAction <- setupX tg
    argHandle <- async tg argAction
    funAction <- setupF tg
    return (funAction <*> wait argHandle)
-- | Bind runs the left task to completion during setup, then performs the
-- setup of the continuation's task and returns its result action.
instance Monad Task where
  return = pure
  Task setup >>= k = Task $ \tg -> do
    a <- join (setup tg)
    runTask' (k a) tg
-- | 'IO' actions lift into 'Task' via 'task'.
instance MonadIO Task where
  liftIO action = task action