{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ConstraintKinds #-}

-- | Wrappers around 'parallel' that use a semaphore to limit the number of threads running at once.

module Test.Sandwich.ParallelN (
  parallelN
  , parallelN'

  , parallelNFromArgs
  , parallelNFromArgs'

  , parallelSemaphore
  , HasParallelSemaphore

  , defaultParallelNodeOptions
  ) where

import Control.Concurrent.QSem
import Control.Exception.Safe
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Control (MonadBaseControl)
import Test.Sandwich.Contexts
import Test.Sandwich.Types.ArgParsing
import Test.Sandwich.Types.RunTree
import Test.Sandwich.Types.Spec



-- | Wrapper around 'parallel'. Introduces a semaphore to limit the parallelism to N threads.
--
-- Equivalent to 'parallelN'' with 'defaultParallelNodeOptions'.
parallelN
  :: (MonadBaseControl IO m, MonadIO m, MonadMask m)
  => Int
  -- ^ Semaphore size: the maximum number of wrapped children running at once.
  -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m ()
  -- ^ Children to run in parallel; they can see the semaphore via 'parallelSemaphore'.
  -> SpecFree context m ()
parallelN = parallelN' defaultParallelNodeOptions

-- | Like 'parallelN', but with explicit 'NodeOptions' for the 'parallel'' node.
--
-- Introduces a 'QSem' of size @n@ under the 'parallelSemaphore' label, then runs
-- @children@ inside 'parallel''. Every action wrapped by 'aroundEach' must take one
-- unit of the semaphore before running and gives it back afterwards, so at most @n@
-- of them run at the same time.
--
-- The size must be at least 1. With a size of 0, every 'waitQSem' would block
-- forever, and 'newQSem' rejects negative sizes. Either case makes the
-- "Introduce parallel semaphore" node fail with a descriptive error.
parallelN'
  :: (MonadBaseControl IO m, MonadIO m, MonadMask m)
  => NodeOptions
  -- ^ Options for the 'parallel'' node.
  -> Int
  -- ^ Semaphore size: the maximum number of wrapped children running at once (must be >= 1).
  -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m ()
  -- ^ Children to run in parallel.
  -> SpecFree context m ()
parallelN' nodeOptions n children =
  introduce "Introduce parallel semaphore" parallelSemaphore allocSemaphore (const $ return ()) $
    parallel' nodeOptions $
      aroundEach "Take parallel semaphore" claimRunSlot children
  where
    -- Validate before allocating. A size of 0 would deadlock every child, and a
    -- negative size would only fail inside newQSem with a generic message.
    allocSemaphore = liftIO $ do
      when (n < 1) $
        ioError . userError $
          "parallelN: semaphore size must be at least 1, but got " ++ show n
      newQSem n

    -- Hold one unit of the semaphore while the wrapped action runs. bracket_
    -- returns the unit even if the action throws.
    claimRunSlot f = do
      s <- getContext parallelSemaphore
      bracket_ (liftIO $ waitQSem s) (liftIO $ signalQSem s) (void f)

-- | Same as 'parallelN', but reads the semaphore size from the command line options.
--
-- Equivalent to 'parallelNFromArgs'' with 'defaultParallelNodeOptions'.
parallelNFromArgs
  :: forall context a m.
     (MonadBaseControl IO m, MonadIO m, MonadMask m, HasCommandLineOptions context a)
  => (CommandLineOptions a -> Int)
  -- ^ Picks the semaphore size out of the parsed command line options.
  -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m ()
  -- ^ Children to run in parallel.
  -> SpecFree context m ()
parallelNFromArgs = parallelNFromArgs' @context @a defaultParallelNodeOptions

-- | Like 'parallelNFromArgs', but with explicit 'NodeOptions' for the 'parallel'' node.
--
-- When the introduce node runs, the semaphore size is computed by applying
-- @getParallelism@ to the 'CommandLineOptions' stored under 'commandLineOptions'.
-- Because that value comes from the user, it is validated. A size below 1 fails
-- the "Introduce parallel semaphore" node with a descriptive error. Without this
-- check, a size of 0 would deadlock every child and a negative size would fail
-- inside 'newQSem'.
parallelNFromArgs'
  :: forall context a m.
     (MonadBaseControl IO m, MonadIO m, MonadMask m, HasCommandLineOptions context a)
  => NodeOptions
  -- ^ Options for the 'parallel'' node.
  -> (CommandLineOptions a -> Int)
  -- ^ Picks the semaphore size out of the parsed command line options (must yield >= 1).
  -> SpecFree (LabelValue "parallelSemaphore" QSem :> context) m ()
  -- ^ Children to run in parallel.
  -> SpecFree context m ()
parallelNFromArgs' nodeOptions getParallelism children =
  introduce "Introduce parallel semaphore" parallelSemaphore getQSem (const $ return ()) $
    parallel' nodeOptions $
      aroundEach "Take parallel semaphore" claimRunSlot children
  where
    -- Read the configured size from the command line options, validate it,
    -- then allocate the semaphore.
    getQSem = do
      n <- getParallelism <$> getContext commandLineOptions
      liftIO $ do
        when (n < 1) $
          ioError . userError $
            "parallelNFromArgs: parallelism must be at least 1, but got " ++ show n
        newQSem n

    -- Hold one unit of the semaphore while the wrapped action runs. bracket_
    -- returns the unit even if the action throws.
    claimRunSlot f = do
      s <- getContext parallelSemaphore
      bracket_ (liftIO $ waitQSem s) (liftIO $ signalQSem s) (void f)

-- | The 'Label' under which 'parallelN' and its variants store the 'QSem' that
-- limits parallelism. Pass it to 'getContext' to retrieve the semaphore.
parallelSemaphore :: Label "parallelSemaphore" QSem
parallelSemaphore = Label @"parallelSemaphore" @QSem

-- | Constraint for contexts that contain the 'QSem' introduced by 'parallelN'
-- (or one of its variants), stored under the 'parallelSemaphore' label.
type HasParallelSemaphore ctx = HasLabel ctx "parallelSemaphore" QSem

-- | The 'NodeOptions' used by 'parallelN' and 'parallelNFromArgs'. They match
-- 'defaultNodeOptions' except that 'nodeOptionsVisibilityThreshold' is set to 70.
defaultParallelNodeOptions :: NodeOptions
defaultParallelNodeOptions =
  defaultNodeOptions { nodeOptionsVisibilityThreshold = parallelVisibilityThreshold }
  where
    -- NOTE(review): the meaning of 70 relative to other nodes' thresholds is
    -- not visible in this module — confirm against the NodeOptions docs.
    parallelVisibilityThreshold = 70