-- | This module contains a representation of linear-memory accessor
-- descriptors (LMAD); see work by Zhu, Hoeflinger and David.
module Futhark.IR.Mem.LMAD
  ( Shape,
    LMAD (..),
    LMADDim (..),
    Monotonicity (..),
    Permutation,
    lmadShape,
    lmadShapeBase,
    substituteInLMAD,
    permuteInv,
    permuteFwd,
    conservativeFlatten,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
    lmadPermutation,
    makeRotIota,
    invertMonotonicity,
  )
where

import Control.Category
import Control.Monad
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sortBy, zipWith4)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax (Type)
import Futhark.IR.Syntax.Core (VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | The shape of an index function: a list with one extent per
-- dimension.
type Shape num = [num]

-- | A permutation, represented as a list of zero-based list
-- positions (see 'permuteFwd' and 'permuteInv' for how it is applied).
type Permutation = [Int]

-- | The physical element ordering alongside a dimension, i.e. the
-- sign of the stride.
data Monotonicity
  = -- | Increasing.
    Inc
  | -- | Decreasing.
    Dec
  | -- | Unknown.
    Unknown
  deriving (Show, Eq)

-- | A single dimension in an 'LMAD'.
data LMADDim num = LMADDim
  { -- | The stride of the dimension, i.e. the factor an index along
    -- this dimension is multiplied with.
    ldStride :: num,
    -- | The number of elements along the dimension.
    ldShape :: num,
    -- | The position of this dimension in the permutation (see
    -- 'lmadPermutation').
    ldPerm :: Int,
    -- | The monotonicity recorded for this dimension.
    ldMon :: Monotonicity
  }
  deriving (Show, Eq)

-- | Total order with @Unknown < Dec < Inc@.  Only '<=' is defined;
-- the remaining 'Ord' methods come from the class defaults.  The
-- equations are order-sensitive: earlier ones shadow later ones.
instance Ord Monotonicity where
  (<=) _ Inc = True
  (<=) Unknown _ = True
  (<=) _ Unknown = False
  (<=) Inc Dec = False
  (<=) _ Dec = True

-- | Lexicographic order on (shape, stride, permutation index,
-- monotonicity).  Note that the shape is compared /before/ the
-- stride, although the stride is the first constructor field.
instance Ord num => Ord (LMADDim num) where
  (LMADDim s1 q1 p1 m1) <= (LMADDim s2 q2 p2 m2) =
    ([q1, s1] < [q2, s2])
      || ( ([q1, s1] == [q2, s2])
             && ( (p1 < p2)
                    || ( (p1 == p2)
                           && (m1 <= m2)
                       )
                )
         )

-- | LMAD's representation consists of a general offset and for each dimension a
-- stride, number of elements (or shape), permutation, and
-- monotonicity. Note that the permutation is not strictly necessary in that the
-- permutation can be performed directly on LMAD dimensions, but then it is
-- difficult to extract the permutation back from an LMAD.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra.
--
-- It follows that the general representation of an index function is a list of
-- LMADs, in which each following LMAD in the list implicitly corresponds to an
-- irregular reshaping operation.
--
-- However, we expect that the common case is when the index function is one
-- LMAD -- we call this the "nice" representation.
--
-- Finally, the list of LMADs is kept in an @IxFun@ together with the shape of
-- the original array, and a bit to indicate whether the index function is
-- contiguous, i.e., if we instantiate all the points of the current index
-- function, do we get a contiguous memory interval?
--
-- By definition, the LMAD \( \sigma + \{ (n_1, s_1), \ldots, (n_k, s_k) \} \),
-- where \(n\) and \(s\) denote the shape and stride of each dimension, denotes
-- the set of points:
--
-- \[
--    \{ ~ \sigma + i_1 * s_1 + \ldots + i_k * s_k ~ | ~ 0 \leq i_1 < n_1, \ldots, 0 \leq i_k < n_k ~ \}
-- \]
data LMAD num = LMAD
  { -- | The general offset \(\sigma\).
    lmadOffset :: num,
    -- | The per-dimension descriptors.
    lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)

-- | Pretty-prints the constructor name as produced by 'show'.
instance Pretty Monotonicity where
  pretty = pretty . show

-- | Pretty-prints an 'LMAD' as a brace-enclosed record with the
-- offset followed by one bracketed list per 'LMADDim' field.
instance Pretty num => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces . semistack $
      [ "offset:" <+> group (pretty offset),
        "strides:" <+> p ldStride,
        "shape:" <+> p ldShape,
        "permutation:" <+> p ldPerm,
        "monotonicity:" <+> p ldMon
      ]
    where
      -- Project one field out of every dimension and render the
      -- results as a comma-separated, bracketed list.
      p f = group $ brackets $ align $ commasep $ map (pretty . f) dims

-- | Applies the substitution to every @num@ in the 'LMAD' (the
-- offset and each stride and shape, via the 'Functor' instance).
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs = fmap $ substituteNames substs

-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (LMAD num) where
  rename = substituteRename

-- | The free names of an 'LMAD' are those of every @num@ it contains
-- (collected through the 'Foldable' instance).
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' = foldMap freeIn'

-- | Only the stride and the shape can contain names; the permutation
-- index and the monotonicity are ignored.
instance FreeIn num => FreeIn (LMADDim num) where
  freeIn' (LMADDim s n _ _) = freeIn' s <> freeIn' n

-- | Derived from the 'Traversable' instance.
instance Functor LMAD where
  fmap = fmapDefault

-- | Derived from the 'Traversable' instance.
instance Foldable LMAD where
  foldMap = foldMapDefault

-- | Visits the offset first, then the stride and shape of each
-- dimension in order.  Permutation indices and monotonicities are
-- carried over unchanged.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse f' dims
    where
      f' (LMADDim s n p m) = LMADDim <$> f s <*> f n <*> pure p <*> pure m

-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity Inc = Dec
invertMonotonicity Dec = Inc
invertMonotonicity Unknown = Unknown

-- | The permutation of an 'LMAD': the 'ldPerm' of each dimension, in
-- the order the dimensions are stored.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation = map ldPerm . lmadDims

-- | Substitute names with 'PrimExp's in an LMAD.  The substitution
-- is applied to the offset and to the stride and shape of every
-- dimension; permutation indices and monotonicities are untouched.
substituteInLMAD ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab (LMAD offset dims) =
  let offset' = sub offset
      dims' =
        map
          ( \(LMADDim s n p m) ->
              LMADDim
                (sub s)
                (sub n)
                p
                m
          )
          dims
   in LMAD offset' dims'
  where
    -- 'substituteInPrimExp' operates on untyped expressions, so strip
    -- the phantom type on the way in and re-attach it on the way out.
    tab' = fmap untyped tab
    sub = TPrimExp . substituteInPrimExp tab' . untyped

-- | Shape of an LMAD: the base shape rearranged with 'permuteInv'
-- according to the LMAD's permutation.
lmadShape :: LMAD num -> Shape num
lmadShape lmad = permuteInv (lmadPermutation lmad) $ lmadShapeBase lmad

-- | Shape of an LMAD, ignoring permutations: the 'ldShape' of each
-- dimension in storage order.
lmadShapeBase :: LMAD num -> Shape num
lmadShapeBase = map ldShape . lmadDims

-- | Apply a permutation: element @i@ of the result is the element of
-- the input list at position @ps !! i@.
--
-- Partial: uses '!!', so every entry of the permutation must be a
-- valid index into the input list.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = map (elems !!) ps

-- | Apply the inverse of a permutation: pair every element with its
-- permutation entry and return the elements sorted by that entry.
-- If the two lists differ in length, the surplus of the longer one is
-- ignored (a consequence of 'zip').
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems = map snd $ sortBy (compare `on` fst) $ zip ps elems

-- | Generalised iota with user-specified offset and rotates.
--
-- Builds a row-major 'LMAD' over the given shape with the identity
-- permutation.  For 'Dec', every stride is negated.  Calls 'error'
-- when the monotonicity is 'Unknown'.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  -- | Offset
  num ->
  -- | Shape
  [num] ->
  LMAD num
makeRotIota mon off ns
  | mon == Inc || mon == Dec =
      let rk = length ns
          -- Row-major strides: the stride of a dimension is the
          -- product of the extents of all dimensions after it.
          ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns
          ss =
            if mon == Inc
              then ss0
              else map (* (-1)) ss0
          ps = [0 .. rk - 1]
          fi = replicate rk mon
       in LMAD off $ zipWith4 LMADDim ss ns ps fi
  | otherwise = error "makeRotIota: requires Inc or Dec"

-- | Computes the span of an 'LMAD': the sum, over all dimensions, of
-- @stride * (shape - 1)@.  This is the flat distance from the offset
-- to the point reached by taking the last index in every dimension.
-- The LMAD's own offset is not included in the result.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) =
  foldr
    ( \dim upper ->
        let spn = ldStride dim * (ldShape dim - 1)
         in -- If you've gotten this far, you've already lost
            spn + upper
    )
    0
    dims

-- | Conservatively flatten a list of LMAD dimensions
--
-- Since not all LMADs can actually be flattened, we try to overestimate the
-- flattened array instead. This means that any "holes" in between dimensions
-- will get filled out.
--
-- A zero-dimensional LMAD becomes a single unit-stride, one-element
-- dimension, and a one-dimensional LMAD is returned unchanged.
-- Otherwise the result has one dimension whose stride is the 'gcd'
-- of all strides and whose shape is @'flatSpan' + 1@.  Returns
-- 'Nothing' when that 'gcd' cannot be determined.
-- conservativeFlatten :: (IntegralExp e, Ord e, Pretty e) => LMAD e -> LMAD e
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1 0 Inc]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
conservativeFlatten l@(LMAD offset dims@(dim0 : _)) = do
  -- Fold over /all/ dimensions (including the first) so that the
  -- first stride also passes through 'gcd' and is thereby made
  -- absolute.
  strd <-
    foldM
      gcd
      (ldStride dim0)
      $ map ldStride dims
  pure $ LMAD offset [LMADDim strd (shp + 1) 0 Unknown]
  where
    shp = flatSpan l

