{-# LANGUAGE Trustworthy, GADTs, Rank2Types, ImplicitParams, Arrows, DeriveFunctor #-}
-- | An implementation of nested data parallelism (due to Simon Peyton Jones et al)
module Control.CUtils.DataParallel (Equal(Equal),
-- * Flattenable arrays
ArrC, newArray, inject, project,
-- * The arrows and associated operations
Structural, A, unA, mapA', liftA, countA, countA', splitOff, assoc, indexA, zipA, unzipA, concatA, dupA, fstA, sndA, eval,
-- * Examples
nQueens, sorting, permute, dotProduct, transpose') where

import qualified Data.Sequence as S
import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Data.Foldable (toList)
import Control.Parallel
import Control.Parallel.Strategies
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))

-- | A rose tree: a strict label together with a strict sequence of child trees.
-- 'ArrC' stores a sequence of @Tree Int@ next to its element array.
data Tree t = Node !t !(S.Seq (Tree t))

-- | Applies a function to every label of a tree. The children are rebuilt
-- with 'fastConcat' and evaluated with a parallel strategy before the new
-- node is returned.
instance Functor Tree where
	-- 'fmap' on trees has the recurrence:
	-- U(n) = U(n/2) + n/c log^2(n) [based on the lemma about 'fastConcat'].
	-- assuming unlimited capabilities.
	-- It solves as O(n/c log^3 n).
	fmap f (Node x sq) =
		-- Each child becomes a singleton sequence ('return' in the 'S.Seq' monad),
		-- so 'fastConcat' acts as a divide-and-conquer map over the children.
		let sq' = fastConcat (return . fmap f) sq in
		-- Evaluate every mapped child with 'parList rseq' before building the node.
		(toList sq' `using` parList rseq) `pseq` Node (f x) sq'

-- | A flattenable array: an element array paired with a frame, a sequence of
-- 'Tree' nodes. Judging from 'inject', 'packImpl' and 'unpackImpl', the node
-- labels are offsets into the element array that mark where sub-arrays begin,
-- with a final node marking the end.
data ArrC t = ArrC !(Array Int t) !(S.Seq (Tree Int)) deriving Functor

-- | Builds an array indexed from zero that holds the elements of a list in order.
-- The empty list yields an array with bounds @(0, -1)@.
newArray xs = listArray (0, lastIx) xs
	where lastIx = length xs - 1

-- | Wraps an ordinary array as an 'ArrC'. The elements are re-indexed to start
-- at zero, whatever the bounds of the input array, and the frame consists of
-- two childless nodes labelled @0@ and the element count.
inject ar = ArrC
	-- 'ixmap' takes a function from NEW indices to OLD indices, so the original
	-- lower bound must be added back. Subtracting it is only correct when the
	-- lower bound happens to be 0.
	(ixmap (0, size - 1) (+ lo) ar)
	(S.fromList [Node 0 S.empty, Node size S.empty])
	where
	(lo, hi) = bounds ar
	size = hi - lo + 1

-- | Returns the flat element array of an 'ArrC', dropping its frame.
project arrc = case arrc of
	ArrC elements _ -> elements

-- | Functions cannot be printed, so every function is shown as the same fixed
-- placeholder, whatever the precedence context.
instance Show (t -> u) where
	showsPrec _ _ = showString "<FUNCTION>"

-- | The syntax tree of a data-parallel program built over an underlying arrow
-- type @a@, taking an input of type @t@ to an output of type @u@. 'eval0'
-- gives each constructor its meaning.
data Structural a t u where
	-- Apply a program to every element of an array.
	Map :: Structural a t u -> Structural a (ArrC t) (ArrC u)
	-- Sequential composition; the second program runs first.
	Comp :: Structural a u v -> Structural a t u -> Structural a t v
	-- The identity program.
	Id :: Structural a t t
	-- Run two programs on the two components of a pair.
	Product :: Structural a t u -> Structural a v w -> Structural a (t, v) (u, w)
	-- Embed an arrow of the underlying arrow type.
	Lift :: a t u -> Structural a t u
	-- Build an array that pairs the given value with each index list drawn
	-- from the given list of extents (see 'countA').
	Count :: Structural a (t, [Int]) (ArrC (t, [Int]))
	-- Fetch the element of an array at the given index.
	Index :: Structural a (ArrC t, Int) t
	-- Not implemented by 'eval0'; in this file it only appears in rewrite rules.
	Split :: Structural a (ArrC t, Array Int Int) (ArrC t)
	{-Zip :: Structural a (ArrC t, ArrC u) (ArrC (t, u))
	Unzip :: Structural a (ArrC (t, u)) (ArrC t, ArrC u)-}
	-- Replace an array's frame by the concatenation of the child frames of its nodes.
	ClearMarks :: Structural a (ArrC t) (ArrC t)
	-- Turn an 'Either' into a pair of arrays: one element on the side that is
	-- present, none on the other.
	Separate :: Structural a (Either t u) (ArrC t, ArrC u)
	-- Turn such a pair back into an 'Either'; 'Left' is chosen when the first
	-- array holds exactly one element.
	Combine :: Structural a (ArrC t, ArrC u) (Either t u)
	-- Flatten an array of arrays into a single array whose frame records the
	-- sub-array boundaries.
	Pack :: Structural a (ArrC (ArrC t)) (ArrC t)
	-- Rebuild an array of arrays from a flat array and its frame.
	Unpack :: Structural a (ArrC t) (ArrC (ArrC t))
	-- Duplicate the input.
	Dup :: Structural a t (t, t)
	-- Project the first component of a pair.
	Fst :: Structural a (t, u) t
	-- Project the second component of a pair.
	Snd :: Structural a (t, u) u

-- | The 'A' arrow includes a set of primitives that may be executed concurrently.
--   Programs are incrementally optimized as they are put together. A program may be
--   optimized once, and the result saved for repeated use.
--
-- Notes:
--
--   * The exact output of the optimizer is subject to change.
--
--   * The program must be a finite data structure, or optimization may diverge.
--     Therefore recursive definitions do not work, unless something is done to
--     limit the depth.
--
-- Representation: an 'A' program is a function that receives the 'Structural'
-- program that runs before it and returns the combined program. This lets each
-- primitive inspect its predecessor and apply rewrite rules while composing.
data A a t u = A (forall v. Structural a v t -> Structural a v u)

-- | The first element of a non-empty sequence.
-- An empty sequence is a caller error and is reported with a descriptive
-- message rather than an anonymous pattern-match failure.
sHead sq = case S.viewl sq of
	x S.:< _ -> x
	S.EmptyL -> error "sHead: empty sequence"

-- | Everything after the first element of a sequence; the empty sequence
-- yields the empty sequence.
sTail sq = S.drop 1 sq

-- | The last element of a non-empty sequence.
-- An empty sequence is a caller error and is reported with a descriptive
-- message rather than an anonymous pattern-match failure.
sLast sq = case S.viewr sq of
	_ S.:> x -> x
	S.EmptyR -> error "sLast: empty sequence"

-- | The slice of a sequence from position @n1@ (inclusive) up to position
-- @n2@ (exclusive): keep the first @n2@ elements, then drop the first @n1@.
fromTo :: Int -> Int -> S.Seq t -> S.Seq t
fromTo n1 n2 sq = S.drop n1 (S.take n2 sq)

-- | Pairs every element of a sequence with its successor; a sequence of
-- @n@ elements gives @n - 1@ pairs (none for the empty sequence).
pairUp sq = S.zip sq successors
	where successors = sTail sq

-- A divide-and-conquer concatMap over sequences; it is described by the
-- recurrence:
--
-- T(n, k) = 2T(n/2, k) + log(kn/2)
--
-- when running sequentially and
--
-- T(n, k) = T(n/2, k) + log(kn/2)
--
-- when running in parallel, where k is the maximum length of a subentry.
-- Consider splitting an array into c pieces of roughly n/c each. The former
-- recurrence solves as O(n/c log^2 (n/c)); the latter as O(log^2 c).
-- Therefore the function runs in O(n/c log^2 n) time [provided c <= n].
fastConcat :: (t -> S.Seq u) -> S.Seq t -> S.Seq u
fastConcat f sq
	| len == 0 = S.empty
	| len == 1 = f (sHead sq)
	-- Spark the left half, evaluate the right half, then append the two.
	| otherwise = (left `par` right) `pseq` (left S.>< right)
	where
	len = S.length sq
	(front, back) = S.splitAt (len `div` 2) sq
	left = fastConcat f front
	right = fastConcat f back

-- | A witness that two types are the same: matching on 'Equal' brings the
-- equality @t ~ u@ into scope.
data Equal t u where
	Equal :: Equal t t

-- | Places a program after an optional predecessor while re-nesting 'Comp'
-- chains to the right. The second argument is either a proof that there is no
-- predecessor ('Left' 'Equal') or the predecessor itself ('Right').
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
-- A composition with 'Id' on the right is just its left operand.
reassociate (Comp a Id) = reassociate a
-- (a . a2) after p becomes a after (a2 after p).
reassociate (Comp a a2) = reassociate a . Right . reassociate a2
-- Any other node: return it as is, or compose it with the predecessor.
reassociate a = either (\Equal -> a) (a.)

-- | Obtain a 'Structural' program from an 'A' program, by supplying the
-- identity program as its predecessor.
unA (A f) = f id

-- | Obtain a 'Structural' program from an 'A' program placed after the given
-- 'Structural' program (which therefore runs first).
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A f) = f

