{-# LANGUAGE Safe, GADTs, Rank2Types, ImplicitParams, Arrows #-}
-- | An implementation of nested data parallelism (due to Simon Peyton Jones et al)
module Control.CUtils.DataParallel (
    -- * Flattenable arrays
    ArrC, newArray, inject, project,
    -- * The arrows and associated operations
    Structural, A, unA, mapA', liftA, countA, indexA, zipA, unzipA, concatA, eval,
    -- * Examples
    nQueens, sorting, permute) where

import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))

-- | A rose tree. The frame of an 'ArrC' is an array of these: each 'Node'
-- holds a segment's starting offset together with that segment's own frame.
data Tree t = Node !t !(Array Int (Tree t))

-- | A flattened, possibly nested, array: the flat element storage (indexed
-- from 0) plus a frame of segment boundaries. The frame has one 'Node' per
-- segment giving its start offset, followed by a sentinel 'Node' whose
-- offset is the end of the last segment.
data ArrC t = ArrC !(Array Int t) !(Array Int (Tree Int))

-- | Build a 0-indexed array from a list.
newArray :: [e] -> Array Int e
newArray ls = listArray (0, length ls - 1) ls

-- | Wrap a plain array as a flat 'ArrC' holding a single segment.
--
-- The input may use any lower bound; the result is re-indexed from 0.
inject :: Array Int t -> ArrC t
inject ar = ArrC (ixmap (0, len - 1) (+ lo) ar) frame
  where
    lo = fst (bounds ar)
    -- 'rangeSize' is 0 for any empty array, whatever its bounds.
    len = rangeSize (bounds ar)
    -- One segment [0, len) with an empty sub-frame, then the end sentinel.
    frame = newArray [Node 0 (newArray []), Node len (newArray [])]
    -- Note: 'ixmap' maps each NEW index to the OLD index it reads from, so
    -- new index i must read ar ! (i + lo), hence (+ lo) above.

-- | Extract the flat element array of an 'ArrC', discarding its frame.
project :: ArrC t -> Array Int t
project (ArrC ar _) = ar

-- | Maps over the elements; the frame (segment structure) is unchanged.
instance Functor ArrC where
    fmap f (ArrC ar fr) = ArrC (fmap f ar) fr

-- | Functions render as the empty string.
-- NOTE(review): orphan instance; it overlaps with the one in
-- "Text.Show.Functions" if a client imports both.
instance Show (t -> u) where
    showsPrec _ _ = (""++)

-- | The primitive operations that parallel programs are built from.
-- Programs are interpreted by 'eval'.
data Structural a t u where
    -- Apply a program to every element of an array.
    Map :: Structural a t u -> Structural a (ArrC t) (ArrC u)
    -- Sequential composition; the second argument runs first.
    Comp :: Structural a u v -> Structural a t u -> Structural a t v
    Id :: Structural a t t
    -- Run two programs on the two components of a pair.
    Product :: Structural a t u -> Structural a v w -> Structural a (t, v) (u, w)
    -- Embed an arrow of the underlying type.
    Lift :: a t u -> Structural a t u
    -- Given (x, n), produce (x, 0) .. (x, n - 1).
    Count :: Structural a (t, Int) (ArrC (t, Int))
    Index :: Structural a (ArrC t, Int) t
    Zip :: Structural a (ArrC t, ArrC u) (ArrC (t, u))
    Unzip :: Structural a (ArrC (t, u)) (ArrC t, ArrC u)
    -- Merge two levels of frame into one.
    ClearMarks :: Structural a (ArrC t) (ArrC t)
    -- Turn an 'Either' into a pair of arrays, exactly one of which holds the value.
    Separate :: Structural a (Either t u) (ArrC t, ArrC u)
    Combine :: Structural a (ArrC t, ArrC u) (Either t u)
    -- Flatten a nested array, recording segment boundaries in the frame.
    Pack :: Structural a (ArrC (ArrC t)) (ArrC t)
    -- Split an array into sub-arrays according to its frame.
    Unpack :: Structural a (ArrC t) (ArrC (ArrC t))

