{-# LANGUAGE Trustworthy, ScopedTypeVariables, DeriveDataTypeable, ImplicitParams, FlexibleInstances, FlexibleContexts #-}
-- | A module of concurrent higher order functions.
module Control.CUtils.Conc (module Control.CUtils.ThreadPool, ExceptionList(..), ConcException(..), assocFold, Concurrent(..), concF_, concF, conc_, conc, concP, progressConcF, oneOfF, oneOf) where

import Prelude hiding (catch)
import Control.Exception
import Data.Typeable
import Control.Concurrent.QSemN
import Control.Concurrent.Chan
import Control.CUtils.ThreadPool
import GHC.Conc
import Data.Array.IO (newArray_, readArray, writeArray, getElems, IOArray)
import Data.Array
import Data.Array.Unsafe
import Data.Array.MArray
import Data.IORef
import Control.Monad
import Control.Arrow
import System.IO.Unsafe

-- | Thrown for exceptions caused by caller code; carries every exception that was collected
--   from the caller-supplied computations.
data ExceptionList
	= ExceptionList [SomeException]
	deriving (Show, Typeable)

instance Exception ExceptionList

-- | Thrown for internal errors of the concurrency machinery itself (rather than of caller code).
--   If a procedure throws this, some threads it created may still be running.
--   It is thrown separately from 'ExceptionList'.
data ConcException
	= ConcException
	deriving (Show, Typeable)

instance Exception ConcException

-- Submits every action in @mnds@ to the implicit thread pool @?pool@ and blocks until each of
-- them has finished.
--
-- The completion signal is issued from 'finally', so an action that throws (or is killed)
-- still releases the waiter instead of leaving 'waitQSemN' blocked forever. Exceptions are not
-- collected here; the callers in this module wrap each action in their own handler first.
simpleConc_ mnds = do
	done <- newQSemN 0
	forM_ mnds $ \m ->
		addToPool ?pool (void m `finally` signalQSemN done 1)
	waitQSemN done (length mnds)

-- Splits the offsets 0 .. nVals - 1 into contiguous, end-exclusive ranges @(start, end)@.
-- With at least as many pieces as values, every value gets its own range; otherwise exactly
-- @nPieces@ ranges are produced, whose ends are @floor (k * nVals / nPieces)@ for k = 1 .. nPieces.
divideUp :: Integral a => a -> a -> [(a, a)]
divideUp nPieces nVals = zip (0 : ends) ends
	where
	ends
		| nPieces >= nVals = [1 .. nVals]
		| otherwise = [(k * nVals) `div` nPieces | k <- [1 .. nPieces]]

-- Drains the exception channel used by the Kleisli IO methods. A 'Nothing' sentinel is written
-- first and reading stops when that sentinel is reached. A 'ConcException' entry is rethrown
-- immediately; otherwise, if any caller exceptions were read, they are thrown together as an
-- 'ExceptionList' (most recently read first).
getExceptions exs =
	writeChan exs Nothing >> drain [] >>= \collected ->
		unless (null collected) $ throwIO (ExceptionList collected)
	where
	drain acc = readChan exs >>= \next -> case next of
		Nothing -> return acc
		Just ex
			| Just (_ :: ConcException) <- fromException ex -> throwIO ex
			| otherwise -> drain (ex : acc)

-- | A type class of arrows that support some form of concurrency.
--
--   The 'Kleisli' 'IO' instance runs work on the implicit thread pool @?pool@; the @(->)@
--   instance evaluates through that IO implementation.
class Concurrent a where
	-- | Runs an associative folding function on the given array.
	--   The input is a pair @(z, xs)@: @z@ is a seed of the result type, and every element of
	--   @xs@ is converted with the function of type @c -> b@ before being combined.
	--   Note: this function only spawns enough threads to make effective use of the /capabilities/.
	--   Any two array elements may be processed sequentially or concurrently. To get parallelism,
	--   you have to set the numCapabilities value, e.g. using GHC's +RTS -N flag.
	arr_assocFold :: (?pool :: BoxedThreadPool) => a (b, b) b -> (c -> b) -> a (b, Array Int c) b
	-- | Runs @n@ indexed computations. The input is a pair @(t, n)@: @t@ is passed to every
	--   computation together with its index, and the indices range from 0 to n - 1.
	--   In the IO instance, @?seq@ chooses between running each capability's share of the
	--   indices one after another (True) or as separate pool tasks (False).
	arr_concF_ :: (?seq :: Bool, ?pool :: BoxedThreadPool) => a (t, Int) () -> a (t, Int) ()
	-- | Like 'arr_concF_', but collects the result of computation @i@ at index @i@ of an array
	--   indexed 0 to n - 1.
	arr_concF :: (?seq :: Bool, ?pool :: BoxedThreadPool) => a (u, Int) t -> a (u, Int) (Array Int t)
	-- | Runs the computations indexed 0 to n - 1 (input pair @(u, n)@) and returns the result of
	--   one of them, terminating the others.
	arr_oneOfF :: a (u, Int) b -> a (u, Int) b

-- IO implementation: work is dispatched to the implicit thread pool @?pool@. Exceptions raised by
-- caller code are collected on a channel and rethrown together as 'ExceptionList' once all work
-- has finished; a failure in the dispatching machinery itself is reported as 'ConcException'.
instance Concurrent (Kleisli IO) where
	-- Splits the array into min(ceiling (sqrt n), capabilities) contiguous chunks, folds each
	-- chunk as its own pool task, then folds the per-chunk results on the calling thread.
	-- NOTE(review): the seed z starts every chunk's fold and the final fold, so it presumably
	-- has to be an identity for f - confirm with callers.
	arr_assocFold f g = Kleisli $ \(z, parm) -> do
		let (lo, hi) = bounds parm
		when (lo > hi) $ error "Conc.arr_assocFold: empty list"
		exs <- newChan
		caps <- getNumCapabilities
		-- With unlimited caps, you can do sqrt(n) folds on each thread, then sqrt(n) to fold the results (O(sqrt(n)f(n)) time).
		let effectiveCaps = ceiling (sqrt (fromIntegral (rangeSize (bounds parm)))) `min` caps
		ar <- (newArray_ (0, effectiveCaps - 1) :: IO (IOArray Int b))
		let
			rtnException ex = writeChan exs (Just ex) >> return undefined
			innerExHandler m = catch m rtnException
			outerExHandler m = catch m (\(_ :: SomeException) -> rtnException (toException ConcException))
			-- divideUp yields 0-based, end-exclusive offsets (x, y), so chunk i covers exactly the
			-- array indices lo + x .. lo + y - 1 (no overlap, nothing past hi, any lower bound).
			foldChunk (i, (x, y)) = innerExHandler $
				foldM (\acc -> runKleisli f . (,) acc . g . (parm !)) z [lo + x .. lo + y - 1] >>= writeArray ar i in
			outerExHandler $ simpleConc_ $ map foldChunk $ zip [0..] (divideUp effectiveCaps (rangeSize $ bounds parm))
		-- Rethrows anything collected above before the (possibly unwritten) slots of ar are read.
		getExceptions exs
		ls <- getElems ar
		foldM (curry (runKleisli f)) z ls
	-- Splits the n indices into at most one contiguous range per capability and runs each range
	-- as a pool task; inside a range the computations run one after another when ?seq is True,
	-- or as separate pool tasks when it is False.
	arr_concF_ mnds = Kleisli $ \(parm, n) -> do
		exs <- newChan
		caps <- getNumCapabilities
		let
			rtnException ex = writeChan exs (Just ex) >> return undefined
			innerExHandler m = catch m rtnException
			outerExHandler m = catch m (\(_ :: SomeException) -> rtnException (toException ConcException)) in
			outerExHandler $
			simpleConc_ $ map (\(x, y) -> outerExHandler $ (if ?seq then sequence_ else simpleConc_) $ map (innerExHandler . runKleisli mnds . (,) parm) [x..y-1]) $ divideUp caps n
		getExceptions exs
	-- Stores the result of computation i at index i of a fresh array indexed 0 .. n - 1,
	-- scheduling the computations with concF_.
	arr_concF mnds = Kleisli $ \(parm, n) -> partConcF (0, n - 1) (concF_ n) (runKleisli mnds . (,) parm)
	-- Races computations 0 .. n - 1 on separate threads via partOneOfF.
	arr_oneOfF mnds = Kleisli $ \(parm, n) -> partOneOfF (0, n - 1) (runKleisli mnds . (,) parm)

-- Pure functions have no effects, but their results can still be computed in parallel
-- (pointlessly so for 'arr_concF_', which therefore just returns ()). The other methods run the
-- Kleisli IO implementation under 'unsafePerformIO', forcing each result with '$!' inside the
-- IO action. 'arr_concF' ignores the caller's ?seq and always uses ?seq = True.
instance Concurrent (->) where
	arr_assocFold f g input = unsafePerformIO $ assocFold combine g input
		where combine l r = return $! f (l, r)
	arr_concF_ _ = const ()
	arr_concF mnds (parm, n) = let ?seq = True in unsafePerformIO $ concF n (\i -> return $! mnds (parm, i))
	arr_oneOfF mnds (parm, n) = unsafePerformIO $ oneOfF n (\i -> return $! mnds (parm, i))

-- | Monadic front end for 'arr_assocFold' on 'Kleisli' 'IO': @f@ combines two partial results,
--   @g@ converts each array element, and the final argument is the pair @(seed, array)@.
assocFold f = runKleisli . arr_assocFold (Kleisli (uncurry f))

-- Runs @f i@ via 'concF_' for every index @i@ within the bounds of @mnds@; only the array's
-- bounds are used, not its elements.
partConc_ f mnds = concF_ (rangeSize (lo, hi)) (\i -> f (i + lo))
	where (lo, hi) = bounds mnds

-- | Runs @mnds i@ for every @i@ from 0 to @n - 1@ via 'arr_concF_', discarding the results.
--   Needs the implicit parameters @?seq@ and @?pool@.
concF_ n mnds = runKleisli (arr_concF_ (Kleisli atIndex)) ((), n)
	where atIndex ~(_, i) = mnds i

-- | Runs @mnds i@ for every @i@ from 0 to @n - 1@ via 'arr_concF' and returns the results in an
--   array indexed 0 to @n - 1@. Needs the implicit parameters @?seq@ and @?pool@.
concF n mnds = runKleisli (arr_concF (Kleisli atIndex)) ((), n)
	where atIndex ~(_, i) = mnds i

-- | Runs every action stored in the array via 'concF_', discarding their results.
conc_ mnds = partConc_ (\i -> mnds ! i) mnds

-- Monomorphic wrapper around 'unsafeFreeze' that fixes the array types to 'IOArray' 'Int' /
-- 'Array' 'Int' (presumably so the mutable array type at the call site in partConcF is not
-- ambiguous). The source array must not be modified after freezing, since no copy is made.
unsafeFreeze' :: IOArray Int e -> IO (Array Int e)
unsafeFreeze' marr = unsafeFreeze marr

-- Allocates a mutable array with bounds @bnds@ and passes @f@ an action that runs @mnds i@ and
-- stores its result at index @i@. Once @f@ returns, the array is frozen without copying and
-- returned. Any index whose store action @f@ never invoked is left unset.
partConcF bnds f mnds =
	newArray_ bnds >>= \res ->
		f (\i -> mnds i >>= writeArray res i) >> unsafeFreeze' res

-- | Runs several computations concurrently, and returns their results as an array with the
--   same bounds as the input. Waits for all threads to end before returning.
--
--   Takes an implicit parameter @?seq@. Set it to True if you want to only spawn threads for
--   the capabilities (as /assocFold/ does; good for speed). If you need all the actions to be
--   executed concurrently, set it to False.
conc mnds = partConcF (bounds mnds) (`partConc_` mnds) (\i -> mnds ! i)

-- | Version of 'concF' specialized for two computations: runs @m@ and @m2@ with @?seq@ fixed
--   to False and returns both results as a pair.
concP m m2 = let ?seq = False in fmap (toPair . elems) (concF 2 pick)
	where
	pick i
		| i == 0 = liftM Left m
		| otherwise = liftM Right m2
	toPair [Left x, Right y] = (x, y)

-- | Like 'concF', but prints a text progress bar to stdout. Computation @i@ prints one @'|'@ when
--   its share @[i/n, (i+1)/n)@ of the work crosses a multiple of 1/80, giving min(n, 80) marks in
--   total. A newline is printed after all computations finish. Marks are printed by whichever
--   thread finished the computation, so they follow completion order, not index order.
progressConcF n f = do
	-- The previous test (i * 80 `mod` n == 0) printed only gcd n 80 marks, e.g. one for n = 81.
	let crossesMark i = (i + 1) * 80 `div` n > i * 80 `div` n
	res <- concF n (\i -> f i >>= \x -> when (crossesMark i) (putChar '|') >> return x)
	putStrLn ""
	return res

-- Runs @mnds i@ for every index @i@ in @bnds@, each on its own thread created with 'forkIO', and
-- returns the first successful result. Failures are collected; if every computation fails, an
-- 'ExceptionList' of their exceptions is thrown (most recently read first), so an empty range
-- throws @ExceptionList []@ immediately. On exit (normal or exceptional) every forked thread is
-- killed; if that cleanup itself fails, 'ConcException' is thrown instead.
-- NOTE(review): a thread id is recorded only after forkIO returns, so an asynchronous exception
-- arriving in between would leave that thread unkilled - confirm whether this matters.
partOneOfF bnds mnds = do
	-- Ids of the forked workers, read by the cleanup action below.
	thds <- newIORef []
	-- Each worker posts either (Right result) or (Left exception) on this channel.
	chn <- newChan
	finally (do
		mapM_ (\n -> do
			thd <- forkIO (catch (mnds n >>= writeChan chn . Right) (\(ex :: SomeException) -> writeChan chn (Left ex) >> return undefined))
			modifyIORef thds (thd:))
			(range bnds)
		-- Reads one outcome per worker: failures are accumulated, the first success is returned.
		let chanToList n exs = if n == rangeSize bnds then
				throwIO (ExceptionList exs)
			else readChan chn >>=
				either
					(chanToList (n + 1) . (:exs))
					return in
			chanToList 0 [])
		(catch (readIORef thds >>= mapM_ killThread) (\(_ :: SomeException) -> throwIO ConcException))

-- | Runs @mnds i@ for every @i@ from 0 to @n - 1@ via 'arr_oneOfF' and returns one of the
--   results (in IO, the first computation to succeed; the others are killed).
oneOfF n mnds = runKleisli (arr_oneOfF (Kleisli atIndex)) ((), n)
	where atIndex ~(_, i) = mnds i

-- | Runs several computations in parallel, and returns one of their results (terminating the
--   other computations): the result of the first computation to succeed. If every computation
--   fails, their exceptions are thrown together as an 'ExceptionList'.
oneOf :: Array Int (IO a) -> IO a
oneOf mnds = partOneOfF (bounds mnds) (\i -> mnds ! i)