-- | Applies an 'A' program to every element of an array. The argument is
-- first turned into a 'Structural' program (with the identity as its
-- predecessor) and then handed to 'mapA'.
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' (A f) = mapA (f id)

-- | Embeds an arrow of the underlying arrow type as an 'A' program.
liftA :: (Category a) => a t u -> A a t u
liftA a = A (\a2 -> case a2 of
	-- The predecessor already ends in a lifted arrow: fuse both into one 'Lift'.
	Comp (Lift a2) a3 -> Comp (Lift (a . a2)) a3
	-- Otherwise simply append a 'Lift' node.
	_ -> Lift a . a2)

-- | The 'Pack' primitive as an 'A' program, with rewrite rules that move it
-- towards earlier stages of the program or cancel it against an 'Unpack'.
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\a -> case a of
	-- A nested map that ends in an inner 'Map': pack first, then run the
	-- inner map once over the flat array.
	Comp (Map (Comp (Map a) a2)) a3 -> Map a . unA' pack (Map a2 . a3)
	-- The same, when the nested map is nothing but the inner 'Map'.
	Comp (Map (Map a)) a2 -> Map a . unA' pack a2
	-- Elements that are themselves packed: pack the outer levels twice instead.
	Comp (Map (Comp Pack a)) a2 -> unA' pack (unA' pack (Map a . a2))
	Comp (Map Pack) a2 -> unA' pack (unA' pack a2)
	-- 'Pack' directly after 'Unpack' cancels out.
	Comp Unpack a2 -> a2
	-- No rule applies: emit a plain 'Pack'.
	_ -> Pack . a)

