{-# LANGUAGE Trustworthy, GADTs, Rank2Types, ImplicitParams, Arrows, DeriveFunctor #-}

-- | An implementation of nested data parallelism (due to Simon Peyton Jones et al)
module Control.CUtils.DataParallel
    ( Equal(Equal)
      -- * Flattenable arrays
    , ArrC
    , newArray
    , inject
    , project
      -- * The arrows and associated operations
    , Structural
    , A
    , unA
    , mapA'
    , liftA
    , countA
    , countA'
    , splitOff
    , assoc
    , indexA
    , zipA
    , unzipA
    , concatA
    , dupA
    , fstA
    , sndA
    , eval
      -- * Examples
    , nQueens
    , sorting
    , permute
    , dotProduct
    , transpose'
    ) where

import qualified Data.Sequence as S
import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Data.Foldable (toList)
import Control.Parallel
import Control.Parallel.Strategies
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))

-- A rose tree with a strict label and a strict sequence of children.
data Tree t = Node !t !(S.Seq (Tree t))

instance Functor Tree where
    -- 'fmap' on trees has the recurrence:
    -- U(n) = U(n/2) + n/c log^2(n) [based on the lemma about 'fastConcat'].
    -- assuming unlimited capabilities.
    -- It solves as O(n/c log^3 n).
    fmap f (Node x sq) =
        let sq' = fastConcat (return . fmap f) sq
        in (toList sq' `using` parList rseq) `pseq` Node (f x) sq'

-- | A flattenable array: the elements, together with a sequence of integer
-- labelled trees (\"marks\"). 'inject' creates two marks, labelled @0@ and the
-- element count.
data ArrC t = ArrC !(Array Int t) !(S.Seq (Tree Int)) deriving Functor

-- | Builds a zero-based array from a list.
newArray :: [t] -> Array Int t
newArray ls = listArray (0, length ls - 1) ls

-- | Wraps an array as an 'ArrC', rebasing its indices so that they start at 0.
--
-- 'ixmap' maps each /new/ index to the /old/ index it reads from, so the
-- lower bound has to be added back (subtracting it, as was done previously,
-- indexed out of range for every array whose lower bound is not 0).
inject :: Array Int t -> ArrC t
inject ar = ArrC rebased marks
  where
    (lo, hi) = bounds ar
    rebased = ixmap (0, hi - lo) (+ lo) ar
    marks = S.fromList [Node 0 S.empty, Node (hi - lo + 1) S.empty]

-- | Extracts the underlying array, discarding the marks.
project :: ArrC t -> Array Int t
project (ArrC ar _) = ar

-- Functions are shown as the empty string.
instance Show (t -> u) where
    showsPrec _ _ = (""++)

-- | The syntax tree of a program: primitives plus arrows lifted with @Lift@.
data Structural a t u where
    Map :: Structural a t u -> Structural a (ArrC t) (ArrC u)
    Comp :: Structural a u v -> Structural a t u -> Structural a t v
    Id :: Structural a t t
    Product :: Structural a t u -> Structural a v w -> Structural a (t, v) (u, w)
    Lift :: a t u -> Structural a t u
    Count :: Structural a (t, [Int]) (ArrC (t, [Int]))
    Index :: Structural a (ArrC t, Int) t
    Split :: Structural a (ArrC t, Array Int Int) (ArrC t)
    {-Zip :: Structural a (ArrC t, ArrC u) (ArrC (t, u))
    Unzip :: Structural a (ArrC (t, u)) (ArrC t, ArrC u)-}
    ClearMarks :: Structural a (ArrC t) (ArrC t)
    Separate :: Structural a (Either t u) (ArrC t, ArrC u)
    Combine :: Structural a (ArrC t, ArrC u) (Either t u)
    Pack :: Structural a (ArrC (ArrC t)) (ArrC t)
    Unpack :: Structural a (ArrC t) (ArrC (ArrC t))
    Dup :: Structural a t (t, t)
    Fst :: Structural a (t, u) t
    Snd :: Structural a (t, u) u

-- | The 'A' arrow includes a set of primitives that may be executed concurrently.
-- Programs are incrementally optimized as they are put together. A program may be
-- optimized once, and the result saved for repeated use.
--
-- Notes:
--
-- * The exact output of the optimizer is subject to change.
--
-- * The program must be a finite data structure, or optimization may diverge.
-- Therefore recursive definitions do not work, unless something is done to
-- limit the depth.
data A a t u = A (forall v. Structural a v t -> Structural a v u)

-- First element of a sequence; the sequence must be non-empty.
sHead :: S.Seq t -> t
sHead sq = case S.viewl sq of
    x S.:< _ -> x
    S.EmptyL -> error "Control.CUtils.DataParallel.sHead: empty sequence"

-- Everything but the first element; empty for an empty sequence.
sTail :: S.Seq t -> S.Seq t
sTail sq = case S.viewl sq of
    _ S.:< xs -> xs
    S.EmptyL -> S.empty

-- Last element of a sequence; the sequence must be non-empty.
sLast :: S.Seq t -> t
sLast sq = case S.viewr sq of
    _ S.:> x -> x
    S.EmptyR -> error "Control.CUtils.DataParallel.sLast: empty sequence"

-- The elements at positions n1 (inclusive) up to n2 (exclusive).
fromTo :: Int -> Int -> S.Seq t -> S.Seq t
fromTo n1 n2 sq =
    let (sq1, _) = S.splitAt n2 sq
    in snd $ S.splitAt n1 sq1

-- Pairs every element with its successor.
pairUp :: S.Seq t -> S.Seq (t, t)
pairUp sq = S.zip sq (sTail sq)

-- A concatenate function; it is described by the recurrence:
--
-- T(n, k) = 2T(n/2, k) + log(kn/2)
--
-- when running sequentially and
--
-- T(n, k) = T(n/2, k) + log(kn/2)
--
-- where k is the maximum length of a subentry,
--
-- when running in parallel. Consider splitting an array into c pieces
-- of roughly n/c each. The former recurrence solves as O(n/c log^2 (n/c));
-- the latter as O(log^2 c). Therefore the function runs in
-- O(n/c log^2 n) time [provided c <= n].
fastConcat :: (t -> S.Seq u) -> S.Seq t -> S.Seq u
fastConcat f sq = case S.length sq of
    0 -> S.empty
    1 -> f (sHead sq)
    n ->
        let (sq1, sq2) = S.splitAt (n `div` 2) sq
            cc1 = fastConcat f sq1
            cc2 = fastConcat f sq2
        in (cc1 `par` cc2) `pseq` (cc1 S.>< cc2)

-- | Evidence that two types are the same.
data Equal t u where
    Equal :: Equal t t

-- Re-nests a chain of 'Comp' to the right, dropping an 'Id' that appears as
-- the right operand of a 'Comp'. The second argument is an optional tail to
-- compose after the chain ('Left' 'Equal' meaning "no tail").
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
reassociate (Comp a Id) = reassociate a
reassociate (Comp a a2) = reassociate a . Right . reassociate a2
reassociate a = either (\Equal -> a) (a.)

-- | Obtain a 'Structural' program from an 'A' program.
unA :: (Category a) => A a t u -> Structural a t u
unA (A f) = f id

-- | Obtain a 'Structural' program but postcompose with another program.
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A f) = f

-- | 'mapA' for programs given as 'A' arrows.
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' (A f) = mapA (f id)

-- | Lifts an arrow of the underlying category into 'A'. A lifted arrow that
-- directly follows another lifted arrow is merged with it into one 'Lift'.
liftA :: (Category a) => a t u -> A a t u
liftA a = A (\a2 -> case a2 of
    Comp (Lift a2) a3 -> Comp (Lift (a . a2)) a3
    _ -> Lift a . a2)

-- Appends a 'Pack', rewriting it against the head of the program it follows
-- (for example 'Pack' after 'Unpack' cancels).
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\a -> case a of
    Comp (Map (Comp (Map a) a2)) a3 -> Map a . unA' pack (Map a2 . a3)
    Comp (Map (Map a)) a2 -> Map a . unA' pack a2
    Comp (Map (Comp Pack a)) a2 -> unA' pack (unA' pack (Map a . a2))
    Comp (Map Pack) a2 -> unA' pack (unA' pack a2)
    Comp Unpack a2 -> a2
    _ -> Pack . a)

