{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation of index functions based on
-- linear-memory accessor descriptors (LMADs); see the work of Zhu,
-- Hoeflinger and David.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    Shape,
    LMAD (..),
    LMADDim (..),
    index,
    mkExistential,
    iota,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    expand,
    shape,
    rank,
    isDirect,
    substituteInIxFun,
    substituteInLMAD,
    existentialize,
    closeEnough,
    disjoint,
    disjoint2,
    disjoint3,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.State
import Data.Map.Strict qualified as M
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
  ( equivalent,
    flatSlice,
    index,
    iota,
    isDirect,
    mkExistential,
    permute,
    rank,
    reshape,
    shape,
    slice,
  )
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( FlatSlice (..),
    Slice (..),
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | An index function maps positions in a multidimensional array
-- index space (the domain) to positions in a one-dimensional memory
-- index space.  In other words, it says where the element at
-- position @[i,j,p]@ of some array is stored inside the flat
-- one-dimensional array that constitutes its memory.  This is what
-- lets us tell apart, for example, row-major and column-major
-- representations.
--
-- An index function is represented as a single LMAD together with the
-- shape of the array it was originally created for.
data IxFun num = IxFun
  { -- | The LMAD that carries out the actual mapping; 'index' and
    -- 'shape' are answered by consulting it.
    ixfunLMAD :: LMAD num,
    -- | The shape of the support array, i.e. the original array that
    -- was the starting point of this index function.
    base :: Shape num
  }
  deriving (Eq, Show)

-- | Prints an index function as a brace-enclosed, semicolon-stacked
-- pair of entries: the base shape (bracketed, comma-separated)
-- followed by the LMAD.
instance (Pretty num) => Pretty (IxFun num) where
  pretty (IxFun lmad oshp) = braces (semistack [base_doc, lmad_doc])
    where
      base_doc = "base:" <+> brackets (commasep (map pretty oshp))
      lmad_doc = "LMAD:" <+> pretty lmad

-- | Name substitution is applied to every leaf of the index function.
instance (Substitute num) => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun

-- | Renaming is implemented in terms of substitution.
instance (Substitute num) => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun

-- | The free names of an index function are those of all its leaves,
-- combined via the 'Foldable' instance.
instance (FreeIn num) => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun

-- | Derived from the 'Traversable' instance.
instance Functor IxFun where
  fmap f = fmapDefault f

-- | Derived from the 'Traversable' instance.
instance Foldable IxFun where
  foldMap f = foldMapDefault f

-- The leaves of the LMAD are visited first, then the base shape.  It
-- is important that this traversal order is the same as the order
-- assumed by mkExistential.
instance Traversable IxFun where
  traverse f (IxFun lmad oshp) = IxFun <$> lmad' <*> oshp'
    where
      lmad' = traverse f lmad
      oshp' = traverse f oshp

-- | Substitute names with 'PrimExp's in an index function.  The
-- substitution is applied both to the LMAD (via 'substituteInLMAD')
-- and to every dimension of the base shape.
substituteInIxFun ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmad oshp) =
  IxFun (substituteInLMAD tab lmad) (map substDim oshp)
  where
    -- substituteInPrimExp works on untyped expressions, so the table
    -- is stripped of its type wrappers once, up front.
    untyped_tab = fmap untyped tab
    substDim d = TPrimExp (substituteInPrimExp untyped_tab (untyped d))

-- | Is this a row-major array?
--
-- This holds when the LMAD offset is zero, the LMAD has as many
-- dimensions as the base shape, and every LMAD dimension has the size
-- of the corresponding base dimension and the stride that dimension
-- would have in a row-major layout of the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect (IxFun (LMAD offset dims) oshp) =
  length oshp == length dims
    && offset == 0
    && and (zipWith3 dimMatches dims oshp strides_expected)
  where
    -- Row-major strides: the innermost dimension has stride 1, and
    -- each outer dimension has the product of all inner sizes.  We use
    -- 'drop' rather than the partial 'tail' so that an empty base
    -- shape is handled without relying on laziness.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    dimMatches (LMADDim s n) d se = s == se && n == d

-- | The index space of the index function, which is the shape of the
-- arrays that the index function supports.  It is read off the
-- underlying LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun = LMAD.shape (ixfunLMAD ixfun)

-- | Compute the flat memory index for a complete set of array
-- indices, by applying 'LMAD.index' to the underlying LMAD.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun inds = LMAD.index (ixfunLMAD ixfun) inds

-- | Like 'iota', but with an explicit offset.  The given shape is
-- used both for the LMAD and as the base shape.
iotaOffset :: (IntegralExp num) => num -> Shape num -> IxFun num
iotaOffset o ns =
  IxFun
    { ixfunLMAD = LMAD.iota o ns,
      base = ns
    }

-- | The index function for an array of the given shape, with offset
-- zero.
iota :: (IntegralExp num) => Shape num -> IxFun num
iota ns = iotaOffset 0 ns

-- | Create a single-LMAD index function that is existential in
-- everything.
--
-- The LMAD is produced by @LMAD.mkExistential lmad_rank start@.  The
-- base shape consists of @basis_rank@ existentials numbered
-- consecutively from @start + 1 + 2 * lmad_rank@.
mkExistential :: Int -> Int -> Int -> IxFun (Ext a)
mkExistential basis_rank lmad_rank start =
  IxFun
    { ixfunLMAD = LMAD.mkExistential lmad_rank start,
      base = [Ext i | i <- take basis_rank [first_base ..]]
    }
  where
    -- Number of the first base-shape existential.
    first_base = start + 1 + lmad_rank * 2

-- | Permute the dimensions of the underlying LMAD.  The base shape is
-- left untouched.
permute ::
  (IntegralExp num) =>
  IxFun num ->
  Permutation ->
  IxFun num
permute ixfun perm_new =
  ixfun {ixfunLMAD = LMAD.permute (ixfunLMAD ixfun) perm_new}

-- | Slice an index function.  A slice that takes every dimension in
-- full (unit stride, starting at zero) returns the index function
-- unchanged; otherwise the underlying LMAD is sliced and the base
-- shape is kept.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun lmad oshp) full_slice@(Slice is)
  -- Avoid identity slicing.
  | is == identity_slice = ixfun
  | otherwise = IxFun (LMAD.slice lmad full_slice) oshp
  where
    identity_slice = [unitSlice 0 d | d <- shape ixfun]