-- | Guard used by 'mapA' to decide whether a mapped program should go through
-- the flattening rule. It is 'False' for programs built only from 'Comp',
-- 'Id', 'Unpack', 'Pack', 'Separate' and 'Combine', and 'True' as soon as any
-- other node occurs.
flatten :: Structural a t u -> Bool
flatten (Comp a a2) = flatten a || flatten a2
flatten Id = False
flatten Unpack = False
flatten Pack = False
{-flatten Zip = False
flatten Unzip = False-}
flatten Separate = False
flatten Combine = False
flatten _ = True

{-flatCounts :: (ArrowChoice a) => A a ((t, [Int]), [Int]) (ArrC (ArrC (t, [Int]), [Int]))
flatCounts = zipA .
	(splitA
	. (mapA' (fstA . fstA &&& (arr (uncurry drop) . (sndA . fstA &&& sndA)))
		. countA
		. ((fstA . fstA &&& arr length . sndA) &&& arr (uncurry (flip (++))) . (sndA . fstA &&& sndA))
		&&& arr (\(ls, ls2) -> let n = product ls in newArray [0,n..n*product ls2]) . (sndA . fstA &&& sndA))
	&&& mapA' sndA . countA)-}

-- | Mapping is the primary way of constructing nested data parallel programs.
--   It applies an (arrow) transformation to each element of an array
--   uniformly. A form of flattening transformation is applied to nested
--   maps (following the NESL paper). The flattening transformation converts
--   two levels of 'Map' into one level.
--
--   The equations are tried in order, and each one inspects the program that
--   precedes the map in order to pick a rewrite rule.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
-- Mapping a 'Map' (a nested map), when 'flatten' allows it.
mapA (Map a) | flatten a = A (\a2 -> case a2 of
	-- The input was just unpacked: map over the flat array, then unpack.
	Comp Unpack a3 -> Unpack . unA' (mapA a) a3
	-- Move the nested map in front of a preceding 'Split' or 'ClearMarks'.
	Comp Split a3 -> Split . unA' (first (mapA' (mapA a))) a3
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA' (mapA a)) a3
	-- General case: pack the input, map once over the flat array, unpack.
	_ -> Unpack . unA' (mapA a) (unA' pack a2))
