{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation of index functions based on
-- linear memory access descriptors (LMADs); see the work of Zhu, Hoeflinger
-- and David.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    Shape,
    LMAD (..),
    LMADDim (..),
    index,
    mkExistential,
    iota,
    iotaOffset,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    rebase,
    shape,
    permutation,
    rank,
    isDirect,
    substituteInIxFun,
    substituteInLMAD,
    existentialize,
    closeEnough,
    equivalent,
    permuteInv,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.State
import Data.List (sort, zip4)
import Data.Map.Strict qualified as M
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
  ( flatSlice,
    index,
    iota,
    mkExistential,
    permutation,
    permute,
    reshape,
    shape,
    slice,
  )
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( FlatSlice (..),
    Slice (..),
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as a single LMAD together with the
-- shape of the array it was originally derived from.
data IxFun num = IxFun
  { -- | The LMAD that maps array indices to flat memory offsets.
    ixfunLMAD :: LMAD num,
    -- | The shape of the support array, i.e., the original array from
    -- which this index function was derived (its starting point).
    base :: Shape num
  }
  deriving (Eq, Show)

-- Rendered as a braced, semicolon-stacked pair: the base shape first,
-- then the LMAD.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmad oshp) = braces (semistack [baseDoc, lmadDoc])
    where
      baseDoc = "base:" <+> brackets (commasep (map pretty oshp))
      lmadDoc = "LMAD:" <+> pretty lmad

-- Name substitution is applied to every element of the index function
-- (LMAD and base shape alike) via the 'Functor' instance.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun

-- Renaming is implemented in terms of the 'Substitute' instance.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun

-- The free variables of an index function are the union of the free
-- variables of all its elements, collected through the 'Foldable' instance.
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun

-- Derived from the 'Traversable' instance.
instance Functor IxFun where
  fmap f ixfun = fmapDefault f ixfun

-- Derived from the 'Traversable' instance.
instance Foldable IxFun where
  foldMap f ixfun = foldMapDefault f ixfun

-- It is important that the traversal order here is the same as in
-- mkExistential: all elements of the LMAD are visited first, followed by
-- the elements of the base shape.
instance Traversable IxFun where
  traverse f (IxFun lmad oshp) =
    IxFun
      <$> traverse f lmad
      <*> traverse f oshp

-- | Substitute names with PrimExps in an index function.  The table is
-- applied to the LMAD (via 'substituteInLMAD') and to every dimension of
-- the base shape.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmad oshp) =
  IxFun (substituteInLMAD tab lmad) (map substDim oshp)
  where
    -- 'substituteInPrimExp' works on untyped expressions, so the phantom
    -- type is stripped from the table and from each dimension, and the
    -- result is re-wrapped afterwards.
    untypedTab = fmap untyped tab
    substDim d = TPrimExp (substituteInPrimExp untypedTab (untyped d))

-- | Is this a row-major array?  This holds when all of the following are
-- true:
--
--   * the LMAD has exactly as many dimensions as the base shape;
--   * the LMAD offset is zero;
--   * dimension @i@ of the LMAD has permutation entry @i@, the extent of
--     base dimension @i@, and the row-major stride of the base shape at
--     position @i@.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect (IxFun (LMAD offset dims) oshp) =
  length oshp == length dims
    && offset == 0
    && all matches (zip4 dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides of the base shape: the innermost stride is 1 and
    -- each outer stride is the product of all inner extents.  'drop 1' is
    -- used instead of the partial 'tail' so this is total for a rank-0 base.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    matches (LMADDim s n p, m, d, se) = s == se && n == d && p == m

-- | Does the index function have an ascending permutation?  True when the
-- LMAD's permutation is already in sorted order.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun = perm == sort perm
  where
    perm = LMAD.permutation (ixfunLMAD ixfun)

-- | The index space of the index function.  This is the same as the
-- shape of arrays that the index function supports.  It is obtained by
-- applying the LMAD's permutation (with 'permuteFwd') to the LMAD's
-- unpermuted dimension extents.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun = permuteFwd (LMAD.permutation lmad) (LMAD.shapeBase lmad)
  where
    lmad = ixfunLMAD ixfun

-- | The permutation of the index function's LMAD: the 'ldPerm' entry of
-- each LMAD dimension, in the order in which the dimensions are stored.
permutation :: IxFun num -> Permutation
permutation ixfun = [LMAD.ldPerm ld | ld <- LMAD.dims (ixfunLMAD ixfun)]

-- | Compute the flat memory index for a complete set of array indices.
-- The indices are resolved by the index function's LMAD alone (via
-- 'LMAD.index'); the base shape is not consulted.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun = LMAD.index (ixfunLMAD ixfun)

-- | iota with offset: an index function over the given shape whose LMAD
-- is 'LMAD.iota' with offset @o@, and whose base shape is that same shape.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun {ixfunLMAD = LMAD.iota o ns, base = ns}

