{-# LANGUAGE GADTs #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Representation.Shape
where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Type
import Language.Haskell.TH
import Prelude hiding ( zip )
import GHC.Base ( quotInt, remInt )
-- | Singleton witness for the structure of a shape. A shape of rank @n@ is a
-- left-nested tuple of @n@ 'Int' components terminated by @()@; matching on
-- the witness refines the shape type in each branch.
data ShapeR sh where
  -- | The rank-0 shape, represented by @()@.
  ShapeRz    :: ShapeR ()
  -- | Extend a shape by one trailing 'Int' component.
  ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int)
-- | Render a shape in @Z :. a :. b@ notation. The component added by the
-- last 'ShapeRsnoc' is printed last.
showShape :: ShapeR sh -> sh -> String
showShape shr sh = foldr step "Z" (shapeToList shr sh)
  where
    -- 'shapeToList' yields the last-snoc'd component first, so each step
    -- appends its component after everything built from the rest.
    step d acc = acc ++ " :. " ++ show d
-- | Shape types of rank 0 to 3. Each rank appends one 'Int' component to
-- the previous rank's tuple.
type DIM0 = ()
type DIM1 = (DIM0, Int)
type DIM2 = (DIM1, Int)
type DIM3 = (DIM2, Int)
-- | Witness for the rank-0 shape.
dim0 :: ShapeR DIM0
dim0 = ShapeRz

-- | Witness for the rank-1 shape.
dim1 :: ShapeR DIM1
dim1 = ShapeRsnoc ShapeRz

-- | Witness for the rank-2 shape.
dim2 :: ShapeR DIM2
dim2 = ShapeRsnoc (ShapeRsnoc ShapeRz)

-- | Witness for the rank-3 shape.
dim3 :: ShapeR DIM3
dim3 = ShapeRsnoc (ShapeRsnoc (ShapeRsnoc ShapeRz))
-- | Number of 'Int' components in shapes described by the witness.
rank :: ShapeR sh -> Int
rank = go 0
  where
    -- Count one per 'ShapeRsnoc' until the 'ShapeRz' terminator.
    go :: Int -> ShapeR s -> Int
    go acc ShapeRz           = acc
    go acc (ShapeRsnoc rest) = go (acc + 1) rest
-- | Total number of elements in a shape: the product of its components.
-- Any component that is zero or negative makes the whole size 0.
size :: ShapeR sh -> sh -> Int
size ShapeRz () = 1
size (ShapeRsnoc shr) (sh, sz) =
  if sz > 0
    then size shr sh * sz
    else 0
-- | The shape of the given structure with every component set to 0.
empty :: ShapeR sh -> sh
empty shr = case shr of
  ShapeRz          -> ()
  ShapeRsnoc inner -> (empty inner, 0)
-- | Component-wise minimum of two shapes.
intersect :: ShapeR sh -> sh -> sh -> sh
intersect shr lhs rhs = zip min shr lhs rhs
-- | Component-wise maximum of two shapes.
union :: ShapeR sh -> sh -> sh -> sh
union shr lhs rhs = zip max shr lhs rhs
-- | Combine two shapes of the same structure component by component using
-- the supplied binary function.
zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh
zip _  ShapeRz          ()      ()      = ()
zip op (ShapeRsnoc shr) (xs, x) (ys, y) = (rest, op x y)
  where
    rest = zip op shr xs ys
-- | Structural equality of two shapes. The last-snoc'd components are
-- compared first, and comparison stops at the first mismatch.
eq :: ShapeR sh -> sh -> sh -> Bool
eq ShapeRz () () = True
eq (ShapeRsnoc shr) (lhsRest, l) (rhsRest, r)
  | l /= r    = False
  | otherwise = eq shr lhsRest rhsRest
-- | Convert a multidimensional index into a linear offset within the given
-- shape. The last-snoc'd component varies fastest. Each component is passed
-- through 'indexCheck' against its extent (presumably a @0 <= i < sz@ check;
-- confirm in "Data.Array.Accelerate.Error").
toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int
toIndex ShapeRz () () = 0
toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) = indexCheck i sz linear
  where
    linear = toIndex shr sh ix * sz + i
-- | Inverse of 'toIndex': recover the multidimensional index of a linear
-- offset within the given shape.
fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh
fromIndex ShapeRz () _ = ()
fromIndex (ShapeRsnoc shr) (sh, sz) i = (outer, inner)
  where
    outer = fromIndex shr sh (i `quotInt` sz)
    -- When this is the first (outermost) component, an in-range offset already
    -- satisfies i < sz. So the remainder is skipped, and the offset is only
    -- bounds-checked there.
    inner = case shr of
      ShapeRz      -> indexCheck i sz i
      ShapeRsnoc _ -> i `remInt` sz