-- Mapping a 'Product'.
mapA (Product a a2) = A (\a3 -> case a3 of
	-- After a 'Count': run 'a' on the first component before counting and
	-- map only 'a2' over the second component.
	Comp Count a4 -> Comp (Map (Product Id a2)) (unA' (countA . first (A (Comp a))) a4)
	Comp ClearMarks a4 -> ClearMarks . unA' (mapA (Product a a2)) a4
	-- Fuse with a preceding mapped 'Product', component by component.
	Comp (Map (Product a4 a5)) a6 -> unA' (mapA (Product (a . a4) (a2 . a5))) a6
	Comp (Map (Comp (Product a4 a5) a6)) a7 -> unA' (mapA (Product (a . a4) (a2 . a5) . a6)) a7
	_ -> Map (Product a a2) . a3)
-- Mapping a composition is the composition of the maps.
mapA (Comp a a2) = mapA a . mapA a2
-- Mapping the identity is the identity.
mapA Id = id
-- mapA (Product a a2) = zipA . (mapA a *** mapA a2) . unzipA
-- Mapping 'Unpack'.
mapA Unpack = A (\a -> case a of
	Comp Unpack a -> Unpack . (Unpack . a)
	-- A per-element 'Unpack' cancels a preceding per-element 'Pack'.
	Comp (Map (Comp Pack a)) a2 -> Map a . a2
	Comp (Map Pack) a -> a
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA Unpack) a3
	_ -> Map Unpack . a)
-- Every other program.
mapA a = A (\a2 -> case a2 of
	-- Two adjacent maps fuse into a single map of the composed program.
	Comp (Map a2) a3 -> Comp (Map (reassociate a (Right a2))) a3
	Comp ClearMarks a3 -> ClearMarks . unA' (mapA a) a3
	_ -> Comp (Map a) a2)

-- | Strips 'Id' nodes composed at the outermost (last-executed) position of a
-- program; anything else is returned unchanged.
scrubIds (Comp Id x) = scrubIds x
scrubIds x = x

-- | Composition of 'A' programs is composition of their builder functions.
-- The result of the earlier stage is cleaned of outer 'Id' nodes (see
-- 'scrubIds') before the later stage inspects it.
instance (Category a) => Category (A a) where
	id = A (\a -> a)
	A f . A g = A (f . scrubIds . g)

-- | Arrow structure for 'A'. Pure functions are lifted through the underlying
-- arrow's 'arr'.
instance (ArrowChoice a) => Arrow (A a) where
	arr = liftA . arr
	A f *** A g = A (\a -> case a of
		-- The predecessor ends in a 'Product': extend each of its components
		-- separately, so the two products fuse into one.
		Comp (Product a2 a3) a4 -> Product (f a2) (g a3) . a4
		-- Otherwise append a fresh 'Product' of the two programs.
		_ -> Product (f id) (g id) . a)
	first a = a *** id
	second a = id *** a
	-- Fan-out duplicates the input with 'dupA' so the rules of 'fstA' and 'sndA' can see it.
	a &&& a2 = (a *** a2) . dupA

