module Control.Concurrent.Hierarchical (
HIO
, runHIO
, newPrimGroup
, newGroup
, local
, close
, Group
, finished
) where
import Control.Monad
import Control.Applicative
import Control.Exception
import Control.Concurrent.MonadIO
import Control.Concurrent.STM.MonadIO
import System.IO.Unsafe
-- | A computation that runs inside a thread 'Group'. The group is supplied
-- implicitly, reader-style. Threads started with 'fork' inside 'HIO' are
-- counted in, and registered with, the current group.
newtype HIO a = HIO {inGroup :: Group -> IO a}
-- | Map over the result while leaving the ambient group untouched.
instance Functor HIO where
  fmap f m = HIO $ \w -> f <$> inGroup m w
-- | Sequencing passes the same ambient group to both computations.
instance Monad HIO where
  -- Canonical form: 'return' is just 'pure'. Defining 'pure' in terms of
  -- 'return' (the previous arrangement) is flagged by GHC's
  -- -Wnoncanonical-monad-instances.
  return = pure
  m >>= k = HIO $ \w -> do
    x <- m `inGroup` w
    k x `inGroup` w

-- | 'pure' ignores the group. '<*>' runs the function computation and then
-- the argument computation, both in the same group.
instance Applicative HIO where
  pure x = HIO $ \_ -> return x
  (<*>) = ap
-- | Lift a plain IO action. The action does not see the ambient group.
instance MonadIO HIO where
  liftIO = HIO . const
-- | A thread group is a pair of:
--
-- * a counter of live threads forked into the group (incremented by 'fork'
--   before the thread starts, decremented when it ends);
-- * the group's membership list.
type Group = (TVar Int, TVar Inhabitants)
-- | Membership of a group. It is either 'Open' and accepting new members
-- (new entries are prepended) or 'Closed' once the group has been killed.
-- Registering into a 'Closed' group kills the registering thread.
data Inhabitants = Closed | Open [Entry]
-- | A member of a group: either a single thread, or a nested sub-group. A
-- nested sub-group is killed recursively when its parent is killed.
data Entry = Thread ThreadId
           | Group Group
-- | Forking inside 'HIO' places the new thread in the current group:
--
-- * the group's live-thread counter is incremented before the thread starts;
-- * the counter is decremented via 'finally' when the thread ends, however
--   it ends;
-- * the new thread registers its 'ThreadId' with the group before running
--   the body.
--
-- Asynchronous exceptions are masked while this bookkeeping runs. The body
-- then runs with the masking state that was in effect when 'fork' was
-- called. (The former code unconditionally unmasked the body.)
--
-- This uses 'mask' because 'block'/'unblock' were removed from base in
-- GHC 7.8.
instance HasFork HIO where
  fork hio = HIO $ \w -> mask $ \restore -> do
    when countingThreads incrementThreadCount
    increment w
    let child = do
          tid <- myThreadId
          register (Thread tid) w
          restore (hio `inGroup` w)
    -- NOTE(review): assumes the IO-level 'fork' behaves like forkIO, whose
    -- child inherits the parent's masked state. That inheritance keeps the
    -- child masked until it has registered itself.
    fork (child `finally` decrement w)
-- | Create a fresh sub-group and register it as a member of the current
-- group, so that killing the current group also kills the new one. If the
-- current group is already closed, 'register' kills the calling thread.
newGroup :: HIO Group
newGroup = HIO $ \parent ->
  newPrimGroup >>= \child ->
    register (Group child) parent >> return child
-- | Run a computation in the given group instead of the current one. Threads
-- it forks are counted in, and registered with, that group.
local :: Group -> HIO a -> HIO a
local grp action = liftIO $ inGroup action grp
-- | Shut down a group. This:
--
-- * kills every thread and nested group registered in it;
-- * marks it 'Closed', so later registrations kill themselves;
-- * resets its live-thread counter to 0.
--
-- The work runs in a separately forked IO thread, so 'close' returns without
-- waiting for it.
--
-- NOTE(review): killed threads run their @finally decrement@ handlers
-- concurrently with the reset to 0. The counter may therefore end up
-- negative, and 'finished' on this group would then block. Confirm.
close :: Group -> HIO ()
close grp@(counter, _) = liftIO $ do
  _ <- fork (kill (Group grp) >> writeTVar counter 0)
  return ()
