{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}

module Test.Hspec.Core.Runner.JobQueue (
  MonadIO
, Job
, Concurrency(..)
, JobQueue
, withJobQueue
, enqueueJob
) where

import           Prelude ()
import           Test.Hspec.Core.Compat hiding (Monad)
import qualified Test.Hspec.Core.Compat as M

import           Control.Concurrent
import           Control.Concurrent.Async (Async, AsyncCancelled(..), async, waitCatch, asyncThreadId)

import           Control.Monad.IO.Class (liftIO)
import qualified Control.Monad.IO.Class as M

-- for compatibility with GHC < 7.10.1
--
-- 'Monad' is redefined here as a constraint synonym that spells out
-- 'Functor' and 'Applicative' next to the real class 'M.Monad'; the class
-- itself is hidden from the unqualified imports above.
-- NOTE(review): presumably needed because older base versions did not make
-- 'Applicative' a superclass of 'Monad' -- confirm before removing.
type Monad m = (Functor m, Applicative m, M.Monad m)
-- | The 'Monad' synonym above combined with 'M.MonadIO'. This synonym (not
-- the class) is what the module exports under the name @MonadIO@.
type MonadIO m = (Monad m, M.MonadIO m)

-- | A job is a computation in @m@ that is handed a callback for reporting
-- values of type @progress@ and eventually yields a result of type @a@.
type Job m progress a = (progress -> m ()) -> m a

-- | Scheduling mode passed to 'enqueueJob': 'Sequential' dispatches to
-- @runSequentially@, 'Concurrent' dispatches to @runConcurrently@.
data Concurrency = Sequential | Concurrent

-- | State shared by all jobs enqueued on one queue.
--
-- * '_semaphore' is handed to concurrently scheduled jobs, which acquire it
--   for the duration of their action.
-- * '_cancelQueue' collects every spawned worker so that 'withJobQueue' can
--   cancel them on exit.
data JobQueue = JobQueue {
  _semaphore :: Semaphore
, _cancelQueue :: CancelQueue
}

-- | An abstract semaphore, represented as a pair of actions.
--
-- A worker runs '_wait' before starting its action and '_signal' once the
-- action has finished (both are wired up through @bracket_@ in
-- @runConcurrently@).
data Semaphore = Semaphore {
  _wait :: IO ()
, _signal :: IO ()
}

-- | Mutable list of spawned workers (results discarded). @runConcurrently@
-- prepends to it; 'withJobQueue' reads it on exit to cancel every entry.
type CancelQueue = IORef [Async ()]

-- | Run an action with a fresh 'JobQueue'.
--
-- The first argument is the capacity of the semaphore shared by concurrently
-- scheduled jobs. When the action returns or throws, every worker that was
-- registered on the queue is sent 'AsyncCancelled' and then waited for, so
-- no worker outlives the call.
withJobQueue :: Int -> (JobQueue -> IO a) -> IO a
withJobQueue concurrency = bracket new cancelAll
  where
    -- Allocate the semaphore and an empty cancel queue.
    new :: IO JobQueue
    new = JobQueue <$> newSemaphore concurrency <*> newIORef []

    -- Cancel everything that has been registered so far.
    cancelAll :: JobQueue -> IO ()
    cancelAll (JobQueue _ cancelQueue) = readIORef cancelQueue >>= cancelMany

    -- Two passes: deliver the cancellation to every job first and only then
    -- wait, so waiting on one job does not delay notifying the others.
    cancelMany :: [Async a] -> IO ()
    cancelMany jobs = do
      mapM_ notifyCancel jobs
      mapM_ waitCatch jobs

    -- Throw 'AsyncCancelled' to the job's thread; completion of the job is
    -- awaited separately via 'waitCatch' in 'cancelMany'.
    notifyCancel :: Async a -> IO ()
    notifyCancel = flip throwTo AsyncCancelled . asyncThreadId

-- | Create a 'Semaphore' backed by a 'QSem' that starts with @capacity@
-- available units: waiting maps to 'waitQSem', signalling to 'signalQSem'.
newSemaphore :: Int -> IO Semaphore
newSemaphore capacity = do
  sem <- newQSem capacity
  return $ Semaphore (waitQSem sem) (signalQSem sem)