-- Decides whether 'mapA' applies the flattening rewrite to a mapped program:
-- False for programs built only from 'Id', 'Unpack', 'Pack', 'Separate' and
-- 'Combine' (and compositions of those), True as soon as anything else occurs.
flatten :: Structural a t u -> Bool
flatten (Comp a a2) = flatten a || flatten a2
flatten Id = False
flatten Unpack = False
flatten Pack = False
{-flatten Zip = False
flatten Unzip = False-}
flatten Separate = False
flatten Combine = False
flatten _ = True

{-flatCounts :: (ArrowChoice a) => A a ((t, [Int]), [Int]) (ArrC (ArrC (t, [Int]), [Int]))
flatCounts = zipA
    . (splitA
        . (mapA' (fstA . fstA &&& (arr (uncurry drop) . (sndA . fstA &&& sndA)))
            . countA
            . ((fstA . fstA &&& arr length . sndA) &&& arr (uncurry (flip (++))) . (sndA . fstA &&& sndA))
        &&& arr (\(ls, ls2) -> let n = product ls in newArray [0,n..n*product ls2]) . (sndA . fstA &&& sndA))
    &&& mapA' sndA . countA)-}

-- | Mapping is the primary way of constructing nested data parallel programs.
-- It applies an (arrow) transformation to each element of an array
-- uniformly. A form of flattening transformation is applied to nested
-- maps (following the NESL paper). The flattening transformation converts
-- two levels of 'Map' into one level.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
mapA (Map a) | flatten a = A (\a2 -> case a2 of
    Comp Unpack a3 -> Unpack . unA' (mapA a) a3
    Comp Split a3 -> Split . unA' (first (mapA' (mapA a))) a3
    Comp ClearMarks a3 -> ClearMarks . unA' (mapA' (mapA a)) a3
    _ -> Unpack . unA' (mapA a) (unA' pack a2))
