{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.Representation.ExplicitMemory.IndexFunction
( IxFun(..)
, index
, iota
, offsetIndex
, strideIndex
, permute
, rotate
, reshape
, slice
, rebase
, repeat
, isContiguous
, shape
, rank
, linearWithOffset
, rearrangeWithOffset
, isDirect
, isLinear
, substituteInIxFun
)
where
import Prelude hiding (mod, repeat)
import Data.List hiding (repeat)
import qualified Data.List.NonEmpty as NE
import Data.List.NonEmpty (NonEmpty(..))
import Data.Function (on)
import Data.Maybe (isJust)
import Control.Monad.Identity
import Control.Monad.Writer
import qualified Data.Map.Strict as M
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Representation.AST.Syntax
(ShapeChange, DimChange(..), DimIndex(..), Slice, unitSlice, dimFix)
import Futhark.Representation.AST.Attributes
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Futhark.Analysis.PrimExp.Convert
-- | A shape: one size per dimension.
type Shape num = [num]
-- | An index: one index value per dimension.
type Indices num = [num]
-- | A permutation of dimension positions (see 'permuteFwd' / 'permuteInv').
type Permutation = [Int]
-- | Direction of a dimension's memory traversal.
--
-- * 'Inc': non-negative strides.
-- * 'Dec': negated strides. For example, 'makeRotIota' negates its strides
--   for 'Dec', and reversing slices swap 'Inc'/'Dec' via
--   'invertMonotonicity'.
-- * 'Unknown': no information.
data Monotonicity = Inc | Dec | Unknown
  deriving (Show, Eq)
-- | One dimension of an 'LMAD'.
data LMADDim num = LMADDim { ldStride :: num
                             -- ^ Stride of the dimension.  A stride of 0 is
                             -- used for repeated dimensions (see 'repeat').
                           , ldRotate :: num
                             -- ^ Rotation amount; 0 means "not rotated"
                             -- (see 'flatOneDim').
                           , ldShape :: num
                             -- ^ Number of elements in the dimension.
                           , ldPerm :: Int
                             -- ^ This dimension's entry in the LMAD's
                             -- permutation (see 'lmadPermutation').
                           , ldMon :: Monotonicity
                             -- ^ Monotonicity of the dimension.
                           }
  deriving (Show, Eq)
-- | A linear memory access descriptor.  It consists of a base offset and a
-- list of dimensions.  Indexing adds the per-dimension contributions
-- ('flatOneDim') to the offset; see 'indexLMAD'.
data LMAD num = LMAD { lmadOffset :: num
                     , lmadDims :: [LMADDim num]
                     }
  deriving (Show, Eq)
-- | An index function: a non-empty chain of 'LMAD's.
--
-- 'index' applies the first LMAD.  For each further LMAD it unflattens the
-- flat result into an index for that LMAD and continues.
data IxFun num = IxFun { ixfunLMADs :: NonEmpty (LMAD num)
                       , base :: Shape num
                         -- ^ Base shape.  'iota' sets it to the array
                         -- shape; reshaping may replace it.
                       , ixfunContig :: Bool
                         -- ^ Contiguity flag ('iota' sets it to True;
                         -- 'repeat' clears it).
                       }
  deriving (Show, Eq)
-- | A monotonicity is printed via its derived 'Show' representation.
instance Pretty Monotonicity where
  ppr mon = text $ show mon
-- | An LMAD is printed as a braced, semicolon-separated record.  The first
-- field is the offset; each remaining field is one per-dimension attribute
-- printed as a bracketed, comma-separated list on one line.
instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces $ semisep $
      text "offset: " <> oneLine (ppr offset)
      : map field [ ("strides: ", ppr . ldStride)
                  , ("rotates: ", ppr . ldRotate)
                  , ("shape: ", ppr . ldShape)
                  , ("permutation: ", ppr . ldPerm)
                  , ("monotonicity: ", ppr . ldMon)
                  ]
    where field (label, pprDim) =
            text label <> oneLine (brackets (commasep (map pprDim dims)))
-- | An index function is printed as a braced record.  It shows the base
-- shape, the contiguity flag, and the chain of LMADs.
instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces $ semisep [baseDoc, contigDoc, lmadsDoc]
    where baseDoc = text "base: " <> brackets (commasep (map ppr oshp))
          contigDoc = text "contiguous: " <> text (show cg)
          lmadsDoc = text "LMADs: " <> brackets (commasep (map ppr (NE.toList lmads)))
-- | Name substitution is applied to every @num@ inside the LMAD.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad
-- | Name substitution is applied to every @num@ inside the index function.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun
-- | Renaming is done by substitution.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Renaming is done by substitution.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun
-- | Free names are collected from every @num@ inside the LMAD.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn lmad = foldMap freeIn lmad
-- | Free names are collected from every @num@ inside the index function.
instance FreeIn num => FreeIn (IxFun num) where
  freeIn ixfun = foldMap freeIn ixfun
-- | 'fmap' is derived from 'traverse' in the 'Identity' applicative.
instance Functor LMAD where
  fmap f lmad = runIdentity $ traverse (Identity . f) lmad
-- | 'fmap' is derived from 'traverse' in the 'Identity' applicative.
instance Functor IxFun where
  fmap f ixfun = runIdentity $ traverse (Identity . f) ixfun
-- | 'foldMap' is derived from 'traverse' by accumulating in a 'Writer'.
instance Foldable LMAD where
  foldMap f lmad = execWriter $ traverse (tell . f) lmad
-- | 'foldMap' is derived from 'traverse' by accumulating in a 'Writer'.
instance Foldable IxFun where
  foldMap f ixfun = execWriter $ traverse (tell . f) ixfun
-- | Visits the offset first.  Then, for each dimension, it visits the stride,
-- rotation and size.  The permutation entry and monotonicity are not
-- visited.
instance Traversable LMAD where
  traverse f (LMAD offset dims) = LMAD <$> f offset <*> traverse onDim dims
    where onDim (LMADDim stride rot n perm mon) =
            LMADDim <$> f stride <*> f rot <*> f n <*> pure perm <*> pure mon
-- | Visits every LMAD in chain order, then the base shape.  The contiguity
-- flag is left as is.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    IxFun <$> traverse (traverse f) lmads <*> traverse f oshp <*> pure cg
-- | Prepend a (possibly empty) list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
prefix ++@ rest@(y :| ys) =
  case prefix of
    [] -> rest
    x : xs -> x :| (xs ++ y : ys)
-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(hd :| tl) @++@ (hd' :| tl') = hd :| concat [tl, [hd'], tl']
-- | Swap 'Inc' and 'Dec'; 'Unknown' is unchanged.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown
-- | The permutation of an LMAD: the 'ldPerm' of each dimension, in order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ ldPerm d | d <- lmadDims lmad ]
-- | Overwrite each dimension's 'ldPerm' with the corresponding entry of the
-- given permutation.  As with 'zipWith', the dimension list is truncated to
-- the shorter of the two.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad { lmadDims = zipWith setPerm perm (lmadDims lmad) }
  where setPerm p d = d { ldPerm = p }
-- | Overwrite each dimension's 'ldShape' with the corresponding size.  As
-- with 'zipWith', the dimension list is truncated to the shorter of the
-- two.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad { lmadDims = zipWith setShape shp (lmadDims lmad) }
  where setShape s d = d { ldShape = s }
-- | Substitute variables in the offset and in every dimension's stride,
-- rotation and size.  Permutation entries and monotonicities are kept.
substituteInLMAD :: Ord a => M.Map a (PrimExp a) -> LMAD (PrimExp a)
                 -> LMAD (PrimExp a)
substituteInLMAD tab (LMAD offset dims) =
  LMAD (sub offset) (map subDim dims)
  where sub = substituteInPrimExp tab
        subDim (LMADDim s r n p m) = LMADDim (sub s) (sub r) (sub n) p m
-- | Substitute variables in every LMAD of the chain and in the base shape.
-- The contiguity flag is preserved.
substituteInIxFun :: (Ord a) => M.Map a (PrimExp a) -> IxFun (PrimExp a)
                  -> IxFun (PrimExp a)
substituteInIxFun tab ixfun =
  ixfun { ixfunLMADs = NE.map (substituteInLMAD tab) (ixfunLMADs ixfun)
        , base = map (substituteInPrimExp tab) (base ixfun)
        }
-- | Does the index function look like one freshly made by 'iota'?  This
-- requires all of the following:
--
-- * a single LMAD, with the contiguity flag set and a sorted permutation;
-- * the same rank as the base shape, and offset 0;
-- * in every dimension: the row-major stride of the base shape, no
--   rotation, the base size, and a permutation entry equal to the
--   dimension's position.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  let -- Row-major strides (innermost stride 1).  'drop 1' rather than the
      -- partial 'tail' keeps this total for a rank-0 base shape.
      strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
  in hasContiguousPerm ixfun &&
     length oshp == length dims &&
     offset == 0 &&
     and (zipWith4 (\(LMADDim s r n p _) i d se ->
                      s == se && r == 0 && n == d && p == i)
                   dims [0..] oshp strides_expected)
isDirect _ = False
-- | True when the contiguity flag is set and the index function consists of
-- a single LMAD with a sorted permutation (see 'hasContiguousPerm').
isContiguous :: (Eq num, IntegralExp num) => IxFun num -> Bool
isContiguous ixfun
  | ixfunContig ixfun = hasContiguousPerm ixfun
  | otherwise = False
-- | True when the index function has exactly one LMAD and its permutation
-- is in ascending order.  Chains of several LMADs always give False.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] -> isSorted (lmadPermutation lmad)
    _ -> False
  where isSorted perm = perm == sort perm