-- | The 'A' arrow includes a set of primitives that may be executed concurrently.
-- Programs are incrementally optimized as they are put together. A program may be
-- optimized once, and the result saved for repeated use.
--
-- Notes:
--
-- * The exact output of the optimizer is subject to change.
--
-- * The program must be a finite data structure, or optimization will diverge.
data A a t u = A (forall v. Structural a v t -> Structural a v u)

-- | Evidence that @t@ and @u@ are the same type. 'reassociate' takes it in
-- place of a preceding program when there is nothing to compose with.
data Equal t u = (t ~ u) => Equal

-- | @reassociate p k@ composes @p@ after the program carried by @k@, or
-- returns @p@ alone when @k@ is @'Left' 'Equal'@. Nested 'Comp's inside @p@
-- are re-associated into a right-nested chain, and an 'Id' in the right-hand
-- position of a 'Comp' is dropped.
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
reassociate (Comp a Id) = reassociate a
reassociate (Comp a a2) = reassociate a . Right . reassociate a2
reassociate a = either (\Equal -> a) (a .)

-- | Obtain a 'Structural' program from an 'A' program.
unA :: (Category a) => A a t u -> Structural a t u
unA (A f) = f id

-- | Obtain a 'Structural' program that first runs the given program and then
-- the 'A' program (the 'A' program is composed after the argument).
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A f) = f

-- | Apply an 'A' program to every element of an array (via 'mapA').
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' (A f) = mapA (f id)

-- | Embed an arrow of the underlying type. If it directly follows another
-- lifted arrow, the two arrows are fused into a single 'Lift'.
liftA :: (Category a) => a t u -> A a t u
liftA a = A (\prog -> case prog of
    Comp (Lift b) rest -> Comp (Lift (a . b)) rest
    _ -> Lift a . prog)

-- | Flatten one level of nesting. Applies these rewrites as it is composed:
--
-- * @pack . map (map f . g)@ becomes @map f . pack . map g@, and
--   @pack . map (map f)@ becomes @map f . pack@;
--
-- * @pack . map (pack . g)@ becomes @pack . pack . map g@, and
--   @pack . map pack@ becomes @pack . pack@;
--
-- * @pack . unpack@ cancels.
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\prog -> case prog of
    Comp (Map (Comp (Map f) g)) rest -> Map f . unA' pack (Map g . rest)
    Comp (Map (Map f)) rest -> Map f . unA' pack rest
    Comp (Map (Comp Pack g)) rest -> unA' pack (unA' pack (Map g . rest))
    Comp (Map Pack) rest -> unA' pack (unA' pack rest)
    Comp Unpack rest -> rest
    _ -> Pack . prog)

-- | Returns 'False' for 'Id', 'Unpack', 'Pack', 'Zip', 'Unzip', 'Separate'
-- and 'Combine', and for compositions built only from these; 'True' for
-- everything else. 'mapA' applies the flattening transformation to a nested
-- 'Map' only when this is 'True'.
flatten :: Structural a t u -> Bool
flatten (Comp a a2) = flatten a || flatten a2
flatten Id = False
flatten Unpack = False
flatten Pack = False
flatten Zip = False
flatten Unzip = False
flatten Separate = False
flatten Combine = False
flatten _ = True

-- | Mapping is the primary way of constructing nested data parallel programs.
-- It applies an (arrow) transformation to each element of an array
-- uniformly. A form of flattening transformation is applied to nested
-- maps (following the NESL paper).
-- The flattening transformation converts two levels of 'Map' into one level.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
mapA (Map a)
    | flatten a = A (\prog -> case prog of
        -- pack . unpack cancels, so a preceding 'Unpack' is simply dropped.
        Comp Unpack rest -> Unpack . unA' (mapA a) rest
        -- map (map a) becomes unpack . map a . pack.
        _ -> Unpack . unA' (mapA a) (unA' pack prog))
