{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation for the index function based on
-- linear-memory accessor descriptors; see Zhu, Hoeflinger and David work.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    Shape,
    LMAD (..),
    LMADDim (..),
    Monotonicity (..),
    index,
    mkExistential,
    iota,
    iotaOffset,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    rebase,
    shape,
    lmadShape,
    rank,
    linearWithOffset,
    rearrangeWithOffset,
    isDirect,
    isLinear,
    substituteInIxFun,
    substituteInLMAD,
    existentialize,
    closeEnough,
    equivalent,
    hasOneLmad,
    permuteInv,
    conservativeFlatten,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.State
import Data.Function (on, (&))
import Data.List (sort, sortBy, zip4, zipWith4)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe (isJust)
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    dimFix,
    flatSliceDims,
    flatSliceStrides,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- A list of array indices, as consumed by 'index'.
type Indices num = [num]

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as a non-empty sequence of 'LMAD's.
data IxFun num = IxFun
  { -- | The LMADs making up the index function.  When indexing, the
    -- head LMAD is applied to the array indices first; each later
    -- LMAD is applied to the unflattened result of the previous one.
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The shape of the support array, i.e., the original array
    -- that is the starting point of this index function.
    base :: Shape num,
    -- | Ignoring permutations, is the index function contiguous?
    contiguous :: Bool
  }
  deriving (Show, Eq)

-- | Render an index function as a brace-delimited, semicolon-stacked
-- record with three entries: the base shape, the contiguity flag, and
-- the list of LMADs.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmads oshp cg) =
    braces . semistack $
      [ "base:" <+> brackets (commasep $ map pretty oshp),
        "contiguous:" <+> if cg then "true" else "false",
        "LMADs:" <+> brackets (commastack $ NE.toList $ NE.map pretty lmads)
      ]

-- | Name substitution is applied to every @num@ element of the index
-- function by going through its 'Functor' instance.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs = fmap $ substituteNames substs

-- | Renaming is implemented in terms of substitution.
instance Substitute num => Rename (IxFun num) where
  rename = substituteRename

-- | The free names of an index function are the combined free names
-- of all its @num@ elements (collected via the 'Foldable' instance).
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' = foldMap freeIn'

-- | Derived from the 'Traversable' instance.
instance Functor IxFun where
  fmap = fmapDefault

-- | Derived from the 'Traversable' instance.
instance Foldable IxFun where
  foldMap = foldMapDefault

-- It is important that the traversal order here is the same as in
-- mkExistential.
--
-- The order is: every LMAD in sequence (each traversed with its own
-- 'Traversable' instance), then the base shape.  The contiguity flag
-- carries no @num@ and is copied unchanged.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    IxFun <$> traverse (traverse f) lmads <*> traverse f oshp <*> pure cg

-- | Prepend an ordinary (possibly empty) list to a non-empty list,
-- yielding a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
es ++@ (ne :| nes) = case es of
  -- The prefix supplies the new head; the old head joins the tail.
  e : es' -> e :| es' ++ [ne] ++ nes
  -- Nothing to prepend: the non-empty list is returned as it was.
  [] -> ne :| nes

-- | Concatenate two non-empty lists; the elements of the first come
-- before the elements of the second.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| xs ++ [y] ++ ys

-- | Overwrite the 'ldPerm' field of each dimension of the LMAD with
-- the corresponding entry of the given permutation, pairing them
-- positionally.  Because the pairing uses 'zipWith', the result has
-- as many dimensions as the shorter of the two lists.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith (\dim p -> dim {ldPerm = p}) (lmadDims lmad) perm}

-- | Overwrite the 'ldShape' field of each dimension of the LMAD with
-- the corresponding entry of the given shape, pairing them
-- positionally.  Because the pairing uses 'zipWith', the result has
-- as many dimensions as the shorter of the two lists.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith (\dim s -> dim {ldShape = s}) (lmadDims lmad) shp}