-- | Schedule a job on the queue and return a handle for collecting it.
--
-- The returned 'Job' (in an arbitrary 'MonadIO' @m@) relays progress reports
-- to the callback it is given and finally yields either the job's result or
-- the exception that ended it.
--
-- * 'Sequential' jobs ignore the queue's semaphore and are scheduled through
--   @runSequentially@.
-- * 'Concurrent' jobs are scheduled through @runConcurrently@ using the
--   queue's shared semaphore.
enqueueJob :: MonadIO m => JobQueue -> Concurrency -> Job IO progress a -> IO (Job m progress (Either SomeException a))
enqueueJob (JobQueue sem cancelQueue) concurrency = case concurrency of
  Sequential -> runSequentially cancelQueue
  Concurrent -> runConcurrently sem cancelQueue

-- | Schedule a job that must not start before its result is requested.
--
-- This reuses 'runConcurrently' with a private, one-shot semaphore: the
-- worker blocks on an empty 'MVar' until the returned 'Job' is run, which
-- fills the 'MVar' and then collects the result as usual. Releasing the
-- semaphore is a no-op ('pass').
runSequentially :: forall m progress a. MonadIO m => CancelQueue -> Job IO progress a -> IO (Job m progress (Either SomeException a))
runSequentially cancelQueue action = do
  barrier <- newEmptyMVar
  let
    -- Executed by the worker before the action: block until released.
    wait :: IO ()
    wait = takeMVar barrier

    -- Executed by the consumer: release the worker.
    signal :: m ()
    signal = liftIO $ putMVar barrier ()

  job <- runConcurrently (Semaphore wait pass) cancelQueue action
  -- Open the barrier first, then wait for the job as usual.
  return $ \ notifyPartial -> signal >> job notifyPartial

-- | Message passed from a worker to the consumer of its job: 'Partial'
-- carries a progress report, 'Done' marks that the worker's action has
-- finished. The type parameter @a@ is not used by either constructor.
data Partial progress a = Partial progress | Done

-- | Spawn a worker thread for a job and return a handle for collecting it.
--
-- The worker acquires the given 'Semaphore' before running the action and
-- releases it afterwards. Progress reports and the final completion marker
-- are handed to the consumer through a single 'MVar'; a newer report
-- overwrites one that has not been picked up yet.
--
-- The returned 'Job' forwards every progress report it observes to its
-- callback and, once the worker signals completion, yields the worker's
-- outcome via 'waitCatch' (@Left@ for an exception, @Right@ for a result).
runConcurrently :: forall m progress a. MonadIO m => Semaphore -> CancelQueue -> Job IO progress a -> IO (Job m progress (Either SomeException a))
runConcurrently (Semaphore wait signal) cancelQueue action = do
  result :: MVar (Partial progress a) <- newEmptyMVar
  let
    -- 'finally' guarantees that 'Done' is published even if the action
    -- throws, so the consumer is never left blocked on the 'MVar'.
    -- NOTE(review): 'interruptible' presumably compensates for the worker
    -- being forked from inside 'bracket''s masked acquire step below, so
    -- that cancellation can still reach the action -- confirm.
    worker :: IO a
    worker = bracket_ wait signal $ do
      interruptible (action partialResult) `finally` done
      where
        -- Publish a progress report, replacing any unread one.
        partialResult :: progress -> IO ()
        partialResult = replaceMVar result . Partial

        -- Publish the completion marker, replacing any unread report.
        done :: IO ()
        done = replaceMVar result Done

    -- Register a spawned worker (result discarded) for later cancellation.
    pushOnCancelQueue :: Async a -> IO ()
    pushOnCancelQueue spawned = modifyIORef cancelQueue (void spawned :)

  -- 'bracket' ensures that a worker, once spawned, is always registered on
  -- the cancel queue before this function can be interrupted.
  job <- bracket (async worker) pushOnCancelQueue return

  let
    -- Relay progress reports until the worker is done, then collect its
    -- outcome.
    waitForResult :: (progress -> m ()) -> m (Either SomeException a)
    waitForResult notifyPartial = do
      r <- liftIO (takeMVar result)
      case r of
        Partial progress -> notifyPartial progress >> waitForResult notifyPartial
        Done -> liftIO $ waitCatch job

  return waitForResult

-- | Store a value in an 'MVar', discarding whatever it currently holds.
--
-- The take and the put are two separate steps, not one atomic operation.
-- In this module the only writers to a given 'MVar' are the callbacks of a
-- single worker. NOTE(review): if several threads ever write concurrently,
-- the 'putMVar' could block -- confirm before reusing elsewhere.
replaceMVar :: MVar a -> a -> IO ()
replaceMVar mvar p = tryTakeMVar mvar >> putMVar mvar p