module Feldspar.Repa where
import qualified Prelude as P
import Language.Syntactic.Sugar
import Feldspar hiding (desugar,sugar,resugar)
-- Left-associative, so @Z :. a :. b@ parses as @(Z :. a) :. b@; the
-- right-most component is the innermost dimension (see 'toIndex').
infixl 3 :.
-- | The rank-zero shape, which is also the only index of a rank-zero shape.
data Z = Z
-- | Extends a shape @tail@ with one more (innermost) component @head@.
data tail :. head = tail :. head
-- | Shapes of rank 0 to 3; every dimension size is a symbolic 'Length'.
type DIM0 = Z
type DIM1 = DIM0 :. Data Length
type DIM2 = DIM1 :. Data Length
type DIM3 = DIM2 :. Data Length
-- | Operations on array extents and multi-dimensional indices. Both are
-- represented by the same type @sh@.
class Shape sh where
  -- | Number of dimensions (rank).
  dim :: sh -> Int
  -- | Shape whose every component is 0.
  zeroDim :: sh
  -- | Shape whose every component is 1.
  unitDim :: sh
  -- | Total number of elements in an array with this extent.
  size :: sh -> Data Length
  -- | @toIndex extent ix@: row-major flat offset of @ix@ within @extent@
  -- (extent first, index second).
  toIndex :: sh -> sh -> Data Index
  -- | @fromIndex extent flat@: converts a row-major flat offset back into
  -- a multi-dimensional index.
  fromIndex :: sh -> Data Index -> sh
  -- | Component-wise minimum of two extents.
  intersectDim :: sh -> sh -> sh
  -- | @inRange lower upper ix@: lower bounds inclusive, upper exclusive.
  inRange :: sh -> sh -> sh -> Data Bool
  -- | The components, innermost dimension first.
  toList :: sh -> [Data Length]
  -- | @toShape k arr@ rebuilds a shape from a symbolic array in which
  -- element @k@ holds the innermost dimension (the order 'toList' uses).
  toShape :: Int -> Data [Length] -> sh
-- | The rank-zero shape has no components, so every operation is trivial.
instance Shape Z where
  zeroDim = Z
  unitDim = Z
  dim Z = 0
  size Z = 1
  toIndex = \_ _ -> 0
  fromIndex = \_ _ -> Z
  intersectDim = \_ _ -> Z
  inRange Z Z Z = true
  toList = const []
  toShape = \_ _ -> Z
-- | A shape extended with one more, innermost, dimension.
instance Shape sh => Shape (sh :. Data Length) where
  zeroDim = zeroDim :. 0
  unitDim = unitDim :. 1
  dim (outer :. _) = dim outer + 1
  size (outer :. n) = size outer * n
  -- Row-major: scale the outer offset by this dimension's size, then add
  -- the innermost index.
  toIndex (extOuter :. extN) (ixOuter :. ixN) =
    toIndex extOuter ixOuter * extN + ixN
  fromIndex (outer :. n) flat = fromIndex outer quotient :. remainder
    where
      quotient  = flat `quot` n
      remainder = flat `rem` n
  intersectDim (outerA :. a) (outerB :. b) =
    intersectDim outerA outerB :. min a b
  inRange (loOuter :. lo) (hiOuter :. hi) (ixOuter :. i) =
    lo <= i && i < hi && inRange loOuter hiOuter ixOuter
  toList (outer :. n) = n : toList outer
  -- Element k of the array is this (innermost) dimension; outer ones follow.
  toShape k arr = toShape (k + 1) arr :. (arr ! P.fromIntegral k)
-- | Slice component: keep this dimension whole.
data All = All
-- | Slice root standing for all remaining outer dimensions, of shape @sh@,
-- which are kept whole.
data Any sh = Any
-- | Shape of the full array that a slice specification refers to: both a
-- fixed index ('Data Length') and 'All' contribute one dimension.
type family FullShape ss
type instance FullShape Z = Z
type instance FullShape (Any sh) = sh
type instance FullShape (sl :. Data Length) = FullShape sl :. Data Length
type instance FullShape (sl :. All) = FullShape sl :. Data Length
-- | Shape of the sliced result: fixed indices ('Data Length') are dropped,
-- while 'All' keeps its dimension.
type family SliceShape ss
type instance SliceShape Z = Z
type instance SliceShape (Any sh) = sh
type instance SliceShape (sl :. Data Length) = SliceShape sl
type instance SliceShape (sl :. All) = SliceShape sl :. Data Length
-- | Slice specifications, which convert between full-array indices and
-- sliced-array indices.
class Slice ss where
  -- | Project a full index onto the slice (drops fixed components).
  sliceOfFull :: ss -> FullShape ss -> SliceShape ss
  -- | Embed a slice index into the full array (re-inserts fixed components).
  fullOfSlice :: ss -> SliceShape ss -> FullShape ss