mapA (Product a a2) = A (\a3 -> case a3 of
    Comp Count a4 -> Comp (Map (Product Id a2)) (unA' (countA . first (A (Comp a))) a4)
    Comp ClearMarks a4 -> ClearMarks . unA' (mapA (Product a a2)) a4
    Comp (Map (Product a4 a5)) a6 -> unA' (mapA (Product (a . a4) (a2 . a5))) a6
    Comp (Map (Comp (Product a4 a5) a6)) a7 -> unA' (mapA (Product (a . a4) (a2 . a5) . a6)) a7
    _ -> Map (Product a a2) . a3)
mapA (Comp a a2) = mapA a . mapA a2
mapA Id = id
-- mapA (Product a a2) = zipA . (mapA a *** mapA a2) . unzipA
mapA Unpack = A (\a -> case a of
    Comp Unpack a -> Unpack . (Unpack . a)
    Comp (Map (Comp Pack a)) a2 -> Map a . a2
    Comp (Map Pack) a -> a
    Comp ClearMarks a3 -> ClearMarks . unA' (mapA Unpack) a3
    _ -> Map Unpack . a)
mapA a = A (\a2 -> case a2 of
    Comp (Map a2) a3 -> Comp (Map (reassociate a (Right a2))) a3
    Comp ClearMarks a3 -> ClearMarks . unA' (mapA a) a3
    _ -> Comp (Map a) a2)

-- Strips leading 'Id's from a composition. The signature is required because
-- the function matches on GADT constructors.
scrubIds :: Structural a t u -> Structural a t u
scrubIds (Comp Id x) = scrubIds x
scrubIds x = x

instance (Category a) => Category (A a) where
    id = A (\a -> a)
    A f . A g = A (f . scrubIds . g)

instance (ArrowChoice a) => Arrow (A a) where
    arr = liftA . arr
    -- A product that directly follows another product is merged with it.
    A f *** A g = A (\a -> case a of
        Comp (Product a2 a3) a4 -> Product (f a2) (g a3) . a4
        _ -> Product (f id) (g id) . a)
    first a = a *** id
    second a = id *** a
    a &&& a2 = (a *** a2) . dupA

instance (ArrowChoice a) => ArrowChoice (A a) where
    -- A sum is run as a pair of mapped arrays; 'Separate' is only inserted
    -- when the preceding program does not already end in 'Combine'.
    a +++ a2 = A (\a3 -> case a3 of
        Comp Combine a3 -> Combine . unA' (mapA (unA a) *** mapA (unA a2)) a3
        _ -> Combine . unA' (mapA (unA a) *** mapA (unA a2)) (Separate . a3))
    left a = a +++ id
    right a = id +++ a

instance Show (Structural a t u) where
    showsPrec prec (Map a) = ("Map " ++) . showParen (prec==11) (showsPrec 11 a)
    showsPrec _ (Comp a a2) = showsPrec 11 a . (" . "++) . showsPrec 11 a2
    showsPrec prec (Product a a2) = showParen (prec>=3) (showsPrec 3 a . (" *** "++) . showsPrec 3 a2)
    showsPrec _ Count = ("Count"++)
    showsPrec _ Index = ("Index"++)
    showsPrec _ Split = ("Split"++)
    showsPrec _ ClearMarks = ("Clr"++)
    showsPrec _ Pack = ("Pk"++)
    showsPrec _ Unpack = ("Unpk"++)
    showsPrec _ Separate = ("Sep"++)
    showsPrec _ Combine = ("Comb"++)
    showsPrec _ Dup = ("Dup"++)
    showsPrec _ Fst = ("Fst"++)
    showsPrec _ Snd = ("Snd"++)
    showsPrec _ Id = ("Id"++)
    -- Only 'Lift' remains; the lifted arrow is not shown.
    showsPrec _ _ = ("_"++)

instance (Category a) => Category (Structural a) where
    id = Id
    (.) = Comp

-- Swaps the two sides of an 'Either'.
mirror :: Either t u -> Either u t
mirror ei = either Right Left ei

-- | Distributes the second component over both halves of the first.
splitOff :: (ArrowChoice a) => A a ((t1, t2), u) ((t1, u), (t2, u))
splitOff = first fstA &&& first sndA

-- | Re-nests a left-nested pair to the right.
assoc :: (ArrowChoice a) => A a ((t, u), v) (t, (u, v))
assoc = fstA . fstA &&& (sndA . fstA &&& sndA)

-- | Access one index of an array.
indexA :: (ArrowChoice a) => A a (ArrC t, Int) t
indexA = A (\a -> case a of
    Comp (Product (Map a) a2) a3 -> a . unA' indexA (Product Id a2 . a3)
    -- Comp (Product Zip a) a2 -> unA' ((indexA *** indexA) . splitOff) (second a . a2)
    Comp (Product Count a) a2 -> unA' (fstA . fstA &&& arr (\(ns, i) -> snd (mapAccumL divMod i ns)) . (sndA . fstA &&& sndA)) (Product Id a . a2)
    _ -> Index . a)

-- | An operation analogous to 'zip', 'zipA' combines two packed arrays into a single array
-- element by element.
zipA :: (ArrowChoice a) => A a (ArrC t, ArrC u) (ArrC (t, u))
zipA =
    id &&& arr (\(ar, ar2) ->
        (uncurry subtract (bounds (project ar)) `min` uncurry subtract (bounds (project ar2))) + 1)
    >>> countA'
    >>> mapA' (splitOff >>> indexA *** indexA)

-- | 'unzipA' and 'zipA' are inverses.
unzipA :: (ArrowChoice a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = mapA' fstA &&& mapA' sndA

-- | Concatenation flattens out nested layers of arrays. The key operation used to implement
-- is erasing marks; erasing marks throws away the structure that would delineate the
-- edges of arrays; effectively flattening them into one array. The operation is divided
-- into packing and erasing marks, in the hope that the packing stage will fuse with an adjacent 'unpack'.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = A (\a -> case a of
    Comp Split a2 -> unA' fstA a2
    _ -> Comp ClearMarks a) . pack

-- Forces both components of a pair to weak head normal form.
forcePair :: (t, u) -> (t, u)
forcePair (x, y) = x `seq` y `seq` (x, y)

-- | Supplies an array of a repeated value paired with the index of each element.
-- Arguably adjacent 'countA's should fuse; however this is hard to implement, so I
-- have opted to provide a more powerful 'countA' that works on arrays of indices;
-- it generates arrays of indices lexicographically ordered.
countA :: (ArrowChoice a) => A a (t, [Int]) (ArrC (t, [Int]))
countA = A (Comp Count)
{- (\a -> case a of
    Comp (Product Count a2) a3 -> unA' flatCounts (unA' (second (A (Comp a2))) a3)
    Comp (Product (Comp a2 Count) a3) a4 -> unA' (mapA' (first (A (Comp a2))) . flatCounts . second (A (Comp a3))) a4
    Comp (Product (Comp Count a2) a3) a4 -> unA' flatCounts (unA' (A (Comp a2) *** A (Comp a3)) a4)
    Comp (Product (Comp a2 (Comp Count a3)) a4) a5 -> unA' (mapA' (first (A (Comp a2))) . flatCounts . (A (Comp a3) *** A (Comp a4))) a5
    _ -> Comp Count a)-}

