{-# LANGUAGE Trustworthy, DeriveDataTypeable #-}

-- | Lists suitable for parallel execution (taken from Hackage's monad-par package). (For converting to regular lists, there is the toList function in Data.Foldable.)
module Control.CUtils.AList (AList(..), filterAList, assocFold, monoid, lenAList, findAList, concatAList) where

import Control.Parallel
import Control.Monad
import Control.Applicative
import Data.Monoid
import Data.Foldable (Foldable, foldMap)
import Data.Traversable (Traversable, traverse, foldMapDefault)
import Data.Data

-- | An append tree of ordinary lists. 'Append' nodes join two sub-trees and
-- are the points where the instances below split work in parallel; 'List'
-- leaves hold plain list chunks.
data AList t
	= Append (AList t) (AList t) -- ^ concatenation of two sub-trees, left first
	| List [t]                    -- ^ a leaf holding a plain list
	deriving (Eq, Ord, Show, Typeable, Data)

-- | Binding replaces every element with a sub-'AList'. For an 'Append' node
-- the left half is sparked with 'par' while the right half is evaluated to
-- WHNF before the node is rebuilt; a 'List' leaf is rebuilt by 'mplus'-ing
-- the per-element results from right to left.
instance Monad AList where
	return x = List [x]
	(List xs) >>= k = foldr (\x rest -> k x `mplus` rest) mzero xs
	(Append lhs rhs) >>= k = (left `par` right) `pseq` Append left right
		where
			left = lhs >>= k
			right = rhs >>= k

-- | 'mzero' is the empty leaf. 'mplus' first sparks its left argument and
-- evaluates its right one, then merges them:
--
-- * a singleton leaf is consed onto a following leaf, or pushed into the
--   left branch of a following 'Append';
-- * an empty leaf on either side is dropped;
-- * anything else becomes a new 'Append' node.
instance MonadPlus AList where
	mzero = List []
	mplus left right = (left `par` right) `pseq` merge left right
		where
			merge (List [x]) (List xs) = List (x : xs)
			merge l@(List [_]) (Append y z) = Append (mplus l y) z
			merge (List []) r = r
			merge l (List []) = l
			merge l r = Append l r

-- | Applicative structure expressed through the 'Monad' instance: every
-- function is applied to every argument, functions in the outer position.
instance Applicative AList where
	pure x = return x
	fs <*> xs = fs >>= \f -> xs >>= \x -> return (f x)

-- | Choice reuses the 'MonadPlus' operations: 'empty' is 'mzero' and '<|>'
-- joins its operands with 'mplus'.
instance Alternative AList where
	empty = mzero
	l <|> r = l `mplus` r

-- | Mapping is expressed through '>>=', wrapping each result as a singleton.
instance Functor AList where
	fmap f xs = xs >>= \x -> return (f x)

-- | Traverses the leaves left to right, rebuilding the same tree shape.
instance Traversable AList where
	traverse f = go
		where
			go (List xs) = List <$> traverse f xs
			go (Append l r) = Append <$> go l <*> go r

-- | Folding is derived from the 'Traversable' instance.
instance Foldable AList where
	foldMap f xs = foldMapDefault f xs

-- | Filters the AList using a predicate, keeping exactly the elements for
-- which it returns 'True'. Each element becomes a singleton or empty leaf.
filterAList f ls = ls >>= keep
	where
		keep x
			| f x = List [x]
			| otherwise = List []

-- Removes empty leaves by re-joining every 'Append' node with 'mplus', which
-- drops an empty operand on either side. Afterwards an empty leaf can only
-- remain when the whole tree has no elements (the result is then 'List []').
noNils tree = case tree of
	Append l r -> noNils l `mplus` noNils r
	_ -> tree

-- Parallel fold worker. For an 'Append' node the left sub-fold is sparked
-- with 'par', the right one is evaluated, and the two results are combined
-- with @f@ (left result first). A leaf is folded with 'foldl1', so an empty
-- leaf makes this fail.
assocFold0 f (List xs) = foldl1 f xs
assocFold0 f (Append l r) = (lhs `par` rhs) `pseq` f lhs rhs
	where
		lhs = assocFold0 f l
		rhs = assocFold0 f r

-- | Folds the AList with a function that must be associative; associativity
-- is what allows the two halves of every 'Append' node to be folded in
-- parallel. Empty leaves are removed first, but the AList must still contain
-- at least one element: an empty AList makes the underlying 'foldl1' fail.
assocFold f ls = assocFold0 f (noNils ls)

-- | Combine monoid elements to get a result, using 'mappend' in parallel
-- over the tree. An AList with no elements yields 'mempty'.
monoid :: Monoid m => AList m -> m
monoid ls = case noNils ls of
	-- Matching on the constructor (rather than comparing with '==') avoids
	-- an 'Eq' constraint on the elements, and the cleaned-up tree is reused
	-- instead of running 'noNils' a second time inside 'assocFold'.
	List [] -> mempty
	cleaned -> assocFold0 mappend cleaned

-- | Length of an AList, counted in parallel over the tree. Empty leaves
-- contribute nothing; an AList with no elements has length 0.
lenAList :: Num n => AList t -> n
lenAList ls = case noNils ls of
	-- Pattern matching instead of '==' means the element type needs no 'Eq'
	-- instance. The cleaned tree has no empty leaves, so it can be handed
	-- straight to the fold worker without a second 'noNils' pass.
	List [] -> 0
	cleaned -> assocFold0 (+) (fmap (const 1) cleaned)

-- | Find the first (leftmost) element satisfying a predicate, or 'Nothing'
-- if none does. Matches are tagged with 'First' and combined via 'monoid'.
findAList f = getFirst . monoid . fmap tag
	where
		tag x
			| f x = First (Just x)
			| otherwise = First Nothing

-- | Concatenate an AList of ALists, flattening one level of nesting.
concatAList nested = join nested