module Control.CUtils.DataParallel (Equal(Equal),
ArrC, newArray, inject, project,
Structural, A, unA, mapA', liftA, countA, countA', splitOff, assoc, indexA, zipA, unzipA, concatA, dupA, fstA, sndA, eval,
nQueens, sorting, permute, dotProduct, transpose') where
import qualified Data.Sequence as S
import Data.Array
import Data.List
import Data.Monoid (Any(Any))
import Data.Foldable (toList)
import Control.Parallel
import Control.Parallel.Strategies
import Control.Category
import Control.Arrow
import Control.Monad.Writer (Writer, tell, runWriter)
import Control.Monad.Identity
import Control.Monad
import Control.CUtils.Conc
import Control.CUtils.StrictArrow
import Prelude hiding (id, (.))
-- | A rose tree whose label and child sequence are both strict fields.
-- Inside this module it is only used at 'Int' labels, as the second
-- component of 'ArrC'.
data Tree label = Node !label !(S.Seq (Tree label))
-- | Maps over the label and, through 'fastConcat', over every child.
-- The mapped children are evaluated with @parList rseq@ before the new
-- 'Node' is returned.
instance Functor Tree where
  fmap g (Node label kids) = forcedKids `pseq` Node (g label) kids'
    where
      -- 'return' here is the 'S.Seq' monad's singleton.
      kids' = fastConcat (\kid -> return (fmap g kid)) kids
      forcedKids = toList kids' `using` parList rseq
-- | An 'Int'-indexed array paired with a sequence of 'Int'-labelled 'Tree's.
-- Both fields are strict. The derived 'Functor' maps over the array elements.
data ArrC e = ArrC !(Array Int e) !(S.Seq (Tree Int))
  deriving Functor
-- | Build a zero-based array from a list.
--
-- The bounds are @(0, length ls - 1)@, so the empty list yields the empty
-- array with bounds @(0, -1)@.
newArray :: [t] -> Array Int t
newArray ls = listArray (0, length ls - 1) ls
-- | Wrap a plain array as an 'ArrC'.
--
-- The array is re-based so that its lower bound becomes 0, and the mark
-- sequence is initialised with two leaf nodes labelled 0 and the element
-- count.
--
-- 'ixmap' takes a function from /new/ index to /old/ index, so the shift
-- must add the original lower bound. Subtracting it only works when the
-- lower bound is already 0.
inject :: Array Int t -> ArrC t
inject ar = ArrC rebased marks
  where
    (lo, hi) = bounds ar
    top = hi - lo
    rebased = ixmap (0, top) (+ lo) ar
    marks = S.fromList [Node 0 S.empty, Node (top + 1) S.empty]
-- | Extract the underlying array of an 'ArrC', discarding its mark sequence.
project :: ArrC t -> Array Int t
project (ArrC payload _) = payload
-- | Functions are shown as the fixed placeholder @<FUNCTION>@, whatever
-- the precedence or the function.
instance Show (t -> u) where
  showsPrec _ _ = showString "<FUNCTION>"
-- | Syntax tree of a data-parallel computation over a base arrow @k@,
-- going from input type @i@ to output type @o@.
data Structural k i o where
  -- | Apply a computation to every element of an 'ArrC'.
  Map :: Structural k x y -> Structural k (ArrC x) (ArrC y)
  -- | Composition; the second argument runs first.
  Comp :: Structural k y z -> Structural k x y -> Structural k x z
  -- | Identity.
  Id :: Structural k x x
  -- | Apply two computations to the two halves of a pair.
  Product :: Structural k x y -> Structural k z w -> Structural k (x, z) (y, w)
  -- | Embed a computation of the base arrow.
  Lift :: k x y -> Structural k x y
  -- | Replicate a value over an index space described by a list of extents.
  Count :: Structural k (x, [Int]) (ArrC (x, [Int]))
  -- | Look up one element.
  Index :: Structural k (ArrC x, Int) x
  Split :: Structural k (ArrC x, Array Int Int) (ArrC x)
  ClearMarks :: Structural k (ArrC x) (ArrC x)
  Separate :: Structural k (Either x y) (ArrC x, ArrC y)
  Combine :: Structural k (ArrC x, ArrC y) (Either x y)
  -- | Flatten one level of nesting.
  Pack :: Structural k (ArrC (ArrC x)) (ArrC x)
  -- | Introduce one level of nesting.
  Unpack :: Structural k (ArrC x) (ArrC (ArrC x))
  Dup :: Structural k x (x, x)
  Fst :: Structural k (x, y) x
  Snd :: Structural k (x, y) y
-- | A computation from @i@ to @o@, represented as a function that extends
-- any already-built 'Structural' pipeline ending in @i@ into one ending in
-- @o@. This lets each combinator inspect what precedes it.
data A k i o = A (forall pre. Structural k pre i -> Structural k pre o)
-- | First element of a sequence. Partial: the match has no case for an
-- empty sequence.
sHead sq =
  case S.viewl sq of
    firstElem S.:< _ -> firstElem
-- | Everything but the first element; the empty sequence maps to itself.
sTail sq = S.drop 1 sq
-- | Last element of a sequence. Partial: the match has no case for an
-- empty sequence.
sLast sq =
  case S.viewr sq of
    _ S.:> lastElem -> lastElem
-- | Slice of a sequence: keep the first @n2@ elements, then drop the first
-- @n1@ of those.
fromTo :: Int -> Int -> S.Seq t -> S.Seq t
fromTo n1 n2 = S.drop n1 . S.take n2
-- | Pair every element with its successor; the result is one element
-- shorter than a non-empty input.
pairUp sq = sq `S.zip` sTail sq
-- | Divide-and-conquer @concatMap@ over a sequence.
--
-- The input is halved recursively; at each split one half is sparked with
-- 'par' and the other is evaluated with 'pseq' before the two results are
-- appended.
fastConcat :: (t -> S.Seq u) -> S.Seq t -> S.Seq u
fastConcat f sq
  | len == 0 = S.empty
  | len == 1 = f (sHead sq)
  | otherwise = (left `par` right) `pseq` (left S.>< right)
  where
    len = S.length sq
    (front, back) = S.splitAt (len `div` 2) sq
    left = fastConcat f front
    right = fastConcat f back
-- | Type-equality witness: the sole constructor only exists when both
-- parameters are the same type.
data Equal l r where
  Equal :: Equal x x
-- | Append a (possibly absent) tail to a composition chain, re-nesting the
-- chain to the right on the way.
--
-- The tail is either @Left Equal@ (nothing to append; the types already
-- agree) or @Right rest@. An 'Id' sitting on the right of a 'Comp' is
-- dropped.
reassociate :: (Category a) => Structural a u v -> Either (Equal t u) (Structural a t u) -> Structural a t v
reassociate (Comp outer Id) tailPart = reassociate outer tailPart
reassociate (Comp outer inner) tailPart = reassociate outer (Right (reassociate inner tailPart))
reassociate leaf tailPart = case tailPart of
  Left Equal -> leaf
  Right rest -> Comp leaf rest
-- | Materialise an 'A' as a stand-alone 'Structural' term by running its
-- rewriter on the identity pipeline.
unA :: (Category a) => A a t u -> Structural a t u
unA (A rewrite) = rewrite Id
-- | Run the rewriter of an 'A' on an existing pipeline.
unA' :: A a u v -> Structural a t u -> Structural a t v
unA' (A rewrite) prefix = rewrite prefix
-- | Map an 'A' computation over every element of an 'ArrC'. The body is
-- first materialised against the identity pipeline and then handed to
-- 'mapA'.
mapA' :: (ArrowChoice a) => A a t u -> A a (ArrC t) (ArrC u)
mapA' body = mapA (unA body)
-- | Embed a base-arrow computation.
--
-- If the pipeline built so far already ends in a 'Lift', the two base
-- computations are composed into a single 'Lift'. Otherwise a new 'Lift'
-- node is placed in front of the pipeline.
liftA :: (Category a) => a t u -> A a t u
liftA f = A (\prefix -> case prefix of
  Comp (Lift g) rest -> Comp (Lift (f . g)) rest
  _ -> Comp (Lift f) prefix)
-- | Flatten one level of nesting, simplifying against the preceding step:
--
-- * a preceding @Map (Map f)@ (optionally followed by more work inside the
--   outer 'Map') is commuted so that @Map f@ runs after packing;
-- * a preceding @Map Pack@ is turned into two successive packs;
-- * a preceding 'Unpack' cancels;
-- * anything else gets a plain 'Pack' node.
pack :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
pack = A (\prefix -> case prefix of
  Comp (Map (Comp (Map inner) mid)) rest -> Comp (Map inner) (unA' pack (Comp (Map mid) rest))
  Comp (Map (Map inner)) rest -> Comp (Map inner) (unA' pack rest)
  Comp (Map (Comp Pack mid)) rest -> unA' pack (unA' pack (Comp (Map mid) rest))
  Comp (Map Pack) rest -> unA' pack (unA' pack rest)
  Comp Unpack rest -> rest
  _ -> Comp Pack prefix)
-- | Predicate used as the guard of the 'Map' equation of 'mapA'.
--
-- 'Id', 'Unpack', 'Pack', 'Separate' and 'Combine' answer 'False'; a
-- composition answers 'True' when either side does; every other node
-- answers 'True'.
flatten :: Structural a t u -> Bool
flatten s = case s of
  Comp l r -> flatten l || flatten r
  Id -> False
  Unpack -> False
  Pack -> False
  Separate -> False
  Combine -> False
  _ -> True
-- | Build the rewriter that applies a 'Structural' term to every element of
-- an 'ArrC'.
--
-- Each equation returns an 'A' whose function inspects the head of the
-- pipeline built so far (its argument) and either fuses with it or falls
-- back to emitting a node in front of it. The equations are tried in
-- order; a @Map a@ whose guard fails falls through to the final catch-all.
mapA :: (ArrowChoice a) => Structural a t u -> A a (ArrC t) (ArrC u)
-- Mapping a 'Map' over a nested array (only when 'flatten' holds for the
-- inner body): by default the nested array is packed, the inner body is
-- mapped over the flat array, and the result is unpacked again.
mapA (Map a) | flatten a = A (\a2 -> case a2 of
-- A preceding 'Unpack' means the data is already available flat.
Comp Unpack a3 -> Unpack . unA' (mapA a) a3
-- 'Split' and 'ClearMarks' are commuted past the mapping.
Comp Split a3 -> Split . unA' (first (mapA' (mapA a))) a3
Comp ClearMarks a3 -> ClearMarks . unA' (mapA' (mapA a)) a3
_ -> Unpack . unA' (mapA a) (unA' pack a2))
-- Mapping a 'Product'.
mapA (Product a a2) = A (\a3 -> case a3 of
-- After 'Count', the first component is handled by composing 'a' in front
-- of 'Count' instead of inside the map; only 'a2' stays under 'Map'.
Comp Count a4 -> Comp (Map (Product Id a2)) (unA' (countA . first (A (Comp a))) a4)
Comp ClearMarks a4 -> ClearMarks . unA' (mapA (Product a a2)) a4
-- Two consecutive mapped products are fused component-wise.
Comp (Map (Product a4 a5)) a6 -> unA' (mapA (Product (a . a4) (a2 . a5))) a6
Comp (Map (Comp (Product a4 a5) a6)) a7 -> unA' (mapA (Product (a . a4) (a2 . a5) . a6)) a7
_ -> Map (Product a a2) . a3)
-- Mapping a composition is the composition of the mappings.
mapA (Comp a a2) = mapA a . mapA a2
mapA Id = id
-- Mapping 'Unpack'.
mapA Unpack = A (\a -> case a of
Comp Unpack a -> Unpack . (Unpack . a)
-- A mapped 'Pack' immediately before cancels against the mapped 'Unpack'.
Comp (Map (Comp Pack a)) a2 -> Map a . a2
Comp (Map Pack) a -> a
Comp ClearMarks a3 -> ClearMarks . unA' (mapA Unpack) a3
_ -> Map Unpack . a)
-- Catch-all for every other node (and for a 'Map' whose guard failed).
mapA a = A (\a2 -> case a2 of
-- Fuse with a preceding 'Map' into a single 'Map' via 'reassociate'.
Comp (Map a2) a3 -> Comp (Map (reassociate a (Right a2))) a3
Comp ClearMarks a3 -> ClearMarks . unA' (mapA a) a3
_ -> Comp (Map a) a2)
-- | Strip every leading @Comp Id@ from a pipeline.
scrubIds :: Structural a t u -> Structural a t u
scrubIds s = case s of
  Comp Id rest -> scrubIds rest
  other -> other
-- | Composition chains the two rewriters, running 'scrubIds' on the
-- intermediate pipeline. The identity leaves the pipeline untouched.
instance (Category a) => Category (A a) where
  id = A (\prefix -> prefix)
  A outer . A inner = A (\prefix -> outer (scrubIds (inner prefix)))
-- | 'Arrow' structure for 'A'.
--
-- '***' fuses with a preceding 'Product' by running each side's rewriter
-- on the corresponding component; otherwise both sides are materialised
-- against 'Id' and a fresh 'Product' is emitted. 'first', 'second' and
-- '&&&' are defined through '***'.
instance (ArrowChoice a) => Arrow (A a) where
  arr f = liftA (arr f)
  A onFst *** A onSnd = A (\prefix -> case prefix of
    Comp (Product l r) rest -> Comp (Product (onFst l) (onSnd r)) rest
    _ -> Comp (Product (onFst Id) (onSnd Id)) prefix)
  first l = l *** id
  second r = id *** r
  l &&& r = dupA >>> (l *** r)
-- | Choice is encoded with arrays: an 'Either' is turned into a pair of
-- arrays by 'Separate', each branch is mapped over its array, and
-- 'Combine' turns the pair back into an 'Either'. When the preceding step
-- is already a 'Combine', the 'Separate' is skipped.
instance (ArrowChoice a) => ArrowChoice (A a) where
  l +++ r = A (\prefix -> case prefix of
      Comp Combine rest -> Comp Combine (unA' branches rest)
      _ -> Comp Combine (unA' branches (Comp Separate prefix)))
    where
      branches = mapA (unA l) *** mapA (unA r)
  left l = l +++ id
  right r = id +++ r
-- | Debug rendering of a pipeline. Constructors without a dedicated case
-- (such as 'Lift') are shown as @_@.
instance Show (Structural a t u) where
  showsPrec d s = case s of
    Map body -> showString "Map " . showParen (d == 11) (showsPrec 11 body)
    Comp l r -> showsPrec 11 l . showString " . " . showsPrec 11 r
    Product l r -> showParen (d >= 3) (showsPrec 3 l . showString " *** " . showsPrec 3 r)
    Count -> showString "Count"
    Index -> showString "Index"
    Split -> showString "Split"
    ClearMarks -> showString "Clr"
    Pack -> showString "Pk"
    Unpack -> showString "Unpk"
    Separate -> showString "Sep"
    Combine -> showString "Comb"
    Dup -> showString "Dup"
    Fst -> showString "Fst"
    Snd -> showString "Snd"
    Id -> showString "Id"
    _ -> showString "_"
-- | Purely syntactic category: identity is 'Id', composition is 'Comp'.
instance (Category a) => Category (Structural a) where
  id = Id
  outer . inner = Comp outer inner
-- | Swap the two sides of an 'Either'.
mirror (Left l) = Right l
mirror (Right r) = Left r
-- | Distribute the second component over a pair held in the first:
-- @((x, y), z)@ becomes @((x, z), (y, z))@.
splitOff :: (ArrowChoice a) => A a ((t1, t2), u) ((t1, u), (t2, u))
splitOff = (fstA *** id) &&& (sndA *** id)
-- | Re-associate a left-nested pair to the right:
-- @((x, y), z)@ becomes @(x, (y, z))@.
assoc :: (ArrowChoice a) => A a ((t, u), v) (t, (u, v))
assoc = (fstA >>> fstA) &&& ((fstA >>> sndA) &&& sndA)
-- | Look up one element of an 'ArrC', simplifying against what precedes:
--
-- * indexing a mapped array indexes first and applies the mapped body to
--   the single result;
-- * indexing the output of 'Count' never builds the array: the flat index
--   is decoded against the extent list by repeated 'divMod' and paired
--   with the replicated value;
-- * otherwise a plain 'Index' node is emitted.
indexA :: (ArrowChoice a) => A a (ArrC t, Int) t
indexA = A (\prefix -> case prefix of
    Comp (Product (Map body) ix) rest -> Comp body (unA' indexA (Comp (Product Id ix) rest))
    Comp (Product Count ix) rest ->
      unA' ((fstA >>> fstA) &&& (((fstA >>> sndA) &&& sndA) >>> arr decode)) (Comp (Product Id ix) rest)
    _ -> Comp Index prefix)
  where
    -- Peel one remainder per extent off the flat index.
    decode (extents, flat) = snd (mapAccumL divMod flat extents)
-- | Zip two arrays element-wise, truncating to the shorter one.
--
-- The pair of arrays is replicated over the common length with 'countA'',
-- and each copy is indexed on both sides at its position.
zipA :: (ArrowChoice a) => A a (ArrC t, ArrC u) (ArrC (t, u))
zipA = (id &&& arr commonLength) >>> countA' >>> mapA' (splitOff >>> (indexA *** indexA))
  where
    commonLength (xs, ys) =
      min (uncurry subtract (bounds (project xs))) (uncurry subtract (bounds (project ys))) + 1
-- | Split an array of pairs into a pair of arrays by mapping the two
-- projections over a duplicated input.
unzipA :: (ArrowChoice a) => A a (ArrC (t, u)) (ArrC t, ArrC u)
unzipA = dupA >>> (mapA' fstA *** mapA' sndA)
-- | Flatten a nested array with 'pack', then post-process the result:
-- a pipeline now ending in 'Split' is replaced by the first component of
-- what fed the 'Split'; otherwise a 'ClearMarks' node is added.
concatA :: (Category a) => A a (ArrC (ArrC t)) (ArrC t)
concatA = pack >>> A (\prefix -> case prefix of
  Comp Split rest -> unA' fstA rest
  _ -> Comp ClearMarks prefix)
-- | Evaluate both components of a pair to weak head normal form, first
-- the left and then the right, and return the pair.
forcePair pair@(l, r) = l `seq` (r `seq` pair)
-- | Emit a 'Count' node in front of the pipeline, with no simplification.
countA :: (ArrowChoice a) => A a (t, [Int]) (ArrC (t, [Int]))
countA = A (\prefix -> Comp Count prefix)
-- | One-dimensional 'countA': the single extent is wrapped in a list
-- before counting and unwrapped with 'head' in every element afterwards.
countA' :: (ArrowChoice a) => A a (t, Int) (ArrC (t, Int))
countA' = mapA' (second (arr head)) . countA . second (arr (: []))
-- | Duplicate the value: emits a 'Dup' node in front of the pipeline.
dupA :: (Category a) => A a t (t, t)
dupA = A (\prefix -> Comp Dup prefix)
-- | First projection, simplified against the preceding step:
--
-- * after 'Dup' it cancels;
-- * after a 'Product' the work on the first component is pulled out and
--   applied after projecting, while the second-component work stays in
--   place in front of a plain 'Fst';
-- * otherwise a plain 'Fst' node is emitted.
fstA :: (Category a) => A a (t, u) t
fstA = A (\prefix -> case prefix of
  Comp Dup rest -> rest
  Comp (Product Id r) rest -> Comp Fst (Comp (Product Id r) rest)
  Comp (Product l Id) rest -> Comp l (unA' fstA rest)
  Comp (Product l r) rest -> Comp l (unA' fstA (Comp (Product Id r) rest))
  _ -> Comp Fst prefix)
-- | Second projection, simplified against the preceding step:
--
-- * after 'Dup' it cancels;
-- * after a 'Product' the work on the second component is pulled out and
--   applied after projecting, while the first-component work stays in
--   place in front of a plain 'Snd';
-- * otherwise a plain 'Snd' node is emitted.
sndA :: (Category a) => A a (t, u) u
sndA = A (\prefix -> case prefix of
  Comp Dup rest -> rest
  Comp (Product l Id) rest -> Comp Snd (Comp (Product l Id) rest)
  Comp (Product Id r) rest -> Comp r (unA' sndA rest)
  Comp (Product l r) rest -> Comp r (unA' sndA (Comp (Product l Id) rest))
  _ -> Comp Snd prefix)
-- | Binary search in a sequence assumed to be sorted in ascending order.
--
-- Returns the index of the last element that is @<= x@, or 0 when there is
-- no such element (including for the empty sequence). Among equal
-- elements the right-most index is returned.
--
-- The window is halved at each step. When the search continues in the
-- upper half, the offset advances by the size of the lower half and the
-- remaining size is the difference of the two.
binarySearch :: (Ord t) => t -> S.Seq t -> Int
binarySearch x = go 0
  where
    go off window
      | len <= 1 = off
      | x < pivot = go off lower
      | otherwise = go (off + half) upper
      where
        len = S.length window
        half = len `div` 2
        (lower, upper) = S.splitAt half window
        -- 'upper' is non-empty whenever @len >= 2@.
        pivot = S.index upper 0
-- | Evaluation of 'Pack': flatten an array of arrays into one array.
--
-- @marks@ holds one 'Node' per inner array, labelled with that array's
-- start offset in the flat result, plus a final sentinel node labelled
-- with the total length. Each inner array's own marks are kept as the
-- children of its node, shifted by the start offset.
--
-- Element @i@ of the result is found by locating the inner array whose
-- segment contains @i@ with 'binarySearch' over the offsets, then reading
-- that array at @i@ minus the segment's start offset.
--
-- NOTE(review): 'arr_concF' is defined outside this file; it is assumed to
-- build an array of the given size from an index function. Confirm.
packImpl (ArrC outer _) = ArrC (arr_concF pick ((), total)) marks
  where
    pick (_, i) =
      let j = binarySearch i offsets
          start = S.index offsets j
          ArrC inner _ = outer ! j
      in inner ! (i - start)
    -- The trailing empty array produces the sentinel node.
    marks = S.fromList $ snd $ mapAccumL step 0 $ elems outer ++ [ArrC (newArray []) S.empty]
    step off (ArrC inner sub) =
      let next = off + rangeSize (bounds inner)
      in (next, Node off (fastConcat ((return $!) . fmap (+ off)) sub))
    offsets = fmap (\(Node off _) -> off) marks
    _ S.:> total = S.viewr offsets
-- | Evaluation of 'Unpack': cut a flat array back into its segments.
--
-- Consecutive marks @j@ and @k@ delimit the segment @[j, k)@. Each segment
-- becomes a zero-based array with bounds @(0, k - j - 1)@ that reads the
-- flat array at the new index plus @j@. The children of the segment's
-- mark become the new array's marks, shifted back by @j@.
unpackImpl (ArrC flat marks) = fastConcat segment (pairUp marks)
  where
    segment (Node j sub, Node k _) =
      return $! ArrC (ixmap (0, k - j - 1) (+ j) flat) (fastConcat ((return $!) . fmap (subtract j)) sub)
-- | Interpret a 'Structural' term as a computation of the base arrow.
-- Requires the implicit parameters @?seq@ and @?pool@.
--
-- NOTE(review): 'arr_concF' and 'force' come from modules outside this
-- file; the comments below describe only how they are used here.
eval0 :: (Concurrent a, Strict a, ArrowChoice a, ?seq :: Bool, ?pool :: BoxedThreadPool) => Structural a t u -> a t u
-- Count: pair the input with the product of its extent list, feed that to
-- 'arr_concF' with a function that decodes each index against the extents
-- by repeated 'divMod', and wrap the result with 'inject'.
eval0 Count = id &&& arr(snd>>>product) >>> arr_concF (arr (\((x, ns), i) -> (x, snd (mapAccumL divMod i ns)))) >>> arr inject
-- Index: plain array lookup.
eval0 Index = arr (\(ArrC ar _, i) -> ar ! i)
-- ClearMarks: replace the mark sequence by the concatenated children of
-- its nodes, i.e. drop the outermost level of marks.
eval0 ClearMarks =
arr (\(ArrC ar fr) ->
ArrC ar (fastConcat id (fmap (\(Node _ fr) -> fr) fr)))
-- Map: pass the array and its element count to 'arr_concF', whose element
-- function looks up the index and runs the evaluated body; the original
-- marks are carried over unchanged.
eval0 (Map a) = (arr (\(ArrC ar _) -> (ar, uncurry subtract (bounds ar) + 1)) >>> arr_concF (arr (uncurry (!)) >>> eval0 a)) &&& arr (\(ArrC _ fr) -> fr) >>> arr (uncurry ArrC)
-- Split has no evaluator: reaching it fails with 'undefined'.
eval0 Split = undefined
eval0 Pack = arr packImpl
eval0 Unpack = arr (inject . newArray . toList . unpackImpl)
-- Separate: a Left becomes (singleton, empty), a Right becomes
-- (empty, singleton); both components are forced with ($!).
eval0 Separate = arr (\ei -> ((,) $! either (\x -> inject $ newArray [x]) (\_ -> inject $ newArray []) ei) $! either (\_ -> inject $ newArray []) (\x -> inject $ newArray [x]) ei)
-- Combine: when the first array's upper-minus-lower bound is 0 (it holds
-- exactly one element) the result is Left of that element; otherwise it is
-- Right of the second array's element at index 0.
eval0 Combine = arr (\(ar, ar2) -> let
a1 = project ar
a2 = project ar2 in
if uncurry subtract (bounds (project ar)) == 0 then Left $! a1 ! 0 else Right $! a2 ! 0)
-- Comp: evaluate the inner stage, then the outer stage wrapped in 'force'.
eval0 (Comp a a2) = force (eval0 a) . eval0 a2
eval0 Id = id
eval0 (Lift a) = a
-- Product: run the first component, then the second (wrapped in 'force'),
-- applying 'forcePair' after each step.
eval0 (Product a a2) = arr forcePair . force (second (eval0 a2)) . arr forcePair . first (eval0 a)
eval0 Dup = arr (\x -> forcePair (x, x))
eval0 Fst = arr fst
eval0 Snd = arr snd
-- | Interpret a 'Structural' term with the implicit parameter @?seq@ bound
-- to 'True'; every other constraint of 'eval0' is left to the caller.
eval structure =
  let ?seq = True
  in eval0 structure
-- | Application evaluates the supplied 'A' down to a base-arrow
-- computation (with @?pool@ bound to @BoxedThreadPool NoPool@) and then
-- defers to the base arrow's own 'app'.
instance (Concurrent a, Strict a, ArrowChoice a, ArrowApply a) => ArrowApply (A a) where
  app = liftA app . first (arr (\inner -> eval (unA inner)))
    where
      ?pool = BoxedThreadPool NoPool
-- | Does a queen in column @n@ attack any queen in the following rows?
--
-- @positions@ lists the columns of the queens in the next rows, nearest
-- row first. A queen @d@ rows away attacks when it sits in the same
-- column, or when its column minus @d@ or plus @d@ equals @n@ (the two
-- diagonals).
checkThreats :: (Eq a, Num a, Enum a) => a -> [a] -> Bool
checkThreats n positions =
  n `elem` positions
    || n `elem` zipWith (-) positions [1 ..]
    || n `elem` zipWith (+) positions [1 ..]
-- | Is any queen of the placement attacked by a queen in a later row?
-- Every non-empty suffix is checked: its head against its tail.
checkThreats2 positions = any attacked (tails positions)
  where
    attacked (queen : later) = checkThreats queen later
    attacked [] = False
-- | Enumerate every index tuple of the given extents with 'countA', map
-- each to a singleton array when 'checkThreats2' finds no attack and to an
-- empty array otherwise, then concatenate.
nQueensImpl :: A (->) ((), [Int]) (ArrC [Int])
nQueensImpl = countA >>> mapA' (arr keepIfSafe) >>> concatA
  where
    keepIfSafe (_, soln)
      | checkThreats2 soln = inject (newArray [])
      | otherwise = inject (newArray [soln])
-- | The @n@-queens search: runs 'nQueensImpl' over @n@ extents of size @n@.
nQueens :: Int -> A (->) () (ArrC [Int])
nQueens n = nQueensImpl . arr (\() -> ((), replicate n n))
-- | Quicksort expressed in the arrow language, with a recursion budget.
--
-- When @depth@ is not positive the whole array is sorted with the list
-- 'sort'. Otherwise arrays of at most one element are returned as they
-- are. Larger arrays are partitioned around their first element, both
-- parts are sorted with @depth - 1@, and the parts are joined with the
-- pivot in between.
sorting :: (Ord t) => Int -> A (->) (ArrC t) (ArrC t)
sorting depth
  | depth <= 0 = arr (\v -> inject (newArray (sort (elems (project v)))))
  | otherwise = arr classify >>> (id ||| quick)
  where
    classify v
      | uncurry subtract (bounds (project v)) <= 0 = Left v
      | otherwise = Right v
    quick = arr splitPivot >>> first (sub *** sub) >>> arr rejoin
    splitPivot v =
      let pivot : others = elems (project v)
          (smaller, larger) = partition (< pivot) others
      in ((inject (newArray smaller), inject (newArray larger)), pivot)
    rejoin ((smaller, larger), pivot) =
      inject (newArray (elems (project smaller) ++ pivot : elems (project larger)))
    sub = sorting (pred depth)
-- | Replicate the array over its own length with 'countA' and, at each
-- position, index the array at that position.
permute :: A (->) (ArrC Int) (ArrC Int)
permute = arr withLength >>> countA >>> mapA' (second (arr head) >>> indexA)
  where
    withLength v = (v, [uncurry subtract (bounds (project v)) + 1])
-- | Dot product of two vectors: zip, multiply element-wise, then sum the
-- products.
dotProduct :: (Num t) => A (->) (ArrC t, ArrC t) t
dotProduct = proc (lhs, rhs) -> do
  paired <- zipA -< (lhs, rhs)
  products <- mapA' (arr (uncurry (*))) -< paired
  returnA -< sum $ elems $ project products
-- | Transpose an array of arrays.
--
-- The outer index of the result ranges over the length of the input's
-- first inner array and the inner index over the length of the input
-- itself. Element @[c][r]@ of the result is element @[r][c]@ of the input.
transpose' :: A (->) (ArrC (ArrC t)) (ArrC (ArrC t))
transpose' = proc matrix -> do
  firstRow <- indexA -< (matrix, 0)
  perColumn <- countA -< (matrix, [uncurry subtract (bounds (project firstRow)) + 1])
  grid <- mapA' (proc (mat, [col]) -> do
      cells <- countA -< ((mat, col), [uncurry subtract (bounds (project mat)) + 1])
      mapA' (proc ((mat', col'), [row]) -> returnA -< (mat', (col', row))) -< cells) -< perColumn
  mapA' (mapA' (proc (mat, (col, row)) -> do
      rowVec <- indexA -< (mat, row)
      indexA -< (rowVec, col)))
    -< grid