-- | iota: 'iotaOffset' with a zero offset.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = iotaOffset 0 ns

-- | Create a single-LMAD index function that is
-- existential in everything, with the provided permutation.
--
-- The LMAD part is produced by 'LMAD.mkExistential' starting at
-- existential @start@.  The @basis_rank@ base-shape existentials are
-- numbered consecutively from @start + 1 + 2 * length perm@.
-- NOTE(review): this assumes 'LMAD.mkExistential' uses exactly the
-- existentials below that index (presumably one offset plus two per
-- dimension); confirm against the LMAD module.
mkExistential :: Int -> [Int] -> Int -> IxFun (Ext a)
mkExistential basis_rank perm start =
  IxFun
    { ixfunLMAD = LMAD.mkExistential perm start,
      base = map Ext (take basis_rank [first_base_ext ..])
    }
  where
    first_base_ext = start + 1 + 2 * length perm

-- | Permute dimensions.  Only the LMAD is permuted (via 'LMAD.permute');
-- the base shape is carried over unchanged.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute ixfun perm_new =
  ixfun {ixfunLMAD = LMAD.permute (ixfunLMAD ixfun) perm_new}

-- | Slice an index function.  A slice that takes every dimension in full
-- (starting at 0, unit stride) returns the index function unchanged;
-- otherwise the LMAD is sliced with 'LMAD.slice' and the base shape is
-- kept.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun lmad@LMAD {} oshp) sl@(Slice is)
  | isIdentity = ixfun
  | otherwise = IxFun (LMAD.slice lmad sl) oshp
  where
    isIdentity = is == map (unitSlice 0) (shape ixfun)

-- | Flat-slice an index function.  Returns 'Nothing' exactly when
-- 'LMAD.flatSlice' does; the base shape is kept.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  Maybe (IxFun num)
flatSlice (IxFun lmad oshp) s =
  (`IxFun` oshp) <$> LMAD.flatSlice lmad s

-- | Reshape an index function.
--
-- The LMAD is reshaped with 'LMAD.reshape', and the base shape of the
-- result is set to the new shape.  The result is 'Nothing' whenever
-- 'LMAD.reshape' fails, i.e. when the reshaped array cannot be described
-- by a single LMAD.
--
-- NOTE(review): the conditions below are enforced inside 'LMAD.reshape',
-- not here.  According to them, both of the following must hold for the
-- result to stay in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshape (IxFun lmad _) new_shape =
  IxFun <$> LMAD.reshape lmad new_shape <*> pure new_shape

-- | Coerce an index function to look like it has a new shape.
-- Dynamically, the new shape must equal the old one; this is not checked.
--
-- The offset and strides of the LMAD are kept.  The extent of each LMAD
-- dimension is replaced by the corresponding entry of @new_shape@, and the
-- base shape becomes @new_shape@.  Dimensions are paired in the order they
-- are stored in the LMAD, and pairing stops at the shorter of the two
-- lists.  NOTE(review): no permutation is applied when pairing; confirm
-- this is intended for LMADs with a non-identity permutation.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun lmad _) new_shape = IxFun (resize lmad) new_shape
  where
    resize (LMAD offset ldims) =
      LMAD offset [ld {ldShape = d} | (ld, d) <- zip ldims new_shape]

-- | The number of dimensions in the domain of the input function, i.e.
-- the number of dimensions of its LMAD.
rank :: IntegralExp num => IxFun num -> Int
rank ixfun = length (LMAD.dims (ixfunLMAD ixfun))

