-- |
-- Module      : Streamly.Internal.Data.Stream.Channel.Dispatcher
-- Copyright   : (c) 2017 Composewell Technologies
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
--
module Streamly.Internal.Data.Stream.Channel.Dispatcher
    (
    -- * Latency collection
      minThreadDelay
    , collectLatency

    -- * Thread accounting
    , addThread
    , delThread
    , modifyThread
    , allThreadsDone
    , recordMaxWorkers

    -- * Diagnostics
    , dumpSVarStats
    )
where

import Data.Set (Set)
import Control.Concurrent (MVar, ThreadId)
import Control.Concurrent.MVar (tryPutMVar)
import Control.Exception (assert)
import Control.Monad (when, void)
import Control.Monad.IO.Class (MonadIO(liftIO))
import Data.IORef (IORef, modifyIORef, readIORef, writeIORef)
import Streamly.Internal.Data.Atomics (atomicModifyIORefCAS, writeBarrier)
import Streamly.Internal.Data.Time.Clock (Clock(Monotonic), getTime)
import Streamly.Internal.Data.Time.Units
       ( AbsTime, NanoSecond64(..), diffAbsTime64, showNanoSecond64
       , showRelTime64)

import qualified Data.Set as S

import Streamly.Internal.Data.Stream.Channel.Types

-------------------------------------------------------------------------------
-- Worker latency data processing
-------------------------------------------------------------------------------

-- | This is a magic number and it is overloaded, and used at several places to
-- achieve batching:
--
-- 1. If we have to sleep to slow down, this is the minimum period that we
--    accumulate before we sleep. Also, workers do not stop until this much
--    sleep time is accumulated.
-- 2. Collected latencies are computed and transferred to measured latency
--    after a minimum of this period.
-- 3. The worker polling interval is derived from it: it is the number of
--    yields expected within this period at the current per-yield latency.
--
-- The value is 1,000,000 ns, i.e. 1 ms.
minThreadDelay :: NanoSecond64
minThreadDelay = 1000000

-- Every once in a while workers update the latencies and check the yield rate.
-- They return if we are above the expected yield rate. If we check too often
-- it may impact performance, if we check less often we may have a stale
-- picture. We update every minThreadDelay but we translate that into a yield
-- count based on latency so that the checking overhead is little.
--
-- The value written to 'workerPollingInterval' is the number of yields that
-- fit in 'minThreadDelay' at the given per-yield latency, clamped below at 1
-- and above at 'magicMaxBuffer'. A latency of 0 would mean an unbounded
-- yield count (and a division by zero), so it maps straight to the upper
-- bound, which is the limit of the formula as the latency approaches 0.
--
-- XXX use a generation count to indicate that the value is updated. If the
-- value is updated an existing worker must check it again on the next yield.
-- Otherwise it is possible that we may keep updating it and because of the mod
-- worker keeps skipping it.
updateWorkerPollingInterval :: YieldRateInfo -> NanoSecond64 -> IO ()
updateWorkerPollingInterval yinfo latency =
    writeIORef (workerPollingInterval yinfo) (fromIntegral period)

    where

    upperBound :: NanoSecond64
    upperBound = fromIntegral magicMaxBuffer

    -- The latency passed in can legitimately be 0: a collected batch whose
    -- total time is smaller than its yield count yields a 0 average under
    -- integer division, and the previously measured latency may also be 0.
    -- Dividing by it would throw DivideByZero, so guard it explicitly.
    period :: NanoSecond64
    period
        | latency == 0 = upperBound
        | otherwise = min (max 1 (minThreadDelay `div` latency)) upperBound

