{-# LANGUAGE GADTs, Rank2Types, StandaloneDeriving, ImplicitParams #-}
-- | An implementation of nested data parallelism
module Control.CUtils.DataParallel (ArrC, inject, project, newArray, A(Count, Index, Zip, Unzip, Concat, Map, Comp, Arr, Prod, Sum), optimize, eval) where

import Data.Array
import Data.Tree
import Data.Monoid (Any(Any))
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad
import Control.CUtils.Conc
import System.IO.Unsafe
import Prelude hiding (id, (.))

-- | An array used by the data-parallel arrows: the elements plus a 'Forest'
--   of segment offsets describing how they are grouped.  Both fields are
--   strict (forced to WHNF on construction).
--
--   'inject' records a single segment as @[Node 0 [], Node n []]@ (@n@ = the
--   element count).  'Pack' records one node per sub-array at its starting
--   offset, whose children are that sub-array's own forest, followed by a
--   final childless node at the total length; 'Unpack' reads consecutive
--   node pairs back as @[start, end)@ slices.
data ArrC t = ArrC !(Array Int t) !(Forest Int)

-- | Wraps a flat array as a single-segment 'ArrC': the segment forest is
--   @[Node 0 [], Node n []]@ where @n@ is the number of elements.
--
--   The size is computed with 'rangeSize' (as 'Pack' does), so degenerate
--   empty bounds such as @(5, 2)@ give a size of 0 instead of a negative one.
--
--   NOTE(review): the rest of the module indexes elements starting at 0, so
--   the array is presumably expected to be 0-based.
inject :: Array Int t -> ArrC t
inject ar = ArrC ar [Node 0 [], Node (rangeSize (bounds ar)) []]

-- | Extracts the element array, dropping the segment forest.
project :: ArrC t -> Array Int t
project (ArrC elements _segments) = elements

-- | Maps over the elements only; the segment forest is carried over as is.
instance Functor ArrC where
	fmap g (ArrC elements segments) = ArrC mapped segments where
		mapped = fmap g elements

-- | Builds a 0-based array from a list: list element @k@ lands at index @k@.
--   An empty list yields an empty array with bounds @(0, -1)@.
newArray :: [t] -> Array Int t
newArray xs = listArray (0, lastIx) xs where
	lastIx = length xs - 1

-- | Pairs every element with its successor: @[a, b, c]@ becomes
--   @[(a, b), (b, c)]@.  Lists with fewer than two elements give @[]@.
pairUp :: [a] -> [(a, a)]
pairUp xs = zip xs (drop 1 xs)

-- | Placeholder rendering for functions.  'A' holds a function in its 'Arr'
--   constructor, so this instance is what lets 'Show' be derived for 'A'.
--   (Orphan instance.)
instance Show (t -> u) where
	showsPrec _ _ = showString "<FUNCTION>"

-- | Data-parallel arrows over 'ArrC' arrays.  The first group of constructors
--   is exported for building programs; the constructors after
--   \"Internal constructors\" are introduced by 'optimize' and handled by the
--   interpreter.
data A t u where
	-- | Given @n@, an array built from the indices passed by 'concF'
	--   (presumably @0 .. n-1@; the optimizer's 'Index'-on-'Count' rewrite
	--   assumes that range).
	Count :: A Int (ArrC Int)
	-- | Looks up the element at the given position.
	Index :: A (ArrC t, Int) t
	-- | Pairs two arrays element by element.
	Zip :: A (ArrC t, ArrC u) (ArrC (t, u))
	-- | Splits an array of pairs into two arrays.
	Unzip :: A (ArrC (t, u)) (ArrC t, ArrC u)
	-- | Flattens an array of arrays into one array.
	Concat :: A (ArrC (ArrC t)) (ArrC t)
	-- | Applies an arrow to every element of an array.
	Map :: A t u -> A (ArrC t) (ArrC u)
	-- | Composition (the 'Category' @(.)@ of this type).
	Comp :: A u v -> A t u -> A t v
	-- | Lifts a plain function (the 'Arrow' @arr@ of this type).
	Arr :: (t -> u) -> A t u
	-- | Runs two arrows on the two components of a pair (@(***)@).
	Prod :: A t u -> A v w -> A (t, v) (u, w)
	-- | Runs one of two arrows depending on the 'Either' tag (@(+++)@).
	Sum :: A t u -> A v w -> A (Either t v) (Either u w)

	-- Internal constructors
	-- | Identity.  Not constructed by the code shown; 'id' is @arr id@.
	Id :: A t t
	-- | Flattens a nested array into one element array plus a segment forest.
	Pack :: A (ArrC (ArrC t)) (ArrC t)
	-- | Inverse of 'Pack': splits a packed array back into its segments.
	Unpack :: A (ArrC t) (ArrC (ArrC t))
	-- | Encodes a pair as the two-element array @[Left x, Right y]@.
	PackProd :: A (t, u) (ArrC (Either t u))
	-- | Inverse of 'PackProd': reads positions 0 and 1.
	UnpackProd :: A (ArrC (Either t u)) (t, u)
	-- | @Left x@ becomes the singleton @[Left x]@; @Right xs@ becomes
	--   @fmap Right xs@.
	PackSum1 :: A (Either t (ArrC u)) (ArrC (Either t u))
	-- | Inverse of 'PackSum1'.
	UnpackSum1 :: A (ArrC (Either t u)) (Either t (ArrC u))
	-- | Mirror image of 'PackSum1' (the array is on the 'Left').
	PackSum2 :: A (Either (ArrC t) u) (ArrC (Either t u))
	-- | Inverse of 'PackSum2'.
	UnpackSum2 :: A (ArrC (Either t u)) (Either (ArrC t) u)

-- | Swaps the two sides of an 'Either'.
mirror :: Either a b -> Either b a
mirror (Left x) = Right x
mirror (Right y) = Left y

-- Standalone-derived 'Show' so that arrows (for example the result of
-- 'optimize') can be printed.  The 'Arr' payload is a function, which is
-- rendered through the 'Show (t -> u)' instance in this module.
deriving instance Show (A t u)

-- | Identity is a lifted identity function; composition builds a 'Comp' node.
instance Category A where
	id = Arr id
	f . g = Comp f g

-- | 'arr' lifts a function into an 'Arr' node and @(***)@ builds a 'Prod';
--   'first' and 'second' pair the given arrow with a lifted identity.
instance Arrow A where
	arr f = Arr f
	f *** g = Prod f g
	first f = Prod f (Arr id)
	second g = Prod (Arr id) g

-- | @(+++)@ builds a 'Sum'; 'left' and 'right' pair the given arrow with a
--   lifted identity on the other branch.
instance ArrowChoice A where
	f +++ g = Sum f g
	left f = Sum f (Arr id)
	right g = Sum (Arr id) g

-- | @reassociate a k@ is @a . k@ with every 'Comp' inside @a@ rebuilt as a
--   right-nested chain ending in @k@; for example @(f . g) . h@ with
--   continuation @k@ becomes @f . (g . (h . k))@.  Non-'Comp' nodes are not
--   entered.
reassociate :: A u v -> A t u -> A t v
reassociate arrow = case arrow of
	Comp outer inner -> reassociate outer . reassociate inner
	leaf -> Comp leaf

-- Optimizer step 1. Pushes indexes and concats to the right and separates maps/products/sums.
-- Once this is done, the result should be internal layers of only Maps.
--
-- Every rule rewrites a composition ('Comp'); 'optimize' right-nests its input
-- with 'reassociate' before calling this.  Equations are tried top to bottom,
-- so their order matters, and a term that matches no rule is returned as is.
step :: A t u -> A t u
-- map (a . a2) ==> map a . map a2
step (Comp (Map (Comp a a2)) a3) = step (Map (step a)) . (Map a2 . a3)
-- map (a *** a2) ==> zip . (map a *** map a2) . unzip
step (Comp (Map (Prod a a2)) a3) = Zip . ((Map a *** Map a2) . (Unzip . a3))
-- otherwise step the mapped arrow itself
step (Comp (Map a) a2) = step (Map (step a)) . a2
-- index (map a xs) i ==> a (index xs i)
step (Comp Index (Prod (Map a) a2)) = step a . (Index . second a2)
-- index (count i) j ==> j, after checking 0 <= j <= i - 1
step (Comp Index (Prod Count a)) = arr (\(i, j) -> if inRange (0, i - 1) j then j else error $ "DataParallel.eval: bad index: " ++ show j) . second a
-- concat . map (map a) ==> map a . concat
step (Comp Concat (Map (Map a))) = step (Map (step a)) . Concat
-- concat . map concat ==> concat . concat
step (Comp Concat (Map Concat)) = Concat . Concat
-- (a . a2) *** a3 ==> (a *** id) . (a2 *** a3), and the mirror-image rule
step (Comp (Prod (Comp a a2) a3) a4) = step (Prod (step a) id) . (Prod a2 a3 . a4)
step (Comp (Prod a (Comp a2 a3)) a4) = step (Prod id (step a2)) . (Prod a a3 . a4)
-- the same splitting for sums
step (Comp (Sum (Comp a a2) a3) a4) = step (Sum (step a) id) . (Sum a2 a3 . a4)
step (Comp (Sum a (Comp a2 a3)) a4) = step (Sum id (step a2)) . (Sum a a3 . a4)
-- a . (a2 . a3): rewrite the leading pair, then continue down the chain.
-- Every equation above returns a 'Comp' when given one, so this match holds.
step (Comp a (Comp a2 a3)) = case step (a . a2) of Comp a4 a5 -> a4 . step (a5 . a3)
step a = a

-- Optimizer step 2. Replaces nested arrays with the packed representation.
-- The first two steps will be repeated, until there is only one layer of Maps.
-- Each rewrite writes 'Any True'; 'repetition2' keeps iterating while that
-- flag is set.  The structural cases at the end recurse without reporting.
step2 :: A t u -> Writer Any (A t u)
-- map (map a) ==> unpack . map a . pack
step2 (Map (Map a)) = tell (Any True) >> liftM ((Unpack .) . (. Pack) . Map) (step2 a)
-- a *** a2 ==> unpackProd . map (a +++ a2) . packProd
step2 (Prod a a2) = tell (Any True) >> liftM ((UnpackProd .) . (. PackProd)) (step2 (Map (Sum a a2)))
-- Sums create the possibility of recursion trees w/ variable depth.
-- a +++ map a2 ==> unpackSum1 . map (a +++ a2) . packSum1
step2 (Sum a (Map a2)) = tell (Any True) >> liftM2 (\x y -> UnpackSum1 . Map (Sum x y) . PackSum1) (step2 a) (step2 a2)
-- map a +++ a2: the same rewrite, with 'mirror' swapping the branches around it
step2 (Sum (Map a) a2) = tell (Any True) >> liftM2 (\x y -> arr mirror . UnpackSum1 . Map (Sum y x) . PackSum1 . arr mirror) (step2 a) (step2 a2)
step2 (Sum a a2) = liftM2 (+++) (step2 a) (step2 a2)
step2 (Map a) = liftM Map (step2 a)
step2 (Comp a a2) = liftM2 (.) (step2 a) (step2 a2)
step2 a = return a

-- Optimizer step 3. Removes redundant packs and zips, combines maps/products/sums, pushes zips right.
-- Returns 'Nothing' when no rule applies, so it can be driven by 'repetition'.
-- NOTE: not currently used; its call in 'optimize' is commented out.
step3 :: A t u -> Maybe (A t u)
-- map a . map a2 ==> map (a . a2)
step3 (Comp (Map a) (Comp (Map a2) a3)) = Just $ Map (repetition step3 (a . a2)) . a3
-- zip . (map a *** map a2) ==> map (a *** a2) . zip
step3 (Comp Zip (Prod (Map a) (Map a2))) = Just $ Map (repetition step3 (a *** a2)) . Zip
-- zip . (count *** count) ==> map (\x -> (x, x)) . count . min
step3 (Comp Zip (Prod Count Count)) = Just $ Map (arr (\x -> (x, x))) . (Count . arr (uncurry min))
-- pack/unpack (and zip/unzip) pairs cancel
step3 (Comp Zip (Comp Unzip a)) = Just a
step3 (Comp Pack (Comp Unpack a)) = Just a
step3 (Comp PackProd (Comp UnpackProd a)) = Just a
step3 (Comp PackSum1 (Comp UnpackSum1 a)) = Just a
step3 (Comp PackSum2 (Comp UnpackSum2 a)) = Just a
-- (a +++ a2) . (a3 +++ a4) ==> (a . a3) +++ (a2 . a4)
step3 (Comp (Sum a a2) (Sum a3 a4)) = Just $ repetition step3 (a . a3) +++ repetition step3 (a2 . a4)
-- otherwise try to rewrite further down the composition chain
step3 (Comp a (Comp a2 a3)) = liftM (a .) (step3 (a2 . a3))
step3 _ = Nothing

-- | Applies a step function until it reports no further progress
--   ('Nothing'), returning the last value it produced.
repetition :: (a -> Maybe a) -> a -> a
repetition f = go where
	go x = case f x of
		Nothing -> x
		Just x' -> go x'

-- | Runs a pass repeatedly while it reports (by writing @Any True@) that it
--   changed something; returns the output of the first pass that reports no
--   change.
repetition2 :: (a -> Writer Any a) -> a -> a
repetition2 f = loop where
	loop x = case runWriter (f x) of
		(x', Any True) -> loop x'
		(x', Any False) -> x'

-- | Optimizes an arrow for parallel execution. The arrow can be optimized once, and the result saved for multiple computations. (The exact output of the optimizer is subject to change.)
--
--   The arrow must be finitely examinable.
--
--   The input is first right-nested with 'reassociate'.  Then one pass of
--   'step', then 'step2', then re-nesting is repeated for as long as 'step2'
--   reports a rewrite.  The 'step3' clean-up pass is currently disabled.
optimize :: A t u -> A t u
optimize = repetition2 onePass . rightNest where
	rightNest a = reassociate a (arr id)
	onePass = liftM rightNest . step2 . step

-- | Interpreter for 'A'.  Covers every constructor, including the internal
--   ones that 'optimize' introduces and 'Id'.
--
--   NOTE(review): the implicit parameter @?seq@ is required by the signature
--   but not consulted by any equation.
eval0 :: (?seq :: Bool) => A t u -> t -> u
-- 'concF n' is assumed to produce results for indices @0 .. n-1@; the bounds
-- check that 'step' emits for 'Index' on 'Count' relies on that range.
eval0 Count n = inject $ unsafePerformIO $ concF n (return $!)
eval0 Index (ArrC ar _, i) = ar ! i
-- The result is as long as the shorter input.  Counting with 'rangeSize'
-- (not the upper bound, which is one less for a 0-based array) keeps the last
-- element and agrees with the @Zip . (Count *** Count)@ rule in 'step3'.
-- Elements are read from index 0, so inputs are assumed 0-based.
eval0 Zip (ArrC ar _, ArrC ar2 _) = inject $ unsafePerformIO $ concF (rangeSize (bounds ar) `min` rangeSize (bounds ar2))
	(\i -> let x = ar ! i; y = ar2 ! i in x `seq` y `seq` return $! (x, y))
eval0 Unzip ar = (fmap fst ar, fmap snd ar)
-- Pack the input, then replace each top-level forest node by its children,
-- shifted by the offset of the sub-array they came from.
eval0 Concat ar0 = ArrC ar [ Node (i + j) ls3 | Node i ls2 <- ls, Node j ls3 <- ls2 ] where ArrC ar ls = eval0 Pack ar0
eval0 (Map a) (ArrC ar ls) = ArrC (unsafePerformIO $ conc $ fmap ((return $!) . eval0 a) ar) ls
-- Concatenate the sub-arrays.  The forest gets one node per sub-array at its
-- starting offset (children: that sub-array's forest) plus a final,
-- childless node at the total length.
eval0 Pack (ArrC ar ls) = ArrC (newArray $ concatMap (elems . project) $ elems ar)
	(zipWith Node (scanl (\i (ArrC ar _) -> i + rangeSize (bounds ar)) 0 $ elems ar)
		(map (\(ArrC _ ls) -> ls) (elems ar) ++ [[]]))
-- Inverse of 'Pack': each pair of consecutive forest nodes @(i, j)@ delimits
-- the slice @[i, j)@, whose forest is the children of the node at @i@.
eval0 Unpack (ArrC ar ls) = inject $ newArray $ map
	(\(Node i ls, Node j _) -> ArrC (ixmap (0, j-i-1) (+i) ar) ls)
	(pairUp ls)
eval0 PackProd (x, y) = inject $ newArray [Left x, Right y]
eval0 UnpackProd ar = (let Left x = project ar ! 0 in x, let Right x = project ar ! 1 in x)
eval0 PackSum1 (Left x) = inject (newArray [Left x])
eval0 PackSum1 (Right ar) = fmap Right ar
-- A 'Left' payload is packed as a one-element array, so a leading 'Left'
-- identifies it.  Anything else is the 'Right' case, including the empty
-- array that @PackSum1 (Right xs)@ produces for an empty @xs@.  Indexing
-- position 0 directly would crash on that empty array.
eval0 UnpackSum1 ar = case elems (project ar) of
	Left x : _ -> Left x
	_ -> Right (fmap (\(Right x) -> x) ar)
eval0 PackSum2 ei = fmap mirror $ eval0 PackSum1 $ mirror ei
eval0 UnpackSum2 ar = mirror $ eval0 UnpackSum1 $ fmap mirror ar
eval0 Id x = x
eval0 (Comp a a2) x = eval0 a $ eval0 a2 x
eval0 (Arr f) x = f x
-- Both components are forced before the pair is returned.
eval0 (Prod a a2) (x, y) = b `seq` c `seq` (b, c) where b = eval0 a x; c = eval0 a2 y
eval0 (Sum a a2) ei = either (Left . eval0 a) (Right . eval0 a2) ei

-- | Evaluates an arrow on an input.  This binds the implicit parameter that
--   the internal interpreter 'eval0' requires (to 'True').
eval :: A t u -> t -> u
eval arrow = eval0 arrow where
	?seq = True