-- | Essentially @rebase new_base ixfun = ixfun o new_base@.  The result
-- keeps the base shape of @ixfun@ and the shape of @ixfun@'s LMAD, but
-- addresses the memory described by @new_base@.
--
-- Returns 'Nothing' unless all of the following hold:
--
--   * (core soundness condition) @base ixfun == shape new_base@;
--   * at least one of the two index functions has an ascending
--     permutation;
--   * the two permutations have the same length, or @ixfun@ has an
--     ascending permutation;
--   * every LMAD dimension of @ixfun@ except the first (outermost) one has
--     the same extent as the corresponding dimension of @base ixfun@;
--   * @new_base@ has at least one dimension.  The offset of @ixfun@ is
--     scaled by the stride of the last dimension of the (re-permuted) new
--     base, so a rank-0 new base cannot be handled.
--
-- We can often stay in the single-LMAD domain if the original ixfun is
-- essentially a slice, e.g. @x[i, (k1,m,s1), (k2,n,s2)] = orig@.
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebase new_base@(IxFun lmad_base _) ixfun@(IxFun lmad shp) = do
  let dims = LMAD.dims lmad
      perm = LMAD.permutation lmad
      perm_base = LMAD.permutation lmad_base

  guard $
    -- Core rebase condition.
    base ixfun == shape new_base
      -- XXX: We should be able to handle some basic cases where both index
      -- functions have non-trivial permutations.
      && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
      -- We need the permutations to be of the same size if we want to compose
      -- them.  They don't have to be of the same size if the ixfun has a trivial
      -- permutation.  Supporting this latter case allows us to rebase when ixfun
      -- has been created by slicing with fixed dimensions.
      && (length perm == length perm_base || hasContiguousPerm ixfun)
      -- To not have to worry about ixfun having non-1 strides, we also check that
      -- it is a row-major array (modulo permutation, which is handled
      -- separately).  Accept a non-full outermost dimension.  XXX: Maybe this can
      -- be less conservative?
      && and
        ( zipWith3
            (\sn ld inner -> inner || sn == ldShape ld)
            shp
            dims
            (True : replicate (length dims - 1) False)
        )

  -- Compose permutations and adjust the offset.
  let perm_base' =
        if hasContiguousPerm ixfun
          then perm_base
          else map (perm !!) perm_base
      lmad_base' = LMAD.setPermutation perm_base' lmad_base
      dims_base = LMAD.dims lmad_base'
      n_fewer_dims = length dims_base - length dims
      -- Each kept dimension currently contributes a zero offset.
      (dims_base', offs_contrib) =
        unzip $
          zipWith
            ( \(LMADDim s1 n1 p1) (LMADDim {}) ->
                (LMADDim s1 n1 (p1 - n_fewer_dims), 0)
            )
            -- If @dims@ is morally a slice, it might have fewer dimensions than
            -- @dims_base@.  Drop extraneous outer dimensions.
            (drop n_fewer_dims dims_base)
            dims
      off_base = LMAD.offset lmad_base' + sum offs_contrib

  -- The offset of @ixfun@ is scaled by the stride of the last dimension of
  -- the new base.  A rank-0 new base has no such dimension; using the
  -- partial 'last' here would yield a result that crashes when forced, so
  -- we conservatively refuse to rebase instead.
  inner_base_dim <- case reverse dims_base of
    ld : _ -> Just ld
    [] -> Nothing

  let lmad_base'' =
        LMAD.setShape
          (LMAD.shape lmad)
          ( LMAD
              (off_base + ldStride inner_base_dim * LMAD.offset lmad)
              dims_base'
          )
  pure $ IxFun lmad_base'' shp

-- | Turn all the leaves of the index function into 'Ext's.  Every element
-- visited by the 'Traversable' instance (the LMAD first, then the base
-- shape) is replaced by a fresh @int64@ leaf @Ext i@, numbered from 0 in
-- traversal order.  The original element values are discarded.
--
-- NOTE(review): callers are presumably expected to pass a contiguous
-- index function whose base shape has only one dimension, but none of
-- that is checked here.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = snd $ mapAccumL mkExt 0 ixfun
  where
    mkExt i _ = (i + 1, TPrimExp $ LeafExp (Ext i) int64)

-- | A deliberately weak equality check on index functions.
--
-- When index functions are compared during the KernelsMem type check, the
-- simplifier can get in the way.  Index functions may be generalised over
-- if-then-else expressions.  The simplifier may then hoist some of the code
-- out of the branches (for instance, the computation of an array offset).
-- After that, the type checker can no longer verify the generalised index
-- function, because some of its existentials are computed elsewhere.
--
-- To work around this, the KernelsMem type checker was relaxed.  This function
-- is used instead of @ixfun1 == ixfun2@ and only checks three things:
--
-- * the base shapes have the same length;
-- * the LMADs have the same number of dimensions;
-- * the per-dimension permutations agree.
--
-- Offsets, strides and shapes are ignored.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  sameBaseRank && permsAgree (ixfunLMAD ixf1) (ixfunLMAD ixf2)
  where
    sameBaseRank = length (base ixf1) == length (base ixf2)

    permsAgree l1 l2 =
      let ds1 = LMAD.dims l1
          ds2 = LMAD.dims l2
       in length ds1 == length ds2
            && map ldPerm ds1 == map ldPerm ds2

-- | Returns true if two 'IxFun's are equivalent.
--
-- Equivalence compares the single LMAD held by each index function.  The two
-- LMADs must have:
--
-- * the same number of dimensions;
-- * the same per-dimension permutation;
-- * the same offset;
-- * the same per-dimension strides.
--
-- The dimension shapes and the base shapes of the index functions are not
-- compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  length ds1 == length ds2
    && map ldPerm ds1 == map ldPerm ds2
    && LMAD.offset lmad1 == LMAD.offset lmad2
    && map ldStride ds1 == map ldStride ds2
  where
    lmad1 = ixfunLMAD ixf1
    lmad2 = ixfunLMAD ixf2
    ds1 = LMAD.dims lmad1
    ds2 = LMAD.dims lmad2