-- | One-dimensional variant of 'countA': the extent is wrapped into a
-- singleton list before counting and each resulting index list is unwrapped
-- again with 'head'.
countA' :: (ArrowChoice a) => A a (t, Int) (ArrC (t, Int))
countA' = second (arr return) >>> countA >>> mapA' (second (arr head))

-- | Replacements for common arrow functions make fusing work better.
dupA :: (Category a) => A a t (t, t)
dupA = A (Dup .)

-- | First projection. It cancels against a preceding 'dupA' and is pushed
-- through a preceding product, keeping the other component's program in place.
fstA :: (Category a) => A a (t, u) t
fstA = A (\a -> case a of
    Comp Dup a -> a
    Comp (Product Id a) a2 -> Fst . (Product Id a . a2)
    Comp (Product a Id) a2 -> a . unA' fstA a2
    Comp (Product a a2) a3 -> a . unA' fstA (Product Id a2 . a3) -- Due to effects, cannot omit to do any operations
    _ -> Fst . a)

-- | Second projection; the mirror image of 'fstA'.
sndA :: (Category a) => A a (t, u) u
sndA = A (\a -> case a of
    Comp Dup a -> a
    Comp (Product a Id) a2 -> Snd . (Product a Id . a2)
    Comp (Product Id a) a2 -> a . unA' sndA a2
    Comp (Product a a2) a3 -> a2 . unA' sndA (Product a Id . a3)
    _ -> Snd . a)