mapA (Comp a a2) = mapA a . mapA a2
mapA Id = id
mapA (Product a a2) = zipA . (mapA a *** mapA a2) . unzipA
mapA Unpack = A (\prog -> case prog of
    Comp Unpack rest -> Unpack . (Unpack . rest)
    -- map unpack . map (pack . f) becomes map f.
    Comp (Map (Comp Pack f)) rest -> Map f . rest
    Comp (Map Pack) rest -> rest
    _ -> Map Unpack . prog)
mapA Count = A (Comp Unpack) . arr expand
  where
    -- Expand every (x, n) into (x, 0) .. (x, n - 1), with one frame segment
    -- per input element so that 'Unpack' can split the result back up.
    expand (ArrC ar _) = ArrC
        (newArray [ (x, i) | (x, n) <- elems ar, i <- [0 .. n - 1] ])
        (newArray (zipWith Node offsets subframes))
      where
        offsets = scanl (\acc (_, m) -> acc + m) 0 (elems ar)
        subframes = [ newArray [Node 0 (newArray []), Node n (newArray [])] | (_, n) <- elems ar ]
            ++ [newArray []]
mapA a = A (\prog -> case prog of
    -- Fuse with a preceding 'Map' into a single 'Map'.
    Comp (Map f) rest -> Comp (Map (reassociate a (Right f))) rest
    _ -> Comp (Map a) prog)

-- | Composition of 'A' programs is function composition of their rewriters.
instance (Category a) => Category (A a) where
    id = A (\prog -> prog)
    A f . A g = A (f . g)

instance (ArrowChoice a) => Arrow (A a) where
    arr = liftA . arr
    A f *** A g = A (\prog -> case prog of
        -- Push the two halves into an existing 'Product'.
        Comp (Product p q) rest -> Product (f p) (g q) . rest
        _ -> Product (f id) (g id) . prog)
    first f = f *** id
    second f = id *** f