-- Fold one latency sample into the min/max worker latency statistics.
--
-- A stored minimum of 0 is treated as "nothing recorded yet", so the first
-- sample always becomes the minimum. The maximum is only raised, never
-- lowered.
{-# INLINE recordMinMaxLatency #-}
recordMinMaxLatency :: SVarStats -> NanoSecond64 -> IO ()
recordMinMaxLatency ss sample = do
    let minRef = minWorkerLatency ss
        maxRef = maxWorkerLatency ss

    curMin <- readIORef minRef
    when (curMin == 0 || sample < curMin) $ writeIORef minRef sample

    curMax <- readIORef maxRef
    when (sample > curMax) $ writeIORef maxRef sample

-- Add a batch of (yield count, total latency) to the running totals from
-- which the average worker latency is reported.
--
-- Both components are forced before being stored. With a lazy 'modifyIORef'
-- each call would leave another unevaluated update (and two '+' thunks) in
-- the IORef. In this module the totals are only read when the stats are
-- dumped, so the chain could grow for the whole lifetime of the stream.
recordAvgLatency :: SVarStats -> (Count, NanoSecond64) -> IO ()
recordAvgLatency ss (count, time) = do
    let ref = avgWorkerLatency ss
    (cnt, t) <- readIORef ref
    let cnt' = cnt + count
        t'   = t + time
    cnt' `seq` t' `seq` writeIORef ref (cnt', t')

-- Pour the pending latency stats into a collection bucket.
--
-- Both IORefs hold a (totalCount, latCount, latTime) triple. The pending
-- bucket @cur@ is swapped with zeros in one atomic step and its contents are
-- added into the collection bucket @col@. @col@ itself is not reset here.
--
-- Returns the accumulated totalCount, paired with @Just (latCount, latTime)@
-- when both of those are positive, or 'Nothing' when there is no usable
-- latency sample yet.
{-# INLINE collectWorkerPendingLatency #-}
collectWorkerPendingLatency
    :: IORef (Count, Count, NanoSecond64)
    -> IORef (Count, Count, NanoSecond64)
    -> IO (Count, Maybe (Count, NanoSecond64))
collectWorkerPendingLatency cur col = do
    -- Take whatever workers have reported and leave zeros behind.
    (pendTotal, pendLatCount, pendLatTime) <-
        atomicModifyIORefCAS cur (\pending -> ((0, 0, 0), pending))

    (colTotal, colLatCount, colLatTime) <- readIORef col
    let totalCount = colTotal + pendTotal
        latCount   = colLatCount + pendLatCount
        latTime    = colLatTime + pendLatTime
    writeIORef col (totalCount, latCount, latTime)

    -- Invariant: a non-zero sample count implies non-zero accumulated time.
    assert (latCount == 0 || latTime /= 0) (return ())

    let usable = latCount > 0 && latTime > 0
    return (totalCount, if usable then Just (latCount, latTime) else Nothing)

-- Decide whether the latency batch collected so far is significant enough to
-- replace the previously measured latency. It is when any of these hold:
--
-- * more than 'magicMaxBuffer' yields have been collected,
-- * more than 'minThreadDelay' of latency time has been collected,
-- * the previous latency is exactly 0, or
-- * the previous latency is positive and the new one differs from it by more
--   than a factor of 2 in either direction.
--
-- A negative previous latency never triggers on its own.
{-# INLINE shouldUseCollectedBatch #-}
shouldUseCollectedBatch
    :: Count
    -> NanoSecond64
    -> NanoSecond64
    -> NanoSecond64
    -> Bool
shouldUseCollectedBatch collectedYields collectedTime newLat prevLat
    | collectedYields > fromIntegral magicMaxBuffer = True
    | collectedTime > minThreadDelay = True
    | prevLat > 0 = ratio > 2 || ratio < 0.5
    | otherwise = prevLat == 0

    where

    -- Only evaluated when prevLat > 0, so it never divides by zero.
    ratio :: Double
    ratio = fromIntegral newLat / fromIntegral prevLat

-- Drain the latency samples reported by workers. When enough data has
-- accumulated (see 'shouldUseCollectedBatch'), or when @drain@ is set, fold
-- the batch into the measured latency.
--
-- @inspecting@ enables recording of min/max/avg latency in the 'SVarStats'.
-- @drain@ forces the collected batch to be used even if it is not yet
-- considered significant.
--
-- Returns a triple:
--
-- 1. the all-time yield count, i.e. the count stored in 'svarAllTimeLatency'
--    plus the yields collected since that count was last updated,
-- 2. the base time stored in 'svarAllTimeLatency' (when we started counting),
-- 3. the latency of the batch just collected if it was used, otherwise the
--    previously measured latency.
--
-- The former two are used for accurate measurement of the going rate whereas
-- the latency is used for future estimates e.g. how many workers should be
-- maintained to maintain the rate.
--
-- CAUTION! keep it in sync with getWorkerLatency
collectLatency ::
       Bool
    -> SVarStats
    -> YieldRateInfo
    -> Bool
    -> IO (Count, AbsTime, NanoSecond64)
collectLatency inspecting ss yinfo drain = do
    let pendingRef   = workerPendingLatency yinfo
        collectedRef = workerCollectedLatency yinfo
        allTimeRef   = svarAllTimeLatency yinfo
        measuredRef  = workerMeasuredLatency yinfo

    (batchYields, batchLatency) <-
        collectWorkerPendingLatency pendingRef collectedRef
    (baseYields, baseTime) <- readIORef allTimeRef
    prevLat <- readIORef measuredRef

    let allYields = baseYields + batchYields
        answer lat = return (allYields, baseTime, lat)

    case batchLatency of
        Nothing -> answer prevLat
        Just (latYields, latTime) -> do
            let newLat = latTime `div` fromIntegral latYields
            when inspecting $ recordMinMaxLatency ss newLat
            -- A small batch is kept accumulating and the previous estimate is
            -- returned. Once the batch is significant (or we are draining) it
            -- becomes the new estimate.
            let useBatch =
                       shouldUseCollectedBatch batchYields latTime newLat prevLat
                    || drain
            if not useBatch
            then answer prevLat
            else do
                -- XXX make this NOINLINE?
                updateWorkerPollingInterval yinfo (max newLat prevLat)
                when inspecting $ recordAvgLatency ss (latYields, latTime)
                writeIORef collectedRef (0, 0, 0)
                -- The stored estimate is smoothed by averaging with the
                -- previous one, while the raw batch latency is returned.
                writeIORef measuredRef ((prevLat + newLat) `div` 2)
                modifyIORef allTimeRef $ \(_, t) -> (allYields, t)
                answer newLat

-------------------------------------------------------------------------------
-- Dumping the SVar for debug/diag
-------------------------------------------------------------------------------

-- Render the channel statistics as a human-readable, newline-terminated
-- report.
--
-- When rate information is present, pending worker latencies are first
-- flushed via 'collectLatency' (with draining forced) so that the report
-- includes them. Each latency / SVar line is only emitted when its value is
-- positive.
--
-- NOTE(review): the gain/loss line is only shown for a positive count. If
-- 'svarGainedLostYields' can go negative (net losses), those are silently
-- omitted -- confirm that is intended.
dumpSVarStats :: Bool -> Maybe YieldRateInfo -> SVarStats -> IO String
dumpSVarStats inspecting rateInfo ss = do
    mapM_ (\yinfo -> void $ collectLatency inspecting ss yinfo True) rateInfo

    dispatches <- readIORef (totalDispatches ss)
    maxWrk <- readIORef (maxWorkers ss)
    maxOq <- readIORef (maxOutQSize ss)
    -- maxHp <- readIORef $ maxHeapSize ss
    minLat <- readIORef (minWorkerLatency ss)
    maxLat <- readIORef (maxWorkerLatency ss)
    (avgCnt, avgTime) <- readIORef (avgWorkerLatency ss)

    -- SVar latency is the elapsed time since the all-time base time (up to
    -- the stop time, or now if not stopped) divided by the yield count.
    (svarCnt, svarGainLossCnt, svarLat) <- case rateInfo of
        Nothing -> return (0, 0, 0)
        Just yinfo -> do
            (cnt, startTime) <- readIORef (svarAllTimeLatency yinfo)
            if cnt <= 0
            then return (0, 0, 0)
            else do
                stopped <- readIORef (svarStopTime ss)
                gl <- readIORef (svarGainedLostYields yinfo)
                endTime <- maybe (getTime Monotonic) return stopped
                let interval = diffAbsTime64 endTime startTime
                return (cnt, gl, interval `div` fromIntegral cnt)

    let showIf cond str = if cond then str else ""
        outQLine = concat
            [ "max outQSize = " <> show maxOq
            , showIf (minLat > 0) $
                "\nmin worker latency = " <> showNanoSecond64 minLat
            , showIf (maxLat > 0) $
                "\nmax worker latency = " <> showNanoSecond64 maxLat
            , showIf (avgCnt > 0) $
                "\navg worker latency = "
                    <> showNanoSecond64 (avgTime `div` fromIntegral avgCnt)
            , showIf (svarLat > 0) $
                "\nSVar latency = " <> showRelTime64 svarLat
            , showIf (svarCnt > 0) $
                "\nSVar yield count = " <> show svarCnt
            , showIf (svarGainLossCnt > 0) $
                "\nSVar gain/loss yield count = " <> show svarGainLossCnt
            ]

    return $ unlines
        [ "total dispatches = " <> show dispatches
        , "max workers = " <> show maxWrk
        , outQLine
        ]

-------------------------------------------------------------------------------
-- Thread accounting
-------------------------------------------------------------------------------

-- Thread tracking is needed for two reasons:
--
-- 1) Killing threads on exceptions. Threads may not be left to go away by
-- themselves because they may run for significant times before going away or
-- worse they may be stuck in IO and never go away.
--
-- 2) To know when all threads are done and the stream has ended.

-- Record a newly forked worker thread in the worker set. Adding an id that
-- is already present leaves the set unchanged.
--
-- NOTE(review): this is a plain, non-atomic 'modifyIORef', whereas
-- 'modifyThread' updates the same kind of set with a CAS. Presumably callers
-- never mix this with concurrent CAS updates on the same set -- confirm.
{-# NOINLINE addThread #-}
addThread :: MonadIO m => IORef (Set ThreadId) -> ThreadId -> m ()
addThread workerSet tid = liftIO (modifyIORef workerSet (S.insert tid))

-- Remove a worker thread id from the worker set; removing an absent id is a
-- no-op.
--
-- This is cheaper than 'modifyThread' because it does not ring the output
-- doorbell (the bell 'MVar'). That can make a difference when more workers
-- are being dispatched. Like 'addThread', the update is a plain,
-- non-atomic 'modifyIORef'.
{-# INLINE delThread #-}
delThread :: MonadIO m => IORef (Set ThreadId) -> ThreadId -> m ()
delThread workerSet tid =
    let dropWorker = S.delete tid
    in liftIO (modifyIORef workerSet dropWorker)

-- Toggle a thread id in the worker set: if present then delete it, else add
-- it. This takes care of out-of-order add and delete, i.e. a delete arriving
-- before we even added a thread. That occurs when the forked thread is done
-- even before the 'addThread' right after the fork gets a chance to run.
--
-- The set update is a single CAS. The bell is rung (non-blocking
-- 'tryPutMVar', after a write barrier) when the set reported by the update
-- is empty. After a removal that is the new set, i.e. no tracked threads
-- remain. After an insertion it is the set as it was before the insertion.
{-# INLINE modifyThread #-}
modifyThread :: MonadIO m => IORef (Set ThreadId) -> MVar () -> ThreadId -> m ()
modifyThread workerSet bell tid = do
    reported <- liftIO $ atomicModifyIORefCAS workerSet toggle
    when (S.null reported) $ liftIO $ do
        writeBarrier
        void $ tryPutMVar bell ()

    where

    -- Returns (new set to store, set to report back to the caller).
    toggle old
        | S.member tid old =
            let remaining = S.delete tid old in (remaining, remaining)
        | otherwise = (S.insert tid old, old)

-- | Check whether the worker set is empty, i.e. no tracked worker threads
-- remain.
--
-- This is safe even if we are adding more threads concurrently because if
-- a child thread is adding another thread then anyway 'workerThreads' will
-- not be empty.
{-# INLINE allThreadsDone #-}
allThreadsDone :: MonadIO m => IORef (Set ThreadId) -> m Bool
allThreadsDone ref = liftIO $ do
    workers <- readIORef ref
    return (S.null workers)

-------------------------------------------------------------------------------
-- Dispatching workers
-------------------------------------------------------------------------------

-- Update dispatch statistics for one worker dispatch. This does two things:
--
-- * it raises 'maxWorkers' if the count read from @countRef@ (presumably
--   the current number of active workers -- confirm at call sites) exceeds
--   it, and
-- * it increments 'totalDispatches' by one.
--
-- The dispatch counter is incremented strictly. A lazy
-- @modifyIORef (+1)@ would stack another thunk on every dispatch until the
-- counter is finally read.
{-# NOINLINE recordMaxWorkers #-}
recordMaxWorkers :: MonadIO m => IORef Int -> SVarStats -> m ()
recordMaxWorkers countRef ss = liftIO $ do
    active <- readIORef countRef
    maxWrk <- readIORef (maxWorkers ss)
    when (active > maxWrk) $ writeIORef (maxWorkers ss) active
    dispatches <- readIORef (totalDispatches ss)
    writeIORef (totalDispatches ss) $! dispatches + 1