{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Sugar.Shape
where
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import qualified Data.Array.Accelerate.Representation.Shape as R
import qualified Data.Array.Accelerate.Representation.Slice as R
import Data.Kind
import GHC.Generics
-- | Shape synonyms with 0 to 9 'Int' components. Each one is 'Z' followed by
-- one @:. Int@ per component; ':.' associates to the left, so @DIM2@ is
-- @(Z :. Int) :. Int@.
type DIM0 = Z
type DIM1 = Z :. Int
type DIM2 = Z :. Int :. Int
type DIM3 = Z :. Int :. Int :. Int
type DIM4 = Z :. Int :. Int :. Int :. Int
type DIM5 = Z :. Int :. Int :. Int :. Int :. Int
type DIM6 = Z :. Int :. Int :. Int :. Int :. Int :. Int
type DIM7 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int
type DIM8 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int
type DIM9 = Z :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int :. Int
-- | The shape with no components. Every shape or index built with ':.'
-- starts from 'Z', and 'Z' is the 'SliceShape' of every 'Shape'.
data Z = Z
  deriving (Eq, Show, Generic, Elt)
-- | Snoc constructor for shapes and indices. The left field holds the
-- preceding components and the right field holds the last (rightmost) one.
-- Both fields are strict. The operator is @infixl 3@, so @Z :. i :. j@
-- reads as @(Z :. i) :. j@.
data sh :. sz = !sh :. !sz
  deriving (Generic, Eq)

infixl 3 :.
-- | Shows a shape or index in the form it is written in source, e.g.
-- @Z :. 3 :. 4@, without parenthesising each left-nested ':.'.
--
-- ':.' is @infixl 3@, so this instance:
--
--   * shows the left operand at precedence 3, so a left-nested @Z :. 3@
--     needs no parentheses there;
--   * shows the right operand at precedence 4, so a right operand that is
--     itself built with ':.' gets parenthesised;
--   * wraps the whole expression in parentheses only when the surrounding
--     context binds tighter than 3, e.g.
--     @show (Just (Z :. 1)) == "Just (Z :. 1)"@.
--
-- At top level (precedence 0) the output is unchanged: @Z :. 3 :. 4@.
instance (Show sh, Show sz) => Show (sh :. sz) where
  showsPrec p (sh :. sz) =
    showParen (p > 3) $
      showsPrec 3 sh . showString " :. " . showsPrec 4 sz
-- | Slice specifier that keeps a whole dimension. In @sl :. All@ the position
-- adds an 'Int' to both 'SliceShape' and 'FullShape', and nothing to
-- 'CoSliceShape'.
data All = All
  deriving (Eq, Show, Generic, Elt)
-- | Slice specifier standing for every dimension of the shape @sh@ at once.
-- For @Any sh@, 'SliceShape' and 'FullShape' are @sh@ and 'CoSliceShape' is
-- 'Z'. The constructor carries no fields, so @sh@ is a phantom parameter.
data Any sh = Any
  deriving (Eq, Show, Generic)
-- | Division specifier. In @sl :. Split@ the position becomes an 'Int'
-- component of the resulting 'DivisionSlice'.
data Split = Split
  deriving (Eq, Show)
-- | Division specifier whose 'DivisionSlice' is the shape @sh@ itself. The
-- constructor carries no fields, so @sh@ is a phantom parameter.
data Divide sh = Divide
  deriving (Eq, Show)
-- | Rank of the shape type @sh@, as computed by 'R.rank' from its
-- representation 'shapeR'. No value mentions @sh@, so callers must pick it
-- by type application, e.g. @rank \@DIM2@ (hence @AllowAmbiguousTypes@).
rank :: forall sh. Shape sh => Int
rank = R.rank shR
  where
    shR = shapeR @sh
-- | Size of a shape as computed by 'R.size' on its representation
-- (presumably the product of its extents — see 'R.size').
size :: forall sh. Shape sh => sh -> Int
size extent = R.size (shapeR @sh) (fromElt extent)
-- | The shape produced by 'R.empty' for @sh@, converted back to the surface
-- type with 'toElt'.
empty :: forall sh. Shape sh => sh
empty = toElt (R.empty (shapeR @sh))
-- | Lifts 'R.intersect' to surface shapes. Both operands are converted with
-- 'fromElt' and the result is converted back with 'toElt'.
intersect :: forall sh. Shape sh => sh -> sh -> sh
intersect x y = toElt (R.intersect shR (fromElt x) (fromElt y))
  where
    shR = shapeR @sh
-- | Lifts 'R.union' to surface shapes. Both operands are converted with
-- 'fromElt' and the result is converted back with 'toElt'.
union :: forall sh. Shape sh => sh -> sh -> sh
union x y = toElt (R.union shR (fromElt x) (fromElt y))
  where
    shR = shapeR @sh
-- | Converts the index @ix@ within the shape @extent@ to an 'Int' using
-- 'R.toIndex' on the representation types (presumably a linear offset —
-- see 'R.toIndex').
toIndex :: forall sh. Shape sh
        => sh
        -> sh
        -> Int
toIndex extent ix =
  let shR = shapeR @sh
  in  R.toIndex shR (fromElt extent) (fromElt ix)
-- | Counterpart of 'toIndex'. Turns an 'Int' into an index of the shape
-- @extent@ via 'R.fromIndex', then converts it with 'toElt'.
fromIndex :: forall sh. Shape sh
          => sh
          -> Int
          -> sh
fromIndex extent i = toElt (R.fromIndex (shapeR @sh) (fromElt extent) i)
-- | Surface-level wrapper around 'R.iter'. Each representation index is
-- converted with 'toElt' before @f@ sees it. The combining function and the
-- final @e@ argument are passed to 'R.iter' unchanged.
iter :: forall sh e. Shape sh
     => sh
     -> (sh -> e)
     -> (e -> e -> e)
     -> e
     -> e
iter extent f combine z =
  R.iter (shapeR @sh) (fromElt extent) (\ix -> f (toElt ix)) combine z
-- | Surface-level wrapper around 'R.iter1'. Like 'iter', but takes no
-- separate @e@ argument. Each representation index is converted with
-- 'toElt' before @f@ sees it.
iter1 :: forall sh e. Shape sh
      => sh
      -> (sh -> e)
      -> (e -> e -> e)
      -> e
iter1 extent f combine =
  R.iter1 (shapeR @sh) (fromElt extent) (\ix -> f (toElt ix)) combine
-- | Lifts 'R.rangeToShape'. Both ends of the pair are converted with
-- 'fromElt' and the resulting shape is converted back with 'toElt'.
rangeToShape :: forall sh. Shape sh => (sh, sh) -> sh
rangeToShape (lo, hi) = toElt $ R.rangeToShape shR (fromElt lo, fromElt hi)
  where
    shR = shapeR @sh
-- | Lifts 'R.shapeToRange'. Each component of the resulting pair is converted
-- back with 'toElt'. The @where@ pattern binding is lazy, so the pair
-- constructor is returned without forcing 'R.shapeToRange'.
shapeToRange :: forall sh. Shape sh => sh -> (sh, sh)
shapeToRange ix = (toElt lo, toElt hi)
  where
    (lo, hi) = R.shapeToRange (shapeR @sh) (fromElt ix)
-- | Components of a shape or index as an 'Int' list, in the order produced by
-- 'R.shapeToList'.
shapeToList :: forall sh. Shape sh => sh -> [Int]
shapeToList ix = R.shapeToList (shapeR @sh) (fromElt ix)
-- | Builds a shape from an 'Int' list via 'R.listToShape'. How a list of the
-- wrong length is handled is up to 'R.listToShape'; 'listToShape'' is the
-- 'Maybe'-returning variant.
listToShape :: forall sh. Shape sh => [Int] -> sh
listToShape ds = toElt (R.listToShape (shapeR @sh) ds)
-- | 'Maybe'-returning variant of 'listToShape'. It wraps 'R.listToShape'' and
-- converts a successful result with 'toElt'.
listToShape' :: forall sh. Shape sh => [Int] -> Maybe sh
listToShape' ds = toElt <$> R.listToShape' (shapeR @sh) ds
-- | Renders a shape as @Z :. a :. b ...@. The head of 'shapeToList' becomes
-- the rightmost component, so the list is reversed before printing.
showShape :: Shape sh => sh -> String
showShape ix = "Z" ++ concatMap (\d -> " :. " ++ show d) (reverse (shapeToList ix))
-- | Applies 'R.sliceShape' for the given slice index to a full shape @dim@
-- and converts the result to the slice shape @sl@.
sliceShape
    :: forall slix co sl dim. (Shape sl, Shape dim)
    => R.SliceIndex slix (EltR sl) co (EltR dim)
    -> dim
    -> sl
sliceShape si extent = toElt (R.sliceShape si (fromElt extent))
-- | Lists the slice specifiers that 'R.enumSlices' generates for the full
-- shape @dim@, converting each one back with 'toElt'.
enumSlices :: forall slix co sl dim. (Elt slix, Elt dim)
           => R.SliceIndex (EltR slix) sl co (EltR dim)
           -> dim
           -> [slix]
enumSlices si extent = toElt <$> R.enumSlices si (fromElt extent)
-- | Surface shape types. The superclass equalities fix 'FullShape' and
-- 'CoSliceShape' of a shape to the shape itself and its 'SliceShape' to 'Z'.
class ( Elt sh
      , Elt (Any sh)
      , FullShape sh ~ sh
      , CoSliceShape sh ~ sh
      , SliceShape sh ~ Z
      )
    => Shape sh where
  -- | Representation-level witness of the shape's structure.
  shapeR :: R.ShapeR (EltR sh)
  -- | Slice index in which @Any sh@ keeps every dimension of @sh@.
  sliceAnyIndex :: R.SliceIndex (EltR (Any sh)) (EltR sh) () (EltR sh)
  -- | Slice index in which the shape itself fixes every dimension.
  sliceNoneIndex :: R.SliceIndex (EltR sh) () (EltR sh) (EltR sh)
-- | Slice specifiers. Each specifier determines three shapes:
-- 'SliceShape' (the dimensions kept, e.g. by 'All'), 'CoSliceShape' (the
-- dimensions fixed, e.g. by an 'Int') and 'FullShape' (all dimensions).
class ( Elt sl
      , Shape (SliceShape sl)
      , Shape (CoSliceShape sl)
      , Shape (FullShape sl)
      )
    => Slice sl where
  type SliceShape sl :: Type
  type CoSliceShape sl :: Type
  type FullShape sl :: Type
  -- | Representation-level slice index relating the specifier to its three
  -- shapes.
  sliceIndex :: R.SliceIndex (EltR sl)
                             (EltR (SliceShape sl))
                             (EltR (CoSliceShape sl))
                             (EltR (FullShape sl))
-- | Division specifiers. Each one maps to a slice specifier, 'DivisionSlice'.
class Slice (DivisionSlice sl) => Division sl where
  type DivisionSlice sl :: Type
  -- | Slice index of the associated 'DivisionSlice'. The equality constraint
  -- names @DivisionSlice sl@ as @slix@ to shorten the signature.
  slicesIndex :: slix ~ DivisionSlice sl
              => R.SliceIndex (EltR slix)
                              (EltR (SliceShape slix))
                              (EltR (CoSliceShape slix))
                              (EltR (FullShape slix))
-- | A ':.' value is represented as the pair of its components'
-- representations. Its tags are every pairing of a left tag with a right
-- tag, left-major.
instance (Elt t, Elt h) => Elt (t :. h) where
  type EltR (t :. h) = (EltR t, EltR h)
  eltR             = TupRpair (eltR @t) (eltR @h)
  tagsR            = TagRpair <$> tagsR @t <*> tagsR @h
  toElt (rt, rh)   = toElt rt :. toElt rh
  fromElt (x :. y) = (fromElt x, fromElt y)
-- | Rank-0 'Any'. Relies entirely on the class's default definitions
-- (presumably the Generic-based defaults, since 'Any' derives 'Generic' —
-- confirm in the 'Elt' class).
instance Elt (Any Z)
-- | @Any (sh :. Int)@ is represented as the representation of @Any sh@
-- paired with @()@. No data is stored for the extra dimension, and 'toElt'
-- ignores its argument.
instance Shape sh => Elt (Any (sh :. Int)) where
  type EltR (Any (sh :. Int)) = (EltR (Any sh), ())
  eltR      = TupRpair (eltR @(Any sh)) TupRunit
  tagsR     = map (`TagRpair` TagRunit) (tagsR @(Any sh))
  fromElt _ = (fromElt (Any :: Any sh), ())
  toElt     = const Any
-- | Rank-0 shape. Its representation is 'R.ShapeRz', and both slice indices
-- are 'R.SliceNil'.
instance Shape Z where
  sliceNoneIndex = R.SliceNil
  sliceAnyIndex  = R.SliceNil
  shapeR         = R.ShapeRz
-- | Adds one 'Int' dimension. 'R.ShapeRsnoc' extends the representation.
-- 'Any' keeps the new dimension ('R.SliceAll'), while the shape itself fixes
-- it ('R.SliceFixed').
instance Shape sh => Shape (sh :. Int) where
  shapeR         = R.ShapeRsnoc $ shapeR @sh
  sliceAnyIndex  = R.SliceAll $ sliceAnyIndex @sh
  sliceNoneIndex = R.SliceFixed $ sliceNoneIndex @sh
-- | The empty slice specifier. All three associated shapes are 'Z'.
instance Slice Z where
  type FullShape Z    = Z
  type CoSliceShape Z = Z
  type SliceShape Z   = Z
  sliceIndex = R.SliceNil
-- | Keeping a dimension with 'All' adds an 'Int' to both the slice shape and
-- the full shape, and leaves the co-slice shape unchanged.
instance Slice sl => Slice (sl :. All) where
  type FullShape (sl :. All)    = FullShape sl :. Int
  type SliceShape (sl :. All)   = SliceShape sl :. Int
  type CoSliceShape (sl :. All) = CoSliceShape sl
  sliceIndex = R.SliceAll $ sliceIndex @sl
-- | Fixing a dimension with an 'Int' adds it to the co-slice shape and the
-- full shape, and leaves the slice shape unchanged.
instance Slice sl => Slice (sl :. Int) where
  type FullShape (sl :. Int)    = FullShape sl :. Int
  type CoSliceShape (sl :. Int) = CoSliceShape sl :. Int
  type SliceShape (sl :. Int)   = SliceShape sl
  sliceIndex = R.SliceFixed $ sliceIndex @sl
-- | @Any sh@ keeps every dimension of @sh@. The slice shape and full shape
-- are @sh@, the co-slice shape is 'Z', and the slice index is
-- 'sliceAnyIndex'.
instance Shape sh => Slice (Any sh) where
  type FullShape (Any sh)    = sh
  type SliceShape (Any sh)   = sh
  type CoSliceShape (Any sh) = Z
  sliceIndex = sliceAnyIndex @sh
-- | The empty division specifier divides into the empty slice 'Z'.
instance Division Z where
  slicesIndex = R.SliceNil
  type DivisionSlice Z = Z
-- | An 'All' position stays 'All' in the division slice and is kept
-- ('R.SliceAll').
instance Division sl => Division (sl :. All) where
  slicesIndex = R.SliceAll $ slicesIndex @sl
  type DivisionSlice (sl :. All) = DivisionSlice sl :. All
-- | A 'Split' position becomes a fixed 'Int' in the division slice
-- ('R.SliceFixed').
instance Division sl => Division (sl :. Split) where
  slicesIndex = R.SliceFixed $ slicesIndex @sl
  type DivisionSlice (sl :. Split) = DivisionSlice sl :. Int
-- | @Any sh@ divides into itself, using the all-keeping 'sliceAnyIndex'.
instance Shape sh => Division (Any sh) where
  slicesIndex = sliceAnyIndex @sh
  type DivisionSlice (Any sh) = Any sh
-- | @Divide sh@ divides into the shape @sh@ itself, with every dimension
-- fixed ('sliceNoneIndex').
instance (Slice sh, Shape sh) => Division (Divide sh) where
  slicesIndex = sliceNoneIndex @sh
  type DivisionSlice (Divide sh) = sh