-- Runs in O(log^2(n)) time in the number of elements.
binarySearch :: (Ord t) => t -> S.Seq t -> Int
binarySearch x sq = recurse 0 (S.length sq) sq
  where
    -- 'off' is the position of 'part' within the whole sequence and 'sz' is
    -- its length. The search keeps the lower half when x is smaller than the
    -- first element of the upper half, and the upper half otherwise.
    recurse off sz part
        | sz <= 1 = off
        | x < y = recurse off half lower
        | otherwise = recurse (off + half) (sz - half) upper
      where
        half = sz `div` 2
        (lower, upper) = S.splitAt half part
        y S.:< _ = S.viewl upper

-- Implementation of 'Pack': lays the elements of all sub-arrays out in one
-- array. The new marks hold the start offset of every sub-array (with that
-- sub-array's own marks, shifted by the offset, as children), followed by a
-- final mark at the total element count.
packImpl (ArrC ar _) = ArrC (arr_concF (\(_, i) ->
        let j = binarySearch i fr''   -- which sub-array holds flat index i
            i2 = S.index fr'' j       -- start offset of that sub-array
            ArrC ar' _ = ar ! j
        in ar' ! (i - i2)) ((), sz)) fr'
  where
    fr' = S.fromList $ snd $ mapAccumL
        (\i (ArrC sub subFr) ->
            let j = i + rangeSize (bounds sub)
            in (j, Node i (fastConcat ((return $!) . fmap (+i)) subFr)))
        0
        -- The trailing empty array contributes the final mark.
        $ elems ar ++ [ArrC (newArray []) S.empty]
    fr'' = fmap (\(Node i _) -> i) fr'
    _ S.:> sz = S.viewr fr''

-- Implementation of 'Unpack': every pair of adjacent marks (j, k) becomes a
-- sub-array holding elements j .. k-1, carrying the children of the first
-- mark shifted back by j.
unpackImpl (ArrC ar fr) = fastConcat
    (\(Node j fr2, Node k _) ->
        return $! ArrC (ixmap (0, k-j-1) (+j) ar) (fastConcat ((return $!) . fmap (subtract j)) fr2))
    (pairUp fr)

-- | An evaluator for 'Structural' arrows. A structural arrow may be obtained from an 'A' arrow
-- by either 'unA' or 'unA''.
--
-- Discussion of complexity bounds for various operations [these are provided c <= k]:
--
-- * Cost of 'ClearMarks' is O(k/c log^3(k)) in the number of subelements k.
--
-- * Cost of 'Pack' and 'Unpack' is O(k/c log^3(k) + k) in the number of subelements k.
-- 'Pack' is O(n) in the worst case in the number of spine elements n.
--
-- * 'Map' costs O(f) assuming unlimited capabilities where 'a' runs in O(f) time.
--
-- * 'Combine' and 'Separate' are both O(1).
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool, ?pool :: BoxedThreadPool) => Structural a t u -> a t u
eval0 Count =
    id &&& arr (snd >>> product)
    >>> arr_concF (arr (\((x, ns), i) -> (x, snd (mapAccumL divMod i ns))))
    >>> arr inject
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
eval0 ClearMarks = arr (\(ArrC ar fr) -> ArrC ar (fastConcat id (fmap (\(Node _ sub) -> sub) fr)))
eval0 (Map a) =
    (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1)) >>> arr_concF (arr (uncurry (!)) >>> eval0 a))
        &&& arr (\(ArrC _ fr) -> fr)
    >>> arr (uncurry ArrC)