-- | Block until the group's live-thread counter reads zero, i.e. until every
-- thread forked directly into the group has ended.
finished :: Group -> HIO ()
finished = liftIO . isZero
-- | Run an 'HIO' computation in a fresh root group. Before returning the
-- result, wait until every thread forked directly into that group has
-- ended. When 'countingThreads' is set, a fork-count report is printed
-- afterwards.
runHIO :: HIO b -> IO b
runHIO hio = newPrimGroup >>= \root -> do
  result <- inGroup hio root
  isZero root
  when countingThreads printThreadReport
  return result
-- | Allocate a new, unregistered group with zero live threads and an empty,
-- open membership list.
newPrimGroup :: IO Group
newPrimGroup = (,) <$> newTVar 0 <*> newTVar (Open [])
-- | Add an entry to a group's membership list. If the group has already been
-- closed, the entry is not added and the calling thread kills itself
-- instead (presumably so that nothing outlives a closed group).
--
-- NOTE(review): entries are never removed when a thread ends, so a
-- long-lived group's list grows with every fork. Confirm this is acceptable.
register :: Entry -> Group -> IO ()
register entry (_, members) = do
  alreadyClosed <- atomically $ do
    current <- readTVarSTM members
    case current of
      Closed        -> return True
      Open existing -> do
        writeTVarSTM members (Open (entry : existing))
        return False
  when alreadyClosed $ myThreadId >>= killThread
-- | Kill an entry.
--
-- * A thread is sent 'killThread'.
-- * A group is marked 'Closed', so later registrations kill themselves, and
--   every member it held is then killed recursively.
--
-- Killing a group that is already closed does nothing.
kill :: Entry -> IO ()
kill (Thread tid) = killThread tid
kill (Group (_, t)) = do
  -- The first component is taken as the membership before closing.
  (old, _) <- modifyTVar t (const Closed)
  case old of
    Closed       -> return ()
    Open members -> mapM_ kill members
-- | Bookkeeping on a group's live-thread counter (the 'Int' half of a
-- 'Group').
--
-- * 'increment' and 'decrement' adjust the counter by one. The new value is
--   forced before it is stored, so the 'TVar' does not accumulate a chain of
--   @(+1)@/@(-1)@ thunks.
-- * 'isZero' blocks (STM 'check' retries) until the counter reads 0.
increment, decrement, isZero :: Group -> IO ()
increment = adjustCount (+ 1)
decrement = adjustCount (subtract 1)
isZero (c, _) = atomically $ readTVarSTM c >>= check . (== 0)

-- | Strictly apply a function to a group's live-thread counter in a single
-- transaction.
adjustCount :: (Int -> Int) -> Group -> IO ()
adjustCount f (c, _) = atomically $ do
  n <- readTVarSTM c
  writeTVarSTM c $! f n
-- | Debug switch. When 'True':
--
-- * every 'fork' in 'HIO' bumps the global 'threadCount';
-- * 'runHIO' prints a report of how many HIO threads were forked.
--
-- NOTE(review): this is enabled, so every 'runHIO' writes the report to
-- stdout. Confirm that is intended for library users.
countingThreads :: Bool
countingThreads = True
-- | Process-wide count of HIO threads forked so far. It is bumped by
-- 'incrementThreadCount' and read by 'printThreadReport'.
--
-- This is the "global mutable variable" idiom: the 'TVar' is created once
-- via 'unsafePerformIO'. The NOINLINE pragma is required. Without it GHC may
-- inline the definition at each use site and create several independent
-- 'TVar's, so the increments and the report would not share one counter.
threadCount :: TVar Integer
threadCount = unsafePerformIO $ newTVar 0
{-# NOINLINE threadCount #-}
-- | Bump the global 'threadCount' by one.
--
-- The new value is forced before it is stored. The counter is only read when
-- the report is printed, so a lazy update would otherwise accumulate one
-- @(+1)@ thunk per fork.
incrementThreadCount :: IO ()
incrementThreadCount = atomically $ do
  n <- readTVarSTM threadCount
  writeTVarSTM threadCount $! n + 1
-- | Print a separator line, then the number of HIO threads recorded in
-- 'threadCount'.
printThreadReport :: IO ()
printThreadReport =
  readTVar threadCount >>= \total ->
    mapM_ putStrLn
      [ "----------------------------"
      , show total ++ " HIO threads were forked"
      ]