-- | The empty slice specification maps 'Z' to 'Z' in both directions.
instance Slice Z where
  fullOfSlice Z Z = Z
  sliceOfFull Z Z = Z
-- | 'Any' keeps the remaining dimensions unchanged in both directions.
instance Slice (Any sh) where
  fullOfSlice Any = id
  sliceOfFull Any = id
-- | A fixed index is dropped when slicing and re-inserted when embedding.
instance Slice sl => Slice (sl :. Data Length) where
  sliceOfFull (specOuter :. _) (fullOuter :. _) =
    sliceOfFull specOuter fullOuter
  fullOfSlice (specOuter :. fixed) sliceIx =
    fullOfSlice specOuter sliceIx :. fixed
-- | An 'All' component passes its index through unchanged.
instance Slice sl => Slice (sl :. All) where
  sliceOfFull (specOuter :. All) (fullOuter :. n) =
    sliceOfFull specOuter fullOuter :. n
  fullOfSlice (specOuter :. All) (sliceOuter :. n) =
    fullOfSlice specOuter sliceOuter :. n
-- | A delayed array: an extent plus a function from index to element.
-- Each lookup applies the function; nothing is stored.
data Vector sh a = Vector sh (sh -> a)
-- | A delayed array whose elements are Feldspar 'Data' values.
type DVector sh a = Vector sh (Data a)
-- | A vector is represented internally as a pair: its shape array and its
-- flattened element array (built by 'freezeVector', read back by
-- 'thawVector').
instance (Shape sh, Syntax a) => Syntactic (Vector sh a) where
  type Domain (Vector sh a) = FeldDomain
  type Internal (Vector sh a) = ([Length],[Internal a])
  desugar vec = desugar (freezeVector (map resugar vec))
  sugar expr = map resugar (thawVector (sugar expr))
-- Collection metadata: a vector holds elements of type @a@, is indexed by
-- shape values, and reports its size as a shape.
type instance Elem (Vector sh a) = a
type instance CollIndex (Vector sh a) = sh
type instance CollSize (Vector sh a) = sh
-- | Indexing applies the vector's index function to a full shape index.
instance Syntax a => Indexed (Vector sh a) where
  vec ! ix = case vec of
    Vector _ lookupFn -> lookupFn ix
-- | The collection size of a vector is its extent.
instance (Syntax a, Shape sh) => Sized (Vector sh a) where
  setCollSize sh vec = newExtent sh vec
  collSize vec = extent vec
-- | Element-wise mapping that preserves the element type.
instance CollMap (Vector sh a) (Vector sh a) where
  collMap f vec = map f vec
-- | Materialise a vector into a flat symbolic array, in row-major order
-- (flat position @k@ holds the element at @fromIndex extent k@).
fromVector :: (Shape sh, Type a) => DVector sh a -> Data [a]
fromVector vec = parallel total elemAt
  where
    shape = extent vec
    total = size shape
    elemAt flat = vec !: fromIndex shape flat
-- | View a flat, row-major symbolic array as a vector with extent @sh@.
--
-- 'toIndex' takes the extent first and the index second (see the @:.@
-- instance of 'Shape', and 'reshape'). It must therefore be called as
-- @toIndex sh ix@. The swapped call @toIndex ix sh@ read from the wrong
-- offset; for 1-D vectors it always read position @n@.
toVector :: (Shape sh, Type a) => sh -> Data [a] -> DVector sh a
toVector sh arr = Vector sh (\ix -> arr ! toIndex sh ix)
-- | Split a vector into its shape (as a symbolic array, innermost
-- dimension first) and its flattened elements.
freezeVector :: (Shape sh, Type a) => DVector sh a -> (Data [Length], Data [a])
freezeVector v = (fromList dims, fromVector v)
  where
    dims = toList (extent v)