-- | The shape of an index function, as given by its first LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape = lmadShape . NE.head . ixfunLMADs
-- | The LMAD's sizes, reordered with 'permuteInv' using its permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv (lmadPermutation lmad) (lmadShapeBase lmad)
-- | The LMAD's sizes in the order the dimensions appear in 'lmadDims'
-- (no permutation applied).
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = map ldShape (lmadDims lmad)
-- | Compute the flat memory position of a multi-dimensional index,
-- multiplied by the element size.
--
-- The index is first applied to the first LMAD of the chain.  While more
-- LMADs follow, the flat result (computed with element size 1) is
-- unflattened against the next LMAD's sizes and indexing continues there.
-- Only the last LMAD's result is scaled by @elm_size@.
index :: (IntegralExp num, Eq num) =>
         IxFun num -> Indices num -> num -> num
index = indexFromLMADs . ixfunLMADs
  where indexFromLMADs :: (IntegralExp num, Eq num) =>
                          NonEmpty (LMAD num) -> Indices num -> num -> num
        indexFromLMADs (lmad :| []) inds elm_size = indexLMAD lmad inds elm_size
        indexFromLMADs (lmad1 :| lmad2 : lmads) inds elm_size =
          -- Intermediate LMADs are indexed in units of elements (size 1).
          let i_flat = indexLMAD lmad1 inds 1
              -- The next LMAD's sizes, reordered with 'permuteFwd'.
              new_inds = unflattenIndex (permuteFwd (lmadPermutation lmad2) $ lmadShapeBase lmad2) i_flat
          in indexFromLMADs (lmad2 :| lmads) new_inds elm_size