-- | Very conservative GCD calculation. Returns 'Nothing' if the result cannot
-- be immediately determined. Does not recurse at all.
--
-- Both arguments are made absolute first.  A result is produced only
-- when the two are equal, when either is @1@, or when the second is
-- @0@.
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    gcd' a 0 = Just a
    gcd' _ _ = Nothing -- gcd' b (a `Futhark.Util.IntegralExp.rem` b)

-- | Returns @True@ if the two 'LMAD's could be proven disjoint.
--
-- Uses some best-approximation heuristics to determine disjointness. For two
-- 1-dimensional arrays, we can guarantee whether or not they are disjoint, but
-- as soon as more than one dimension is involved, things get more
-- tricky. Currently, we try to 'conservativeFlatten' any LMAD with more than
-- one dimension.
--
-- Two one-dimensional LMADs are reported disjoint when the (known)
-- 'gcd' of their strides provably does not divide the difference of
-- their offsets, or when the last element of one provably lies
-- before the offset of the other.  A result of @False@ means "could
-- not prove", not "overlapping".
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2)
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset2 + (ldShape dim2 - 1) * ldStride dim2)
      offset1
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset1 + (ldShape dim1 - 1) * ldStride dim1)
      offset2
  where
    -- @True@ only if @y `mod` x@ constant-folds to a known non-zero
    -- value; an unknown gcd or an undecidable remainder gives @False@.
    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      Futhark.Util.IntegralExp.mod y x
        & untyped
        & constFoldPrimExp
        & TPrimExp
        & (.==.) (0 :: TPrimExp Int64 VName)
        & primBool
        & maybe False not
    doesNotDivide _ _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  case (conservativeFlatten lmad1, conservativeFlatten lmad2) of
    (Just lmad1', Just lmad2') -> disjoint less_thans non_negatives lmad1' lmad2'
    _ -> False