-- | Build a symbolic array from a compile-time list of symbolic elements.
--
-- Every slot is first filled with the head element. The remaining elements
-- are then written one by one with 'setIx'. An empty list yields an empty
-- array; previously it called 'P.head' on @[]@, which crashed, for example,
-- when freezing a rank-0 vector (@toList Z = []@).
fromList :: Type a => [Data a] -> Data [a]
fromList [] = value []
fromList ls@(x : xs) = go 1 xs (parallel (value len) (const x))
  where
    len = P.fromIntegral (P.length ls)
    -- Walk the tail alongside its index instead of using @P.!!@, which made
    -- the original loop quadratic in the list length.
    go _ [] arr = arr
    go i (y : ys) arr = go (i + 1) ys (setIx arr (value i) y)
-- | Rebuild a vector from a (shape array, element array) pair of the form
-- produced by 'freezeVector'.
thawVector :: (Shape sh, Type a) => (Data [Length], Data [a]) -> DVector sh a
thawVector (dims, elems) = toVector shape elems
  where
    shape = toShape 0 dims
-- | Write every element to a flat array and read the vector back from it,
-- so that later lookups index the array instead of recomputing elements.
memorize :: (Shape sh, Type a) => DVector sh a -> DVector sh a
memorize vec = toVector shape (fromVector vec)
  where
    shape = extent vec
-- | The extent (shape) of a vector.
extent :: Vector sh a -> sh
extent vec = case vec of
  Vector shape _ -> shape
-- | Replace a vector's extent, keeping its index function unchanged.
newExtent :: sh -> Vector sh a -> Vector sh a
newExtent shape vec = case vec of
  Vector _ lookupFn -> Vector shape lookupFn
