{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} -- -- Copyright (c) 2009-2011, ERICSSON AB -- All rights reserved. -- -- Redistribution and use in source and binary forms, with or without -- modification, are permitted provided that the following conditions are met: -- -- * Redistributions of source code must retain the above copyright notice, -- this list of conditions and the following disclaimer. -- * Redistributions in binary form must reproduce the above copyright -- notice, this list of conditions and the following disclaimer in the -- documentation and/or other materials provided with the distribution. -- * Neither the name of the ERICSSON AB nor the names of its contributors -- may be used to endorse or promote products derived from this software -- without specific prior written permission. -- -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -- module Feldspar.Repa where import qualified Prelude as P import Language.Syntactic.Sugar import Feldspar hiding (desugar,sugar,resugar) -- | * Shapes infixl 3 :. data Z = Z data tail :. head = tail :. head type DIM0 = Z type DIM1 = DIM0 :. Data Length type DIM2 = DIM1 :. Data Length type DIM3 = DIM2 :. Data Length class Shape sh where -- | Get the number of dimensions in a shape dim :: sh -> Int -- | The shape of an array of size zero, with a particular dimension zeroDim :: sh -- | The shape of an array with size one, with a particular dimension unitDim :: sh -- | Get the total number of elements in an array with this shape. size :: sh -> Data Length -- | Index into flat, linear, row-major representation toIndex :: sh -> sh -> Data Index -- | Inverse of `toIndex`. fromIndex :: sh -> Data Index -> sh -- | The intersection of two dimensions. intersectDim :: sh -> sh -> sh -- | Check whether an index is within a given shape. -- @inRange l u i@ checks that 'i' fits between 'l' and 'u'. inRange :: sh -> sh -> sh -> Data Bool -- | Turn a shape into a list. Used in the 'Syntactic' instance. toList :: sh -> [Data Length] -- | Reconstruct a shape. Used in the 'Syntactic' instance. toShape :: Int -> Data [Length] -> sh instance Shape Z where dim Z = 0 zeroDim = Z unitDim = Z size Z = 1 toIndex _ _ = 0 fromIndex _ _ = Z intersectDim _ _ = Z inRange Z Z Z = true toList _ = [] toShape _ _ = Z instance Shape sh => Shape (sh :. Data Length) where dim (sh :. _) = dim sh + 1 zeroDim = zeroDim :. 0 unitDim = unitDim :. 1 size (sh :. i) = size sh * i toIndex (sh1 :. sh2) (sh1' :. sh2') = toIndex sh1 sh1' * sh2 + sh2' fromIndex (ds :. d) ix = fromIndex ds (ix `quot` d) :. (ix `rem` d) intersectDim (sh1 :. n1) (sh2 :. n2) = (intersectDim sh1 sh2 :. (min n1 n2)) inRange (shL :. l) (shU :. u) (sh :. i) = l <= i && i < u && inRange shL shU sh toList (sh :. i) = i : toList sh toShape i arr = toShape (i+1) arr :. (arr ! (P.fromIntegral i)) -- | * Slices data All = All data Any sh = Any type family FullShape ss type instance FullShape Z = Z type instance FullShape (Any sh) = sh type instance FullShape (sl :. Data Length) = FullShape sl :. Data Length type instance FullShape (sl :. All) = FullShape sl :. Data Length type family SliceShape ss type instance SliceShape Z = Z type instance SliceShape (Any sh) = sh type instance SliceShape (sl :. Data Length) = SliceShape sl type instance SliceShape (sl :. All) = SliceShape sl :. Data Length class Slice ss where sliceOfFull :: ss -> FullShape ss -> SliceShape ss fullOfSlice :: ss -> SliceShape ss -> FullShape ss instance Slice Z where sliceOfFull Z Z = Z fullOfSlice Z Z = Z instance Slice (Any sh) where sliceOfFull Any sh = sh fullOfSlice Any sh = sh instance Slice sl => Slice (sl :. Data Length) where sliceOfFull (fsl :. _) (ssl :. _) = sliceOfFull fsl ssl fullOfSlice (fsl :. n) ssl = fullOfSlice fsl ssl :. n instance Slice sl => Slice (sl :. All) where sliceOfFull (fsl :. All) (ssl :. s) = sliceOfFull fsl ssl :. s fullOfSlice (fsl :. All) (ssl :. s) = fullOfSlice fsl ssl :. s -- | * Vectors data Vector sh a = Vector sh (sh -> a) type DVector sh a = Vector sh (Data a) instance (Shape sh, Syntax a) => Syntactic (Vector sh a) where type Domain (Vector sh a) = FeldDomain type Internal (Vector sh a) = ([Length],[Internal a]) desugar = desugar . freezeVector . map resugar sugar = map resugar . thawVector . sugar type instance Elem (Vector sh a) = a type instance CollIndex (Vector sh a) = sh type instance CollSize (Vector sh a) = sh instance Syntax a => Indexed (Vector sh a) where (Vector _ ixf) ! i = ixf i instance (Syntax a, Shape sh) => Sized (Vector sh a) where collSize = extent setCollSize = newExtent instance CollMap (Vector sh a) (Vector sh a) where collMap = map -- | * Fuctions -- | Store a vector in an array. fromVector :: (Shape sh, Type a) => DVector sh a -> Data [a] fromVector vec = parallel (size ext) (\ix -> vec !: fromIndex ext ix) where ext = extent vec -- | Restore a vector from an array toVector :: (Shape sh, Type a) => sh -> Data [a] -> DVector sh a toVector sh arr = Vector sh (\ix -> arr ! toIndex ix sh) freezeVector :: (Shape sh, Type a) => DVector sh a -> (Data [Length], Data [a]) freezeVector v = (shapeArr, fromVector v) where shapeArr = fromList (toList $ extent v) fromList :: Type a => [Data a] -> Data [a] fromList ls = loop 1 (parallel (value len) (const (P.head ls))) where loop i arr | i P.< len = loop (i+1) (setIx arr (value i) (ls P.!! (P.fromIntegral i))) | otherwise = arr len = P.fromIntegral $ P.length ls thawVector :: (Shape sh, Type a) => (Data [Length], Data [a]) -> DVector sh a thawVector (l,arr) = toVector (toShape 0 l) arr -- | Store a vector in memory. Use this function instead of 'force' if -- possible as it is both much more safe and faster. memorize :: (Shape sh, Type a) => DVector sh a -> DVector sh a memorize vec = toVector (extent vec) (fromVector vec) -- | The shape and size of the vector extent :: Vector sh a -> sh extent (Vector sh _) = sh -- | Change the extent of the vector to the supplied value. If the supplied -- extent will contain more elements than the old extent, the new elements -- will have undefined value. newExtent :: sh -> Vector sh a -> Vector sh a newExtent sh (Vector _ ixf) = Vector sh ixf -- | Change shape and transform elements of a vector. This function is the -- most general way of manipulating a vector. traverse :: (Shape sh, Shape sh') => Vector sh a -> (sh -> sh') -> ((sh -> a) -> sh' -> a') -> Vector sh' a' traverse (Vector sh ixf) shf elemf = Vector (shf sh) (elemf ixf) -- | Duplicates part of a vector along a new dimension. replicate :: (Slice sl, Shape (FullShape sl) ,Shape (SliceShape sl)) => sl -> Vector (SliceShape sl) a -> Vector (FullShape sl) a replicate sl vec = backpermute (fullOfSlice sl (extent vec)) (sliceOfFull sl) vec -- | Extracts a slice from a vector. slice :: (Slice sl ,Shape (FullShape sl) ,Shape (SliceShape sl)) => Vector (FullShape sl) a -> sl -> Vector (SliceShape sl) a slice vec sl = backpermute (sliceOfFull sl (extent vec)) (fullOfSlice sl) vec -- | Change the shape of a vector. This function is potentially unsafe, the -- new shape need to have fewer or equal number of elements compared to -- the old shape. reshape :: (Shape sh, Shape sh') => sh -> Vector sh' a -> Vector sh a reshape sh' (Vector sh ixf) = Vector sh' (ixf . fromIndex sh . toIndex sh') -- | A scalar (zero dimensional) vector unit :: a -> Vector Z a unit a = Vector Z (const a) -- | Index into a vector (!:) :: (Shape sh) => Vector sh a -> sh -> a (Vector _ ixf) !: ix = ixf ix -- | Extract the diagonal of a two dimensional vector diagonal :: Vector DIM2 a -> Vector DIM1 a diagonal vec = backpermute (Z :. width) (\ (_ :. x) -> Z :. x :. x) vec where Z :. _ :. width = extent vec -- | Change the shape of a vector. backpermute :: (Shape sh, Shape sh') => sh' -> (sh' -> sh) -> Vector sh a -> Vector sh' a backpermute sh perm vec = traverse vec (const sh) (. perm) -- | Map a function on all the elements of a vector map :: (a -> b) -> Vector sh a -> Vector sh b map f (Vector sh ixf) = Vector sh (f . ixf) -- | Combines the elements of two vectors. The size of the resulting vector -- will be the intersection of the two argument vectors. zip :: (Shape sh) => Vector sh a -> Vector sh b -> Vector sh (a,b) zip = zipWith (\a b -> (a,b)) -- | Combines the elements of two vectors pointwise using a function. -- The size of the resulting vector will be the intersection of the -- two argument vectors. zipWith :: (Shape sh) => (a -> b -> c) -> Vector sh a -> Vector sh b -> Vector sh c zipWith f arr1 arr2 = Vector (intersectDim (extent arr1) (extent arr2)) (\ix -> f (arr1 !: ix) (arr2 !: ix)) -- | Reduce a vector along its last dimension fold :: (Shape sh, Syntax a) => (a -> a -> a) -> a -> Vector (sh :. Data Length) a -> Vector sh a fold f x vec = Vector sh ixf where sh :. n = extent vec ixf i = forLoop n x (\ix s -> f s (vec !: (i :. ix))) -- Here's another version of fold which has a little bit more freedom -- when it comes to choosing the initial element when folding -- | A generalization of 'fold' which allows for different initial -- values when starting to fold. fold' :: (Shape sh, Syntax a) => (a -> a -> a) -> Vector sh a -> Vector (sh :. Data Length) a -> Vector sh a fold' f x vec = Vector sh ixf where sh :. n = extent vec ixf i = forLoop n (x!:i) (\ix s -> f s (vec !: (i :. ix))) -- | Summing a vector along its last dimension sum :: (Shape sh, Type a, Numeric a) => DVector (sh :. Data Length) a -> DVector sh a sum = fold (+) 0 -- | Enumerating a vector (...) :: Data Index -> Data Index -> DVector DIM1 Index from ... to = Vector (Z :. (to - from + 1)) (\(Z :. ix) -> ix + from) -- This one should generalize to arbitrary shapes -- Laplace stencil :: DVector DIM2 Float -> DVector DIM2 Float stencil vec = traverse vec id update where _ :. height :. width = extent vec update get d@(sh :. i :. j) = isBoundary i j ? get d $ ( get (sh :. (i-1) :. j) + get (sh :. i :. (j-1)) + get (sh :. (i+1) :. j) + get (sh :. i :. (j+1))) / 4 isBoundary i j = (i == 0) || (i >= width - 1) || (j == 0) || (j >= height - 1) laplace :: Data Length -> DVector DIM2 Float -> DVector DIM2 Float laplace steps vec = toVector (extent vec) $ forLoop steps (fromVector vec) $ const $ fromVector . stencil . toVector (extent vec) -- Matrix Multiplication transpose2D :: Vector DIM2 e -> Vector DIM2 e transpose2D vec = backpermute new_extent swp vec where swp (Z :. i :. j) = Z :. j :. i new_extent = swp (extent vec) -- | Matrix multiplication mmMult :: (Type e, Numeric e) => DVector DIM2 e -> DVector DIM2 e -> DVector DIM2 e mmMult vA vB = sum (zipWith (*) vaRepl vbRepl) where vaRepl = replicate (Z :. All :. colsB :. All) vA vbRepl = replicate (Z :. rowsA :. All :. All) vB (Z :. _ :. rowsA) = extent vA (Z :. colsB :. _ ) = extent vB -- One dimensional vectors, meant to help transitioning from -- the old vector library mapDIM1 :: (Data Index -> Data Index) -> DIM1 -> DIM1 mapDIM1 ixmap (Z :. i) = (Z :. ixmap i) indexed :: Data Length -> (Data Index -> a) -> Vector DIM1 a indexed l idxFun = Vector (Z :. l) (\ (Z :. i) -> idxFun i) length :: Vector DIM1 a -> Data Length length (Vector (Z :. l) _) = l -------------------------------------------------------------------------------- -- * Operations on one dimensional vectors -------------------------------------------------------------------------------- -- | Change the length of the vector to the supplied value. If the supplied -- length is greater than the old length, the new elements will have undefined -- value. The resulting vector has only one segment. newLen :: Syntax a => Data Length -> Vector DIM1 a -> Vector DIM1 a newLen l (Vector (Z :. _) ixf) = Vector (Z :. l) ixf (++) :: Syntax a => Vector DIM1 a -> Vector DIM1 a -> Vector DIM1 a Vector (Z :. l1) ixf1 ++ Vector (Z :. l2) ixf2 = Vector (Z :. l1 + l2) (\ (Z :. i) -> i < l1 ? ixf1 (Z :. i) $ ixf2 (Z :. (i + l1))) infixr 5 ++ take :: Data Length -> Vector DIM1 a -> Vector DIM1 a take n (Vector (Z :. l) ixf) = Vector (Z :. (min n l)) ixf drop :: Data Length -> Vector DIM1 a -> Vector DIM1 a drop n (Vector (Z :. l) ixf) = Vector (Z :. (l - min l n)) (ixf . mapDIM1 (+ n)) splitAt :: Data Index -> Vector DIM1 a -> (Vector DIM1 a, Vector DIM1 a) splitAt n vec = (take n vec, drop n vec) head :: Syntax a => Vector DIM1 a -> a head = (! (Z :. 0)) last :: Syntax a => Vector DIM1 a -> a last vec = vec ! (Z :. (length vec - 1)) tail :: Vector DIM1 a -> Vector DIM1 a tail = drop 1 init :: Vector DIM1 a -> Vector DIM1 a init vec = take (length vec - 1) vec tails :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a) tails vec = indexed (length vec + 1) (\n -> drop n vec) inits :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a) inits vec = indexed (length vec + 1) (\n -> take n vec) inits1 :: Vector DIM1 a -> Vector DIM1 (Vector DIM1 a) inits1 = tail . inits -- | Permute a vector permute :: (Data Length -> Data Index -> Data Index) -> (Vector DIM1 a -> Vector DIM1 a) permute perm (Vector s@(Z :. l) ixf) = Vector s (ixf . mapDIM1 (perm l)) reverse :: Syntax a => Vector DIM1 a -> Vector DIM1 a reverse = permute $ \l i -> l-1-i rotateVecL :: Syntax a => Data Index -> Vector DIM1 a -> Vector DIM1 a rotateVecL ix = permute $ \l i -> (i + ix) `rem` l rotateVecR :: Syntax a => Data Index -> Vector DIM1 a -> Vector DIM1 a rotateVecR ix = reverse . rotateVecL ix . reverse replicate1 :: Data Length -> a -> Vector DIM1 a replicate1 n a = Vector (Z :. n) (const a) -- | @enumFromTo m n@: Enumerate the indexes from @m@ to @n@ -- -- In order to enumerate a different type, use 'i2n', e.g: -- -- > map i2n (10...20) :: Vector1 Word8 enumFromTo :: Data Index -> Data Index -> Vector DIM1 (Data Index) enumFromTo 1 n = indexed n (+1) enumFromTo m n = indexed l (+m) where l = (n Vector DIM1 (Data Index) enumFrom = flip enumFromTo (value maxBound) unzip :: Vector DIM1 (a,b) -> (Vector DIM1 a, Vector DIM1 b) unzip (Vector l ixf) = (Vector l (fst.ixf), Vector l (snd.ixf)) -- | Corresponds to the standard 'foldl'. foldl :: (Syntax a) => (a -> b -> a) -> a -> Vector DIM1 b -> a foldl f x (Vector (Z :. l) ixf) = forLoop l x $ \ix s -> f s (ixf (Z :. ix)) -- | Corresponds to the standard 'foldl1'. fold1 :: Syntax a => (a -> a -> a) -> Vector DIM1 a -> a fold1 f a = foldl f (head a) (tail a) sum1 :: (Syntax a, Num a) => Vector DIM1 a -> a sum1 = foldl (+) 0 maximum :: Ord a => Vector DIM1 (Data a) -> Data a maximum = fold1 max minimum :: Ord a => Vector DIM1 (Data a) -> Data a minimum = fold1 min -- | Scalar product of two vectors scalarProd :: (Syntax a, Num a) => Vector DIM1 a -> Vector DIM1 a -> a scalarProd a b = sum1 (zipWith (*) a b) -------------------------------------------------------------------------------- -- Misc. -------------------------------------------------------------------------------- tVec :: Patch a a -> Patch (Vector DIM1 a) (Vector DIM1 a) tVec _ = id tVec1 :: Patch a a -> Patch (Vector DIM1 (Data a)) (Vector DIM1 (Data a)) tVec1 _ = id tVec2 :: Patch a a -> Patch (Vector DIM2 (Data a)) (Vector DIM2 (Data a)) tVec2 _ = id