-- | Apply @f@ to every index of the shape, in increasing linear order, and
-- combine the results right-nested with @c@. The value @r@ terminates each
-- run of every dimension, so it presumably should be a unit for @c@; this is
-- not checked.
iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a
iter ShapeRz () f _ _ = f ()
iter (ShapeRsnoc shr) (sh, sz) f c r = iter shr sh row c r
  where
    -- Sweep the last-snoc'd component from 0 up to (but excluding) sz.
    row ix = go 0
      where
        go i
          | i >= sz   = r
          | otherwise = f (ix, i) `c` go (i + 1)
-- | Like 'iter', but with no terminal value. It applies @f@ to every index
-- of the shape in increasing linear order and combines the results
-- right-nested with @c@, so the iteration space must be non-empty.
--
-- Calls 'boundsError' if any component is zero or negative, consistent with
-- 'size' treating such components as empty. A negative extent has to be
-- rejected explicitly: the inner loop stops on @i == sz - 1@ with @i@
-- counting up from 0, which never holds for @sz < 0@ and would not terminate.
iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a
iter1 ShapeRz () f _ = f ()
iter1 (ShapeRsnoc shr) (sh, sz) f c
  | sz <= 0   = boundsError "empty iteration space"
  | otherwise = iter1 shr sh (\ix -> go ix 0) c
  where
    -- Sweep the last-snoc'd component from 0 to sz-1 inclusive.
    go ix i
      | i == sz - 1 = f (ix, i)
      | otherwise   = f (ix, i) `c` go ix (i + 1)
-- | Shape spanned by an inclusive @(low, high)@ index range: each component
-- is @high - low + 1@.
rangeToShape :: ShapeR sh -> (sh, sh) -> sh
rangeToShape ShapeRz ((), ()) = ()
rangeToShape (ShapeRsnoc shr) ((loRest, lo), (hiRest, hi)) = (rest, extent)
  where
    rest   = rangeToShape shr (loRest, hiRest)
    extent = hi - lo + 1
-- | Inclusive index range covered by a shape: every lower component is 0 and
-- every upper component is the extent minus one.
shapeToRange :: ShapeR sh -> sh -> (sh, sh)
shapeToRange ShapeRz () = ((), ())
shapeToRange (ShapeRsnoc shr) (sh, sz) = ((lows, 0), (highs, sz - 1))
  where
    (lows, highs) = shapeToRange shr sh
-- | Components of a shape as a list, last-snoc'd component first.
shapeToList :: ShapeR sh -> sh -> [Int]
shapeToList ShapeRz () = []
shapeToList (ShapeRsnoc shr) (rest, ext) = ext : remaining
  where
    remaining = shapeToList shr rest
-- | Inverse of 'shapeToList'. Calls 'error' when the list length does not
-- match the rank of the witness.
listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh
listToShape shr ds =
  maybe (error "listToShape: unable to convert list to a shape at the specified type")
        id
        (listToShape' shr ds)
-- | Total version of 'listToShape'. The list is ordered last-snoc'd
-- component first, and the result is 'Nothing' unless its length equals
-- the rank of the witness.
listToShape' :: ShapeR sh -> [Int] -> Maybe sh
listToShape' ShapeRz ds = case ds of
  [] -> Just ()
  _  -> Nothing
listToShape' (ShapeRsnoc shr) ds = case ds of
  x : xs -> fmap (\rest -> (rest, x)) (listToShape' shr xs)
  []     -> Nothing
-- | Tuple type representation of a shape: a left-nested pair of 'Int'
-- scalars ending in the unit type.
shapeType :: ShapeR sh -> TypeR sh
shapeType ShapeRz = TupRunit
shapeType (ShapeRsnoc shr) = TupRpair (shapeType shr) intR
  where
    intR = TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt)))
-- | Force every component of a shape to weak head normal form.
rnfShape :: ShapeR sh -> sh -> ()
rnfShape ShapeRz () = ()
rnfShape (ShapeRsnoc shr) (rest, ext) = seq ext (rnfShape shr rest)
-- | Fully evaluate a shape witness by walking it to its 'ShapeRz' end.
rnfShapeR :: ShapeR sh -> ()
rnfShapeR shr = case shr of
  ShapeRz          -> ()
  ShapeRsnoc inner -> rnfShapeR inner
-- | Build a typed Template Haskell expression that reconstructs the given
-- shape witness.
--
-- NOTE(review): @Q (TExp a)@ is the pre-GHC-9.0 typed-quote type; newer
-- compilers type @[|| .. ||]@ as @Code Q a@. Confirm against the project's
-- supported GHC versions.
liftShapeR :: ShapeR sh -> Q (TExp (ShapeR sh))
liftShapeR shr = case shr of
  ShapeRz          -> [|| ShapeRz ||]
  ShapeRsnoc inner -> [|| ShapeRsnoc $$(liftShapeR inner) ||]