-- | A thread pool with a prioritised job queue.
--
--   'runPool' creates a pool, runs an initial action on it and blocks until
--   all queued work has finished, rethrowing the first exception raised by a
--   job. 'addPool' queues further work at a given 'PoolPriority'.
--   'increasePool' temporarily raises the thread limit. 'keepAlivePool'
--   stops the pool from finishing until the returned action is run.
module General.Pool(
    Pool, runPool,
    addPool, PoolPriority(..),
    increasePool, keepAlivePool
    ) where
import Control.Concurrent.Extra
import System.Time.Extra
import Control.Exception
import Control.Monad.Extra
import General.Timing
import General.Extra
import qualified Data.Heap as Heap
import qualified Data.HashSet as Set
import Data.IORef.Extra
import System.Random
-- | Handle to a running pool of worker threads.
--
--   Once the state becomes 'Nothing' (the pool finished, a worker failed, or
--   'runPool' cleaned up), further operations on the pool are no-ops.
data Pool = Pool
    !(Var (Maybe S)) -- mutable pool state; 'Nothing' once finished/failed/cleaned up
    !(Barrier (Either SomeException S)) -- signalled at most once: 'Left' first worker exception, 'Right' final state
-- | Mutable state of a pool, kept inside 'Pool'.
--   Constructed positionally by 'emptyS', so field order matters.
data S = S
    {threads :: !(Set.HashSet ThreadId) -- ^ worker threads spawned by 'step' that are still recorded as running
    ,threadsLimit :: {-# UNPACK #-} !Int -- ^ 'step' only forks a worker while 'threadsCount' is below this
    ,threadsCount :: {-# UNPACK #-} !Int -- ^ number of workers currently recorded as running
    ,threadsMax :: {-# UNPACK #-} !Int -- ^ highest 'threadsCount' seen (reported by 'runPool')
    ,threadsSum :: {-# UNPACK #-} !Int -- ^ total number of workers ever forked (reported by 'runPool')
    ,rand :: IO Int -- ^ source of tie-breakers for jobs of equal priority (see 'emptyS')
    ,todo :: !(Heap.Heap (Heap.Entry (PoolPriority, Int) (IO ()))) -- ^ queued jobs keyed by (priority, tie-breaker)
    }
-- | Initial pool state: no threads, a thread limit of @n@, zeroed counters
--   and an empty queue.
--
--   When @deterministic@ is set, tie-breakers come from a private counter
--   (0, 1, 2, ...). Otherwise they come from 'randomIO'.
emptyS :: Int -> Bool -> IO S
emptyS n deterministic = do
    tieBreaker <- if deterministic then counter else return randomIO
    return S
        {threads = Set.empty
        ,threadsLimit = n
        ,threadsCount = 0
        ,threadsMax = 0
        ,threadsSum = 0
        ,rand = tieBreaker
        ,todo = Heap.empty
        }
    where
        -- an action that yields 0, 1, 2, ... on successive calls
        counter = do
            cell <- newIORef 0
            return $ do
                current <- readIORef cell
                writeIORef' cell (current + 1)
                return current
-- | Body of a worker thread. It repeatedly removes the front job from 'todo'
--   (via 'Heap.uncons') and runs it. It returns when the queue is empty or
--   the pool state is 'Nothing'.
--
--   Bookkeeping ('threads', 'threadsCount') is done by 'step', not here.
--   NOTE(review): this loop never consults 'threadsLimit', so lowering the
--   limit only takes effect as workers run out of queued work.
worker :: Pool -> IO ()
worker pool@(Pool var _) = join $ modifyVar var next
    where
        -- returns (new state, action to run after releasing the Var)
        next Nothing = return (Nothing, return ())
        next (Just s) = return $ case Heap.uncons (todo s) of
            Nothing -> (Just s, return ())
            Just (Heap.Entry _ job, rest) -> (Just s{todo = rest}, job >> worker pool)
-- | Apply @op@ to the pool state, then restore the pool's invariants.
--
--   After @op@ has run:
--
--   * if a job is queued and fewer than 'threadsLimit' workers are running,
--     one new worker is forked. It runs that job and then keeps draining the
--     queue via 'worker'.
--
--   * if nothing is queued and no workers are running, the pool is finished.
--     The done barrier is signalled with the final state and the state
--     becomes 'Nothing'.
--
--   @op@ should only change 'threadsLimit' or 'todo'. At most one worker is
--   forked per call, so @op@ should free at most one slot or add at most one
--   job. Once the state is 'Nothing', the call does nothing.
--
--   When a worker ends with an exception, every other recorded worker is
--   killed, the done barrier is signalled with that exception, and the state
--   becomes 'Nothing'. When a worker ends normally it is removed from
--   'threads' and 'step' runs again to reuse the freed slot.
step :: Pool -> (S -> IO S) -> IO ()
-- mask_: 'modifyVar_' (presumably a wrapper over 'modifyMVar_') runs its body
-- with the caller's masking state. Without masking, an asynchronous exception
-- arriving after the fork but before the new state is returned rolls the state
-- back. The forked thread then keeps running without being in 'threads' (so it
-- is never killed), later decrements a count it never incremented, and its
-- job also stays queued in 'todo'.
-- NOTE(review): 'forkFinallyUnmasked' (General.Extra) is assumed, from its name,
-- to start the child unmasked so workers do not inherit this mask; confirm.
step pool@(Pool var done) op = mask_ $ onVar $ \s0 -> do
    s <- op s0
    case Heap.uncons $ todo s of
        Just (Heap.Entry _ now, todo2) | threadsCount s < threadsLimit s -> do
            -- spawn one worker for the front job
            t <- forkFinallyUnmasked (now >> worker pool) $ \res -> case res of
                Left e -> onVar $ \s2 -> do
                    -- a job failed: kill every other worker and report the error
                    me <- myThreadId
                    mapM_ killThread $ Set.toList $ Set.delete me $ threads s2
                    signalBarrier done $ Left e
                    return Nothing
                Right _ -> do
                    -- worker drained the queue: release its slot and re-check invariants
                    me <- myThreadId
                    step pool $ \s2 -> return s2{threads = Set.delete me $ threads s2, threadsCount = threadsCount s2 - 1}
            return $ Just s{todo = todo2, threads = Set.insert t $ threads s, threadsCount = threadsCount s + 1
                           ,threadsSum = threadsSum s + 1, threadsMax = threadsMax s `max` (threadsCount s + 1)}
        Nothing | threadsCount s == 0 -> do
            -- no work left and nobody running: the pool is done
            signalBarrier done $ Right s
            return Nothing
        _ -> return $ Just s
    where
        -- run an update only while the pool is still live
        onVar act = modifyVar_ var $ maybe (return Nothing) act
-- | Queue @act@ on the pool at the given priority; its result is discarded.
--
--   Jobs with the same priority are ordered by a tie-breaker drawn from the
--   state's 'rand' (a counter in deterministic mode, otherwise 'randomIO').
--   'step' may start the job immediately if a thread slot is free.
addPool :: PoolPriority -> Pool -> IO a -> IO ()
addPool priority pool act = step pool enqueue
    where
        enqueue s = do
            tieBreak <- rand s
            let entry = Heap.Entry (priority, tieBreak) (void act)
            return s{todo = Heap.insert entry (todo s)}
-- | Priority of a queued job.
--
--   The derived 'Ord' follows constructor order, with 'PoolDeprioritize'
--   values further ordered by their 'Double'. Jobs are taken with
--   'Heap.uncons'. Assuming 'Data.Heap' is a min-heap (as in the @heaps@
--   package), smaller priorities run first.
--   Constructor order is significant; do not reorder.
data PoolPriority
    = PoolException
    | PoolResume -- used by 'keepAlivePool'
    | PoolStart -- used by 'runPool' for the initial action
    | PoolBatch
    | PoolDeprioritize Double
    deriving (Eq,Ord)
-- | Raise the pool's thread limit by one, which may fork a worker right
--   away via 'step'. Returns an action that lowers the limit by one again.
--
--   NOTE(review): nothing stops the returned action from being run more than
--   once, and each run lowers the limit again.
increasePool :: Pool -> IO (IO ())
increasePool pool = do
    adjust 1
    return $ adjust (-1)
    where
        adjust delta = step pool $ \s -> return s{threadsLimit = threadsLimit s + delta}
-- | Stop the pool from finishing until the returned action is run.
--
--   This queues a 'PoolResume' job that raises the thread limit by one
--   (so it does not consume a normal slot). The job then blocks on a barrier
--   and, once released, lowers the limit again. While it is queued or
--   running, the pool cannot reach its "nothing queued, nobody running"
--   finished state.
--   NOTE(review): the returned action signals a 'Barrier'; calling it twice
--   presumably fails — confirm against Control.Concurrent.Extra.
keepAlivePool :: Pool -> IO (IO ())
keepAlivePool pool = do
    released <- newBarrier
    let holder = increasePool pool >>= \undo -> waitBarrier released >> undo
    addPool PoolResume pool holder
    return $ signalBarrier released ()
-- | Create a pool with a thread limit of @n@, queue @act@ on it at
--   'PoolStart' priority, and block until the pool finishes.
--
--   If any job throws, that exception is rethrown here. If this thread itself
--   gets an exception, all recorded workers are killed and the pool is marked
--   finished. @deterministic@ makes same-priority jobs use a counter instead
--   of 'randomIO' as the tie-breaker (see 'emptyS'). On success, a timing
--   entry with the total and maximum thread counts is recorded.
runPool :: Bool -> Int -> (Pool -> IO ()) -> IO ()
runPool deterministic n act = do
    stateVar <- newVar . Just =<< emptyS n deterministic
    done <- newBarrier
    let pool = Pool stateVar done

        -- kill every recorded worker and mark the pool as finished
        killWorkers = modifyVar_ stateVar $ \ms -> do
            maybe (return ()) (mapM_ killThread . Set.toList . threads) ms
            return Nothing

        -- Presumably a workaround for GHC ticket #10793. If this thread is
        -- reported as blocked indefinitely, wait a moment and prefer a real
        -- worker exception from the barrier when one has been recorded.
        ghc10793 = do
            sleep 1
            mres <- waitBarrierMaybe done
            case mres of
                Just (Left e) -> throwIO e
                _ -> throwIO BlockedIndefinitelyOnMVar

        body = do
            addPool PoolStart pool $ act pool
            outcome <- waitBarrier done
            case outcome of
                Left e -> throwIO e
                Right final -> addTiming $ "Pool finished (" ++ show (threadsSum final) ++ " threads, " ++ show (threadsMax final) ++ " max)"

    handle (\BlockedIndefinitelyOnMVar -> ghc10793) (body `onException` killWorkers)