-- | Choice is encoded with arrays: the 'Either' is separated into a pair of
-- arrays, each branch is mapped over its own array, and the pair is combined
-- back into an 'Either'.
instance (ArrowChoice a) => ArrowChoice (A a) where
	a +++ a2 = A (\a3 -> case a3 of
		-- The predecessor just produced an 'Either' via 'Combine': work on the
		-- pair of arrays beneath it and skip the 'Separate'.
		Comp Combine a3 -> Combine . unA' (mapA (unA a) *** mapA (unA a2)) a3
		_ -> Combine . unA' (mapA (unA a) *** mapA (unA a2)) (Separate . a3))
	left a = a +++ id
	right a = id +++ a

-- | A compact rendering of a program, for inspecting the optimizer's output.
-- Constructors without their own equation (such as 'Lift') print as @_@.
instance Show (Structural a t u) where
	-- The operand of 'Map' is parenthesized only when the enclosing precedence is exactly 11.
	showsPrec prec (Map a) = ("Map " ++) . showParen (prec==11) (showsPrec 11 a)
	-- Compositions are printed with @.@ and are never parenthesized themselves.
	showsPrec _ (Comp a a2) = showsPrec 11 a . (" . "++) . showsPrec 11 a2
	-- Products are printed with @***@ and parenthesized at precedence 3 and above.
	showsPrec prec (Product a a2) = showParen (prec>=3) (showsPrec 3 a . (" *** "++) . showsPrec 3 a2)
	showsPrec _ Count = ("Count"++)
	showsPrec _ Index = ("Index"++)
	showsPrec _ Split = ("Split"++)
	showsPrec _ ClearMarks = ("Clr"++)
	showsPrec _ Pack = ("Pk"++)
	showsPrec _ Unpack = ("Unpk"++)
	showsPrec _ Separate = ("Sep"++)
	showsPrec _ Combine = ("Comb"++)
	showsPrec _ Dup = ("Dup"++)
	showsPrec _ Fst = ("Fst"++)
	showsPrec _ Snd = ("Snd"++)
	showsPrec _ Id = ("Id"++)
	showsPrec _ _ = ("_"++)

-- | Purely syntactic category structure: identity and composition just build
-- 'Id' and 'Comp' nodes, with no simplification.
instance (Category a) => Category (Structural a) where
	id = Id
	(.) = Comp

-- | Swaps the two sides of an 'Either'.
mirror (Left x) = Right x
mirror (Right y) = Left y

-- | Distributes the second component over a pair held in the first:
-- @((x, y), z)@ becomes @((x, z), (y, z))@.
splitOff :: (ArrowChoice a) => A a ((t1, t2), u) ((t1, u), (t2, u))
splitOff = first fstA &&& first sndA

-- | Re-associates nested pairs: @((x, y), z)@ becomes @(x, (y, z))@.
assoc :: (ArrowChoice a) => A a ((t, u), v) (t, (u, v))
assoc = fstA . fstA &&& (sndA . fstA &&& sndA)

-- | Access one index of an array.
indexA :: (ArrowChoice a) => A a (ArrC t, Int) t
indexA = A (\a -> case a of
	-- Indexing a mapped array: index first, then run the mapped program on
	-- that single element.
	Comp (Product (Map a) a2) a3 -> a . unA' indexA (Product Id a2 . a3)
	-- Comp (Product Zip a) a2 -> unA' ((indexA *** indexA) . splitOff) (second a . a2)
	-- Indexing the result of 'Count': build the requested element directly,
	-- using the same 'mapAccumL' 'divMod' index decomposition as 'eval0' 'Count'.
	Comp (Product Count a) a2 -> unA' (fstA . fstA &&& arr (\(ns, i) -> snd (mapAccumL divMod i ns)) . (sndA . fstA &&& sndA)) (Product Id a . a2)
	_ -> Index . a)

-- | An operation analogous to 'zip', 'zipA' combines two packed arrays into a single array
-- element by element. The result is as long as the shorter of the two inputs.
zipA :: (ArrowChoice a) => A a (ArrC t, ArrC u) (ArrC (t, u))
-- Pair the two arrays with the smaller of their element counts ...
zipA = id &&& arr (\(ar, ar2) -> (uncurry subtract (bounds (project ar)) `min` uncurry subtract (bounds (project ar2))) + 1)
	-- ... make one copy of the pair per index ...
	>>> countA'
	-- ... and look up that index in both arrays.
	>>> mapA' (splitOff >>> indexA *** indexA)

-- | Splits an array of pairs into a pair of arrays, by mapping the two
-- projections. 'unzipA' and 'zipA' are inverses (for arrays of equal length,
-- since 'zipA' truncates to the shorter input).
unzipA :: (ArrowChoice a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = mapA' fstA &&& mapA' sndA

-- | Concatenation flattens out nested layers of arrays. The key operation used to implement
-- it is erasing marks: erasing marks throws away the structure that delineates the
-- edges of the sub-arrays, effectively flattening them into one array. The operation is divided
-- into packing and erasing marks, in the hope that the packing stage will fuse with an adjacent 'unpack'.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = A (\a -> case a of
	-- Directly after a 'Split', only the first component of its input is kept.
	Comp Split a2 -> unA' fstA a2
	-- Otherwise erase the marks left by the preceding 'pack'.
	_ -> Comp ClearMarks a) . pack

-- | Evaluates both components of a pair to weak head normal form before
-- handing the pair back.
forcePair pair@(x, y) = x `seq` y `seq` pair

-- | Supplies an array of a repeated value paired with the index of each element.
-- The input pairs the value with a list of extents; the array has one element
-- per combination of indices (the product of the extents).
-- Arguably adjacent 'countA's should fuse; however this is hard to implement, so I
-- have opted to provide a more powerful 'countA' that works on arrays of indices;
-- it generates arrays of indices lexicographically ordered.
countA :: (ArrowChoice a) => A a (t, [Int]) (ArrC (t, [Int]))
countA = A(Comp Count)
{- (\a -> case a of
	Comp (Product Count a2) a3 -> unA' flatCounts (unA' (second (A (Comp a2))) a3)
	Comp (Product (Comp a2 Count) a3) a4 -> unA' (mapA' (first (A (Comp a2))) . flatCounts . second (A (Comp a3))) a4
	Comp (Product (Comp Count a2) a3) a4 -> unA' flatCounts (unA' (A (Comp a2) *** A (Comp a3)) a4)
	Comp (Product (Comp a2 (Comp Count a3)) a4) a5 -> unA' (mapA' (first (A (Comp a2))) . flatCounts . (A (Comp a3) *** A (Comp a4))) a5
	_ -> Comp Count a)-}