-- | Index a single LMAD.  The result is its offset plus each dimension's
-- 'flatOneDim' contribution, multiplied by @elm_size@ (skipped when it is
-- 1).  The indices are first reordered with 'permuteInv' using the LMAD's
-- permutation.
--
-- NOTE(review): 'lmadShape' uses the same 'permuteInv' mapping for the
-- sizes.  The two mappings do not agree on dimension numbering:
--
-- * here, index @j@ is paired with dimension @perm !! j@;
-- * in 'lmadShape', dimension @j@'s size lands at position @perm !! j@.
--
-- For permutations that are not their own inverse (e.g. 3-cycles) the
-- mappings differ — confirm which one is intended.
indexLMAD :: (IntegralExp num, Eq num) =>
             LMAD num -> Indices num -> num -> num
indexLMAD lmad@(LMAD off dims) inds elm_size =
  let prod = sum $ zipWith flatOneDim
                 (map (\(LMADDim s r n _ _) -> (s, r, n)) dims)
                 (permuteInv (lmadPermutation lmad) inds)
      ind = off + prod
  in if elm_size == 1 then ind else ind * elm_size
-- | The index function of a fresh array of the given shape.  It is a single
-- increasing LMAD built by 'makeRotIota' with offset 0, no rotations and
-- the identity permutation.  The base shape is the shape itself, and the
-- result is marked contiguous.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = IxFun (lmad :| []) ns True
  where lmad = makeRotIota Inc 0 [ (0, n) | n <- ns ]
-- | Apply a permutation to the first LMAD.  The new permutation entry @k@ is
-- the current entry at position @perm_new !! k@.  The other LMADs, the base
-- and the contiguity flag are unchanged.
permute :: IntegralExp num =>
           IxFun num -> Permutation -> IxFun num
permute ixfun@(IxFun (lmad :| lmads) _ _) perm_new =
  ixfun { ixfunLMADs = setLMADPermutation composed lmad :| lmads }
  where composed = [ lmadPermutation lmad !! i | i <- perm_new ]
-- | Insert repeated (stride-0) dimensions into the first LMAD.
--
-- * @shps !! k@ lists the sizes of the repeated dimensions placed in front
--   of the k-th existing dimension.  They are reordered with 'permuteInv'
--   before being paired with the LMAD's dims, so @shps@ is presumably
--   given in the permuted (user-visible) order — TODO confirm against
--   callers.
-- * @shp@ lists repeated dimensions appended after all existing ones.
--
-- The permutation is expanded so each existing entry covers its new group
-- of dimensions, and the trailing repeats are numbered after everything
-- else.  The result is marked non-contiguous.
repeat :: (Eq num, IntegralExp num) =>
          IxFun num -> [Shape num] -> Shape num -> IxFun num
