{-# LANGUAGE ScopedTypeVariables #-}
--
-- Copyright (c) 2013   Stefan Wehr - http://www.stefanwehr.de
--
-- This library is free software; you can redistribute it and/or
-- modify it under the terms of the GNU Lesser General Public
-- License as published by the Free Software Foundation; either
-- version 2.1 of the License, or (at your option) any later version.
--
-- This library is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-- Lesser General Public License for more details.
--
-- You should have received a copy of the GNU Lesser General Public
-- License along with this library; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA
--

module Test.Framework.ThreadPool (

    ThreadPoolEntry, ThreadPool(..), StopFlag(..), sequentialThreadPool, parallelThreadPool
  , threadPoolTest

) where

import qualified Control.Exception as Ex
import Control.Monad
import Control.Monad.Trans
import Control.Concurrent

-- for tests
import System.Random

-- | Verdict returned by a post-action: either the thread pool should be
-- terminated as soon as possible ('DoStop'), or it should carry on with
-- the remaining entries ('DoNotStop').
data StopFlag = DoStop | DoNotStop
    deriving (Eq, Show, Read)

-- | A single job for a 'ThreadPool', given as a triple of
--
--   1. a pre-action (must not throw exceptions) that produces the input of
--      the action,
--   2. the action itself, which runs in 'IO' (on a worker thread when the
--      pool was created by 'parallelThreadPool'), and
--   3. a post-action (must not throw exceptions) that receives the outcome
--      of the action, with any exception thrown by the action captured as
--      'Left'. If the result is 'DoStop', the thread pool is terminated asap.
--
-- The pre- and post-actions run in the caller's monad @m@.
type ThreadPoolEntry m i o =
    ( m i                                      -- pre-action
    , i -> IO o                                -- action
    , Either Ex.SomeException o -> m StopFlag  -- post-action
    )

-- | A strategy for executing a batch of 'ThreadPoolEntry' values.
-- 'tp_run' runs the given entries, starting no further entries once a
-- post-action has returned 'DoStop'.
data ThreadPool m i o = ThreadPool
    { tp_run :: [ThreadPoolEntry m i o] -> m ()
    }

-- | A "pool" without any extra threads: every entry is run in the calling
-- thread, one after the other, until a post-action returns 'DoStop'.
sequentialThreadPool :: MonadIO m => ThreadPool m a b
sequentialThreadPool = ThreadPool { tp_run = runSequentially }

-- | Create a thread pool that runs the actions of its entries on up to @n@
-- worker threads. Fails (via 'fail' in 'IO') if @n@ is smaller than 1.
parallelThreadPool :: MonadIO m => Int -> m (ThreadPool m a b)
parallelThreadPool n
    | n < 1     = liftIO (fail ("invalid number of workers: " ++ show n))
    | otherwise = return ThreadPool { tp_run = runParallel n }

-- | Execute the entries one after another in the calling thread. For each
-- entry, the pre-action is run, its result is fed to the action (in 'IO',
-- catching every exception), and the outcome is passed to the post-action.
-- Processing stops right after the first post-action that returns 'DoStop'.
runSequentially :: MonadIO m => [ThreadPoolEntry m a b] -> m ()
runSequentially = go
    where
      go [] = return ()
      go ((pre, action, post) : rest) =
          do input <- pre
             outcome <- liftIO (Ex.try (action input))
             flag <- post outcome
             case flag of
               DoStop    -> return ()
               DoNotStop -> go rest

-- | Message from the pool driver to a single worker thread.
data WorkItem m r
    = Work (IO r) (Either Ex.SomeException r -> m StopFlag)
      -- run the action, then send back the post-action applied to its outcome
    | Done
      -- leave the worker loop

instance Show (WorkItem m r) where
    show item =
        case item of
          Work _ _ -> "Work"
          Done     -> "Done"

-- | An 'MVar' or a 'Chan' paired with a name. The name is only used to
-- label the trace messages emitted through 'debug'.
type NamedMVar v = (String, MVar v)
type NamedChan v = (String, Chan v)

-- | A worker's inbox: the driver puts the worker's next 'WorkItem' here.
type ToWorker m r = NamedMVar (WorkItem m r)

-- | Message from a worker back to the driver. It carries the post-action,
-- already applied to the action's outcome, together with the worker's own
-- inbox so the driver can hand it more work.
data WorkResult m r = WorkResult (m StopFlag) (ToWorker m r)

instance Show (WorkResult m r) where
    show = const "WorkResult"

-- | Channel shared by all workers of one pool. Each finished action is
-- reported on it as a 'WorkResult'.
type FromWorker m r = NamedChan (WorkResult m r)

-- | Run the given entries on at most @n@ worker threads. Never more
-- workers are started than there are entries.
--
-- The pre- and post-actions run in the caller's monad @m@ on the calling
-- thread; only the actions themselves are executed by the workers.
--
-- When a post-action returns 'DoStop', no further entries are started.
-- The pool is not abandoned, though: the entries that are already running
-- are waited for (so their post-actions still run), and every worker is
-- told to exit before this function returns.
--
-- Fails (via 'fail' in 'IO') if @n < 1@ and there is at least one entry.
runParallel :: forall m a b . MonadIO m => Int -> [ThreadPoolEntry m a b] -> m ()
runParallel _ [] = return ()
runParallel n entries =
    do when (n < 1) $ liftIO (fail ("invalid number of workers: " ++ show n))
       fromWorker <- liftIO $ newNamedChan "fromWorker"
       let nWorkers = min n (length entries)
       toWorkers <- mapM (\i -> liftIO $ mkWorker i fromWorker) [1..nWorkers]
       let (initEntries, restEntries) = splitAt nWorkers entries
       -- Give every worker one entry; afterwards all nWorkers workers are busy.
       mapM_ (\(mvar, entry) -> runEntry entry mvar) (zip toWorkers initEntries)
       loop fromWorker nWorkers restEntries
    where
      -- Feed the remaining entries to whichever worker reports back first.
      -- Invariant: exactly @nWorkers@ workers are busy, i.e. each of them
      -- will still write one result to @fromWorker@.
      loop :: FromWorker m b -> Int -> [ThreadPoolEntry m a b] -> m ()
      loop fromWorker nWorkers [] =
          cleanup fromWorker nWorkers
      loop fromWorker nWorkers (x:xs) =
          do (toWorker, stop) <- waitForWorkerResult fromWorker
             case stop of
               DoNotStop ->
                   do runEntry x toWorker
                      loop fromWorker nWorkers xs
               DoStop ->
                   -- Start nothing new, but do not abandon the pool either.
                   -- The reporting worker is idle and can exit right away.
                   -- The other nWorkers-1 are still busy and must be drained
                   -- so that their post-actions run and they terminate
                   -- instead of blocking forever on their inbox.
                   do liftIO $ putNamedMVar toWorker Done
                      cleanup fromWorker (nWorkers - 1)
      -- Wait for @pending@ busy workers. For each one, run its post-action
      -- (its StopFlag no longer matters) and tell the worker to exit.
      cleanup :: FromWorker m b -> Int -> m ()
      cleanup fromWorker pending
          | pending <= 0 = return ()
          | otherwise =
              do debug ("cleanup, n=" ++ show pending)
                 (toWorker, _) <- waitForWorkerResult fromWorker
                 liftIO $ putNamedMVar toWorker Done
                 cleanup fromWorker (pending - 1)
      -- Block until some worker has finished, then run that entry's
      -- post-action here (in m, on the calling thread).
      waitForWorkerResult :: FromWorker m b -> m (ToWorker m b, StopFlag)
      waitForWorkerResult fromWorker =
          do WorkResult postAction toWorker <- liftIO $ readNamedChan fromWorker
             flag <- postAction
             return (toWorker, flag)
      -- Run the pre-action here, then ship the action to the given worker.
      runEntry :: ThreadPoolEntry m a b -> ToWorker m b -> m ()
      runEntry (pre, action, post) toWorker =
          do a <- pre
             liftIO $ putNamedMVar toWorker (Work (action a) post)
      -- Fork worker number @i@. The worker repeatedly takes a WorkItem from
      -- its inbox and runs the action, catching every exception. It then
      -- sends the post-action, applied to the outcome, back on @fromWorker@.
      -- The loop ends when the worker receives Done.
      mkWorker :: Int -> FromWorker m b -> IO (ToWorker m b)
      mkWorker i fromWorker =
          do toWorker <- newEmptyNamedMVar ("worker" ++ show i)
             let workerLoop =
                     do workItem <- takeNamedMVar toWorker
                        case workItem of
                          Done ->
                              debug ("worker" ++ show i ++ " exiting!")
                          Work action post ->
                              do res <- Ex.try action
                                 writeNamedChan fromWorker (WorkResult (post res) toWorker)
                                 workerLoop
             -- forkIO's default handler silently discards
             -- BlockedIndefinitelyOnMVar. Re-throwing it as an IOError makes
             -- a stuck worker (a driver bug) visible instead.
             _ <- forkIO (workerLoop `Ex.catch` (\(e::Ex.BlockedIndefinitelyOnMVar) ->
                                           fail ("worker " ++ show i ++ ": " ++ show e)))
             return toWorker

--
-- Debugging and testing
--

-- | Switch for the tracing emitted by 'debug'. Set it to 'True' to print a
-- line for every named channel/MVar operation performed by the pool.
_DEBUG_ :: Bool
_DEBUG_ = False

-- | Create a fresh, empty channel labelled with the given name.
newNamedChan :: String -> IO (NamedChan a)
newNamedChan label = fmap (\ch -> (label, ch)) newChan

-- | Block until a value is available on the named channel and return it.
-- The read is traced before and after via 'debug'.
readNamedChan :: Show a => NamedChan a -> IO a
readNamedChan (label, ch) =
    debug ("readChan[" ++ label ++ "]...") >> readChan ch >>= report
    where
      report item =
          do debug ("DONE readChan[" ++ label ++"]=" ++ show item)
             return item

-- | Trace the write via 'debug', then append the value to the named channel.
writeNamedChan :: Show a => NamedChan a -> a -> IO ()
writeNamedChan (label, ch) value =
    debug ("writeChan[" ++ label ++ "]=" ++ show value) >> writeChan ch value

-- | Create an empty 'MVar' labelled with the given name.
newEmptyNamedMVar :: String -> IO (NamedMVar a)
newEmptyNamedMVar label = newEmptyMVar >>= \box -> return (label, box)

-- | Put a value into the named 'MVar'. This blocks while the 'MVar' is full.
-- The operation is traced before and after via 'debug'.
putNamedMVar :: Show a => NamedMVar a -> a -> IO ()
putNamedMVar (label, box) value =
    do let rendered = show value
       debug ("putMVar[" ++ label ++ "]=" ++ rendered ++ "...")
       putMVar box value
       debug ("DONE putMVar[" ++ label ++ "]=" ++ rendered)

-- | Take the value out of the named 'MVar'. This blocks while the 'MVar' is
-- empty. The operation is traced before and after via 'debug'.
takeNamedMVar :: Show a => NamedMVar a -> IO a
takeNamedMVar (label, box) =
    debug ("takeMVar[" ++ label ++ "]...") >>
    takeMVar box >>= \value ->
    debug ("DONE takeMVar[" ++ label ++ "]=" ++ show value) >>
    return value

-- | Print a trace line to stdout, but only when '_DEBUG_' is enabled.
debug :: MonadIO m => String -> m ()
debug msg = when _DEBUG_ (liftIO (putStrLn msg))

-- | Stress test for 'runParallel'. It runs @nEntries@ entries on a pool of
-- @n@ workers and then checks, for every entry, that
--
--   * the pre- and post-actions ran on the calling thread,
--   * the action ran on a different thread, and
--   * the entry's result ended up in its own box.
--
-- Any violation is reported by throwing an 'IOError' (via 'fail').
runTestParallel :: Int -> Int -> IO ()
runTestParallel nEntries n =
    do debug ("Running test " ++ show n)
       boxes <- mapM mkBox [1..nEntries]
       let entries = map mkEntry boxes
       runParallel n entries
       debug ("Checking boxes in test " ++ show n)
       mapM_ assertBox boxes
       debug ("Test " ++ show n ++ " successful")
    where
      -- One result box per entry, paired with the entry's index.
      mkBox i =
          do mvar <- newEmptyNamedMVar ("testbox" ++ show i)
             return (mvar, i)
      -- Each component records the thread it ran on, so assertBox can
      -- check where pre, action and post were executed.
      mkEntry (mvar, i) =
          let pre = myThreadId
              post x = case x of
                         Left err -> fail ("Exception in worker thread: " ++ show err)
                         Right y -> do tid <- myThreadId
                                       putNamedMVar mvar (y, tid)
                                       return DoNotStop
              action x = do tid <- myThreadId
                            j <- randomIO
                            -- random 0-49 microsecond delay to vary the
                            -- order in which workers finish
                            let micros = j `mod` 50
                            threadDelay micros
                            return (x, tid, i)
          in (pre, action, post)
      assertBox (mvar, i) =
          do ((preTid, actionTid, i'), postTid) <- takeNamedMVar mvar
             tid <- myThreadId
             assertEq "pre-tid" tid preTid
             assertEq "post-tid" tid postTid
             assertNeq "action-tid" tid actionTid
             assertEq "i" i i'
      assertEq what expected actual =
          when (expected /= actual) $
              fail (what ++ " wrong, expected=" ++ show expected ++ ", actual=" ++
                    show actual)
      assertNeq what unexpected actual =
          when (unexpected == actual) $
              fail (what ++ " wrong, did not expect " ++ show unexpected)

-- | Run 'runTestParallel' with @nEntries@ entries once for every worker
-- count in the inclusive range @[i..j]@.
--
-- If the calling thread receives 'Ex.BlockedIndefinitelyOnMVar' (it is
-- waiting on an MVar that can never be filled), the exception is re-thrown
-- as an 'IOError' via 'fail' with a "main-thread blocked" prefix.
threadPoolTest :: (Int, Int) -> Int -> IO [()]
threadPoolTest (i, j) nEntries =
    mapM (runTestParallel nEntries) [i..j] `Ex.catch`
             (\(e::Ex.BlockedIndefinitelyOnMVar) ->
                  fail ("main-thread blocked " ++ show e))