-- | One-dimensional 'countA': takes a single extent instead of a list and
-- pairs the value with a single index. The extent is wrapped in a singleton
-- list for 'countA', and each resulting index list is unwrapped with 'head'.
countA' :: (ArrowChoice a) => A a (t, Int) (ArrC (t, Int))
countA' = second (arr return) >>> countA >>> mapA' (second (arr head))

-- | Replacements for common arrow functions make fusing work better.
-- 'dupA' duplicates its input; it always emits a 'Dup' node, which 'fstA' and
-- 'sndA' can later cancel.
dupA :: (Category a) => A a t (t, t)
dupA = A (Dup .)

-- | First projection, with rewrite rules that simplify the preceding program.
fstA :: (Category a) => A a (t, u) t
fstA = A (\a -> case a of
	-- Projecting out of a duplication gives back the original value.
	Comp Dup a -> a
	-- The first component is untouched; the second is still executed.
	Comp (Product Id a) a2 -> Fst . (Product Id a . a2)
	-- Only the first component is transformed: project first, then transform.
	Comp (Product a Id) a2 -> a . unA' fstA a2
	-- Both are transformed: keep running the second, apply the first after projecting.
	Comp (Product a a2) a3 -> a . unA' fstA (Product Id a2 . a3) -- Due to effects, cannot omit to do any operations
	_ -> Fst . a)

-- | Second projection, with rewrite rules that mirror those of 'fstA'.
sndA :: (Category a) => A a (t, u) u
sndA = A (\a -> case a of
	-- Projecting out of a duplication gives back the original value.
	Comp Dup a -> a
	-- The second component is untouched; the first is still executed.
	Comp (Product a Id) a2 -> Snd . (Product a Id . a2)
	-- Only the second component is transformed: project first, then transform.
	Comp (Product Id a) a2 -> a . unA' sndA a2
	-- Both are transformed: keep running the first, apply the second after projecting.
	Comp (Product a a2) a3 -> a2 . unA' sndA (Product a Id . a3)
	_ -> Snd . a)

-- Searches an ascending sequence by repeated halving and returns the position
-- at which the search window has shrunk to at most one element: the last
-- position whose element is not greater than the key, or 0 when there is none.
-- Runs in O(log^2(n)) time in the number of elements.
binarySearch :: (Ord t) => t -> S.Seq t -> Int
binarySearch x sq = go 0 (S.length sq) sq where
	go base len window
		| len <= 1 = base
		| x < pivot = go base half lower
		| otherwise = go (base + half) (len - half) upper
		where
		half = len `div` 2
		(lower, upper) = S.splitAt half window
		-- First element of the upper half; only demanded when len > 1.
		pivot S.:< _ = S.viewl upper