-- | Substitute names with 'TPrimExp's in an index function.  The
-- substitution is applied to every LMAD (via 'substituteInLMAD') and
-- to every element of the base shape; the contiguity flag is kept.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmads oshp cg) =
  IxFun
    (NE.map (substituteInLMAD tab) lmads)
    (map (TPrimExp . substituteInPrimExp tab' . untyped) oshp)
    cg
  where
    -- 'substituteInPrimExp' works on untyped expressions, so the
    -- table is untyped once up front.
    tab' = fmap untyped tab

-- | Is this a row-major array?
--
-- This holds only when the index function consists of a single LMAD,
-- is flagged contiguous, has an ascending permutation, has the same
-- rank as its base, has offset zero, and every dimension @i@ has
--
-- * a stride equal to the product of the base sizes of all
--   dimensions after @i@,
-- * a size equal to the @i@th base size, and
-- * permutation entry @i@.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  -- Row-major strides for the base shape: suffix products of the
  -- sizes, innermost stride 1.  'drop 1' is used rather than 'tail'
  -- so that the expression is total for an empty base shape.
  let strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
   in hasContiguousPerm ixfun
        && length oshp == length dims
        && offset == 0
        && all
          (\(LMADDim s n p _, m, d, se) -> s == se && n == d && p == m)
          (zip4 dims [0 .. length dims - 1] oshp strides_expected)
isDirect _ = False

-- | Is the index function "analyzable", i.e., does it consist of
-- exactly one LMAD?
hasOneLmad :: IxFun num -> Bool
hasOneLmad (IxFun (_ :| []) _ _) = True
hasOneLmad _ = False

-- | Does the index function have an ascending permutation?  That is,
-- is the permutation of its LMAD already sorted?  Always 'False' for
-- index functions with more than one LMAD.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm (IxFun (lmad :| []) _ _) =
  let perm = lmadPermutation lmad
   in perm == sort perm
hasContiguousPerm _ = False

-- | The index space of the index function.  This is the same as the
-- shape of arrays that the index function supports.
--
-- Only the first LMAD is consulted: its base shape ('lmadShapeBase')
-- is rearranged according to its permutation.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape (IxFun (lmad :| _) _ _) =
  permuteFwd (lmadPermutation lmad) $ lmadShapeBase lmad

-- | Compute the flat memory index for a complete set @inds@ of array
-- indices.
--
-- With a single LMAD the result is the LMAD's offset plus the sum of
-- the per-dimension contributions.  With several LMADs, the flat
-- index produced by one LMAD is unflattened according to the
-- (permuted) shape of the next LMAD and used as that LMAD's indices.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index = indexFromLMADs . ixfunLMADs
  where
    -- Thread the indices through the LMADs from first to last.
    indexFromLMADs ::
      (IntegralExp num, Eq num) =>
      NonEmpty (LMAD num) ->
      Indices num ->
      num
    indexFromLMADs (lmad :| []) inds = indexLMAD lmad inds
    indexFromLMADs (lmad1 :| lmad2 : lmads) inds =
      let i_flat = indexLMAD lmad1 inds
          -- Reinterpret the flat index as indices into the index
          -- space of the next LMAD.
          new_inds =
            unflattenIndex
              (permuteFwd (lmadPermutation lmad2) $ lmadShapeBase lmad2)
              i_flat
       in indexFromLMADs (lmad2 :| lmads) new_inds
    -- Apply a single LMAD: the indices are first brought into the
    -- LMAD's own dimension order ('permuteInv'), then each is
    -- combined with the stride of its dimension.
    indexLMAD ::
      (IntegralExp num, Eq num) =>
      LMAD num ->
      Indices num ->
      num
    indexLMAD lmad@(LMAD off dims) inds =
      let prod =
            sum $
              zipWith
                flatOneDim
                (map ldStride dims)
                (permuteInv (lmadPermutation lmad) inds)
       in off + prod

-- | iota with offset: a single-LMAD index function for an array of
-- shape @ns@ whose LMAD is built by 'makeRotIota' with increasing
-- monotonicity and offset @o@.  The base shape is @ns@ and the result
-- is flagged contiguous.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (makeRotIota Inc o ns :| []) ns True

-- | iota: 'iotaOffset' with a zero offset.
iota :: IntegralExp num => Shape num -> IxFun num
iota = iotaOffset 0

-- | Create a single-LMAD index function that is existential in
-- everything, with the provided permutation, monotonicity, and
-- contiguousness.
--
-- Existential numbers are handed out starting at @start@: the LMAD
-- offset is @start@, dimension @i@ uses @start + 1 + 2*i@ and
-- @start + 2 + 2*i@ for its first two 'LMADDim' fields, and the
-- @basis_rank@ base-shape elements use the consecutive numbers that
-- follow the last dimension.
mkExistential :: Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a)
mkExistential basis_rank perm contig start =
  IxFun (NE.singleton lmad) basis contig
  where
    -- Base shape: numbered after the offset and the two numbers used
    -- by each dimension.
    basis = take basis_rank $ map Ext [start + 1 + dims_rank * 2 ..]
    dims_rank = length perm
    lmad = LMAD (Ext start) $ zipWith onDim perm [0 ..]
    onDim (p, mon) i =
      LMADDim (Ext (start + 1 + i * 2)) (Ext (start + 2 + i * 2)) p mon

-- | Permute dimensions.  The new permutation is composed with the
-- permutation already stored in the first LMAD; the remaining LMADs,
-- the base shape and the contiguity flag are left untouched.
--
-- Every element of @perm_new@ is used as a position into the current
-- permutation (via '!!'), so it must be a valid index into it.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  let perm_cur = lmadPermutation lmad
      perm = map (perm_cur !!) perm_new
   in IxFun (setLMADPermutation perm lmad :| lmads) oshp cg

-- | Handle the case where a slice can stay within a single LMAD.
--
-- The slice is applied to the first LMAD only: the slice components
-- are brought into the LMAD's own dimension order, each one is
-- combined with its dimension by @sliceOne@, and the permutation is
-- rebuilt without the fixed dimensions.  The contiguity flag stays
-- set only if it was set before and @slicePreservesContiguous@ holds.
--
-- NOTE(review): as written, every path through this function ends in
-- 'pure', so the result is always a 'Just'.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) = do
  let perm = lmadPermutation lmad
      is' = permuteInv perm is
      cg' = cg && slicePreservesContiguous lmad (Slice is')
  let lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims
      -- need to remove the fixed dims from the permutation
      perm' =
        updatePerm perm $
          map fst $
            filter (isJust . dimFix . snd) $
              zip [0 .. length is' - 1] is'

  pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    -- Drop every entry of @ps@ that occurs in @inds@ (the fixed
    -- positions, in ascending order) and decrease each remaining
    -- entry by the number of elements of @inds@ smaller than it.
    updatePerm ps inds = concatMap decrease ps
      where
        decrease p =
          -- @f@ counts the fixed positions below @p@; -1 is a
          -- sentinel meaning that @p@ itself is fixed.
          let f n i
                | i == p = -1
                | i > p = n
                | n /= -1 = n + 1
                | otherwise = n
              d = foldl f 0 inds
           in [p - d | d /= -1]

    -- Combine one slice component with its LMAD dimension, appending
    -- the resulting dimension (if any) to the LMAD under construction.
    -- The equations are tried in order; the two guarded equations
    -- fall through to the later ones when their guard fails.
    --
    -- XXX: TODO: what happens to r on a negative-stride slice; is there
    -- such a case?
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    -- Fixed index: the dimension disappears into the offset.
    sliceOne (LMAD off dims) (DimFix i, LMADDim s _x _ _) =
      LMAD (off + flatOneDim s i) dims
    -- Slicing a stride-0 dimension: only the size changes.
    sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _ p _) =
      LMAD off (dims ++ [LMADDim 0 ne p Unknown])
    -- Full unit-stride slice: the dimension is kept as is.
    sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ n _ _))
      | dmind == unitSlice 0 n = LMAD off (dims ++ [dim])
    -- Full reversal: start at the last element, negate the stride and
    -- invert the monotonicity.
    sliceOne (LMAD off dims) (dmind, LMADDim s n p m)
      | dmind == DimSlice (n - 1) n (-1) =
          let off' = off + flatOneDim s (n - 1)
           in LMAD off' (dims ++ [LMADDim (s * (-1)) n p (invertMonotonicity m)])
    -- Stride-0 slice: the start goes into the offset and the new
    -- dimension has stride 0.
    sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s _ p _) =
      LMAD (off + flatOneDim s b) (dims ++ [LMADDim 0 ne p Unknown])
    -- General slice: strides multiply; the monotonicity follows the
    -- sign of the slice stride when that sign is known.
    sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s _ p m) =
      let m' = case sgn ss of
            Just 1 -> m
            Just (-1) -> invertMonotonicity m
            _ -> Unknown
       in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) ns p m'])

    slicePreservesContiguous ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    slicePreservesContiguous (LMAD _ dims) (Slice slc) =
      -- remove from the slice the LMAD dimensions that have stride 0.
      -- If the LMAD was contiguous in mem, then these dims will not
      -- influence the contiguousness of the result.
      -- Also normalize the input slice, i.e., 0-stride and size-1
      -- slices are rewritten as DimFixed.
      let (dims', slc') =
            unzip $
              filter ((/= 0) . ldStride . fst) $
                zip dims $
                  map normIndex slc
          -- Check that:
          -- 1. a clean split point exists between Fixed and Sliced dims
          -- 2. the outermost sliced dim has +/- 1 stride.
          -- 3. the rest of inner sliced dims are full.
          --
          -- The fold state is (has a sliced dim been seen yet?, do
          -- the checks still hold?).
          (_, success) =
            foldl
              ( \(found, res) (slcdim, LMADDim _ n _ _) ->
                  case (slcdim, found) of
                    (DimFix {}, True) -> (found, False)
                    (DimFix {}, False) -> (found, res)
                    (DimSlice _ _ ds, False) ->
                      -- outermost sliced dim: +/-1 stride
                      let res' = (ds == 1 || ds == -1)
                       in (True, res && res')
                    (DimSlice _ ne ds, True) ->
                      -- inner sliced dim: needs to be full
                      let res' = (n == ne) && (ds == 1 || ds == -1)
                       in (found, res && res')
              )
              (False, True)
              $ zip slc' dims'
       in success

    -- Rewrite size-1 and stride-0 slices as a fixed index at the
    -- slice's start; everything else is returned unchanged.
    normIndex ::
      (Eq num, IntegralExp num) =>
      DimIndex num ->
      DimIndex num
    normIndex (DimSlice b 1 _) = DimFix b
    normIndex (DimSlice b _ 0) = DimFix b
    normIndex d = d

-- | Slice an index function.
--
-- Three cases are tried in order:
--
-- 1. An identity slice (a full unit slice of every dimension)
--    returns the index function unchanged.
-- 2. If 'sliceOneLMAD' succeeds, its result is used.
-- 3. Otherwise the slice is applied to a fresh 'iota' over the shape
--    of the first LMAD, and the resulting LMAD is pushed in front of
--    the existing ones.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  -- Avoid identity slicing.
  | unSlice dim_slices == map (unitSlice 0) (shape ixfun) = ixfun
  | Just ixfun' <- sliceOneLMAD ixfun dim_slices = ixfun'
  | otherwise =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"

-- | Flat-slice an index function.
--
-- If 'hasContiguousPerm' holds and the first LMAD has at least one
-- dimension, the outermost dimension of that LMAD is replaced in place
-- by the dimensions of the flat slice.  Otherwise a fresh LMAD
-- describing the slice is put in front of the existing ones.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is)
  | hasContiguousPerm ixfun =
      IxFun (sliced :| lmads) oshp cg
  where
    outer_stride = ldStride dim
    sliced =
      LMAD (offset + new_offset * outer_stride) (map sliceDim is <> dims)
        & setLMADPermutation [0 ..]
    -- Every slice dimension is scaled by the stride of the dimension it
    -- replaces.  Only a resulting stride of exactly 1 is marked 'Inc'.
    sliceDim (FlatDimIndex n s) =
      LMADDim (outer_stride * s) n 0 $
        if outer_stride * s == 1 then Inc else Unknown
flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) =
  IxFun (slice_lmad :| lmad : lmads) oshp cg
  where
    -- All but the outermost dimension of the current first LMAD are
    -- kept, with strides recomputed as products of the inner sizes.
    inner_shape = tail $ lmadShape lmad
    inner_size = product inner_shape
    inner_strides = drop 1 $ scanr (*) 1 inner_shape
    slice_shape = flatSliceDims s
    slice_strides = map (* inner_size) $ flatSliceStrides s
    slice_dims =
      zipWith4 LMADDim slice_strides slice_shape [0 ..] (repeat Inc)
    -- The kept dimensions are numbered after the slice dimensions.
    inner_dims =
      zipWith4 LMADDim inner_strides inner_shape [length slice_shape ..] (repeat Inc)
    slice_lmad = LMAD (new_offset * inner_size) (slice_dims <> inner_dims)