-- | General transformation: derive the new extent from the old one, and the
-- new index function from the old index function.
traverse :: (Shape sh, Shape sh') =>
            Vector sh a -> (sh -> sh') -> ((sh -> a) -> sh' -> a')
         -> Vector sh' a'
traverse (Vector shape lookupFn) reshapeFn remapFn = Vector newShape newLookup
  where
    newShape  = reshapeFn shape
    newLookup = remapFn lookupFn
-- | Replicate a vector along the dimensions that the slice specification
-- fixes. Each full index reads the element at its projection onto the slice.
replicate :: (Slice sl, Shape (FullShape sl)
             ,Shape (SliceShape sl))
          => sl -> Vector (SliceShape sl) a
          -> Vector (FullShape sl) a
replicate sl vec = backpermute fullShape project vec
  where
    fullShape = fullOfSlice sl (extent vec)
    project   = sliceOfFull sl
-- | Extract a slice. Each slice index reads the element at its embedding
-- into the full array.
slice :: (Slice sl
         ,Shape (FullShape sl)
         ,Shape (SliceShape sl))
      => Vector (FullShape sl) a
      -> sl -> Vector (SliceShape sl) a
slice vec sl = backpermute sliceShape embed vec
  where
    sliceShape = sliceOfFull sl (extent vec)
    embed      = fullOfSlice sl
-- | Reinterpret a vector under a new extent, preserving row-major order.
-- NOTE(review): the two extents' sizes are not checked against each other.
reshape :: (Shape sh, Shape sh') => sh -> Vector sh' a -> Vector sh a
reshape newShape (Vector oldShape lookupFn) = Vector newShape remapped
  where
    remapped ix = lookupFn (fromIndex oldShape (toIndex newShape ix))
-- | A rank-zero vector holding a single value.
unit :: a -> Vector Z a
unit a = Vector Z (\_ -> a)
-- | Index a vector with a full shape index.
(!:) :: (Shape sh) => Vector sh a -> sh -> a
vec !: ix = case vec of
  Vector _ lookupFn -> lookupFn ix
-- | The main diagonal of a 2-D vector: element @x@ is @vec !: (Z :. x :. x)@.
--
-- The length is the smaller of the two dimensions. Using only the inner
-- dimension read past the outer dimension whenever rows < columns.
diagonal :: Vector DIM2 a -> Vector DIM1 a
diagonal vec = backpermute (Z :. min rows cols) (\ (_ :. x) -> Z :. x :. x) vec
  where Z :. rows :. cols = extent vec
-- | Build a vector of extent @sh@ whose element at @ix@ is the source
-- element at @perm ix@.
backpermute :: (Shape sh, Shape sh') =>
               sh' -> (sh' -> sh) -> Vector sh a -> Vector sh' a
backpermute sh perm vec = traverse vec (\_ -> sh) (\lookupFn -> lookupFn . perm)
-- | Apply a function to every element; the result stays delayed.
map :: (a -> b) -> Vector sh a -> Vector sh b
map f vec = case vec of
  Vector shape lookupFn -> Vector shape (\ix -> f (lookupFn ix))
-- | Pair up corresponding elements over the intersection of both extents.
zip :: (Shape sh) => Vector sh a -> Vector sh b -> Vector sh (a,b)
zip xs ys = zipWith (,) xs ys
-- | Combine corresponding elements. The result's extent is the
-- component-wise minimum of the two input extents.
zipWith :: (Shape sh) =>
           (a -> b -> c) -> Vector sh a -> Vector sh b -> Vector sh c
zipWith f xs ys = Vector common combine
  where
    common = intersectDim (extent xs) (extent ys)
    combine ix = f (xs !: ix) (ys !: ix)
-- | Reduce the innermost dimension with 'forLoop', starting every outer
-- index from the same initial accumulator @x@.
fold :: (Shape sh, Syntax a) =>
        (a -> a -> a)
     -> a
     -> Vector (sh :. Data Length) a
     -> Vector sh a
fold f x vec = Vector outer reduceAt
  where
    outer :. n = extent vec
    step i ix acc = f acc (vec !: (i :. ix))
    reduceAt i = forLoop n x (step i)
-- | Like 'fold', but each outer index starts from the corresponding element
-- of the vector @x@ instead of from one shared initial value.
fold' :: (Shape sh, Syntax a) =>
         (a -> a -> a)
      -> Vector sh a
      -> Vector (sh :. Data Length) a
      -> Vector sh a
fold' f x vec = Vector outer reduceAt
  where
    outer :. n = extent vec
    step i ix acc = f acc (vec !: (i :. ix))
    reduceAt i = forLoop n (x !: i) (step i)
-- | Sum along the innermost dimension.
sum :: (Shape sh, Type a, Numeric a) =>
       DVector (sh :. Data Length) a -> DVector sh a
sum vec = fold (+) 0 vec
-- | @from ... to@: the inclusive index range @[from .. to]@, which is empty
-- when @to < from@. The length guard prevents underflow and mirrors
-- 'enumFromTo'.
(...) :: Data Index -> Data Index -> DVector DIM1 Index
from ... to = Vector (Z :. len) (\(Z :. ix) -> ix + from)
  where
    len = (to < from) ? 0 $ (to - from + 1)
-- | One relaxation step on a 2-D grid. Every interior element becomes the
-- average of its four orthogonal neighbours in the input grid; elements on
-- the border are copied unchanged.
--
-- In an index @sh :. i :. j@, @i@ ranges over the outer dimension (size
-- @height@) and @j@ over the inner one (size @width@). The boundary test
-- previously compared them against the wrong sizes.
stencil :: DVector DIM2 Float -> DVector DIM2 Float
stencil vec = traverse vec id update
  where
    _ :. height :. width = extent vec
    update get d@(sh :. i :. j) =
      isBoundary i j ? get d $ average
      where
        average =
          ( get (sh :. (i - 1) :. j)
          + get (sh :. i :. (j - 1))
          + get (sh :. (i + 1) :. j)
          + get (sh :. i :. (j + 1))
          ) / 4
    isBoundary i j =
      (i == 0) || (i >= height - 1) || (j == 0) || (j >= width - 1)
-- | Apply 'stencil' @steps@ times. The grid is materialised into a flat
-- array between steps, and that array is the loop state.
laplace :: Data Length -> DVector DIM2 Float -> DVector DIM2 Float
laplace steps vec = toVector shape (forLoop steps (fromVector vec) relax)
  where
    shape = extent vec
    relax _ grid = fromVector (stencil (toVector shape grid))
-- | Swap the two dimensions of a 2-D vector.
transpose2D :: Vector DIM2 e -> Vector DIM2 e
transpose2D vec = backpermute (flipIx (extent vec)) flipIx vec
  where
    flipIx (Z :. r :. c) = Z :. c :. r
-- | Matrix product of an @m x k@ vector (extent @Z :. m :. k@) and a
-- @k x n@ vector (extent @Z :. k :. n@), giving extent @Z :. m :. n@.
--
-- Both operands are replicated into a @Z :. m :. n :. k@ cube, multiplied
-- element-wise, and the innermost (shared) dimension is summed away. The
-- second operand is transposed first so that its contracted dimension is
-- innermost. Without the transpose the result was A * transpose(B), and
-- the replication counts were read from the wrong dimensions.
mmMult :: (Type e, Numeric e) =>
          DVector DIM2 e -> DVector DIM2 e
       -> DVector DIM2 e
mmMult vA vB = sum (zipWith (*) vaRepl vbRepl)
  where
    vbT = transpose2D vB
    -- vaRepl at (Z :. i :. j :. p) is vA at (Z :. i :. p)
    vaRepl = replicate (Z :. All :. colsB :. All) vA
    -- vbRepl at (Z :. i :. j :. p) is vB at (Z :. p :. j)
    vbRepl = replicate (Z :. rowsA :. All :. All) vbT
    Z :. rowsA :. _ = extent vA
    Z :. _ :. colsB = extent vB
-- | Apply a function to the single component of a 1-D index.
mapDIM1 :: (Data Index -> Data Index) -> DIM1 -> DIM1
mapDIM1 ixmap sh = case sh of
  Z :. i -> Z :. ixmap i
-- | A 1-D vector of length @l@ whose element @i@ is @idxFun i@.
indexed :: Data Length -> (Data Index -> a) -> Vector DIM1 a
indexed l idxFun = Vector (Z :. l) lookupFn
  where
    lookupFn (Z :. i) = idxFun i
-- | Number of elements in a 1-D vector.
length :: Vector DIM1 a -> Data Length
length vec = case extent vec of
  Z :. l -> l
-- | Change the length of a 1-D vector, keeping its index function.
newLen :: Syntax a => Data Length -> Vector DIM1 a -> Vector DIM1 a
newLen l vec = case vec of
  Vector (Z :. _) lookupFn -> Vector (Z :. l) lookupFn
-- | Concatenate two 1-D vectors.
--
-- Indices at or beyond @l1@ belong to the second vector and are shifted
-- down by @l1@. The previous code added @l1@, which read past the end of
-- the second vector.
(++) :: Syntax a => Vector DIM1 a -> Vector DIM1 a -> Vector DIM1 a
Vector (Z :. l1) ixf1 ++ Vector (Z :. l2) ixf2 = Vector (Z :. l1 + l2) ixf
  where
    ixf (Z :. i) = (i < l1) ? ixf1 (Z :. i) $ ixf2 (Z :. (i - l1))
infixr 5 ++
-- | The first @n@ elements, or the whole vector if it is shorter.
take :: Data Length -> Vector DIM1 a -> Vector DIM1 a
take n vec = case vec of
  Vector (Z :. l) lookupFn -> Vector (Z :. min n l) lookupFn
-- | Drop the first @n@ elements. The length saturates at zero, so dropping
-- more than the length gives an empty vector.
drop :: Data Length -> Vector DIM1 a -> Vector DIM1 a
drop n (Vector (Z :. l) ixf) = Vector (Z :. (l - min l n)) (ixf . mapDIM1 (+ n))
-- | Split a 1-D vector into its first @n@ elements and the rest.
splitAt :: Data Index -> Vector DIM1 a -> (Vector DIM1 a, Vector DIM1 a)
splitAt n vec = (front, back)
  where
    front = take n vec
    back  = drop n vec
-- | The element at index 0 (not checked against the length).
head :: Syntax a => Vector DIM1 a -> a
head vec = vec ! (Z :. 0)
-- | The element at index @length vec - 1@ (not checked for empty vectors).
last :: Syntax a => Vector DIM1 a -> a
last vec = vec ! (Z :. (length vec - 1))
-- | Everything but the first element.
tail :: Vector DIM1 a -> Vector DIM1 a
tail vec = drop 1 vec
-- | Everything but the last element. The subtraction saturates (as in
-- 'drop'), so an empty vector stays empty.
init :: Vector DIM1 a -> Vector DIM1 a
init vec = take (l - min l 1) vec
  where
    l = length vec
-- | All suffixes: element @n@ is @drop n vec@, for @n@ from 0 to the length.
tails :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a)
tails vec = indexed (length vec + 1) (`drop` vec)
-- | All prefixes: element @n@ is @take n vec@, for @n@ from 0 to the length.
inits :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a)
inits vec = indexed (length vec + 1) (`take` vec)
-- | All non-empty prefixes.
inits1 :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a)
inits1 vec = tail (inits vec)
-- | Reorder a 1-D vector: element @i@ of the result is element
-- @perm len i@ of the input.
permute :: (Data Length -> Data Index -> Data Index) -> (Vector DIM1 a -> Vector DIM1 a)
permute perm (Vector shape@(Z :. l) lookupFn) = Vector shape remapped
  where
    remapped (Z :. i) = lookupFn (Z :. perm l i)
