{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Array.Representation (
Shape(..), Slice(..), SliceIndex(..),
sliceShape, enumSlices,
) where
import Data.Array.Accelerate.Error
import GHC.Base ( quotInt, remInt )
-- | Class of array shapes and indices, represented as nested pairs that grow
-- to the right: @()@, @((), Int)@, @(((), Int), Int)@, ...
--
-- The same type is used both for a /shape/ (one extent per dimension) and for
-- an /index/ into an array of that shape.
class (Eq sh, Slice sh) => Shape sh where
  -- | Number of dimensions.
  rank :: sh -> Int

  -- | Total number of elements, i.e. the product of all extents.
  size :: sh -> Int

  -- | The shape whose every extent is zero.
  empty :: sh

  -- | A marker value whose every component is @-1@.
  ignore :: sh

  -- | Component-wise minimum of two shapes.
  intersect :: sh -> sh -> sh

  -- | Component-wise maximum of two shapes.
  union :: sh -> sh -> sh

  -- | @toIndex sh ix@ linearises the index @ix@ within the shape @sh@; the
  -- right-most component varies fastest.
  toIndex :: sh -> sh -> Int

  -- | @fromIndex sh i@ recovers the multi-dimensional index from the linear
  -- offset @i@ within the shape @sh@.
  fromIndex :: sh -> Int -> sh

  -- | @iter sh f c r@ applies @f@ to the indices of @sh@ and combines the
  -- results with @c@; @r@ is the value used once a dimension is exhausted.
  iter :: sh -> (sh -> a) -> (a -> a -> a) -> a -> a

  -- | Like 'iter' but without a terminating value, so the iteration space
  -- must not be empty.
  iter1 :: sh -> (sh -> a) -> (a -> a -> a) -> a

  -- | Convert an inclusive @(lower, upper)@ index range into a shape.
  rangeToShape :: (sh, sh) -> sh

  -- | Convert a shape into the inclusive index range starting at zero.
  shapeToRange :: sh -> (sh, sh)

  -- | List the extents, right-most component first.
  shapeToList :: sh -> [Int]

  -- | Rebuild a shape from a list of extents, right-most component first.
  listToShape :: [Int] -> sh
-- | The rank-0 shape: exactly one element, addressed by the unit index.
instance Shape () where
  rank _  = 0
  size () = 1

  empty  = ()
  ignore = ()

  intersect () () = ()
  union     () () = ()

  toIndex   () () = 0
  fromIndex () _  = ()

  -- There is a single index, so the combining function (and, for 'iter', the
  -- terminating value) are never used.
  iter  () f _ _ = f ()
  iter1 () f _   = f ()

  rangeToShape ((), ()) = ()
  shapeToRange ()       = ((), ())

  shapeToList () = []

  -- Only the empty list describes a rank-0 shape.
  listToShape dims =
    case dims of
      [] -> ()
      _  -> $internalError "listToShape" "non-empty list when converting to unit"
-- | A shape of rank @n + 1@: a lower-rank shape @sh@ extended on the right by
-- one more 'Int' component. The right-most component is the one that varies
-- fastest in the linearised index computed by 'toIndex'.
instance Shape sh => Shape (sh, Int) where
  -- The argument is never inspected; only its type matters.
  rank _ = rank (undefined :: sh) + 1

  empty  = (empty, 0)
  ignore = (ignore, -1)

  -- Component-wise minimum / maximum.
  (sh1, sz1) `intersect` (sh2, sz2) = (sh1 `intersect` sh2, sz1 `min` sz2)
  (sh1, sz1) `union`     (sh2, sz2) = (sh1 `union` sh2, sz1 `max` sz2)

  -- Product of all extents; the check is handed the condition @sz >= 0@.
  size (sh, sz)
    = $boundsCheck "size" "negative shape dimension" (sz >= 0)
    $ size sh * sz

  -- Linearise: the offset of the outer dimensions scaled by this extent, plus
  -- the position within this dimension. The check is handed @i@ and @sz@.
  toIndex (sh, sz) (ix, i)
    = $indexCheck "toIndex" i sz
    $ toIndex sh ix * sz + i

  -- Inverse of 'toIndex'. For the outermost dimension (nothing but @()@ to
  -- its left) the remainder is not computed; @i@ is passed through the index
  -- check as-is instead.
  fromIndex (sh, sz) i = (fromIndex sh (i `quotInt` sz), r)
    where
      r | rank sh == 0 = $indexCheck "fromIndex" i sz i
        | otherwise    = i `remInt` sz

  -- Walk the outer dimensions via the nested instance and, for each outer
  -- index, run over this dimension from 0 to sz-1, ending every run in @r@.
  iter (sh, sz) f c r = iter sh (\ix -> iter' (ix,0)) c r
    where
      iter' (ix,i) | i >= sz   = r
                   | otherwise = f (ix,i) `c` iter' (ix,i+1)

  -- As 'iter' but without a terminating value, so the extent must be at least
  -- one. A negative extent is rejected together with zero: the inner loop
  -- counts upwards from 0 and stops only at i == sz-1, which it would never
  -- reach for sz < 0.
  iter1 (sh, sz) f c
    | sz <= 0   = $boundsError "iter1" "empty iteration space"
    | otherwise = iter1 sh (\ix -> iter1' (ix,0)) c
    where
      iter1' (ix,i) | i == sz-1 = f (ix,i)
                    | otherwise = f (ix,i) `c` iter1' (ix,i+1)

  -- Inclusive bounds (lower, upper) to an extent: upper - lower + 1.
  rangeToShape ((sh1, sz1), (sh2, sz2))
    = (rangeToShape (sh1, sh2), sz2 - sz1 + 1)

  -- Extent to inclusive bounds starting at zero: (0, sz - 1).
  shapeToRange (sh, sz)
    = let (low, high) = shapeToRange sh
      in
      ((low, 0), (high, sz - 1))

  -- The right-most component comes first in the list.
  shapeToList (sh,sz) = sz : shapeToList sh

  -- Inverse of 'shapeToList': the head of the list becomes the right-most
  -- component.
  listToShape []     = $internalError "listToShape" "empty list when converting to Ix"
  listToShape (x:xs) = (listToShape xs,x)
-- | Class of slice specifications, represented as nested pairs. Each
-- component of the specification describes what happens to one dimension of
-- the array being sliced.
class Slice sl where
  -- | Shape made up of the dimensions that are kept in the slice.
  type SliceShape sl

  -- | Shape made up of the dimensions that are fixed to a single index.
  type CoSliceShape sl

  -- | Shape of the array the slice is taken from.
  type FullShape sl

  -- | Value-level description of the slice specification type. The argument
  -- is only used to select the instance.
  sliceIndex
    :: sl
    -> SliceIndex sl (SliceShape sl) (CoSliceShape sl) (FullShape sl)
-- | The empty slice specification: no dimensions on any side.
instance Slice () where
  type SliceShape   () = ()
  type CoSliceShape () = ()
  type FullShape    () = ()

  -- The argument carries no information and is ignored.
  sliceIndex = const SliceNil
-- | A @()@ component keeps the corresponding dimension in full: it appears in
-- both the slice shape and the full shape, but not in the co-slice shape.
instance Slice sl => Slice (sl, ()) where
  type SliceShape   (sl, ()) = (SliceShape sl, Int)
  type CoSliceShape (sl, ()) = CoSliceShape sl
  type FullShape    (sl, ()) = (FullShape sl, Int)

  -- The argument is never inspected; the description of the remaining
  -- components is selected purely by type.
  sliceIndex _ = SliceAll rest
    where
      rest = sliceIndex (undefined :: sl)
-- | An 'Int' component fixes the corresponding dimension to a single index:
-- it appears in the co-slice shape and the full shape, but not in the slice
-- shape.
instance Slice sl => Slice (sl, Int) where
  type SliceShape   (sl, Int) = SliceShape sl
  type CoSliceShape (sl, Int) = (CoSliceShape sl, Int)
  type FullShape    (sl, Int) = (FullShape sl, Int)

  -- The argument is never inspected; the description of the remaining
  -- components is selected purely by type.
  sliceIndex _ = SliceFixed rest
    where
      rest = sliceIndex (undefined :: sl)
-- | Value-level description of a slice specification. The four type indices
-- are, in order: the slice specification itself, the shape of the resulting
-- slice, the shape of the fixed (co-slice) dimensions, and the shape of the
-- full array.
data SliceIndex slix sl co dim where
  -- | No dimensions at all.
  SliceNil
    :: SliceIndex () () () ()

  -- | Keep this dimension in full: it is added to the slice shape and to the
  -- full shape.
  SliceAll
    :: SliceIndex slix sl co dim
    -> SliceIndex (slix, ()) (sl, Int) co (dim, Int)

  -- | Fix this dimension to one index: it is added to the co-slice shape and
  -- to the full shape.
  SliceFixed
    :: SliceIndex slix sl co dim
    -> SliceIndex (slix, Int) sl (co, Int) (dim, Int)
-- | Renders the constructor structure, with every nested description wrapped
-- in parentheses.
instance Show (SliceIndex ix slice coSlice sliceDim) where
  show slix =
    case slix of
      SliceNil         -> "SliceNil"
      SliceAll inner   -> concat ["SliceAll (", show inner, ")"]
      SliceFixed inner -> concat ["SliceFixed (", show inner, ")"]
-- | Project the shape of a slice out of the shape of the full array: extents
-- of dimensions marked 'SliceAll' are kept, those marked 'SliceFixed' are
-- dropped.
sliceShape :: forall slix co sl dim.
              SliceIndex slix sl co dim
           -> dim
           -> sl
sliceShape spec full =
  case spec of
    SliceNil ->
      case full of
        () -> ()
    SliceAll rest ->
      case full of
        (outer, extent) -> (sliceShape rest outer, extent)
    SliceFixed rest ->
      case full of
        (outer, _) -> sliceShape rest outer
-- | Enumerate every slice index that can be taken from an array of the given
-- full shape. Dimensions marked 'SliceAll' contribute @()@; dimensions marked
-- 'SliceFixed' contribute each index from @0@ to @n-1@, with the right-most
-- fixed dimension varying fastest.
enumSlices :: forall slix co sl dim.
              SliceIndex slix sl co dim
           -> dim
           -> [slix]
enumSlices SliceNil () =
  [()]
enumSlices (SliceAll rest) (outer, _) =
  map (\prefix -> (prefix, ())) (enumSlices rest outer)
enumSlices (SliceFixed rest) (outer, n) = do
  prefix <- enumSlices rest outer
  i      <- [0 .. n - 1]
  return (prefix, i)