-- 'Split' has no evaluator; fail with a message instead of a bare 'undefined'.
eval0 Split = error "Control.CUtils.DataParallel.eval0: Split is not implemented"
eval0 Pack = arr packImpl
eval0 Unpack = arr (inject . newArray . toList . unpackImpl)
-- A 'Left' becomes (singleton, empty); a 'Right' becomes (empty, singleton).
eval0 Separate = arr (\ei ->
    ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei)
        $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
-- Inverse of 'Separate': the side whose array has exactly one element wins.
eval0 Combine = arr (\(ar, ar2) ->
    let a1 = project ar
        a2 = project ar2
    in if uncurry subtract (bounds a1) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)
eval0 Dup = arr (\x -> forcePair (x, x))
eval0 Fst = arr fst
eval0 Snd = arr snd

-- | Evaluates arrows.
eval :: (Concurrent a, Strict a, ArrowChoice a, ?pool :: BoxedThreadPool) => Structural a t u -> a t u
eval a = let ?seq = True in eval0 a

instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
    -- The applied program is evaluated with 'eval', then run with the
    -- underlying arrow's 'app'.
    app = first (arr (eval . unA)) >>> liftA app
      where
        ?pool = BoxedThreadPool NoPool

--------------------------------
-- Examples using NDP techniques

