{-# LANGUAGE ScopedTypeVariables, DeriveDataTypeable, ImplicitParams #-}
-- | A module of concurrent higher-order functions.
module Control.CUtils.Conc (ExceptionList(..), ConcException(..), assocFold, concF_, conc_, concF, conc, concP, oneOfF, oneOf) where

import Prelude hiding (catch)
import Control.Exception
import Data.Typeable
import Control.Concurrent.QSemN
import Control.Concurrent.Chan
import GHC.Conc
import Data.Array.IO (newArray_, readArray, writeArray, getElems, IOArray)
import Data.Array
-- import Data.Array.Unsafe
import Data.Array.MArray
import Data.IORef
import Control.Monad

-- | Carries the exceptions raised by caller-supplied actions, bundled into
-- a single exception value.
data ExceptionList
	= ExceptionList [SomeException]
	deriving (Show, Typeable)

instance Exception ExceptionList

-- | Signals an internal error. When a procedure throws this, some of the
-- threads it created may still be running. It is always thrown on its own,
-- never wrapped inside an 'ExceptionList'.
data ConcException
	= ConcException
	deriving (Show, Typeable)

instance Exception ConcException

-- | Forks one thread per action and blocks until every forked thread has
-- finished. The results of the actions are discarded.
--
-- The completion signal is sent from a 'finally' handler, so an action that
-- throws still counts as finished. Without it, a single failing action
-- would leave the caller blocked in 'waitQSemN' forever. The exception
-- itself is not propagated to the caller; it only ends the forked thread.
simpleConc_ :: [IO a] -> IO ()
simpleConc_ mnds = do
	sem <- newQSemN 0
	-- One unit is signalled per action, whether it returns or throws.
	mapM_ (\m -> forkIO ((m >> return ()) `finally` signalQSemN sem 1)) mnds
	-- Wait for as many units as there were actions.
	waitQSemN sem (length mnds)

-- | Splits the offsets @0 .. nVals - 1@ into consecutive half-open ranges
-- @(start, end)@, where @end@ is exclusive. When there are at least as many
-- pieces as values, every value gets its own range; otherwise exactly
-- @nPieces@ ranges are produced, with boundaries at @k * nVals / nPieces@.
divideUp :: Int -> Int -> [(Int, Int)]
divideUp nPieces nVals = zip (0 : ends) ends where
	ends
		| nPieces >= nVals = [1..nVals]
		| otherwise = [(k * nVals) `div` nPieces | k <- [1..nPieces]]

-- | Drains the exception channel and rethrows what it finds.
--
-- A 'Nothing' is written first to mark the end of the queue, then entries
-- are read until that marker comes back. A 'ConcException' is rethrown as
-- soon as it is read. Otherwise, if anything was collected, the whole batch
-- is thrown as an 'ExceptionList' (most recently read exception first).
getExceptions exs = do
	writeChan exs Nothing
	collected <- drain []
	unless (null collected) $ throwIO (ExceptionList collected)
	where
		-- Read entries until the end marker is reached.
		drain acc = readChan exs >>= maybe (return acc) (classify acc)
		-- Internal errors abort the drain; everything else is accumulated.
		classify acc ex = case fromException ex of
			Just (_ :: ConcException) -> throwIO ex
			Nothing -> drain (ex : acc)

-- | Runs an associative folding function on the given non-empty array.
--
-- The array is cut into at most 'numCapabilities' contiguous chunks. Each
-- chunk is folded left to right in its own thread, and the per-chunk
-- results are then folded, in order, on the calling thread. Only enough
-- threads are spawned to make effective use of the /capabilities/, so any
-- two elements may be processed sequentially or concurrently. To get
-- parallelism you have to set the numCapabilities value, e.g. using GHC's
-- +RTS -N flag.
--
-- Calls 'error' when the array is empty. Exceptions raised while folding a
-- chunk are collected and rethrown as an 'ExceptionList'; a failure of the
-- thread runner itself is reported as a 'ConcException'.
assocFold :: forall a. (a -> a -> IO a) -> Array Int a -> IO a
assocFold f parm = do
	let
		(lo, hi) = bounds parm
		chunks = divideUp numCapabilities (rangeSize (lo, hi))
	when (lo > hi) $ error "Conc.assocFold: empty list"
	exs <- newChan
	-- One result slot per chunk.
	ar <- (newArray_ (0, length chunks - 1) :: IO (IOArray Int a))
	let
		rtnException ex = writeChan exs (Just ex)
		innerExHandler m = catch m rtnException
		outerExHandler m = catch m (\(_ :: SomeException) -> rtnException (toException ConcException))
		-- divideUp yields half-open offset ranges [x, y) counted from 0, so
		-- chunk i covers the array indices lo + x .. lo + y - 1.
		foldChunk (i, (x, y)) = innerExHandler $
			foldM f (parm ! (lo + x)) (map (parm !) [lo + x + 1 .. lo + y - 1]) >>= writeArray ar i
	outerExHandler $ simpleConc_ $ map foldChunk $ zip [0..] chunks
	-- Rethrow anything the workers reported before reading their slots.
	getExceptions exs
	partials <- getElems ar
	case partials of
		x : xs -> foldM f x xs
		-- Not reachable: a non-empty array always yields at least one chunk.
		[] -> error "Conc.assocFold: empty list"

-- | Runs the actions @mnds 0 .. mnds (n - 1)@ and waits for all of them.
--
-- The indices are split into chunks with 'divideUp' 'numCapabilities', and
-- every chunk gets its own thread. Inside a chunk the actions run one after
-- another when @?seq@ is True, and each in its own thread when it is False.
--
-- Exceptions raised by the actions are collected and rethrown as an
-- 'ExceptionList'; a failure of the thread runner itself is reported as a
-- 'ConcException'.
concF_ :: (?seq :: Bool) => Int -> (Int -> IO ()) -> IO ()
concF_ n mnds = do
	exs <- newChan
	let
		report ex = writeChan exs (Just ex) >> return undefined
		-- Guards a single caller-supplied action.
		guardAction m = m `catch` report
		-- Guards the thread runner; any failure there is an internal error.
		guardRunner m = m `catch` \(_ :: SomeException) -> report (toException ConcException)
		runChunk (x, y) = guardRunner $
			(if ?seq then sequence_ else simpleConc_) (map (guardAction . mnds) [x .. y - 1])
	guardRunner (simpleConc_ (map runChunk (divideUp numCapabilities n)))
	getExceptions exs

-- Runs @f@ over every index of the array @mnds@ via 'concF_', translating
-- the zero-based offsets used by 'concF_' into indices of the array.
partConc_ f mnds = concF_ (rangeSize bnds) (\offset -> f (offset + lo)) where
	bnds = bounds mnds
	lo = fst bnds

-- | Runs every action stored in the array through 'concF_', discarding
-- results. Takes the same implicit parameter @?seq@ as 'concF_'.
conc_ mnds = partConc_ (\i -> mnds ! i) mnds

-- Converts a mutable array into an immutable one. Despite the name, this is
-- the ordinary 'freeze', not an in-place unsafe freeze.
unsafeFreeze' :: IOArray Int e -> IO (Array Int e)
unsafeFreeze' arr = freeze arr

-- Allocates a result array with bounds @bnds@, hands the runner @f@ an
-- action that computes @mnds i@ and stores it at index @i@, and returns the
-- frozen array once @f@ has finished.
partConcF bnds f mnds = do
	slots <- newArray_ bnds
	let store i = mnds i >>= writeArray slots i
	f store
	unsafeFreeze' slots

-- | 'concF', 'conc' and 'concP' take an implicit parameter @?seq@. Set it
-- to True if you only want to spawn threads for the capabilities (the same
-- as /assocFold/, good for speed). Set it to False if you need all the
-- actions to be executed concurrently.
--
-- The original author states that these functions promise O(m f(n)/c)
-- time, provided:
--
--   * unsafeFreeze does a pointer cast (which it doesn't: the
--     @unsafeFreeze'@ used here is the copying 'freeze')
--
--   * green threads are created on the same OS thread as the creating
--     thread where possible
--
-- @n@ is the number of computations, which are indexed from 0 to n - 1.
-- The results are returned as an array with bounds @(0, n - 1)@.
concF n mnds = partConcF (0, n - 1) (concF_ n) mnds

-- | Runs the computations stored in an array concurrently and returns their
-- results as an array with the same bounds. Waits for all threads to end
-- before returning.
conc mnds = partConcF (bounds mnds) (`partConc_` mnds) (\i -> mnds ! i)

-- | Version of 'concF' specialized for two computations: runs @m@ and @m2@
-- and returns both results as a pair.
concP m m2 = liftM (unpack . elems) (concF 2 pick) where
	-- Index 0 runs the first computation, any other index the second.
	pick i
		| i == 0 = liftM Left m
		| otherwise = liftM Right m2
	-- The result array always holds the Left result, then the Right one.
	unpack [Left x, Right y] = (x, y)

-- Starts one thread per index in @bnds@, each running @mnds@ at that index,
-- and returns the first successful result that arrives. If every
-- computation fails, the failures are thrown as an 'ExceptionList'. In all
-- cases the started threads are killed before returning; if that cleanup
-- itself fails, a 'ConcException' is thrown.
partOneOfF bnds mnds = do
	thds <- newIORef []
	chn <- newChan
	let
		total = rangeSize bnds
		-- A failed computation reports its exception on the channel.
		report (ex :: SomeException) = writeChan chn (Left ex) >> return undefined
		spawn n = do
			thd <- forkIO ((mnds n >>= writeChan chn . Right) `catch` report)
			modifyIORef thds (thd :)
		-- Read outcomes until one succeeds or all of them have failed.
		await failed exs
			| failed == total = throwIO (ExceptionList exs)
			| otherwise = do
				outcome <- readChan chn
				case outcome of
					Left ex -> await (failed + 1) (ex : exs)
					Right val -> return val
		killAll = (readIORef thds >>= mapM_ killThread) `catch` \(_ :: SomeException) -> throwIO ConcException
	(mapM_ spawn (range bnds) >> await 0 []) `finally` killAll

-- | Runs the computations @mnds 0 .. mnds (n - 1)@, each in its own thread,
-- and returns the result of one that succeeds; the other computations are
-- terminated. If all of them fail, an 'ExceptionList' is thrown.
oneOfF :: Int -> (Int -> IO a) -> IO a
oneOfF n mnds = partOneOfF (0, n - 1) mnds

-- | Runs the computations stored in an array in parallel and returns one of
-- their results, terminating the other computations.
oneOf :: Array Int (IO a) -> IO a
oneOf mnds = partOneOfF (bounds mnds) (\i -> mnds ! i)