{-# LANGUAGE GADTs           #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections   #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Shape
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Shape
  where

import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Type

import Language.Haskell.TH
import Prelude                                                      hiding ( zip )

import GHC.Base                                                     ( quotInt, remInt )


-- | Shape and index representations as nested pairs
--
-- 'ShapeRz' witnesses the rank-0 shape @()@, and 'ShapeRsnoc' extends a shape
-- @sh@ with one further 'Int' dimension on the right, giving @(sh, Int)@.
--
data ShapeR sh where
  ShapeRz    :: ShapeR ()
  ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int)

-- | Nicely format a shape as a string
--
-- The result starts with @Z@ and appends @" :. n"@ for every dimension, so the
-- dimension held in the outermost pair is printed last.
--
showShape :: ShapeR sh -> sh -> String
showShape shr = foldr (\sh str -> str ++ " :. " ++ show sh) "Z" . shapeToList shr

-- Synonyms for common shape types
--
type DIM0 = ()
type DIM1 = ((), Int)
type DIM2 = (((), Int), Int)
type DIM3 = ((((), Int), Int), Int)

-- | Shape witness for rank-0 shapes
--
dim0 :: ShapeR DIM0
dim0 = ShapeRz

-- | Shape witness for rank-1 shapes
--
dim1 :: ShapeR DIM1
dim1 = ShapeRsnoc dim0

-- | Shape witness for rank-2 shapes
--
dim2 :: ShapeR DIM2
dim2 = ShapeRsnoc dim1

-- | Shape witness for rank-3 shapes
--
dim3 :: ShapeR DIM3
dim3 = ShapeRsnoc dim2

-- | Number of dimensions of a /shape/ or /index/ (>= 0)
--
rank :: ShapeR sh -> Int
rank ShapeRz          = 0
rank (ShapeRsnoc shr) = rank shr + 1

-- | Total number of elements in an array of the given shape
--
-- A dimension whose extent is zero or negative makes the result zero.
--
size :: ShapeR sh -> sh -> Int
size ShapeRz () = 1
size (ShapeRsnoc shr) (sh, sz)
  | sz <= 0   = 0
  | otherwise = size shr sh * sz

-- | The empty shape
--
-- Every dimension of the result has extent zero.
--
empty :: ShapeR sh -> sh
empty ShapeRz          = ()
empty (ShapeRsnoc shr) = (empty shr, 0)

-- | Yield the intersection of two shapes
--
-- This is the dimension-wise minimum of the two arguments.
--
intersect :: ShapeR sh -> sh -> sh -> sh
intersect = zip min

-- | Yield the union of two shapes
--
-- This is the dimension-wise maximum of the two arguments.
--
union :: ShapeR sh -> sh -> sh -> sh
union = zip max

-- | Combine two shapes of the same rank dimension-by-dimension using the
-- supplied function.
--
zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh
zip _ ShapeRz          ()      ()      = ()
zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b)

-- | Structural equality of two shapes (or indices): 'True' exactly when the
-- corresponding components of every dimension are equal.
--
eq :: ShapeR sh -> sh -> sh -> Bool
eq ShapeRz          ()      ()        = True
eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh'

-- | Map a multi-dimensional index into one in a linear, row-major
-- representation of the array (first argument is the /shape/, second
-- argument is the /index/).
--
-- Each index component is passed to 'indexCheck' together with the extent of
-- its dimension before it contributes to the result.
--
toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int
toIndex ShapeRz () () = 0
toIndex (ShapeRsnoc shr) (sh, sz) (ix, i)
  = indexCheck i sz
  $ toIndex shr sh ix * sz + i

-- | Inverse of 'toIndex'
--
fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh
fromIndex ShapeRz () _ = ()
fromIndex (ShapeRsnoc shr) (sh, sz) i
  = (fromIndex shr sh (i `quotInt` sz), r)
  -- If we assume that the index is in range, there is no point in computing
  -- the remainder for the highest dimension since i < sz must hold.
  --
  where
    r = case shr of -- Check if rank of shr is 0
      ShapeRz -> indexCheck i sz i
      _       -> i `remInt` sz