-- | Handle the case where a reshape operation can stay inside a single LMAD.
--
-- There are three conditions that all must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--   (3) finally, the underlying memory is contiguous (and monotonous).
--
-- If any of these conditions do not hold, 'Nothing' is returned; 'reshape'
-- then conservatively adds a new LMAD to the list, leading to a
-- representation that provides less opportunities for further analysis.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let perm = lmadPermutation lmad
      mid_dims = take (length dims) $ permuteFwd perm dims
      mon = ixfunMonotonicity ixfun

  guard $
    -- checking condition (2)
    all ((/= 0) . ldStride) mid_dims
      -- checking condition (1)
      && consecutive 0 (map ldPerm mid_dims)
      -- checking condition (3)
      && hasContiguousPerm ixfun
      && cg
      && (mon == Inc || mon == Dec)

  -- The new permutation is the identity over the new shape.
  let perm' = [0 .. length newshape - 1]
      -- Order the new dimensions by their permutation index.  Since
      -- perm' is currently the identity, both sorts below leave the
      -- order unchanged.
      (sup_inds, support) =
        unzip $ sortBy (compare `on` fst) $ zip perm' newshape
      LMAD off' dims_sup = makeRotIota mon off support
      dims' = map snd $ sortBy (compare `on` fst) $ zip sup_inds dims_sup
  pure $ IxFun (setLMADPermutation perm' (LMAD off' dims') :| lmads) oldbase cg
  where
    -- Does the list count upwards one by one, starting from @start@?
    -- Vacuously true for the empty list.
    consecutive start ps = and $ zipWith (==) ps [start ..]

-- | Reshape an index function.
--
-- First tries 'reshapeOneLMAD'; if that is not applicable, the LMAD of
-- @iota new_shape@ is put in front of the existing LMADs.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
reshape ixfun@(IxFun (lmad0 :| lmad0s) oshp cg) new_shape =
  case reshapeOneLMAD ixfun new_shape of
    Just ixfun' -> ixfun'
    Nothing ->
      case iota new_shape of
        IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
        _ -> error "reshape: reached impossible case"

-- | Coerce an index function to look like it has a new shape.
-- Dynamically the shape must be the same.
--
-- Only the first LMAD is touched: its dimensions get their sizes
-- overwritten pairwise with @new_shape@.  The pairing stops at the
-- shorter of the two lists.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun (lmad :| lmads) oshp cg) new_shape =
  IxFun (resize lmad :| lmads) oshp cg
  where
    resize (LMAD offset dims) =
      LMAD offset [ld {ldShape = d} | (ld, d) <- zip dims new_shape]