-- | Interval-based disjointness test.  Returns @True@ only if
-- disjointness could be established; the first two arguments are
-- ignored.
--
-- Both LMADs are converted to intervals, which are paired up and
-- sorted by stride complexity.  The offset difference is split into
-- negated and non-negated terms and distributed over the second and
-- first interval list respectively.  The result is @True@ when
-- neither list self-overlaps and not every pair of corresponding
-- intervals overlaps.  If an offset cannot be distributed, the
-- result is @False@.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      (neg_offset, pos_offset) =
        partition AlgSimplify.negated $
          offset1 `AlgSimplify.sub` offset2
      (interval1', interval2') =
        unzip $
          sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride . fst)) $
            intervalPairs interval1 interval2
      -- 'selfOverlap' wants the non-negative names as expressions.
      non_negatives' =
        map (flip LeafExp $ IntType Int64) $ namesToList non_negatives
   in case ( distributeOffset pos_offset interval1',
             distributeOffset (map AlgSimplify.negate neg_offset) interval2'
           ) of
        (Just interval1'', Just interval2'') ->
          isNothing (selfOverlap () () less_thans non_negatives' interval1'')
            && isNothing (selfOverlap () () less_thans non_negatives' interval2'')
            && not
              ( all
                  (uncurry (intervalOverlap less_thans non_negatives))
                  (zip interval1'' interval2'')
              )
        _ ->
          False

-- | Try to establish that two LMADs are disjoint.
--
-- Both LMADs are converted to interval lists with 'lmadToIntervals',
-- sorted by descending stride complexity, normalised with 'joinDims'
-- and 'mergeDims' until a fixed point is reached, and paired up with
-- 'intervalPairs'.  A bounded search (at most four levels deep) then
-- distributes the difference of the two offsets over the intervals
-- and, whenever one side overlaps with itself, splits the offending
-- dimension with 'splitDim' and recurses.
--
-- 'True' is returned only when the search concludes that the LMADs
-- do not overlap.  'False' is also returned when the depth budget is
-- exhausted, when an offset cannot be distributed, or when some
-- element of @non_negatives@ is not accepted by 'justLeafExp'.
disjoint3 :: M.Map VName Type -> [PrimExp VName] -> [(VName, PrimExp VName)] -> [PrimExp VName] -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  search 4 start1 start2 $ offset1 `AlgSimplify.sub` offset2
  where
    (offset1, raw1) = lmadToIntervals lmad1
    (offset2, raw2) = lmadToIntervals lmad2
    (start1, start2) = alignByStride (canonical raw1) (canonical raw2)

    -- The simplified sum-of-products form of an interval's stride.
    strideTerms iv = AlgSimplify.simplify0 (untyped (stride iv))

    -- Comparator that orders intervals by descending stride
    -- complexity (the arguments to 'compareComplexity' are swapped).
    descendingStride a b =
      AlgSimplify.compareComplexity (strideTerms b) (strideTerms a)

    -- Sort, then join/merge dimensions until nothing changes.
    canonical ivs =
      fixPoint (mergeDims . joinDims) (sortBy descendingStride ivs)

    -- Pair up the two interval lists and order the pairs by the
    -- stride of the first component.
    alignByStride xs ys =
      unzip $ sortBy (descendingStride `on` fst) $ intervalPairs xs ys

    -- The first argument is the remaining recursion depth.
    search :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    search 0 _ _ _ = False
    search fuel is10 is20 offset =
      -- Positive offset terms are distributed over the first interval
      -- list; the negated negative terms over the second.
      case ( distributeOffset pos_offset is1,
             distributeOffset (map AlgSimplify.negate neg_offset) is2
           ) of
        (Just is1', Just is2') ->
          case (overlapIn is1', overlapIn is2') of
            (Nothing, Nothing) ->
              noPairwiseOverlap
            (Just overlapping_dim, _) ->
              all
                ( \(new_offset, new_is1) ->
                    descend (joinDims new_is1) (joinDims is2') new_offset
                )
                (splitDim overlapping_dim is1')
                || retryExpanded is1
            (_, Just overlapping_dim) ->
              all
                ( \(new_offset, new_is2) ->
                    descend (joinDims is1') (joinDims new_is2) $
                      map AlgSimplify.negate new_offset
                )
                (splitDim overlapping_dim is2')
                || retryExpanded is2
        _ -> False
      where
        (is1, is2) = alignByStride is10 is20
        (neg_offset, pos_offset) = partition AlgSimplify.negated offset

        descend = search (fuel - 1)

        overlapIn ivs = selfOverlap scope asserts less_thans non_negatives ivs

        -- Fallback: expand the offset against the given intervals and
        -- retry with the undistributed interval lists.
        retryExpanded ivs =
          case AlgSimplify.simplifySofP' <$> expandOffset offset ivs of
            Just expanded_offset -> descend is1 is2 expanded_offset
            Nothing -> False

        -- NOTE(review): this check runs on 'is1'/'is2', i.e. the
        -- intervals before the offset was distributed, whereas the
        -- self-overlap checks above use the distributed 'is1''/'is2''.
        -- Confirm that this is intended.
        noPairwiseOverlap =
          case mapM justLeafExp non_negatives of
            Just vars ->
              let non_negatives' = namesFromList vars
               in not $
                    all
                      (uncurry (intervalOverlap less_thans non_negatives'))
                      (zip is1 is2)
            Nothing -> False

-- | Collapse neighbouring intervals that have the same stride and
-- both a lower bound of zero.  The surviving interval is the first of
-- the two, with its number of elements replaced by the product of
-- both element counts; it is then compared against the next
-- neighbour again.
joinDims :: [Interval] -> [Interval]
joinDims (x : y : rest)
  | stride x == stride y,
    lowerBound x == 0,
    lowerBound y == 0 =
      joinDims $ x {numElements = numElements x * numElements y} : rest
  | otherwise =
      x : joinDims (y : rest)
joinDims short = short

-- | Merge neighbouring intervals, working from the end of the list
-- towards the front.  An interval @x@ is merged with the interval
-- @y@ preceding it when @stride x * numElements x == stride y@ and
-- both lower bounds are zero.  The merged interval keeps the stride
-- and lower bound of @x@ and has the product of both element counts.
-- The relative order of the surviving intervals is preserved.
mergeDims :: [Interval] -> [Interval]
mergeDims = reverse . go . reverse
  where
    -- Operates on the reversed list, so @x@ is the later dimension.
    go (x : y : rest)
      | stride x * numElements x == stride y,
        lowerBound x == 0,
        lowerBound y == 0 =
          go $ x {numElements = numElements x * numElements y} : rest
      | otherwise =
          x : go (y : rest)
    go short = short

-- | Split up a self-overlapping pair of dimensions.
--
-- @overlapping_dim0@ is looked up in @is@; the dimension that is
-- actually split is the one immediately /after/ it.  Each result is a
-- pair of an additional offset and a replacement interval list.
--
-- If the strides of both dimensions, and the span
-- (@stride * numElements@) of the following dimension, each simplify
-- to a single product, the following dimension has a zero lower
-- bound, and the required divisions succeed, then a single
-- alternative with no extra offset is produced, in which the two
-- dimensions are replaced by two re-sized ones.
--
-- Otherwise two alternatives are produced: one where the following
-- dimension is dropped and the offset of its last element is returned
-- as the extra offset, and one where that dimension has one element
-- fewer.
--
-- This function is partial: 'fromJust' fails if @overlapping_dim0@
-- is not found in @is@ or if 'focusNth' yields nothing for the
-- following index (presumably when @overlapping_dim0@ is the last
-- element -- verify against 'focusNth').
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is
  | Just outer_stride <- soleTerm (stride overlapping_dim0),
    Just inner_stride <- soleTerm (stride overlapping_dim),
    Just inner_span <- soleTerm (stride overlapping_dim * numElements overlapping_dim),
    lowerBound overlapping_dim == 0,
    Just big_dim_elems <- AlgSimplify.maybeDivide inner_span outer_stride,
    Just small_dim_elems <- AlgSimplify.maybeDivide outer_stride inner_stride =
      let big_dim =
            Interval
              0
              (isInt64 $ AlgSimplify.prodToExp big_dim_elems)
              (stride overlapping_dim0)
          small_dim =
            Interval
              0
              (isInt64 $ AlgSimplify.prodToExp small_dim_elems)
              (stride overlapping_dim)
       in -- 'init' drops 'overlapping_dim0', the last element of 'before'.
          [([], init before <> [big_dim, small_dim] <> after)]
  | otherwise =
      let last_index = numElements overlapping_dim - 1
          shrunk_dim = overlapping_dim {numElements = last_index}
          point_offset =
            AlgSimplify.simplify0 $
              untyped $
                (last_index + lowerBound overlapping_dim) * stride overlapping_dim
       in [ (point_offset, before <> after),
            ([], before <> [shrunk_dim] <> after)
          ]
  where
    -- The simplified expression, provided it is exactly one product.
    soleTerm e =
      case AlgSimplify.simplify0 (untyped e) of
        [p] -> Just p
        _ -> Nothing

    -- Focus on the dimension right after 'overlapping_dim0'; hence
    -- 'before' ends with 'overlapping_dim0' itself.
    (before, overlapping_dim, after) =
      fromJust $ do
        idx <- elemIndex overlapping_dim0 is
        focusNth (idx + 1) is

-- | Convert an LMAD into its simplified offset (as a sum of products)
-- and one 'Interval' per dimension.
--
-- Every interval has lower bound zero and carries the simplified
-- shape and stride of its dimension.  The dimensions are reordered
-- with 'permuteInv' according to 'lmadPermutation'.  An LMAD without
-- dimensions yields a single interval with one element and stride
-- one.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals lmad@(LMAD off dims) =
  (AlgSimplify.simplify0 (untyped off), intervals)
  where
    intervals
      | null dims = [Interval 0 1 1]
      | otherwise = map toInterval (permuteInv (lmadPermutation lmad) dims)

    toInterval :: LMADDim (TPrimExp Int64 VName) -> Interval
    toInterval (LMADDim dim_stride dim_shape _ _) =
      Interval
        0
        (AlgSimplify.simplify' dim_shape)
        (AlgSimplify.simplify' dim_stride)

-- | Build an expression that checks at run-time whether two
-- 'LMADDim's are equal.
--
-- Stride and shape are compared dynamically; the permutation index
-- and the monotonicity are compared statically and embedded as
-- constants.
dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  same_stride .&&. same_shape .&&. same_perm .&&. same_mon
  where
    same_stride = ldStride dim1 .==. ldStride dim2
    same_shape = ldShape dim1 .==. ldShape dim2
    same_perm = fromBool $ ldPerm dim1 == ldPerm dim2
    same_mon = fromBool $ ldMon dim1 == ldMon dim2

-- | Dynamically determine if two 'LMAD' are equal.
--
-- True if the offsets are equal and all constituent 'LMADDim's are
-- pairwise equal according to 'dynamicEqualsLMADDim'.
--
-- LMADs with a different number of dimensions are never equal.  This
-- is decided statically, because zipping the dimension lists would
-- otherwise silently ignore the surplus dimensions of the longer one.
dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  | length dims1 /= length dims2 = fromBool False
  | otherwise =
      lmadOffset lmad1 .==. lmadOffset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip dims1 dims2)
  where
    dims1 = lmadDims lmad1
    dims2 = lmadDims lmad2