repeat (IxFun (lmad@(LMAD off dims) :| lmads) oshp _) shps shp =
  let perm = lmadPermutation lmad
      -- Each group is its repeated dims plus the original dim itself.
      lens = map (\s -> 1 + length s) shps
      (shps', lens') = unzip $ permuteInv perm $ zip shps lens
      -- scn !! k: end position (exclusive) of group k in the expanded order.
      scn = drop 1 $ scanl (+) 0 lens'
      perm' = concatMap (\(p, l) -> map (\i-> (scn !! p) - l + i) [0..l-1])
              $ zip perm lens
      tmp = length perm'
      -- Trailing repeats get the next free permutation positions.
      perm'' = perm' ++ [tmp..tmp-1+length shp]
      dims' = concatMap (\(shp_k, srnp) ->
                           map fakeDim shp_k ++ [srnp]
                        ) $ zip shps' dims
      lmad' = setLMADPermutation perm'' $ LMAD off (dims' ++ map fakeDim shp)
  in IxFun (lmad' :| lmads) oshp False
  -- A repeated dimension: stride 0 and no rotation.  Its permutation entry
  -- is overwritten by setLMADPermutation above.
  where fakeDim x = LMADDim 0 0 x 0 Unknown
-- | Add per-dimension rotation offsets to the first LMAD.  The offsets are
-- reordered with 'permuteInv' first.  A repeated (stride-0) dimension is
-- instead reset to stride 0 and rotation 0, with 'Unknown' monotonicity.
rotate :: (Eq num, IntegralExp num) =>
          IxFun num -> Indices num -> IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  IxFun (LMAD off (zipWith rotDim dims offs') :| lmads) oshp cg
  where offs' = permuteInv (lmadPermutation lmad) offs
        rotDim dim@(LMADDim s r _ _ _) o
          | s == 0 = dim { ldStride = 0, ldRotate = 0, ldMon = Unknown }
          | otherwise = dim { ldRotate = r + o }
-- | Try to apply a slice by rewriting only the first LMAD.
--
-- Returns 'Nothing' when some dimension's rotation would be broken by the
-- slice (see 'harmlessRotation').  Otherwise:
--
-- * each dimension is sliced with 'sliceOne';
-- * fixed ('DimFix') dimensions are removed from the permutation, which is
--   renumbered;
-- * contiguity is kept only if 'slicePreservesContiguous' holds.
sliceOneLMAD :: (Eq num, IntegralExp num) =>
                IxFun num -> Slice num -> Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) is = do
  let perm = lmadPermutation lmad
      is' = permuteInv perm is
      cg' = cg && slicePreservesContiguous lmad is'
  guard $ harmlessRotation lmad is'
  let lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims
      -- Positions of the fixed (DimFix) indices, which disappear.
      perm' = updatePerm perm $ map fst $ filter (isJust . dimFix . snd) $
              zip [0..length is' - 1] is'
  return $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  -- Drop every permutation entry that equals a removed position.  Shift
  -- each remaining entry down by the number of removed positions below it.
  where updatePerm ps inds = foldl (\acc p -> acc ++ decrease p) [] ps
          where decrease p =
                  -- d counts removed positions below p; -1 marks that p
                  -- itself was removed (and stays -1 from then on).
                  let d = foldl (\n i -> if i == p then -1
                                         else if i > p
                                              then n
                                              else if n /= -1 then n + 1
                                                   else n
                                ) 0 inds
                  in if d == -1 then [] else [p - d]
-- | Can one dimension be sliced without disturbing its rotation?  Yes when
-- any of these holds:
--
-- * the index is fixed;
-- * the dimension is repeated (stride 0);
-- * the dimension is unrotated;
-- * the slice takes the whole dimension, forwards or fully reversed.
harmlessRotation' :: (Eq num, IntegralExp num) =>
                     LMADDim num -> DimIndex num -> Bool
harmlessRotation' _ (DimFix _) = True
harmlessRotation' (LMADDim s r n _ _) dslc =
  s == 0 || r == 0 || dslc == DimSlice (n - 1) n (-1) || dslc == unitSlice 0 n
-- | 'harmlessRotation'' holds for every (dimension, slice) pair.
harmlessRotation :: (Eq num, IntegralExp num) =>
                    LMAD num -> Slice num -> Bool
harmlessRotation lmad iss = and (zipWith harmlessRotation' (lmadDims lmad) iss)
-- | Fold step for slicing one dimension.  It consumes one (slice, dimension)
-- pair and either adjusts the accumulated offset or appends a new
-- dimension.  Cases are tried in order.
sliceOne :: (Eq num, IntegralExp num) =>
            LMAD num -> (DimIndex num, LMADDim num) -> LMAD num
-- Fixed index: the dimension disappears; its contribution goes into the
-- offset.
sliceOne (LMAD off dims) (DimFix i, LMADDim s r n _ _) =
  LMAD (off + flatOneDim (s, r, n) i) dims
-- Repeated (stride-0) dimension: stays repeated, with the slice's size.
sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _ _ p _) =
  LMAD off (dims ++ [LMADDim 0 0 ne p Unknown])
-- Whole-dimension slice: the dimension is kept as is.
sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ _ n _ _))
  | dmind == unitSlice 0 n = LMAD off (dims ++ [dim])
-- Whole dimension reversed: start at the last element, negate the stride,
-- mirror the rotation and invert the monotonicity.
sliceOne (LMAD off dims) (dmind, LMADDim s r n p m)
  | dmind == DimSlice (n - 1) n (-1) =
    let r' = if r == 0 then 0 else n - r
        off' = off + flatOneDim (s, 0, n) (n - 1)
    in LMAD off' (dims ++ [LMADDim (s * (-1)) r' n p (invertMonotonicity m)])
-- Stride-0 slice: the start element repeated, as a stride-0 dimension.
sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s r n p _) =
  LMAD (off + flatOneDim (s, r, n) b) (dims ++ [LMADDim 0 0 ne p Unknown])
-- General slice, unrotated dimensions only: move the offset to the start
-- and scale the stride.  Monotonicity is kept for slice stride +1,
-- inverted for -1, and otherwise Unknown.
sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s 0 _ p m) =
  let m' = case sgn ss of
             Just 1 -> m
             Just (-1) -> invertMonotonicity m
             _ -> Unknown
  in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) 0 ns p m'])