-- | Reverse a 1-D vector: element @i@ is read from index @l - 1 - i@.
reverse :: Syntax a => Vector DIM1 a -> Vector DIM1 a
reverse = permute $ \l i -> l - 1 - i
-- | Rotate left by @ix@: element @i@ is read from @(i + ix) `rem` l@.
rotateVecL :: Syntax a => Data Index -> Vector DIM1 a -> Vector DIM1 a
rotateVecL ix = permute shift
  where
    shift l i = (i + ix) `rem` l
-- | Rotate right by @ix@, expressed as a left rotation of the reversed vector.
rotateVecR :: Syntax a => Data Index -> Vector DIM1 a -> Vector DIM1 a
rotateVecR ix vec = reverse (rotateVecL ix (reverse vec))
-- | A 1-D vector of length @n@ in which every element is @a@.
replicate1 :: Data Length -> a -> Vector DIM1 a
replicate1 n a = Vector (Z :. n) (\_ -> a)
-- | The inclusive range @[m .. n]@, which is empty when @n < m@.
-- A literal @1@ lower bound gets a simpler index function.
enumFromTo :: Data Index -> Data Index -> Vector DIM1 (Data Index)
enumFromTo 1 n = indexed n (+1)
enumFromTo m n = indexed l (+m)
  where
    l = (n < m) ? 0 $ (n - m + 1)