-- | Iterate through the entire shape, applying the function in the second
-- argument; third argument combines results and fourth is an initial value
-- that is combined with the results; the index space is traversed in
-- row-major order
--
-- Note that the traversal of /every/ dimension is terminated with the initial
-- value, so for shapes of rank two or more it is combined into the result
-- more than once. For a rank-0 shape the result is just the function applied
-- to @()@ and the initial value is not used.
--
iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a
iter ShapeRz          ()       f _ _ = f ()
iter (ShapeRsnoc shr) (sh, sz) f c r = iter shr sh (\ix -> iter' (ix,0)) c r
  where
    -- Walk the current dimension from index 0 up to (but excluding) sz
    iter' (ix,i) | i >= sz   = r
                 | otherwise = f (ix,i) `c` iter' (ix,i+1)

-- | Variant of 'iter' without an initial value
--
-- Because there is no initial value to fall back on, the iteration space must
-- be non-empty: if any dimension has an extent of zero /or less/ this calls
-- 'boundsError'. (Testing only for an extent of exactly zero would let a
-- negative extent through, and the traversal below, which counts upwards from
-- zero until it reaches @sz-1@, would then never reach its stopping index.)
--
iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a
iter1 ShapeRz          ()       f _ = f ()
iter1 (ShapeRsnoc shr) (sh, sz) f c
  | sz <= 0   = boundsError "empty iteration space"
  | otherwise = iter1 shr sh (\ix -> iter1' (ix,0)) c
  where
    -- Walk the current dimension from index 0 up to and including sz-1; the
    -- guard above ensures sz >= 1, so the stopping index is always reached.
    iter1' (ix,i) | i == sz-1 = f (ix,i)
                  | otherwise = f (ix,i) `c` iter1' (ix,i+1)

-- Operations to facilitate conversion with IArray

-- | Convert a minpoint-maxpoint index into a shape
--
-- The extent of each dimension is @max - min + 1@, i.e. both end points are
-- treated as inclusive.
--
rangeToShape :: ShapeR sh -> (sh, sh) -> sh
rangeToShape ShapeRz          ((), ())                 = ()
rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) = (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1)

-- | Converse of 'rangeToShape'
--
-- The minpoint is zero in every dimension and the maxpoint is one less than
-- the extent of that dimension.
--
shapeToRange :: ShapeR sh -> sh -> (sh, sh)
shapeToRange ShapeRz          ()       = ((), ())
shapeToRange (ShapeRsnoc shr) (sh, sz) = let (low, high) = shapeToRange shr sh in ((low, 0), (high, sz - 1))

-- | Convert a shape or index into its list of dimensions
--
-- The dimension held in the outermost pair (the one added last with
-- 'ShapeRsnoc') is the head of the resulting list.
--
shapeToList :: ShapeR sh -> sh -> [Int]
shapeToList ShapeRz          ()      = []
shapeToList (ShapeRsnoc shr) (sh,sz) = sz : shapeToList shr sh

-- | Convert a list of dimensions into a shape
--
-- Calls 'error' if the length of the list does not match the rank of the
-- requested shape; see 'listToShape'' for a variant that returns 'Nothing'
-- instead.
--
listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh
listToShape shr ds =
  case listToShape' shr ds of
    Just sh -> sh
    Nothing -> error "listToShape: unable to convert list to a shape at the specified type"

-- | Attempt to convert a list of dimensions into a shape
--
-- The list is read in the same order that 'shapeToList' produces. The result
-- is 'Nothing' if the list has too few or too many elements for the rank of
-- the requested shape.
--
listToShape' :: ShapeR sh -> [Int] -> Maybe sh
listToShape' ShapeRz          []     = Just ()
listToShape' (ShapeRsnoc shr) (x:xs) = (, x) <$> listToShape' shr xs
listToShape' _                _      = Nothing

-- | The type representation of a shape: a unit for 'ShapeRz', paired with a
-- single scalar 'Int' type for every 'ShapeRsnoc'.
--
shapeType :: ShapeR sh -> TypeR sh
shapeType ShapeRz          = TupRunit
shapeType (ShapeRsnoc shr) =
  shapeType shr
  `TupRpair`
  TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt)))

-- | Force the extent of every dimension of a shape (or every component of an
-- index) with 'seq'.
--
rnfShape :: ShapeR sh -> sh -> ()
rnfShape ShapeRz          ()      = ()
rnfShape (ShapeRsnoc shr) (sh, s) = s `seq` rnfShape shr sh

-- | Force the complete spine of a shape witness by traversing all of its
-- constructors.
--
rnfShapeR :: ShapeR sh -> ()
rnfShapeR ShapeRz          = ()
rnfShapeR (ShapeRsnoc shr) = rnfShapeR shr

-- | Produce a typed Template Haskell expression which rebuilds the given
-- shape witness constructor-by-constructor.
--
liftShapeR :: ShapeR sh -> Q (TExp (ShapeR sh))
liftShapeR ShapeRz         = [|| ShapeRz ||]
liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||]