-- A general slice of a rotated dimension is expected to be excluded by
-- 'harmlessRotation' beforehand.
sliceOne _ _ = error "slice: reached impossible case"
-- | Does the slice keep the LMAD contiguous?
--
-- Repeated (stride-0) dimensions are ignored.  Slices are first
-- normalised with 'normIndex'.  The remaining dimensions are scanned in
-- order, and all of these must hold:
--
-- * every slice has stride 1 or -1;
-- * the first sliced dimension is unrotated, or taken whole;
-- * every later sliced dimension is taken whole;
-- * no fixed index appears after the first sliced dimension.
slicePreservesContiguous :: (Eq num, IntegralExp num) =>
                            LMAD num -> Slice num -> Bool
slicePreservesContiguous (LMAD _ dims) slc =
  let (dims', slc') = unzip $
                      filter ((/= 0) . ldStride . fst) $
                      zip dims $ map normIndex slc
      -- found: a sliced (non-fixed) dimension has already been seen.
      (_, success) =
        foldl (\(found, res) (slcdim, LMADDim _ r n _ _) ->
                 case (slcdim, found) of
                   (DimFix{}, True ) -> (found, False)
                   (DimFix{}, False) -> (found, res)
                   (DimSlice _ ne ds, False) ->
                     let res' = (r == 0 || n == ne) && (ds == 1 || ds == -1)
                     in (True, res && res')
                   (DimSlice _ ne ds, True) ->
                     let res' = (n == ne) && (ds == 1 || ds == -1)
                     in (found, res && res')
              ) (False, True) $ zip slc' dims'
  in success
-- | A slice of exactly one element, or of stride 0, is treated as a fixed
-- index at its start.  Any other index is returned unchanged.
normIndex :: (Eq num, IntegralExp num) =>
             DimIndex num -> DimIndex num
normIndex (DimSlice b n s)
  | n == 1 || s == 0 = DimFix b
normIndex d = d
-- | Slice an index function.
--
-- * A slice that takes every dimension whole returns the index function
--   unchanged.
-- * Otherwise, the first LMAD is rewritten in place ('sliceOneLMAD') when
--   possible.
-- * Failing that, the slice is applied to an 'iota' of the current shape,
--   and the resulting LMAD is pushed onto the front of the chain.
--
-- An empty slice is an error.
slice :: (Eq num, IntegralExp num) =>
         IxFun num -> Slice num -> IxFun num
slice _ [] = error "slice: empty slice"
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  | dim_slices == map (unitSlice 0) (shape ixfun) = ixfun
  | Just ixfun' <- sliceOneLMAD ixfun dim_slices = ixfun'
  | otherwise =
    case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
      Just (IxFun (lmad' :| []) _ cg') ->
        IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
      _ -> error "slice: reached impossible case"
-- | Cheap reshape that only overwrites dimension sizes in the first LMAD.
--
-- Succeeds when the shape change has this form: leading coercions, at most
-- one reshape (and only if exactly one dimension lies between the leading
-- and trailing coercions), then trailing coercions.  Strides, rotations
-- and permutation entries are kept.  The base shape is replaced by the new
-- dims.  It presumably assumes the rank is unchanged, since sizes are
-- paired with dims by 'zipWith' — TODO confirm.
reshapeCoercion :: (Eq num, IntegralExp num) =>
                   IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) _ cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims' = permuteFwd perm dims
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims'
      num_rshps = length reshapes
  guard (num_rshps == 0 || (num_rshps == 1 && length mid_dims == 1))
  -- Assign the new sizes, then restore the order given by ldPerm.
  let dims'' = map snd $ sortBy (compare `on` fst) $
               zipWith (\ld n -> (ldPerm ld, ld { ldShape = n }))
                       dims' (newDims newshape)
      lmad' = LMAD off dims''
  return $ IxFun (lmad' :| lmads) (newDims newshape) cg
-- | Reshape by rebuilding the single LMAD of a contiguous index function.
--
-- Preconditions (checked by 'guard'):
--
-- * the shape change has the form "coercions, reshapes, coercions";
-- * the dims covered by the reshape are non-repeated and unrotated, with
--   consecutive permutation entries starting at the number of leading
--   coercions;
-- * the index function has one LMAD with a sorted permutation, is flagged
--   contiguous, and is uniformly 'Inc' or 'Dec' (rotations ignored).
--
-- How the new LMAD is built:
--
-- * Coerced dims keep their rotation.  If they were repeated (stride 0)
--   they are re-added as stride-0 dims.
-- * Reshaped dims get rotation 0.
-- * All non-repeated dims receive fresh strides from 'makeRotIota'.
reshapeOneLMAD :: (Eq num, IntegralExp num) =>
                  IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) _ cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims_perm = permuteFwd perm dims
      -- The old dims replaced by the reshape.
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims_perm
      mon = ixfunMonotonicityRots True ixfun
  guard $
    all (\ (LMADDim s r _ _ _) -> s /= 0 && r == 0) mid_dims &&
    consecutive hd_len (map ldPerm mid_dims) &&
    hasContiguousPerm ixfun && cg && (mon == Inc || mon == Dec)
  let rsh_len = length reshapes
      -- Change in rank caused by the reshape.
      diff = length newshape - length dims
      iota_shape = [0..length newshape-1]
      -- New permutation: reshaped positions map to themselves.  Coerced
      -- positions keep their old entry, shifted by diff when past the
      -- leading coercions.
      perm' = map (\i -> let ind = if i < hd_len
                                         then i else i - diff
                         in if (i >= hd_len) && (i < hd_len + rsh_len)
                            then i
                            else let p = ldPerm (dims !! ind)
                                 in if p < hd_len
                                    then p
                                    else p + diff
                  ) iota_shape
      -- Split the new dims.  Non-repeated ("support") dims are tagged with
      -- their permutation position and (rotation, size).  Repeated dims
      -- (old stride 0) are tagged with their position and size.
      (support_inds, repeat_inds) =
        foldl (\(sup, rpt) (i, shpdim, ip) ->
                 case (i < hd_len, i >= hd_len + rsh_len, shpdim) of
                   (True, _, DimCoercion n) ->
                     case dims_perm !! i of
                       (LMADDim 0 _ _ _ _) -> ( sup, (ip, n) : rpt )
                       (LMADDim _ r _ _ _) -> ( (ip, (r, n)) : sup, rpt )
                   (_, True, DimCoercion n) ->
                     case dims_perm !! (i-diff) of
                       (LMADDim 0 _ _ _ _) -> ( sup, (ip, n) : rpt )
                       (LMADDim _ r _ _ _) -> ( (ip, (r, n)) : sup, rpt )
                   (False, False, _) ->
                     ( (ip, (0, newDim shpdim)) : sup, rpt )
                   _ -> error "reshape: reached impossible case"
              ) ([], []) $ reverse $ zip3 iota_shape newshape perm'
      (sup_inds, support) = unzip $ sortBy (compare `on` fst) support_inds
      (rpt_inds, repeats) = unzip repeat_inds
      -- Fresh strides for the support dims, with the original monotonicity.
      LMAD off' dims_sup = makeRotIota mon off support
      repeats' = map (\n -> LMADDim 0 0 n 0 Unknown) repeats
      -- Merge support and repeated dims back into permutation order.
      dims' = map snd $ sortBy (compare `on` fst)
              $ zip sup_inds dims_sup ++ zip rpt_inds repeats'
      lmad' = LMAD off' dims'
  return $ IxFun (setLMADPermutation perm' lmad' :| lmads) (newDims newshape) cg
  -- Is the list exactly i, i+1, i+2, ...?
  where consecutive _ [] = True
        consecutive i [p]= i == p
        consecutive i ps = and $ zipWith (==) ps [i, i+1..]
-- | Split a shape change into three parts:
--
-- * leading coercions;
-- * the following non-coercion changes (up to the next coercion);
-- * the remainder.
--
-- Fails unless the remainder consists only of coercions.
splitCoercions :: (Eq num, IntegralExp num) =>
                  ShapeChange num -> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions newshape'
  | all isCoercion tail_coercions = Just (head_coercions, reshapes, tail_coercions)
  | otherwise = Nothing
  where (head_coercions, rest) = span isCoercion newshape'
        (reshapes, tail_coercions) = break isCoercion rest
        isCoercion DimCoercion{} = True
        isCoercion _ = False
-- | Reshape an index function.  The strategies are tried in order:
--
-- 1. overwrite the sizes only ('reshapeCoercion');
-- 2. rebuild the single LMAD ('reshapeOneLMAD');
-- 3. push a fresh 'iota' LMAD of the new dims onto the front of the chain,
--    keeping the base shape and contiguity flag.
reshape :: (Eq num, IntegralExp num) =>
           IxFun num -> ShapeChange num -> IxFun num
reshape ixfun new_shape =
  case (reshapeCoercion ixfun new_shape, reshapeOneLMAD ixfun new_shape) of
    (Just ixfun', _) -> ixfun'
    (Nothing, Just ixfun') -> ixfun'
    (Nothing, Nothing) -> pushIota ixfun
  where pushIota (IxFun (lmad0 :| lmad0s) oshp cg) =
          case iota (newDims new_shape) of
            IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
            _ -> error "reshape: reached impossible case"
-- | Number of dimensions of the first LMAD in the chain.
rank :: IntegralExp num =>
        IxFun num -> Int
rank = length . lmadDims . NE.head . ixfunLMADs
-- | Try to rebase @ixfun@ onto @new_base@ without growing the chain much.
--
-- The last LMAD of @ixfun@ (the one indexing its base) has its repeated
-- dims removed.  It is then folded into @new_base@'s first LMAD; any
-- removed repeats are re-applied with 'repeat'.  The other LMADs of
-- @ixfun@ are kept in front, in their original order.
--
-- Returns 'Nothing' unless all of these hold:
--
-- * @ixfun@'s base equals @new_base@'s shape;
-- * @ixfun@ is flagged contiguous;
-- * every remaining dim of that last LMAD has a known monotonicity;
-- * one side has a sorted permutation;
-- * the permutation lengths agree, or @ixfun@'s permutation is sorted;
-- * each base size matches the dim's size, except the innermost dim, which
--   may instead have stride 1.
rebaseNice :: (Eq num, IntegralExp num) =>
              IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
  -- The LMAD closest to the base, and the ones before it in chain order.
  -- (Taking the tail of 'NE.reverse' would yield them reversed, which
  -- scrambles chains of three or more LMADs.)
  let lmad_full = NE.last lmads
      lmads_init = NE.init lmads
      ((outer_shapes, inner_shape), lmad) = shaveoffRepeats lmad_full
      dims = lmadDims lmad
      perm = lmadPermutation lmad
      perm_base = lmadPermutation lmad_base
  guard $
    base ixfun == shape new_base
    && cg && all ((/= Unknown) . ldMon) dims
    && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
    && (length perm == length perm_base || hasContiguousPerm ixfun)
    && and (zipWith3 (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1))
             shp dims (replicate (length dims - 1) False ++ [True]))
  let perm_base' = if hasContiguousPerm ixfun
                   then perm_base
                   else map (perm !!) perm_base
      lmad_base' = setLMADPermutation perm_base' lmad_base
      dims_base = lmadDims lmad_base'
      n_fewer_dims = length dims_base - length dims
      -- Combine each inner base dim with the matching ixfun dim.  A 'Dec'
      -- ixfun dim flips the base stride and moves the offset to its last
      -- element; rotations are merged accordingly.
      (dims_base', offs_contrib) = unzip $
        zipWith (\(LMADDim s1 r1 n1 p1 _) (LMADDim _ r2 _ _ m2) ->
                   let (s', off') | m2 == Inc = (s1, 0)
                                  | otherwise = (s1 * (-1), s1 * (n1 - 1))
                       r' | m2 == Inc = if r2 == 0 then r1 else r1 + r2
                          | r1 == 0 = r2
                          | r2 == 0 = n1 - r1
                          | otherwise = n1 - r1 + r2
                   in (LMADDim s' r' n1 (p1 - n_fewer_dims) Inc, off'))
                (drop n_fewer_dims dims_base) dims
      off_base = lmadOffset lmad_base' + sum offs_contrib
      -- A non-zero ixfun offset is scaled by the innermost base stride, and
      -- the ixfun's shape is used.
      lmad_base''
        | lmadOffset lmad == 0 = LMAD off_base dims_base'
        | otherwise =
            setLMADShape (lmadShape lmad)
              (LMAD (off_base + ldStride (last dims_base) * lmadOffset lmad)
                dims_base')
      new_base' = IxFun (lmad_base'' :| lmads_base) shp cg_base
      IxFun lmads_base' _ _ = if all null outer_shapes && null inner_shape
                              then new_base'
                              else repeat new_base' outer_shapes inner_shape
      lmads'' = lmads_init ++@ lmads_base'
  return $ IxFun lmads'' shp (cg && cg_base)
  -- Remove stride-0 (repeated) dims from an LMAD.  Returns:
  --
  -- * the sizes of the repeats in front of each remaining dim, in
  --   'permuteFwd' order;
  -- * the sizes of the trailing repeats;
  -- * the LMAD without repeats, with permutation entries renumbered.
  where shaveoffRepeats :: (Eq num, IntegralExp num) =>
                           LMAD num -> (([Shape num], Shape num), LMAD num)
        shaveoffRepeats lmad =
          let perm = lmadPermutation lmad
              dims = lmadDims lmad
              resacc = foldl (\acc (LMADDim s _ n _ _) ->
                                case acc of
                                  rpt : acc0 ->
                                    if s == 0 then (n : rpt) : acc0
                                    else [] : (rpt : acc0)
                                  _ -> error "shaveoffRepeats: empty accumulator"
                             ) [[]] $ reverse $ permuteFwd perm dims
              last_shape = last resacc
              shapes = take (length resacc - 1) resacc
              -- Number of repeated dims whose permutation entry is below k.
              howManyRepLT k =
                foldl (\i (LMADDim s _ _ p _) ->
                         if s == 0 && p < k then i + 1 else i
                      ) 0 dims
              dims' = foldl (\acc (LMADDim s r n p info) ->
                               if s == 0 then acc
                               else let p' = p - howManyRepLT p
                                    in LMADDim s r n p' info : acc
                            ) [] $ reverse dims
              lmad' = LMAD (lmadOffset lmad) dims'
          in ((shapes, last_shape), lmad')
-- | Rebase @ixfun@ onto @new_base@.  'rebaseNice' is tried first.
--
-- Otherwise the two chains are concatenated, @ixfun@'s LMADs first.  When
-- @ixfun@'s base does not match @new_base@'s shape, @new_base@ is first
-- coerced to @ixfun@'s base.  The result is contiguous only if both are.
rebase :: (Eq num, IntegralExp num) =>
          IxFun num -> IxFun num -> IxFun num
rebase new_base ixfun@(IxFun lmads shp cg)
  | Just ixfun' <- rebaseNice new_base ixfun = ixfun'
  | otherwise =
    IxFun (lmads @++@ ixfunLMADs target) (base target) (cg && ixfunContig new_base)
  where target
          | base ixfun == shape new_base = new_base
          | otherwise = reshape new_base $ map DimCoercion shp
-- | Overall monotonicity of the index function.  Rotated dimensions count
-- against it (see 'ixfunMonotonicityRots').
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity ixfun = ixfunMonotonicityRots False ixfun
-- | Drop the first @i@ elements of the outermost dimension.  This is a
-- slice starting at @i@ with @d - i@ elements and stride 1.  Inner
-- dimensions are untouched, and an offset of 0 is a no-op.
offsetIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
offsetIndex ixfun i
  | i == 0 = ixfun
  | d : ds <- shape ixfun = slice ixfun (DimSlice i (d - i) 1 : map (unitSlice 0) ds)
  | otherwise = error "offsetIndex: underlying index function has rank zero"
-- | Stride the outermost dimension by @s@: a slice starting at 0 with @d@
-- elements and stride @s@.  Inner dimensions are untouched.
--
-- NOTE(review): the element count stays @d@.  For @s > 1@ the view
-- therefore spans @(d-1)*s + 1@ positions of the original outer dimension.
-- Presumably callers guarantee that the underlying memory is large
-- enough — confirm.
strideIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
strideIndex ixfun s =
  case shape ixfun of
    d : ds -> slice ixfun (DimSlice 0 d s : map (unitSlice 0) ds)
    [] -> error "strideIndex: underlying index function has rank zero"
-- | If the index function is a single contiguous, increasing LMAD with a
-- sorted permutation, return its offset times the element size.
linearWithOffset :: (Eq num, IntegralExp num) =>
                    IxFun num -> num -> Maybe num
linearWithOffset ixfun elem_size =
  case ixfunLMADs ixfun of
    lmad :| []
      | hasContiguousPerm ixfun && ixfunContig ixfun && ixfunMonotonicity ixfun == Inc ->
          Just (lmadOffset lmad * elem_size)
    _ -> Nothing
-- | For a single-LMAD index function that is linear once its permutation is
-- reset to the identity, return two things:
--
-- * the linear offset (see 'linearWithOffset');
-- * the permutation entries paired with the LMAD's sizes reordered by
--   'permuteFwd'.
rearrangeWithOffset :: (Eq num, IntegralExp num) =>
                       IxFun num -> num -> Maybe (num, [(Int,num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  fmap (\offset -> (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))) $
    linearWithOffset unpermuted elem_size
  where perm = lmadPermutation lmad
        unpermuted = IxFun (setLMADPermutation [0 .. length perm - 1] lmad :| []) oshp cg
rearrangeWithOffset _ _ = Nothing
-- | Linear (per 'linearWithOffset') with offset zero.
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0
-- | Gather: the @k@-th result is @elems !! (ps !! k)@.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [ elems !! p | p <- ps ]
-- | Scatter: element @k@ of @elems@ is placed at position @ps !! k@.  This
-- is the inverse of 'permuteFwd' for a proper permutation.  As with 'zip',
-- the inputs are truncated to the shorter length.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems = map snd $ sortOn fst $ zip ps elems
-- | Memory contribution of index @i@ in a dimension with stride @s@,
-- rotation @r@ and size @n@:
--
-- * zero for a repeated (stride-0) dimension;
-- * @i*s@ when unrotated;
-- * @((i + r) mod n) * s@ otherwise.
flatOneDim :: (Eq num, IntegralExp num) =>
              (num, num, num) -> num -> num
flatOneDim (s, r, n) i =
  if s == 0 then 0
  else if r == 0 then i * s
  else ((i + r) `mod` n) * s
-- | Build a row-major LMAD from an offset and per-dimension
-- @(rotation, size)@ pairs.
--
-- * Strides are running products of the inner sizes (innermost stride 1).
--   They are multiplied by -1 for 'Dec'.
-- * The permutation is the identity.
-- * Every dimension carries the given monotonicity.
--
-- Only 'Inc' and 'Dec' are accepted.
makeRotIota :: IntegralExp num =>
               Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota mon off support
  | mon /= Inc && mon /= Dec = error "makeRotIota: requires Inc or Dec"
  | otherwise =
    LMAD off $ zipWith5 LMADDim strides rots sizes [0 .. n_dims - 1] (replicate n_dims mon)
  where n_dims = length support
        (rots, sizes) = unzip support
        row_major = reverse $ take n_dims $ scanl (*) 1 $ reverse sizes
        strides = if mon == Inc then row_major else map (* (-1)) row_major
-- | Monotonicity shared by every LMAD in the chain, or 'Unknown' if they
-- disagree.
--
-- An LMAD is 'Inc' ('Dec') when each of its dims is repeated (stride 0) or
-- carries that monotonicity.  A dim also needs to be unrotated, unless
-- @ignore_rots@ is set.
ixfunMonotonicityRots :: (Eq num, IntegralExp num) =>
                         Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots ignore_rots ixfun =
  case map lmadMon (NE.toList (ixfunLMADs ixfun)) of
    mon0 : rest | all (== mon0) rest -> mon0
    _ -> Unknown
  where lmadMon :: (Eq num, IntegralExp num) =>
                   LMAD num -> Monotonicity
        lmadMon lmad
          | all (dimIs Inc) (lmadDims lmad) = Inc
          | all (dimIs Dec) (lmadDims lmad) = Dec
          | otherwise = Unknown
        dimIs :: (Eq num, IntegralExp num) =>
                 Monotonicity -> LMADDim num -> Bool
        dimIs mon (LMADDim s r _ _ ldmon) =
          s == 0 || ((ignore_rots || r == 0) && mon == ldmon)