{-# OPTIONS_HADDOCK not-home #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
module Hedgehog.Internal.Queue (
    TaskIndex(..)
  , TasksRemaining(..)

  , runTasks
  , finalizeTask

  , runActiveFinalizers
  , dequeueMVar

  , updateNumCapabilities
  ) where

import           Control.Concurrent (rtsSupportsBoundThreads)
import           Control.Concurrent.Async (forConcurrently)
import           Control.Concurrent.MVar (MVar)
import qualified Control.Concurrent.MVar as MVar
import           Control.Monad (when)
import           Control.Monad.IO.Class (MonadIO(..))

import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map

import qualified GHC.Conc as Conc

import           Hedgehog.Internal.Config


-- | Position of a task in the queue built by 'runTasks'.
--
--   'runTasks' numbers its tasks @0, 1, 2, ...@ in list order. The 'Enum'
--   and 'Num' instances are inherited from 'Int' (via
--   @GeneralizedNewtypeDeriving@), so indices can be enumerated and offset,
--   e.g. @ix + 1@ when checking which finalizer is next in line.
newtype TaskIndex
  = TaskIndex Int
  deriving (Eq, Ord, Enum, Num)

-- | Number of tasks still waiting in the queue, as reported to the @start@
--   callback when a task is dequeued. The task being dequeued is not counted.
newtype TasksRemaining = TasksRemaining Int

-- | Pop the next task off a shared queue and run the @start@ callback on it.
--
--   @start@ receives three arguments: the number of tasks left behind in the
--   queue, the task's index and the task itself. It runs while the queue
--   'MVar' is held, so concurrent callers see @start@ invoked one task at a
--   time, in queue order. If @start@ throws, 'MVar.modifyMVar' restores the
--   original queue and the exception propagates.
--
--   Returns 'Nothing' when the queue is empty. Otherwise it returns the
--   popped index paired with the callback's result.
dequeueMVar ::
     MVar [(TaskIndex, a)]
  -> (TasksRemaining -> TaskIndex -> a -> IO b)
  -> IO (Maybe (TaskIndex, b))
dequeueMVar mvar start =
  MVar.modifyMVar mvar pop
  where
    pop queue =
      case queue of
        [] ->
          pure ([], Nothing)
        (ix, task) : rest -> do
          started <- start (TasksRemaining (length rest)) ix task
          pure (rest, Just (ix, started))

-- | Run @tasks@ on @max 1 n@ concurrent workers that share a single queue.
--
--   For each task, a worker does the following, in this order:
--
--   1. dequeues it and calls @start@ on it (see 'dequeueMVar');
--   2. calls @runTask@ on the started value to produce a result;
--   3. calls @finish@ on the started value;
--   4. registers @finalize@ for it via 'finalizeTask'. Finalizers only run
--      once every lower-indexed task has been finalized.
--
--   The workers' result lists are concatenated in worker order. Each worker
--   collects its own results newest-first, so the returned list is not in
--   task order.
runTasks ::
     WorkerCount
  -> [a]
  -> (TasksRemaining -> TaskIndex -> a -> IO b)
  -> (b -> IO ())
  -> (b -> IO ())
  -> (b -> IO c)
  -> IO [c]
runTasks n tasks start finish finalize runTask = do
  queue <- MVar.newMVar (zip [0 ..] tasks)
  -- The queue hands out indices from 0 and a finalizer runs only when its
  -- index is exactly one past the last finalized index, so start at -1.
  finalizers <- MVar.newMVar (-1, Map.empty)

  let
    drain acc = do
      next <- dequeueMVar queue start
      case next of
        Nothing ->
          pure acc
        Just (ix, started) -> do
          result <- runTask started
          finish started
          finalizeTask finalizers ix (finalize started)
          drain (result : acc)

    workers =
      [1 .. max 1 n]

  -- FIXME ensure all workers have finished running
  perWorker <- forConcurrently workers (\_ -> drain [])
  pure (concat perWorker)

-- | Run every pending finalizer that is next in line.
--
--   The 'MVar' holds the index of the most recently finalized task together
--   with the finalizers still waiting to run. The smallest pending finalizer
--   runs only when its index is exactly one past the last finalized index.
--   This repeats until nothing is pending or there is a gap. As a result,
--   finalizers run in ascending 'TaskIndex' order and no index is skipped.
--
--   Each finalizer runs while the 'MVar' is held. If it throws,
--   'MVar.modifyMVar' leaves the state unchanged, so the finalizer stays
--   pending, and the exception propagates.
runActiveFinalizers ::
     MonadIO m
  => MVar (TaskIndex, Map TaskIndex (IO ()))
  -> m ()
runActiveFinalizers mvar =
  liftIO loop
  where
    loop = do
      progressed <- MVar.modifyMVar mvar step
      when progressed loop

    step state@(lastIx, pending) =
      case Map.minViewWithKey pending of
        Just ((ix, finalize), rest)
          | ix == lastIx + 1 -> do
              finalize
              pure ((ix, rest), True)
        _ ->
          pure (state, False)

-- | Register @finalize@ as the finalizer for task @ix@, then run any
--   finalizers that are now next in line (see 'runActiveFinalizers').
--
--   The finalizer is stored with 'Map.insert', so it replaces any finalizer
--   already registered under the same index.
finalizeTask ::
     MonadIO m
  => MVar (TaskIndex, Map TaskIndex (IO ()))
  -> TaskIndex
  -> IO ()
  -> m ()
finalizeTask mvar ix finalize = do
  liftIO (MVar.modifyMVar_ mvar register)
  runActiveFinalizers mvar
  where
    register (lastIx, pending) =
      pure (lastIx, Map.insert ix finalize pending)

-- | Raise the number of RTS capabilities to at least the requested worker
--   count. The count is never lowered below its current value.
--
--   Does nothing unless 'rtsSupportsBoundThreads' is 'True' (presumably
--   meaning the program is running on the threaded RTS).
--
updateNumCapabilities :: WorkerCount -> IO ()
updateNumCapabilities (WorkerCount requested)
  | rtsSupportsBoundThreads =
      Conc.getNumCapabilities >>= Conc.setNumCapabilities . max requested
  | otherwise =
      pure ()