{-# LANGUAGE TupleSections #-}
module General.Pool(
    Pool, runPool,
    addPool, PoolPriority(..),
    increasePool, keepAlivePool
    ) where
import Control.Concurrent.Extra
import System.Time.Extra
import Control.Exception
import Control.Monad.Extra
import General.Timing
import General.Thread
import qualified Data.Heap as Heap
import qualified Data.HashSet as Set
import Data.IORef.Extra
import System.Random
-- | The mutable state of a pool, stored inside the 'Var' held by 'Pool'.
data S = S
    {alive :: !Bool -- 'withPool' skips its operation once this is False; cleared on a worker exception, on normal completion and by the cleanup in 'runPool'
    ,threads :: !(Set.HashSet Thread) -- threads registered by 'step' that have not yet been removed; 'runPool' passes these to 'stopThreads' on exit
    ,threadsLimit :: {-# UNPACK #-} !Int -- 'step' only spawns a thread while threadsCount < threadsLimit; adjusted by 'increasePool'
    ,threadsCount :: {-# UNPACK #-} !Int -- incremented/decremented alongside every insert/delete on 'threads'
    ,threadsMax :: {-# UNPACK #-} !Int -- largest value 'threadsCount' has reached (reported by 'runPool')
    ,threadsSum :: {-# UNPACK #-} !Int -- total number of threads spawned so far (reported by 'runPool')
    ,rand :: IO Int -- produces the Int used as the second component of a task's heap key in 'addPool'
    ,todo :: !(Heap.Heap (Heap.Entry (PoolPriority, Int) (IO ()))) -- queued tasks, keyed by (priority, value from 'rand')
    }
-- | Build the initial pool state: alive, no threads, an empty queue, all
--   counters at zero and a thread limit of @n@.
--
--   When @deterministic@ is True the tie-breaker handed out by 'rand' is a
--   counter starting at 0 and increasing by one per call; otherwise it is
--   'randomIO'.
emptyS :: Int -> Bool -> IO S
emptyS n deterministic = do
    gen <- if deterministic then counter else return randomIO
    return S
        {alive = True
        ,threads = Set.empty
        ,threadsLimit = n
        ,threadsCount = 0
        ,threadsMax = 0
        ,threadsSum = 0
        ,rand = gen
        ,todo = Heap.empty
        }
    where
        -- Allocates a fresh reference and returns the action that reads it and
        -- bumps it. The read and write are two separate steps (not atomic).
        counter :: IO (IO Int)
        counter = do
            ref <- newIORef 0
            return $ do
                i <- readIORef ref
                writeIORef' ref (i+1)
                return i
-- | Handle to a running pool, as passed to the action given to 'runPool'.
data Pool = Pool
    !(Var S) -- the current pool state; every pool operation goes through 'modifyVar' on this
    !(Barrier (Either SomeException S)) -- signalled once: Left with a worker's exception, or Right with the final state when all work is done
-- | Update the pool state with @f@ and afterwards run the 'IO' action that
--   @f@ returned alongside the new state. The follow-up action is executed
--   only after 'modifyVar' has returned.
--
--   If the pool is no longer 'alive', @f@ is not called, the state is left
--   untouched and nothing is run.
withPool :: Pool -> (S -> IO (S, IO ())) -> IO ()
withPool (Pool state _) f = do
    followUp <- modifyVar state guarded
    followUp
    where
        guarded cur
            | alive cur = f cur
            | otherwise = return (cur, return ())
-- | Variant of 'withPool' for updates that have no follow-up action: the
--   state is replaced by the result of @act@ and nothing is run afterwards.
withPool_ :: Pool -> (S -> IO S) -> IO ()
withPool_ pool act = withPool pool $ \old -> do
    new <- act old
    return (new, return ())
-- | Loop run by a pool thread after its first task: take the next entry off
--   the 'todo' heap, run it, and repeat. Returns when the heap is empty (or,
--   via 'withPool', when the pool is no longer alive).
worker :: Pool -> IO ()
worker pool = withPool pool (return . dequeue)
    where
        -- Pure step: the new state paired with the action to run after the
        -- state has been written back.
        dequeue s = case Heap.uncons (todo s) of
            Just (Heap.Entry _ job, rest) -> (s{todo = rest}, job >> worker pool)
            Nothing -> (s, return ())
-- | Apply the state transformation @op@ to the pool and then bring the pool
--   back into a consistent scheduling state: start at most one new thread if
--   a task is queued and a slot is free, or declare the pool finished if no
--   task is queued and no thread is running.
--
--   Because it goes through 'withPool_', nothing happens (and @op@ is not
--   run) once the pool is no longer 'alive'.
step :: Pool -> (S -> IO S) -> IO ()
-- The whole update runs under mask_; presumably so that an asynchronous
-- exception cannot arrive between spawning a thread and recording it in the
-- state (TODO confirm intent).
step pool@(Pool _ done) op = mask_ $ withPool_ pool $ \s -> do
    s <- op s
    case Heap.uncons $ todo s of
        Just (Heap.Entry _ now, todo2) | threadsCount s < threadsLimit s -> do
            -- A task is queued and we are below the limit: spawn a thread that
            -- runs this task and then keeps draining the queue via 'worker'.
            -- NOTE(review): newThreadFinally is defined in General.Thread; judging
            -- by its use here, the handler is given the thread and its outcome.
            t <- newThreadFinally (now >> worker pool) $ \t res -> case res of
                -- The thread ended with an exception: publish it on the barrier,
                -- deregister the thread and mark the pool as dead.
                Left e -> withPool_ pool $ \s -> do
                    signalBarrier done $ Left e
                    return (remThread t s){alive = False}
                -- The thread ended normally: deregister it and run 'step' again,
                -- which may spawn another thread or signal completion.
                Right _ ->
                    step pool $ return . remThread t
            return (addThread t s){todo = todo2}
        Nothing | threadsCount s == 0 -> do
            -- Nothing queued and nothing running: the pool is finished. Hand the
            -- final state to whoever waits on the barrier and mark the pool dead.
            signalBarrier done $ Right s
            return s{alive = False}
        -- Otherwise: tasks are queued but every slot is taken, or the queue is
        -- empty while threads are still running. Just store the updated state.
        _ -> return s
    where
        -- Register a thread, keeping the count, the running total and the high-water mark in sync.
        addThread t s = s{threads = Set.insert t $ threads s, threadsCount = threadsCount s + 1
                         ,threadsSum = threadsSum s + 1, threadsMax = threadsMax s `max` (threadsCount s + 1)}
        -- Deregister a thread; the total and the high-water mark are left unchanged.
        remThread t s = s{threads = Set.delete t $ threads s, threadsCount = threadsCount s - 1}
-- | Queue a task on the pool at the given priority. The task's result is
--   discarded. The heap key is the priority paired with a value drawn from
--   the pool's 'rand' action, which orders tasks of equal priority.
--
--   Goes through 'step', so a thread is started for the task if a slot is
--   free, and nothing is queued once the pool is no longer alive.
addPool :: PoolPriority -> Pool -> IO a -> IO ()
addPool priority pool act = step pool enqueue
    where
        enqueue s = do
            tie <- rand s
            let entry = Heap.Entry (priority, tie) (void act)
            return s{todo = Heap.insert entry (todo s)}
-- | Priority of a queued task. The derived 'Ord' follows constructor order,
--   so 'PoolException' is the smallest value and 'PoolDeprioritize' the
--   largest; two 'PoolDeprioritize' values compare by their 'Double'.
--   Reordering the constructors therefore changes scheduling.
--
--   Tasks are keyed on this in the 'todo' heap. Assuming "Data.Heap" yields
--   its minimum first, smaller priorities run earlier (TODO confirm against
--   the heap library).
data PoolPriority
    = PoolException
    | PoolResume -- used by 'keepAlivePool' for its slot-holding task
    | PoolStart -- used by 'runPool' for the initial action
    | PoolBatch
    | PoolDeprioritize Double -- carries a Double that takes part in the ordering
      deriving (Eq,Ord)
-- | Raise the pool's thread limit by one straight away, and return an action
--   that lowers it by one again when run. Both changes go through 'step', so
--   raising the limit can start a thread for a queued task.
increasePool :: Pool -> IO (IO ())
increasePool pool = do
    adjust 1
    return (adjust (-1))
    where
        -- Shift the thread limit by the given amount.
        adjust delta = step pool $ \s -> return s{threadsLimit = threadsLimit s + delta}
-- | Queue a 'PoolResume' task that raises the thread limit by one, blocks
--   until released, and then lowers the limit again. The returned action
--   releases it.
--
--   While that task is blocked its thread is still counted, so 'step' cannot
--   take the "no tasks and no threads" branch that finishes the pool.
keepAlivePool :: Pool -> IO (IO ())
keepAlivePool pool = do
    release <- newBarrier
    let holdSlot = do
            shrink <- increasePool pool
            waitBarrier release
            shrink
    addPool PoolResume pool holdSlot
    return $ signalBarrier release ()
-- | Create a pool with a limit of @n@ threads, queue @act pool@ as the first
--   task (at 'PoolStart' priority) and block until the pool signals that it
--   is done.
--
--   If the completion barrier holds an exception, it is rethrown in the
--   calling thread. On success a timing message with the thread statistics
--   is recorded via 'addTiming'. In every case the pool is marked dead on
--   exit and 'stopThreads' is applied to the threads still registered.
--
--   The first argument is passed to 'emptyS' and selects the deterministic
--   tie-breaker for tasks of equal priority.
runPool :: Bool -> Int -> (Pool -> IO ()) -> IO () 
runPool deterministic n act = do
    s <- newVar =<< emptyS n deterministic
    done <- newBarrier
    let pool = Pool s done
    -- Runs on every exit path (see 'finally' below): mark the pool dead so no
    -- further pool operations run, then stop whatever threads are still registered.
    let cleanup = join $ modifyVar s $ \s -> return (s{alive=False}, stopThreads $ Set.toList $ threads s)
    let ghc10793 = do
            -- Fallback for when this thread receives BlockedIndefinitelyOnMVar while
            -- waiting: the done barrier may already hold a more useful exception.
            -- NOTE(review): the name suggests a workaround for GHC issue #10793 - confirm.
            sleep 1 -- brief pause before inspecting the barrier; presumably to let worker finalisers run
                    -- ('sleep' comes from System.Time.Extra; the unit is presumably seconds - confirm)
            res <- waitBarrierMaybe done
            case res of
                Just (Left e) -> throwIO e
                _ -> throwIO BlockedIndefinitelyOnMVar
    flip finally cleanup $ handle (\BlockedIndefinitelyOnMVar -> ghc10793) $ do
        addPool PoolStart pool $ act pool
        res <- waitBarrier done
        case res of
            Left e -> throwIO e
            Right s -> addTiming $ "Pool finished (" ++ show (threadsSum s) ++ " threads, " ++ show (threadsMax s) ++ " max)"