-- Implementation of 'Pack': flattens an array of arrays into one 'ArrC'.
-- The new frame has one node per sub-array, labelled with that sub-array's
-- starting offset in the flat array, plus a final node labelled with the total
-- element count. The outer array's own frame is not used.
-- NOTE(review): 'arr_concF' comes from Control.CUtils.Conc; presumably it
-- builds an array of 'sz' elements by calling the function at every index --
-- confirm against that module.
packImpl (ArrC ar fr) = ArrC
	(arr_concF (\(_, i) -> let
		j = binarySearch i fr''
		i2 = S.index fr'' j
		ArrC ar' _ = ar ! j in
		ar' ! (i-i2))
		((), sz))
	fr'
	where
	-- Running offsets of the sub-arrays. Each sub-array's own frame is shifted
	-- by its offset and kept as the children of its node. An empty sentinel
	-- array is appended so the last node records the total size.
	fr' = S.fromList $ snd $ mapAccumL (\i (ArrC ar fr) -> let j = i + rangeSize (bounds ar) in (j, Node i (fastConcat ((return $!) . fmap (+i)) fr))) 0 $ elems ar ++ [ArrC (newArray []) S.empty]
	-- Just the offsets, used to find the sub-array that owns a flat index.
	fr'' = fmap (\(Node i _) -> i) fr'
	-- The last offset is the total number of elements.
	_ S.:> sz = S.viewr fr''

-- Implementation of 'Unpack': rebuilds the sub-arrays of a flat 'ArrC' from
-- its frame. Every pair of adjacent frame nodes (labels j and k) yields one
-- sub-array holding the elements at positions j up to k-1, re-indexed from
-- zero; the children of the first node, shifted back by j, become its frame.
unpackImpl (ArrC ar fr) = fastConcat
	(\(Node j fr2, Node k _) -> return $! ArrC (ixmap (0, k-j-1) (+j) ar) (fastConcat ((return $!) . fmap (subtract j)) fr2))
	(pairUp fr)

-- | An evaluator for 'Structural' arrows. A structural arrow may be obtained from an 'A' arrow
-- by either 'unA' or 'unA''.
--
-- The implicit parameters @?seq@ and @?pool@ are required by the signature and
-- are passed on implicitly to the concurrency primitives.
--
-- Discussion of complexity bounds for various operations [these are provided c <= k]:
--
-- * Cost of 'ClearMarks' is O(k/c log^3(k)) in the number of subelements k.
--
-- * Cost of 'Pack' and 'Unpack' is O(k/c log^3(k) + k) in the number of subelements k.
-- 'Pack' is O(n) in the worst case in the number of spine elements n.
--
-- * 'Map' costs O(f) assuming unlimited capabilities where 'a' runs in O(f) time.
--
-- * 'Combine' and 'Separate' are both O(1).
--
-- 'Split' has no evaluator; evaluating it raises an error with a descriptive message.
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool, ?pool :: BoxedThreadPool) => Structural a t u -> a t u
-- One element per combination of indices: the flat position is decomposed into
-- an index list with 'mapAccumL' 'divMod' over the extents.
eval0 Count = id &&& arr(snd>>>product) >>> arr_concF (arr (\((x, ns), i) -> (x, snd (mapAccumL divMod i ns)))) >>> arr inject
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
-- Replace the frame by the concatenation of the child frames of its nodes.
eval0 ClearMarks =
	arr (\(ArrC ar fr) ->
		ArrC ar (fastConcat id (fmap (\(Node _ fr) -> fr) fr)))
-- Evaluate the mapped program at every index; the frame is carried over unchanged.
eval0 (Map a) = (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1)) >>> arr_concF (arr (uncurry (!)) >>> eval0 a)) &&& arr (\(ArrC _ fr) -> fr) >>> arr (uncurry ArrC)
-- No evaluation rule exists for 'Split'; fail with a message that names the culprit.
eval0 Split = error "Control.CUtils.DataParallel.eval0: Split has no evaluator"
eval0 Pack = arr packImpl
eval0 Unpack = arr (inject . newArray . toList . unpackImpl)
-- The side of the 'Either' that is present becomes a one-element array, the other an empty one.
eval0 Separate = arr (\ei -> ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei) $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
-- Inverse of 'Separate': the first array holding exactly one element means 'Left'.
eval0 Combine = arr (\(ar, ar2) -> let
	a1 = project ar
	a2 = project ar2 in
	if uncurry subtract (bounds a1) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
-- Run the first component's program, then the second's, forcing the pair after each step.
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)
eval0 Dup = arr (\x -> forcePair (x, x))
eval0 Fst = arr fst
eval0 Snd = arr snd

-- | Evaluates arrows. This is 'eval0' with the implicit parameter @?seq@ bound
-- to 'True'; the other implicit parameter of 'eval0' (@?pool@) is still taken
-- from the caller's context.
eval a = let ?seq = True in eval0 a

-- | Application of a program received as a value. The incoming 'A' program is
-- converted with 'unA', evaluated with 'eval' to an arrow of the underlying
-- type, and then applied with the underlying arrow's 'app'. The evaluation
-- uses a @?pool@ of @BoxedThreadPool NoPool@.
instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
	app = first (arr (eval . unA)) >>> liftA app where
		?pool = BoxedThreadPool NoPool

--------------------------------
-- Examples using NDP techniques