-- | Flat-slice an index function by flat-slicing its underlying LMAD.
-- The base shape is kept.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun s =
  ixfun {ixfunLMAD = LMAD.flatSlice (ixfunLMAD ixfun) s}

-- | Reshape an index function.
--
-- The following conditions must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--
-- NOTE(review): these conditions are not checked in this function;
-- presumably they are enforced by 'LMAD.reshape' -- confirm there.
--
-- The result is 'Nothing' exactly when 'LMAD.reshape' yields
-- 'Nothing'.  On success, the new shape also becomes the base shape
-- of the resulting index function.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshape (IxFun lmad _) new_shape = do
  lmad' <- LMAD.reshape lmad new_shape
  pure (IxFun lmad' new_shape)

-- | Coerce an index function to look like it has a new shape.
-- Dynamically the shape must be the same.
--
-- The size of each LMAD dimension is overwritten by the corresponding
-- element of the new shape (offset and strides are kept), and the new
-- shape also becomes the base shape.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun lmad _) new_shape = IxFun (reshaped lmad) new_shape
  where
    reshaped (LMAD offset dims) =
      LMAD offset [ld {ldShape = d} | (ld, d) <- zip dims new_shape]

-- | The number of dimensions in the domain of the index function,
-- i.e. the number of dimensions of its LMAD.
rank :: (IntegralExp num) => IxFun num -> Int
rank = length . LMAD.dims . ixfunLMAD

-- | Conceptually expand index function to be a particular slice of
-- another by adjusting the offset and strides.  Used for memory
-- expansion.
--
-- Given @expand o p ixfun@, every stride of the LMAD is multiplied by
-- @p@ and the offset becomes @o + p * offset@.  The base shape is
-- kept.  The result is always a 'Just'.
expand ::
  (Eq num, IntegralExp num) => num -> num -> IxFun num -> Maybe (IxFun num)
expand o p (IxFun lmad oshp) = Just (IxFun expanded oshp)
  where
    expanded =
      LMAD
        (o + p * LMAD.offset lmad)
        (map scaleStride (LMAD.dims lmad))
    scaleStride ld = ld {LMAD.ldStride = p * LMAD.ldStride ld}

-- | Turn all the leaves of the index function into 'Ext's.
--
-- Every leaf is replaced by a fresh @int64@-typed existential.  The
-- existentials are numbered consecutively from zero, in the order in
-- which the 'Traversable' instance visits the leaves.  The original
-- leaf values are discarded.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const fresh) ixfun) 0
  where
    -- Yield the existential for the current counter and bump the
    -- counter by one.
    fresh = state $ \i -> (TPrimExp (LeafExp (Ext i) int64), i + 1)

-- | When comparing index functions as part of the type check in
-- KernelsMem, we may run into problems caused by the simplifier.  As
-- index functions can be generalized over if-then-else expressions,
-- the simplifier might hoist some of the code from inside the
-- if-then-else (computing the offset of an array, for instance), but
-- then the type checker cannot verify that the generalized index
-- function is valid, because some of the existentials are computed
-- somewhere else.  To work around this, the KernelsMem type checker
-- has been relaxed a bit: this function checks whether two index
-- functions are "close enough" that we can assume that they match.
-- It is used instead of @ixfun1 == ixfun2@, in the hope that it is
-- good enough.
--
-- Concretely, two index functions are close enough when their base
-- shapes have the same number of dimensions and their LMADs have the
-- same number of dimensions.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  sameLength (base ixf1) (base ixf2)
    && sameLength (lmadDims ixf1) (lmadDims ixf2)
  where
    lmadDims = LMAD.dims . ixfunLMAD
    sameLength xs ys = length xs == length ys