-- | Choice is implemented as 'Separate', a 'Map' over each side, then
-- 'Combine'; a 'Combine' immediately before the choice is cancelled.
instance (ArrowChoice a) => ArrowChoice (A a) where
    f +++ g = A (\prog -> case prog of
        Comp Combine rest -> Combine . unA' (mapA (unA f) *** mapA (unA g)) rest
        _ -> Combine . unA' (mapA (unA f) *** mapA (unA g)) (Separate . prog))
    left f = f +++ id
    right f = id +++ f

-- | Debugging rendering of a program. 'Lift' and 'Id' are shown as @_@.
-- Parentheses follow the fixities of '.' (infixr 9) and '***' (infixr 3),
-- with 'Map' rendered as a prefix application.
instance Show (Structural a t u) where
    showsPrec prec (Map a) = showParen (prec > 10) (("Map " ++) . showsPrec 11 a)
    showsPrec prec (Comp a a2) = showParen (prec > 9) (showsPrec 10 a . (" . " ++) . showsPrec 9 a2)
    showsPrec prec (Product a a2) = showParen (prec > 3) (showsPrec 4 a . (" *** " ++) . showsPrec 3 a2)
    showsPrec _ Count = ("Count" ++)
    showsPrec _ Index = ("Index" ++)
    showsPrec _ Zip = ("Zip" ++)
    showsPrec _ Unzip = ("Unzip" ++)
    showsPrec _ ClearMarks = ("Clr" ++)
    showsPrec _ Pack = ("Pk" ++)
    showsPrec _ Unpack = ("Unpk" ++)
    showsPrec _ Separate = ("Sep" ++)
    showsPrec _ Combine = ("Comb" ++)
    showsPrec _ _ = ("_" ++)

-- | Composition builds 'Comp' nodes directly, without any rewriting.
instance (Category a) => Category (Structural a) where
    id = Id
    (.) = Comp

-- | Swap the sides of an 'Either'. (Not exported.)
mirror :: Either a b -> Either b a
mirror ei = either Right Left ei

-- | Given @(x, n)@, supplies the array @(x, 0), (x, 1), ..., (x, n - 1)@:
-- a repeated value paired with the index of each element.
countA :: A a (t, Int) (ArrC (t, Int))
countA = A (Comp Count)

-- | Access one index of an array.
indexA :: A a (ArrC t, Int) t
indexA = A (Comp Index)

-- | An operation analogous to 'zip'. As it is composed, @zip . unzip@
-- cancels, and a 'Product' of 'Map's (or of a 'Map' and 'Id') is pushed past
-- the zip as a single 'Map' of a 'Product'.
zipA :: (Category a) => A a (ArrC t, ArrC u) (ArrC (t, u))
zipA = A (\prog -> case prog of
    Comp Unzip rest -> rest
    Comp (Product (Map f) (Map g)) rest -> Map (Product f g) . unA' zipA rest
    Comp (Product (Map f) Id) rest -> Map (Product f Id) . unA' zipA rest
    Comp (Product Id (Map g)) rest -> Map (Product Id g) . unA' zipA rest
    _ -> Zip . prog)

-- | 'unzipA' and 'zipA' are inverses. A 'Map' of a 'Product' is split into
-- a 'Product' of 'Map's.
unzipA :: (Category a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = A (\prog -> case prog of
    Comp Zip rest -> rest
    Comp (Map (Product f g)) rest -> Product (Map f) (Map g) . unA' unzipA rest
    _ -> Unzip . prog)

-- | Concatenate an array of arrays: 'pack' followed by 'ClearMarks'.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = A (Comp ClearMarks) . pack

-- | Evaluate both components to WHNF before returning the pair.
forcePair :: (t, u) -> (t, u)
forcePair (x, y) = x `seq` y `seq` (x, y)

-- | An evaluator for 'Structural' arrows.
-- NOTE(review): 'arr_concF' and 'force' come from the project's
-- Control.CUtils modules; their exact parallelism/strictness is not visible here.
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool) => Structural a t u -> a t u
eval0 Count = arr_concF id >>> arr inject
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
eval0 Zip =
    -- Pair the inputs with the length of the shorter one, then build the
    -- zipped elements index by index.
    arr (\pr@(ArrC ar _, ArrC ar2 _) -> (pr, (snd (bounds ar) `min` snd (bounds ar2)) + 1))
    >>> arr_concF (arr (\((ArrC ar _, ArrC ar2 _), i) -> forcePair (ar ! i, ar2 ! i)))
    >>> arr inject
eval0 Unzip = arr (\ar -> (fmap fst ar, fmap snd ar))
eval0 ClearMarks = arr (\(ArrC ar fr) -> ArrC ar (newArray
    [ Node (i + j) fr3 | Node i fr2 <- elems fr, Node j fr3 <- elems fr2 ]))
eval0 (Map a) =
    (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1))
        >>> arr_concF (arr (uncurry (!)) >>> eval0 a))
    &&& arr (\(ArrC _ fr) -> fr)
    >>> arr (uncurry ArrC)
eval0 Pack = arr (\(ArrC ar _) -> ArrC
    (newArray $ concatMap (elems . project) $ elems ar)
    (newArray $ zipWith Node
        (scanl (\i (ArrC sub _) -> i + rangeSize (bounds sub)) 0 $ elems ar)
        (map (\(ArrC _ fr) -> fr) (elems ar) ++ [newArray []])))
eval0 Unpack =
    -- The frame has one more entry (the sentinel) than there are segments.
    arr (\arc@(ArrC _ fr) -> (arc, uncurry subtract (bounds fr)))
    >>> arr_concF (arr (\(ArrC ar fr, index) ->
            let Node i fr2 = fr ! index
                Node j _ = fr ! (index + 1)
            in ArrC (ixmap (0, j - i - 1) (+ i) ar) fr2))
    >>> arr inject
eval0 Separate = arr (\ei ->
    ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei)
        $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