-- | The number of dimensions in the domain of the input function,
-- i.e. the number of dimensions of its first LMAD.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank ixfun = length (lmadDims (NE.head (ixfunLMADs ixfun)))

-- | Essentially @rebase new_base ixfun = ixfun o new_base@
-- Core soundness condition: @base ixfun == shape new_base@
-- Handles the case where a rebase operation can stay within m + n - 1 LMADs,
-- where m is the number of LMADs in the index function, and n is the number of
-- LMADs in the new base.  If both index functions have only one LMAD, this
-- means that we stay within the single-LMAD domain.
--
-- We can often stay in that domain if the original ixfun is essentially a
-- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`.
--
-- Returns 'Nothing' when the conservative conditions checked below do not
-- hold.
--
-- XXX: TODO: handle repetitions in both lmads.
--
-- How to handle repeated dimensions in the original?
--
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
    -- The last LMAD of ixfun is the one that gets merged with the first
    -- LMAD of the new base.  The LMADs in front of it are carried over
    -- untouched and in their original order.
    let lmad = NE.last lmads
        lmads' = NE.init lmads
        dims = lmadDims lmad
        perm = lmadPermutation lmad
        perm_base = lmadPermutation lmad_base

    guard $
      -- Core rebase condition.
      base ixfun == shape new_base
        -- Conservative safety conditions: ixfun is contiguous and has known
        -- monotonicity for all dimensions.
        && cg
        && all ((/= Unknown) . ldMon) dims
        -- XXX: We should be able to handle some basic cases where both index
        -- functions have non-trivial permutations.
        && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
        -- We need the permutations to be of the same size if we want to compose
        -- them.  They don't have to be of the same size if the ixfun has a trivial
        -- permutation.  Supporting this latter case allows us to rebase when ixfun
        -- has been created by slicing with fixed dimensions.
        && (length perm == length perm_base || hasContiguousPerm ixfun)
        -- To not have to worry about ixfun having non-1 strides, we also check that
        -- it is a row-major array (modulo permutation, which is handled
        -- separately).  Accept a non-full innermost dimension.  XXX: Maybe this can
        -- be less conservative?
        && and
          ( zipWith3
              (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1))
              shp
              dims
              (replicate (length dims - 1) False ++ [True])
          )

    -- Compose permutations, reverse strides and adjust offset if necessary.
    let perm_base' =
          if hasContiguousPerm ixfun
            then perm_base
            else map (perm !!) perm_base
        lmad_base' = setLMADPermutation perm_base' lmad_base
        dims_base = lmadDims lmad_base'
        n_fewer_dims = length dims_base - length dims
        (dims_base', offs_contrib) =
          unzip $
            zipWith
              ( \(LMADDim s1 n1 p1 _) (LMADDim _ _ _ m2) ->
                  -- A dimension that is not 'Inc' in ixfun is traversed
                  -- backwards: negate the stride and start from the far end.
                  let (s', off')
                        | m2 == Inc = (s1, 0)
                        | otherwise = (s1 * (-1), s1 * (n1 - 1))
                   in (LMADDim s' n1 (p1 - n_fewer_dims) Inc, off')
              )
              -- If @dims@ is morally a slice, it might have fewer dimensions than
              -- @dims_base@.  Drop extraneous outer dimensions.
              (drop n_fewer_dims dims_base)
              dims
        off_base = lmadOffset lmad_base' + sum offs_contrib
        lmad_base''
          | lmadOffset lmad == 0 = LMAD off_base dims_base'
          | otherwise =
              -- If the innermost dimension of the ixfun was not full (but still
              -- had a stride of 1), add its offset relative to the new base.
              setLMADShape
                (lmadShape lmad)
                ( LMAD
                    (off_base + ldStride (last dims_base) * lmadOffset lmad)
                    dims_base'
                )
    pure $
      IxFun
        (lmads' ++@ (lmad_base'' :| lmads_base))
        shp
        (cg && cg_base)

-- | Rebase an index function on top of a new base.
--
-- Uses 'rebaseNice' when it applies.  Otherwise the LMADs of the two
-- index functions are simply concatenated; this corresponds to index
-- function composition, which is always safe.  If the base of @ixfun@
-- differs from the shape of @new_base@, the new base is reshaped to
-- the base of @ixfun@ first.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just ixfun' -> ixfun'
    Nothing -> IxFun (lmads @++@ fitted_lmads) fitted_shp (cg && cg_base)
  where
    (fitted_lmads, fitted_shp)
      | base ixfun == shape new_base = (lmads_base, shp_base)
      | otherwise =
          let IxFun reshaped_lmads reshaped_shp _ = reshape new_base shp
           in (reshaped_lmads, reshaped_shp)

-- | If the memory support of the index function is contiguous and row-major
-- (i.e., no transpositions, repetitions, rotates, etc.), then this should
-- return the offset from which the memory-support of this index function
-- starts, multiplied by the given element size.  Index functions with more
-- than one LMAD always give 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ cg
      | hasContiguousPerm ixfun,
        cg,
        ixfunMonotonicity ixfun == Inc ->
          Just (lmadOffset lmad * elem_size)
    _ -> Nothing

-- | Similar restrictions to @linearWithOffset@ except for transpositions, which
-- are returned together with the offset.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  withPerm <$> linearWithOffset unpermuted elem_size
  where
    perm = lmadPermutation lmad
    -- Note that @cg@ describes whether the index function is
    -- contiguous, *ignoring permutations*, which is what is needed
    -- here: the linearity check runs on a copy of the LMAD whose
    -- permutation has been reset to the identity.
    unpermuted =
      IxFun (setLMADPermutation [0 .. length perm - 1] lmad :| []) oshp cg
    withPerm offset =
      (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing

-- | Is this a row-major array starting at offset zero?  That is, does
-- 'linearWithOffset' with an element size of 1 give offset 0?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0

-- | The contribution of index @i@ along a dimension with stride @s@.
-- A zero stride gives a literal zero rather than the product @i * 0@.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  num ->
  num ->
  num
flatOneDim s i = if s == 0 then 0 else i * s

-- | Check monotonicity of an index function.  The result is the
-- monotonicity of the first LMAD if all other LMADs agree with it, and
-- 'Unknown' otherwise.
ixfunMonotonicity ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Monotonicity
ixfunMonotonicity (IxFun (lmad :| lmads) _ _)
  | all ((== first_mon) . lmadMon) lmads = first_mon
  | otherwise = Unknown
  where
    first_mon = lmadMon lmad

    -- An LMAD is 'Inc' (respectively 'Dec') when every dimension agrees
    -- with that direction.  'Inc' is tried first, so an LMAD without
    -- dimensions counts as 'Inc'.
    lmadMon ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Monotonicity
    lmadMon (LMAD _ dims)
      | all (agrees Inc) dims = Inc
      | all (agrees Dec) dims = Dec
      | otherwise = Unknown

    -- A dimension with zero stride agrees with any direction.
    agrees ::
      (Eq num, IntegralExp num) =>
      Monotonicity ->
      LMADDim num ->
      Bool
    agrees mon ld = ldStride ld == 0 || mon == ldMon ld

-- | Turn all the leaves of the index function into 'Ext's.  We
--  require that there's only one LMAD, that the index function is
--  contiguous, and the base shape has only one dimension.  (None of
--  this is checked here.)
--
--  The leaves are numbered from zero in traversal order.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const freshExt) ixfun) 0
  where
    -- Hand out the current counter value as an 'Ext' leaf and bump the
    -- counter.
    freshExt = state $ \i -> (TPrimExp (LeafExp (Ext i) int64), i + 1)

-- | When comparing index functions as part of the type check in KernelsMem,
-- we may run into problems caused by the simplifier. As index functions can be
-- generalized over if-then-else expressions, the simplifier might hoist some of
-- the code from inside the if-then-else (computing the offset of an array, for
-- instance), but now the type checker cannot verify that the generalized index
-- function is valid, because some of the existentials are computed somewhere
-- else. To work around this, we've had to relax the KernelsMem type-checker
-- a bit; specifically, we've introduced this function to verify whether two
-- index functions are "close enough" that we can assume that they match. We use
-- this instead of @ixfun1 == ixfun2@ and hope that it's good enough.
--
-- Two index functions are considered close enough when all of the
-- following hold:
--
-- * their base shapes have the same length,
--
-- * they have the same number of LMADs,
--
-- * pairwise, their LMADs have the same number of dimensions and the
--   same permutation,
--
-- * if the first one is contiguous, then so is the second.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  and
    [ sameCount (length . base),
      sameCount (NE.length . ixfunLMADs),
      and (NE.zipWith similarLMADs (ixfunLMADs ixf1) (ixfunLMADs ixf2)),
      -- This treats ixf1 as the "declared type" that we are matching
      -- against: on 'Bool', '<=' is implication.
      contiguous ixf1 <= contiguous ixf2
    ]
  where
    -- Does the given size measure agree on both index functions?
    sameCount count = count ixf1 == (count ixf2 :: Int)

    -- Same number of dimensions and the same permutation; offsets,
    -- strides and shapes are deliberately not compared.
    similarLMADs :: LMAD num -> LMAD num -> Bool
    similarLMADs lmad1 lmad2 =
      ((==) `on` (length . lmadDims)) lmad1 lmad2
        && ((==) `on` (map ldPerm . lmadDims)) lmad1 lmad2

-- | Returns true if two 'IxFun's are equivalent.
--
-- Equivalence in this case is defined as having the same number of LMADs, with
-- each pair of LMADs matching in permutation, offsets, strides and rotations.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent :: forall num. Eq num => IxFun num -> IxFun num -> Bool
equivalent IxFun num
ixf1 IxFun num
ixf2 =
  forall a. NonEmpty a -> Int
NE.length (forall num. IxFun num -> NonEmpty (LMAD num)
ixfunLMADs IxFun num
ixf1) forall a. Eq a => a -> a -> Bool
== forall a. NonEmpty a -> Int
NE.length (forall num. IxFun num -> NonEmpty (LMAD num)
ixfunLMADs IxFun num
ixf2)
    Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all forall {b}. Eq b => (LMAD b, LMAD b) -> Bool
equivalentLMADs (forall a b. NonEmpty a -> NonEmpty b -> NonEmpty (a, b)
NE.zip (forall num. IxFun num -> NonEmpty (LMAD num)
ixfunLMADs IxFun num
ixf1) (forall num. IxFun num -> NonEmpty (LMAD num)
ixfunLMADs IxFun num
ixf2))
  where
    equivalentLMADs :: (LMAD b, LMAD b) -> Bool
equivalentLMADs (LMAD b
lmad1, LMAD b
lmad2) =
      forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad1) forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad2)
        Bool -> Bool -> Bool
&& forall a b. (a -> b) -> [a] -> [b]
map forall num. LMADDim num -> Int
ldPerm (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad1)
          forall a. Eq a => a -> a -> Bool
== forall a b. (a -> b) -> [a] -> [b]
map forall num. LMADDim num -> Int
ldPerm (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad2)
        Bool -> Bool -> Bool
&& forall num. LMAD num -> num
lmadOffset LMAD b
lmad1
          forall a. Eq a => a -> a -> Bool
== forall num. LMAD num -> num
lmadOffset LMAD b
lmad2
        Bool -> Bool -> Bool
&& forall a b. (a -> b) -> [a] -> [b]
map forall num. LMADDim num -> num
ldStride (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad1)
          forall a. Eq a => a -> a -> Bool
== forall a b. (a -> b) -> [a] -> [b]
map forall num. LMADDim num -> num
ldStride (forall num. LMAD num -> [LMADDim num]
lmadDims LMAD b
lmad2)