{-# LANGUAGE LambdaCase, ScopedTypeVariables, GADTs, DeriveDataTypeable, Trustworthy #-}
module Control.CUtils.ThreadPool (Pool, addToPoolMulti, newPool, stopPool_,
globalPool,
ThreadPool(..), Interruptible(..), NoPool(..), BoxedThreadPool(..)) where
import Control.Exception
import Control.Concurrent
import Data.Data
import Data.Array
import Data.Foldable
import Data.IORef
import Data.List.Extras.Argmax
import Data.Maybe
import Control.Monad
import Control.Monad.Identity
import Control.Monad.Loops
import Control.CUtils.BoundedQueue
import System.IO.Unsafe
import System.IO
import Prelude hiding (mapM_)
-- | A message read by a worker thread's loop.
data Instruction
  = NextTask (IO ()) -- ^ run this action on the worker thread
  | S                -- ^ make the worker loop exit
  deriving (Typeable)

-- | One pool thread's mailbox together with its pending-task count.
data Worker = Worker
  { instructions :: !(BoundedQueue Instruction)
    -- ^ queue the worker thread reads 'Instruction's from
  , counter      :: !(IORef Int)
    -- ^ tasks enqueued but not yet dequeued: incremented by 'addToWorker',
    -- decremented by the worker loop when it picks a task up
  } deriving (Typeable)

-- 'Instruction' and 'Worker' hold IO actions, queues and IORefs, none of
-- which has a meaningful generic representation. These instances exist so
-- that 'Pool' can derive 'Data'. They describe the types as opaque
-- ('mkNoRepType'), following the convention base uses for e.g. 'Ptr'. The
-- previous empty instances left 'gunfold', 'toConstr' and 'dataTypeOf'
-- undefined, so any generic traversal reaching them crashed with a
-- missing-method error.
instance Data Instruction where
  gunfold _ _ _ = error "Data.Data.gunfold(Control.CUtils.ThreadPool.Instruction)"
  toConstr _ = error "Data.Data.toConstr(Control.CUtils.ThreadPool.Instruction)"
  dataTypeOf _ = mkNoRepType "Control.CUtils.ThreadPool.Instruction"

instance Data Worker where
  gunfold _ _ _ = error "Data.Data.gunfold(Control.CUtils.ThreadPool.Worker)"
  toConstr _ = error "Data.Data.toConstr(Control.CUtils.ThreadPool.Worker)"
  dataTypeOf _ = mkNoRepType "Control.CUtils.ThreadPool.Worker"

-- | A fixed set of workers. 'newPool' indexes them from 0.
newtype Pool = Pool { workers_ :: Array Int Worker } deriving (Typeable, Data)
-- | Queue a task on one specific worker. The task's result is discarded.
--
-- The worker's 'counter' is incremented before the task is enqueued, so that
-- 'addToPoolMulti' sees the extra load immediately. The worker loop
-- decrements it when it dequeues the task. Both steps run under 'mask_', so
-- an asynchronous exception cannot arrive between them.
--
-- Blocking operations remain interruptible under 'mask_', so 'writeRB' can
-- still throw (presumably while waiting on a full bounded queue). In that
-- case the increment is rolled back. Otherwise the worker would look
-- permanently busier than it is, and the load balancer would avoid it.
--
-- NOTE(review): the rollback assumes 'writeRB' has not enqueued the item
-- when it throws. Confirm against Control.CUtils.BoundedQueue.
addToWorker :: Worker -> IO t -> IO ()
{-# INLINE addToWorker #-}
addToWorker w task = mask_ $ do
  atomicModifyIORef' (counter w) (\n -> (succ n, ()))
  (writeRB (instructions w) $! NextTask (void task))
    `onException` atomicModifyIORef' (counter w) (\n -> (pred n, ()))
-- | Submit a task to the pool. It is routed to the worker whose 'counter'
-- (tasks queued but not yet started) is smallest, as chosen by 'argmin'.
--
-- Throws 'ErrorCall' if the pool has no workers. The counters are read one
-- at a time without any lock, so under concurrent submission the choice is
-- a best-effort load heuristic, not an exact minimum.
addToPoolMulti :: Pool -> IO t -> IO ()
{-# INLINABLE addToPoolMulti #-}
addToPoolMulti (Pool arr) task
  | null arr = throwIO (ErrorCall "addToPoolMulti: pool is empty")
  | otherwise = do
      let candidates = elems arr
      loads <- forM candidates (readIORef . counter)
      let (target, _) = argmin snd (zip candidates loads)
      addToWorker target task
-- | Start a worker. It consists of a fresh queue from @newRB 10000@
-- (presumably a capacity of 10000; confirm in BoundedQueue), a pending-task
-- counter starting at 0, and a 'forkIO' thread that serves the queue.
--
-- The thread repeatedly reads an 'Instruction':
--
-- * 'NextTask': decrement the counter first, so the counter tracks tasks
--   queued but not yet started. Then run the task. Any exception the task
--   throws, including asynchronous ones delivered while it runs, is printed
--   to stderr and the loop continues.
-- * 'S': the loop ends and the thread exits.
newWorker :: IO Worker
newWorker = do
  queue <- newRB 10000
  pending <- newIORef 0
  let self = Worker queue pending
  _ <- forkIO (serve self)
  return $! self
  where
    serve self = readRB (instructions self) >>= \case
      S -> return ()
      NextTask job -> do
        atomicModifyIORef' (counter self) (\n -> (pred n, ()))
        job `catch` \(ex :: SomeException) -> hPrint stderr ex
        serve self
-- | Create a pool of @n@ workers, indexed @0 .. n-1@, each started with
-- 'newWorker'. A non-positive @n@ yields an empty pool, on which
-- 'addToPoolMulti' throws.
newPool :: Int -> IO Pool
newPool n = do
  ws <- replicateM n newWorker
  return (Pool (listArray (0, n - 1) ws))
-- | Enqueue the stop instruction 'S' on a worker. Its loop exits when it
-- reads it.
stopWorker :: Worker -> IO ()
stopWorker w = instructions w `writeRB` S
-- | Send 'S' to every worker in the pool, in index order.
--
-- This does not wait for the workers. Each one exits when its loop reads
-- 'S'. Whether earlier-queued tasks still run first depends on the queue's
-- ordering (presumably FIFO; see BoundedQueue).
stopPool_ :: Pool -> IO ()
stopPool_ (Pool ws) = forM_ ws stopWorker
-- | A process-wide pool with one worker per capability. It is sized by
-- 'getNumCapabilities' at the moment this value is first forced.
--
-- This is the usual top-level 'unsafePerformIO' idiom for a global. The
-- NOINLINE pragma stops GHC from inlining the definition at use sites, which
-- could otherwise create more than one pool.
globalPool :: Pool
{-# NOINLINE globalPool #-}
globalPool = unsafePerformIO (newPool =<< getNumCapabilities)
-- | Things that accept an 'IO' action to run asynchronously. The action's
-- result is discarded.
class ThreadPool p where
  addToPool :: p -> IO t -> IO ()
-- | Pools that can be told to shut down.
class Interruptible p where
  stopPool :: p -> IO ()
-- | Tasks go to the least-loaded worker. See 'addToPoolMulti'.
instance ThreadPool Pool where
  addToPool pool task = addToPoolMulti pool task
-- | Sends every worker the stop instruction. See 'stopPool_'.
instance Interruptible Pool where
  stopPool pool = stopPool_ pool
-- | A stand-in "pool" that gives every task its own thread. See its
-- 'ThreadPool' instance.
data NoPool = NoPool
  deriving (Data, Typeable)
-- | Each task runs on a fresh 'forkIO' thread, and the 'ThreadId' is thrown
-- away. The pool argument is never inspected.
instance ThreadPool NoPool where
  addToPool _ task = () <$ forkIO (void task)
-- | Existential wrapper that hides the concrete type of a 'ThreadPool'.
data BoxedThreadPool where
  BoxedThreadPool :: ThreadPool p => p -> BoxedThreadPool
-- | Unwraps the box and forwards to the hidden pool's own 'addToPool'.
instance ThreadPool BoxedThreadPool where
  addToPool = \case
    BoxedThreadPool inner -> addToPool inner
{-# SPECIALIZE addToPool :: Pool -> IO t -> IO() #-}
{-# SPECIALIZE addToPool :: NoPool -> IO t -> IO() #-}
{-# SPECIALIZE addToPool :: BoxedThreadPool -> IO t -> IO() #-}