eval0 Combine = arr (\(ar, ar2) ->
    let a1 = project ar
        a2 = project ar2
    -- The left array holds exactly one element iff the value was a 'Left'.
    in if uncurry subtract (bounds a1) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)

-- | Evaluates arrows.
--
-- Notes:
--
-- * Effects are supported, but with much weaker semantics than the Kleisli arrows
-- of the monad. In particular, the 'Map' and '***' operations are allowed to be parallelized,
-- but on the other hand parallelism is not guaranteed.
eval :: (Concurrent a, Strict a, ArrowChoice a) => Structural a t u -> a t u
eval a = let ?seq = True in eval0 a

-- | Applying an 'A' program evaluates it with 'eval' and then applies the
-- resulting arrow with the underlying 'app'.
instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
    app = first (arr (eval . unA)) >>> liftA app

-------------------------------

-- | Whether a queen in row @n@ is attacked by the queens in @positions@,
-- where the k-th element of @positions@ (counting from 1) is k columns away.
checkThreats :: Int -> [Int] -> Bool
checkThreats n positions =
    n `elem` positions                            -- Check if there is a piece on the row
        || n `elem` zipWith (-) positions [1 ..]  -- ... the diagonal
        || n `elem` zipWith (+) positions [1 ..]  -- ... or the other diagonal

-- | Whether any queen in a placement attacks a queen later in the list.
checkThreats2 :: [Int] -> Bool
checkThreats2 positions = or [ checkThreats n tl | n : tl <- tails positions ]

-- | @nQueensImpl m n@ extends a partial placement by @n@ more queens, each
-- placed in one of the rows @0 .. m - 1@. Once no queens remain to be
-- placed, a placement is kept only if no queen attacks another.
nQueensImpl :: Int -> Int -> A (->) [Int] (ArrC [Int])
nQueensImpl _ n
    | n <= 0 = arr (\soln -> if checkThreats2 soln then inject (newArray []) else inject (newArray [soln]))
nQueensImpl m n =
    arr (\partialSoln -> (partialSoln, m))
    >>> countA
    >>> mapA' (arr (uncurry (flip (:))) >>> nQueensImpl m (pred n))
    >>> concatA

-- | All solutions of the n-queens problem on an n-by-n board, each given as
-- a list of row positions.
nQueens :: Int -> A (->) () (ArrC [Int])
nQueens n = arr (\() -> []) >>> nQueensImpl n n

-------------------------------

-- | Quicksort expressed as an 'A' program. Arrays of at most one element
-- are returned unchanged. Otherwise the first element is the pivot, the
-- remaining elements are split into those less than it and the rest, and
-- both halves are sorted recursively.
--
-- In order to make this recursive function a finite structure, there is a
-- depth limit parameter, beyond which the standard 'sort' takes over.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth | depth <= 0 = arr (inject . newArray . sort . elems . project)
sorting depth =
    arr (\x -> if uncurry subtract (bounds (project x)) <= 0 then Left x else Right x)
    >>> id ||| (arr split
        >>> first (s *** s)
        >>> arr (\((bef, aft), x) -> inject (newArray (elems (project bef) ++ x : elems (project aft)))))
  where
    -- Bound once so that both recursive calls share the same program value.
    s = sorting (pred depth)
    -- Only reached for arrays of two or more elements, so the pattern matches.
    split ar =
        let x : xs = elems (project ar)
            (bef, aft) = partition (< x) xs
        in ((inject (newArray bef), inject (newArray aft)), x)

-------------------------------

-- | Gathers @ar ! i@ for every index @i@ of the input array @ar@.
--
-- NOTE(review): each index reads its own position, so as written the output
-- elements equal the input's. Confirm whether a different source of indices
-- was intended.
permute :: A (->) (ArrC Int) (ArrC Int)
permute = arr (\ar -> (ar, uncurry subtract (bounds (project ar)) + 1)) >>> countA >>> mapA' indexA