-- | Is the piece at position @n@ threatened by any of the following pieces?
-- The k-th entry of @positions@ (counting from 1) is the piece k steps away.
-- A threat exists when some piece sits on the same row, or on either diagonal
-- (its position shifted by its distance equals @n@).
checkThreats n positions = any (n `elem`) [sameRow, oneDiagonal, otherDiagonal]
	where
	sameRow = positions
	oneDiagonal = zipWith (-) positions [1..]
	otherDiagonal = zipWith (+) positions [1..]

-- | Does any piece of a placement threaten one of the pieces listed after it?
checkThreats2 positions = any threatensLater (tails positions)
	where
	threatensLater (n:rest) = checkThreats n rest
	threatensLater [] = False

-- | Enumerates candidate placements with 'countA' (one index list per
-- candidate), turns every candidate into an empty array if 'checkThreats2'
-- finds a threat and into a singleton array otherwise, and concatenates the
-- results. The output holds exactly the threat-free placements.
nQueensImpl :: A (->) ((), [Int]) (ArrC [Int])
nQueensImpl = countA >>> mapA' (arr (\(_, soln) -> if checkThreats2 soln then inject (newArray []) else inject (newArray [soln])))
	>>> concatA

-- | The n-queens problem for a board of size @n@: feeds 'nQueensImpl' with @n@
-- extents of size @n@, so that every assignment of one of @n@ positions to
-- each of @n@ pieces is tried.
nQueens n = arr (\() -> ((), replicate n n)) >>> nQueensImpl

-------------------------------

-- | A quicksort expressed as an 'A' program. The argument bounds the recursion
-- depth of the program that is built; once it reaches zero the remaining
-- sub-arrays are sorted with the standard 'sort'.
--
-- At each level, arrays of at most one element are passed through unchanged.
-- Larger arrays are partitioned around their first element, both parts are
-- sorted by the next level, and the results are joined around the pivot.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth | depth <= 0 = arr (inject . newArray . sort . elems . project)
sorting depth = arr (\x -> if uncurry subtract (bounds (project x)) <= 0 then Left x else Right x)
	>>> id
		||| (arr (\ar -> let
			x:xs = elems (project ar)
			(bef, aft) = partition (<x) xs in
			((inject (newArray bef), inject (newArray aft)), x))
			>>> first (s *** s)
			>>> arr (\((bef, aft), x) -> inject (newArray (elems (project bef) ++ x : elems (project aft)))))
	where s = sorting (pred depth) -- Memoize the answer
-- In order to make this recursive function a finite data structure, there is a depth limit
-- parameter, beyond which the standard 'sort' takes over.

-------------------------------

-- | Rebuilds an array by element-wise lookup: the array is paired with its
-- element count, 'countA' makes one copy of it per index, and each copy is
-- indexed at its own position.
permute :: A (->) (ArrC Int) (ArrC Int)
permute = arr (\ar -> (ar, [uncurry subtract (bounds (project ar)) + 1])) >>> countA >>> mapA' (second (arr head) >>> indexA)

-------------------------------

-- | The dot product of two vectors: zip them, multiply the paired elements,
-- and sum the products. Because 'zipA' truncates, elements beyond the length
-- of the shorter vector are ignored.
dotProduct :: (Num t) => A (->) (ArrC t, ArrC t) t
dotProduct = proc (v1, v2) -> do
	vzip <- zipA -< (v1, v2)
	vdots <- mapA' (arr (uncurry (*))) -< vzip
	returnA -< sum $ elems $ project vdots

-- | Transposes an array of arrays: element @ii@ of result row... that is, the
-- result at position @(ii, jj)@ is the input's element @ii@ of row @jj@.
-- The number of result rows is the length of the input's first row, and every
-- result row is as long as the input has rows. The input must have a first row.
transpose' :: A (->) (ArrC (ArrC t)) (ArrC (ArrC t))
transpose' = proc m -> do
	firstrow <- indexA -< (m, 0)
	-- One entry per column of the input.
	rows <- countA -< (m, [uncurry subtract (bounds (project firstrow)) + 1])
	-- Build skeleton of result array
	rowcols <- mapA' (proc (m, [ii]) -> do
		-- One entry per row of the input.
		v <- countA -< ((m, ii), [uncurry subtract (bounds (project m)) + 1])
		mapA' (proc ((m, ii), [jj]) -> returnA -< (m, (ii, jj))) -< v) -< rows
	-- Build result
	mapA' (mapA' (proc (m, (ii, jj)) -> do
		v <- indexA -< (m, jj)
		indexA -< (v, ii)))
		-< rowcols