{-# LANGUAGE TypeOperators, FlexibleInstances, ScopedTypeVariables #-}
-- | Index types: the rank-0 index 'Z', the snoc-style constructor ':.',
--   synonyms for the shapes of rank 0 to 5, and helpers for building them.
module Data.Array.Repa.Index
(
-- * Index types
Z (..)
, (:.) (..)
-- * Common dimensions
, DIM0, DIM1, DIM2, DIM3, DIM4, DIM5
-- * Index construction helpers
, ix1, ix2, ix3, ix4, ix5)
where
import Data.Array.Repa.Shape
import GHC.Base (quotInt, remInt)
-- | Fully qualified name of this module; prepended to the text of the
--   errors raised by the 'shapeOfList' implementations below.
stage :: String
stage = "Data.Array.Repa.Index"
-- | An index of dimension zero. It has exactly one value, also named 'Z',
--   and is the starting point from which every larger index is built.
data Z = Z
    deriving (Eq, Ord, Show, Read)
-- Left-associative, so @Z :. 1 :. 2@ reads as @(Z :. 1) :. 2@.
infixl 3 :.

-- | Strict pair that appends one more component (on the right) to an
--   existing index (on the left). Both fields are forced on construction.
data rest :. dim = !rest :. !dim
    deriving (Eq, Ord, Show, Read)
-- Synonyms for the 'Int'-indexed shapes of rank 0 through 5.
-- ':.' is declared @infixl@, so each right-hand side nests to the left,
-- e.g. @Z :. Int :. Int@ is @(Z :. Int) :. Int@.

-- | Rank-0 shape.
type DIM0 = Z

-- | Rank-1 shape.
type DIM1 = Z :. Int

-- | Rank-2 shape.
type DIM2 = Z :. Int :. Int

-- | Rank-3 shape.
type DIM3 = Z :. Int :. Int :. Int

-- | Rank-4 shape.
type DIM4 = Z :. Int :. Int :. Int :. Int

-- | Rank-5 shape.
type DIM5 = Z :. Int :. Int :. Int :. Int :. Int
-- | Build a rank-1 index from its single component.
ix1 :: Int -> DIM1
{-# INLINE ix1 #-}
ix1 i = Z :. i
-- | Build a rank-2 index; arguments are given in the same left-to-right
--   order in which they appear in the resulting @Z :. _ :. _@ value.
ix2 :: Int -> Int -> DIM2
{-# INLINE ix2 #-}
ix2 j i = ix1 j :. i
-- | Build a rank-3 index; arguments are given in the same left-to-right
--   order in which they appear in the resulting @Z :. _ :. _ :. _@ value.
ix3 :: Int -> Int -> Int -> DIM3
{-# INLINE ix3 #-}
ix3 k j i = ix2 k j :. i
-- | Build a rank-4 index; arguments are given in the same left-to-right
--   order in which they appear in the resulting index value.
ix4 :: Int -> Int -> Int -> Int -> DIM4
{-# INLINE ix4 #-}
ix4 l k j i = ix3 l k j :. i
-- | Build a rank-5 index; arguments are given in the same left-to-right
--   order in which they appear in the resulting index value.
ix5 :: Int -> Int -> Int -> Int -> Int -> DIM5
{-# INLINE ix5 #-}
ix5 m l k j i = ix4 m l k j :. i
-- | The rank-0 shape. Every method is the base case of the recursion
--   performed by the @(sh :. Int)@ instance.
instance Shape Z where
    {-# INLINE [1] rank #-}
    rank _ = 0

    {-# INLINE [1] zeroDim #-}
    zeroDim = Z

    {-# INLINE [1] unitDim #-}
    unitDim = Z

    {-# INLINE [1] intersectDim #-}
    intersectDim _ _ = Z

    {-# INLINE [1] addDim #-}
    addDim _ _ = Z

    -- The empty product: a rank-0 shape has size 1.
    {-# INLINE [1] size #-}
    size _ = 1

    {-# INLINE [1] sizeIsValid #-}
    sizeIsValid _ = True

    {-# INLINE [1] toIndex #-}
    toIndex _ _ = 0

    {-# INLINE [1] fromIndex #-}
    fromIndex _ _ = Z

    -- Matches on the constructors, so all three arguments are forced.
    {-# INLINE [1] inShapeRange #-}
    inShapeRange Z Z Z = True

    {-# NOINLINE listOfShape #-}
    listOfShape _ = []

    -- Only the empty list corresponds to 'Z'; any leftover elements mean the
    -- list was longer than the rank of the requested shape.
    {-# NOINLINE shapeOfList #-}
    shapeOfList [] = Z
    shapeOfList _  = error $ stage ++ ".shapeOfList: non-empty list when converting to Z."

    {-# INLINE deepSeq #-}
    deepSeq Z x = x
-- | A shape with one more 'Int' component than the underlying shape @sh@.
--   Each method handles the right-most component and recurses on the rest.
instance Shape sh => Shape (sh :. Int) where
    {-# INLINE [1] rank #-}
    rank (sh :. _)
        = rank sh + 1

    {-# INLINE [1] zeroDim #-}
    zeroDim = zeroDim :. 0

    {-# INLINE [1] unitDim #-}
    unitDim = unitDim :. 1

    -- Component-wise minimum of the two shapes.
    {-# INLINE [1] intersectDim #-}
    intersectDim (sh1 :. n1) (sh2 :. n2)
        = intersectDim sh1 sh2 :. (min n1 n2)

    -- Component-wise sum of the two shapes.
    {-# INLINE [1] addDim #-}
    addDim (sh1 :. n1) (sh2 :. n2)
        = addDim sh1 sh2 :. (n1 + n2)

    -- Product of all components. Plain 'Int' multiplication: it can wrap
    -- around, which is what 'sizeIsValid' is there to detect.
    {-# INLINE [1] size #-}
    size (sh1 :. n)
        = size sh1 * n

    -- The prefix shape must itself be valid before its size can be trusted:
    -- a prefix whose product already wrapped around 'Int' to a positive value
    -- would otherwise pass the division test below. With a valid, positive
    -- prefix size, @n@ fits exactly when the full product does not exceed
    -- 'maxBound'. A prefix of size zero (or less) is reported as invalid.
    {-# INLINE [1] sizeIsValid #-}
    sizeIsValid (sh1 :. n)
        | sizeIsValid sh1 && size sh1 > 0
        = n <= maxBound `div` size sh1
        | otherwise
        = False

    -- Linearise an index: the first argument supplies the extents used as
    -- multipliers, the second supplies the index components.
    {-# INLINE [1] toIndex #-}
    toIndex (sh1 :. sh2) (sh1' :. sh2')
        = toIndex sh1 sh1' * sh2 + sh2'

    -- Inverse of 'toIndex': the quotient by this extent is handed to the
    -- remaining dimensions and the remainder becomes this component.
    {-# INLINE [1] fromIndex #-}
    fromIndex (ds :. d) n
        = fromIndex ds (n `quotInt` d) :. r
        where
            -- When @ds@ has rank 0 this is the left-most component and the
            -- remainder is skipped: @n@ is used as-is. That equals the
            -- remainder only when @0 <= n < d@, i.e. the linear index is
            -- assumed to be in range.
            r   | rank ds == 0 = n
                | otherwise    = n `remInt` d

    -- True when every component of the third argument is at least the
    -- matching component of the first and strictly less than that of the
    -- second.
    {-# INLINE [1] inShapeRange #-}
    inShapeRange (zs :. z) (sh1 :. n1) (sh2 :. n2)
        = (n2 >= z) && (n2 < n1) && (inShapeRange zs sh1 sh2)

    -- The right-most component comes first in the resulting list.
    {-# NOINLINE listOfShape #-}
    listOfShape (sh :. n)
        = n : listOfShape sh

    -- Inverse of 'listOfShape': the head of the list becomes the right-most
    -- component. Running out of elements means the list was shorter than the
    -- rank of the requested shape.
    {-# NOINLINE shapeOfList #-}
    shapeOfList xx
        = case xx of
            []   -> error $ stage ++ ".shapeOfList: empty list when converting to (_ :. Int)"
            x:xs -> shapeOfList xs :. x

    -- Forces this component, then the remaining ones, before yielding @x@.
    {-# INLINE deepSeq #-}
    deepSeq (sh :. n) x = deepSeq sh (n `seq` x)