-- True if position n collides with one of the given positions.
checkThreats :: (Eq a, Num a, Enum a) => a -> [a] -> Bool
checkThreats n positions = n `elem` positions -- Check if there is a piece on the row
    || n `elem` zipWith (-) positions [1..] -- ... the diagonal
    || n `elem` zipWith (+) positions [1..] -- ... or the other diagonal

-- True if any position collides with one of the positions after it.
checkThreats2 :: (Eq a, Num a, Enum a) => [a] -> Bool
checkThreats2 positions = or [ checkThreats n tl | n:tl <- tails positions ]

-- Enumerates every index list within the given extents and keeps those for
-- which 'checkThreats2' is False.
nQueensImpl :: A (->) ((), [Int]) (ArrC [Int])
nQueensImpl =
    countA
    >>> mapA' (arr (\(_, soln) -> if checkThreats2 soln then inject (newArray []) else inject (newArray [soln])))
    >>> concatA

-- | The n-queens program for a board of size @n@: candidate placements are
-- the lists of @n@ numbers, each in the range @0 .. n-1@.
nQueens :: Int -> A (->) () (ArrC [Int])
nQueens n = arr (\() -> ((), replicate n n)) >>> nQueensImpl

-------------------------------

-- | A quicksort-style program. In order to make this recursive function a
-- finite data structure, there is a depth limit parameter, beyond which the
-- standard 'sort' takes over.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth | depth <= 0 = arr (inject . newArray . sort . elems . project)
sorting depth =
    -- Arrays of at most one element are returned unchanged.
    arr (\x -> if uncurry subtract (bounds (project x)) <= 0 then Left x else Right x)
    >>> id ||| (arr (\ar ->
            -- The array has at least two elements on this branch.
            let x:xs = elems (project ar)
                (bef, aft) = partition (<x) xs
            in ((inject (newArray bef), inject (newArray aft)), x))
        >>> first (s *** s)
        >>> arr (\((bef, aft), x) -> inject (newArray (elems (project bef) ++ x : elems (project aft)))))
  where
    s = sorting (pred depth) -- Memoize the answer

-------------------------------

-- | Rebuilds an array by reading, for every index of the input, the element
-- stored at that index.
permute :: A (->) (ArrC Int) (ArrC Int)
permute =
    arr (\ar -> (ar, [uncurry subtract (bounds (project ar)) + 1]))
    >>> countA
    >>> mapA' (second (arr head) >>> indexA)

-------------------------------

-- | Dot product of two vectors; 'zipA' truncates to the shorter one.
dotProduct :: (Num t) => A (->) (ArrC t, ArrC t) t
dotProduct = proc (v1, v2) -> do
    vzip <- zipA -< (v1, v2)
    vdots <- mapA' (arr (uncurry (*))) -< vzip
    returnA -< sum $ elems $ project vdots

-- | Transposes an array of arrays. The width of the result is taken from the
-- first row (presumably all rows are expected to have that length).
transpose' :: A (->) (ArrC (ArrC t)) (ArrC (ArrC t))
transpose' = proc m -> do
    firstrow <- indexA -< (m, 0)
    rows <- countA -< (m, [uncurry subtract (bounds (project firstrow)) + 1])
    -- Build skeleton of result array
    rowcols <- mapA' (proc (m, [ii]) -> do
        v <- countA -< ((m, ii), [uncurry subtract (bounds (project m)) + 1])
        mapA' (proc ((m, ii), [jj]) -> returnA -< (m, (ii, jj))) -< v) -< rows
    -- Build result
    mapA' (mapA' (proc (m, (ii, jj)) -> do
        v <- indexA -< (m, jj)
        indexA -< (v, ii))) -< rowcols