-- | The range from @m@ up to @maxBound@.
enumFrom :: Data Index -> Vector DIM1 (Data Index)
enumFrom m = enumFromTo m (value maxBound)
-- | Split a vector of pairs into two vectors with the same extent.
unzip :: Vector DIM1 (a,b) -> (Vector DIM1 a, Vector DIM1 b)
unzip (Vector l lookupFn) = (Vector l firsts, Vector l seconds)
  where
    firsts  = fst . lookupFn
    seconds = snd . lookupFn
-- | Left fold over a 1-D vector, implemented with 'forLoop'.
foldl :: (Syntax a) => (a -> b -> a) -> a -> Vector DIM1 b -> a
foldl f x vec = case vec of
  Vector (Z :. l) lookupFn -> forLoop l x (\ix acc -> f acc (lookupFn (Z :. ix)))
-- | Left fold seeded with the first element (not checked for emptiness).
fold1 :: Syntax a => (a -> a -> a) -> Vector DIM1 a -> a
fold1 f vec = foldl f seed rest
  where
    seed = head vec
    rest = tail vec
-- | Sum of a 1-D vector.
sum1 :: (Syntax a, Num a) => Vector DIM1 a -> a
sum1 vec = foldl (+) 0 vec
-- | Largest element, computed with 'fold1'.
maximum :: Ord a => Vector DIM1 (Data a) -> Data a
maximum vec = fold1 max vec
-- | Smallest element, computed with 'fold1'.
minimum :: Ord a => Vector DIM1 (Data a) -> Data a
minimum vec = fold1 min vec
-- | Dot product over the common length of two 1-D vectors.
scalarProd :: (Syntax a, Num a) => Vector DIM1 a -> Vector DIM1 a -> a
scalarProd a b = sum1 products
  where
    products = zipWith (*) a b
-- | Type-annotation helper that behaves as the identity. It fixes a 1-D
-- vector's element type to match the given patch.
tVec :: Patch a a -> Patch (Vector DIM1 a) (Vector DIM1 a)
tVec = const id
-- | Type-annotation helper that behaves as the identity. It fixes a 1-D
-- 'Data' vector's element type to match the given patch.
tVec1 :: Patch a a -> Patch (Vector DIM1 (Data a)) (Vector DIM1 (Data a))
tVec1 = const id
-- | Type-annotation helper that behaves as the identity. It fixes a 2-D
-- 'Data' vector's element type to match the given patch.
tVec2 :: Patch a a -> Patch (Vector DIM2 (Data a)) (Vector DIM2 (Data a))
tVec2 = const id