module Futhark.IR.Mem.LMAD
( Shape,
LMAD (..),
LMADDim (..),
Monotonicity (..),
Permutation,
lmadShape,
lmadShapeBase,
substituteInLMAD,
permuteInv,
permuteFwd,
conservativeFlatten,
disjoint,
disjoint2,
disjoint3,
dynamicEqualsLMAD,
lmadPermutation,
makeRotIota,
invertMonotonicity,
)
where
import Control.Category
import Control.Monad
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sortBy, zipWith4)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax (Type)
import Futhark.IR.Syntax.Core (VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))
-- | The extent of every dimension, outermost first.
type Shape num = [num]

-- | A permutation of dimension indices, as consumed by 'permuteFwd' and
-- 'permuteInv'.
type Permutation = [Int]
-- | Direction in which a dimension is traversed.
data Monotonicity
  = -- | Increasing (see 'makeRotIota', which pairs it with positive strides).
    Inc
  | -- | Decreasing (see 'makeRotIota', which pairs it with negated strides).
    Dec
  | -- | The direction is not known.
    Unknown
  deriving (Show, Eq)
-- | Descriptor of a single dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Distance between consecutive elements along this dimension.
    ldStride :: num,
    -- | Extent (size) of this dimension.
    ldShape :: num,
    -- | Index of this dimension in the permutation returned by
    -- 'lmadPermutation'.
    ldPerm :: Int,
    -- | Traversal direction of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Show, Eq)
-- | Total order with @Unknown < Dec < Inc@.
instance Ord Monotonicity where
  x <= y = rank x <= rank y
    where
      -- Position of each constructor in the order.
      rank :: Monotonicity -> Int
      rank Unknown = 0
      rank Dec = 1
      rank Inc = 2
-- | Lexicographic order on (shape, stride), then permutation index, then
-- monotonicity.
instance (Ord num) => Ord (LMADDim num) where
  LMADDim s1 q1 p1 m1 <= LMADDim s2 q2 p2 m2
    | [q1, s1] < [q2, s2] = True
    | [q1, s1] /= [q2, s2] = False
    -- Shape and stride coincide: fall back to the permutation index ...
    | p1 /= p2 = p1 < p2
    -- ... and finally to the monotonicity.
    | otherwise = m1 <= m2
-- | An offset together with one 'LMADDim' per dimension (presumably a
-- linear memory access descriptor, going by the name — confirm against the
-- module documentation). The derived 'Ord' uses the hand-written 'Ord'
-- instance of 'LMADDim'.
data LMAD num = LMAD
  { -- | Starting offset.
    lmadOffset :: num,
    -- | Per-dimension descriptors.
    lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)
-- | Pretty-prints the constructor name, exactly as 'show' renders it.
instance Pretty Monotonicity where
  pretty m = pretty (show m)
-- | Renders the offset followed by one bracketed, comma-separated row per
-- 'LMADDim' field, all inside braces.
instance (Pretty num) => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces $
      semistack
        [ "offset:" <+> group (pretty offset),
          "strides:" <+> row ldStride,
          "shape:" <+> row ldShape,
          "permutation:" <+> row ldPerm,
          "monotonicity:" <+> row ldMon
        ]
    where
      -- One field projected out of every dimension, e.g. @[s0, s1, s2]@.
      row field = group . brackets . align . commasep $ map (pretty . field) dims
-- | Renames every expression in the LMAD (offset, strides and shapes) via
-- the 'Functor' instance.
instance (Substitute num) => Substitute (LMAD num) where
  substituteNames substs lmad = fmap (substituteNames substs) lmad
-- | Renaming is implemented through the 'Substitute' instance.
instance (Substitute num) => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Free variables of the offset, strides and shapes, collected through the
-- 'Foldable' instance.
instance (FreeIn num) => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad
-- | Free variables of the stride and shape; the permutation index and the
-- monotonicity carry no names.
instance (FreeIn num) => FreeIn (LMADDim num) where
  freeIn' dim@LMADDim {} = freeIn' (ldStride dim) <> freeIn' (ldShape dim)
-- | Derived from the 'Traversable' instance.
instance Functor LMAD where
  fmap f lmad = fmapDefault f lmad
-- | Derived from the 'Traversable' instance.
instance Foldable LMAD where
  foldMap f lmad = foldMapDefault f lmad
-- | Visits the offset first, then the stride and shape of each dimension in
-- order. Permutation indices and monotonicities are carried through as-is.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse visitDim dims
    where
      visitDim (LMADDim str shp perm mon) =
        LMADDim <$> f str <*> f shp <*> pure perm <*> pure mon
-- | Swaps 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon = case mon of
  Inc -> Dec
  Dec -> Inc
  Unknown -> Unknown
-- | The 'ldPerm' of every dimension, in dimension order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]
-- | Applies 'substituteInPrimExp' with the given table to the offset and to
-- every stride and shape. Permutation indices and monotonicities are left
-- untouched (this is exactly what the 'Functor' instance of 'LMAD' maps
-- over).
substituteInLMAD ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab = fmap substitute
  where
    untypedTab = fmap untyped tab
    substitute = TPrimExp . substituteInPrimExp untypedTab . untyped
-- | The shape after applying the inverse of the LMAD's permutation to the
-- per-dimension shapes ('lmadShapeBase').
lmadShape :: LMAD num -> Shape num
lmadShape lmad =
  let perm = lmadPermutation lmad
      base = lmadShapeBase lmad
   in permuteInv perm base
-- | The 'ldShape' of every dimension in storage order, with no permutation
-- applied.
lmadShapeBase :: LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]
-- | @permuteFwd ps xs@ picks @xs !! p@ for each @p@ in @ps@, in order.
-- Like the original, this is partial: an out-of-range index errors.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! i | i <- ps]
-- | Inverse of 'permuteFwd'. Each element is tagged with its permutation
-- index, the pairs are stably sorted by that index, and the tags are
-- dropped.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  map snd (sortBy (\(i, _) (j, _) -> compare i j) (zip ps elems))
-- | Builds an LMAD with offset @off@ and shape @ns@.
--
-- The strides are row-major (the last dimension has stride 1). They are
-- negated when @mon@ is 'Dec'. The permutation is the identity, and every
-- dimension gets monotonicity @mon@. Calls 'error' unless @mon@ is 'Inc' or
-- 'Dec'.
makeRotIota ::
  (IntegralExp num) =>
  Monotonicity ->
  num ->
  [num] ->
  LMAD num
makeRotIota mon off ns
  | mon `elem` [Inc, Dec] =
      LMAD off $ zipWith4 LMADDim strides ns [0 .. rank - 1] (replicate rank mon)
  | otherwise = error "makeRotIota: requires Inc or Dec"
  where
    rank = length ns
    -- Suffix products of the shape: [n1*...*n(k-1), ..., n(k-1), 1].
    rowMajor = reverse $ take rank $ scanl (*) 1 $ reverse ns
    strides
      | mon == Inc = rowMajor
      | otherwise = map (* (-1)) rowMajor
-- | Sum over all dimensions of @stride * (shape - 1)@, i.e. the distance
-- from the offset to the last element when every stride is non-negative.
-- The offset itself is ignored.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) = foldr ((+) . dimSpan) 0 dims
  where
    -- A right fold with a trailing 0 keeps the same symbolic expression
    -- shape: span0 + (span1 + (... + 0)).
    dimSpan dim = ldStride dim * (ldShape dim - 1)
-- | Over-approximates an LMAD by a single-dimensional one with the same
-- offset.
--
-- * A zero-dimensional LMAD becomes one dimension of shape 1, stride 1.
-- * A one-dimensional LMAD is returned unchanged.
-- * Otherwise the stride is the 'gcd' of all strides, the shape is
--   @flatSpan l + 1@, and the monotonicity is 'Unknown'.
--
-- Returns 'Nothing' when 'gcd' cannot determine a common stride.
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1 0 Inc]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
-- Matching the first dimension directly avoids the partial 'head'; the
-- earlier equations guarantee at least two dimensions reach this point.
conservativeFlatten l@(LMAD offset dims@(dim0 : _)) = do
  -- The fold starts from the first stride and covers all strides
  -- (including the first), as before.
  strd <- foldM gcd (ldStride dim0) $ map ldStride dims
  pure $ LMAD offset [LMADDim strd (shp + 1) 0 Unknown]
  where
    shp = flatSpan l
-- | Conservative greatest common divisor of two symbolic integers.
--
-- Both arguments are replaced by their absolute values first. An answer is
-- produced only in syntactically recognisable cases: equal operands, an
-- operand equal to 1, or an operand equal to 0. Every other case yields
-- 'Nothing' ("unknown").
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    gcd' a 0 = Just a
    -- gcd(0, b) = b. Without this equation the function was not symmetric:
    -- a zero (e.g. broadcast) stride in the first position gave 'Nothing'
    -- while the swapped arguments gave an answer.
    gcd' 0 b = Just b
    gcd' _ _ = Nothing
-- | Tries to prove that two LMADs access disjoint index sets. 'True' means
-- disjointness was established; 'False' means it could not be shown.
--
-- @less_thans@ and @non_negatives@ are forwarded to
-- 'AlgSimplify.lessThanish'.
--
-- For two one-dimensional LMADs the result is 'True' if any of these hold:
--
-- * the offset difference is provably not a multiple of the strides' 'gcd';
-- * the last element of one LMAD is "less-thanish" the other's offset.
--
-- Other LMADs are first reduced with 'conservativeFlatten'.
--
-- NOTE(review): @offset + (shape - 1) * stride@ is the largest touched index
-- only for non-negative strides. 'makeRotIota' with 'Dec' produces negative
-- strides, and 'conservativeFlatten' returns one-dimensional LMADs unchanged.
-- Confirm that negative-stride inputs cannot reach the first equation.
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  or
    [ doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2),
      below (lastIndex offset2 dim2) offset1,
      below (lastIndex offset1 dim1) offset2
    ]
  where
    below = AlgSimplify.lessThanish less_thans non_negatives

    lastIndex off dim = off + (ldShape dim - 1) * ldStride dim

    -- True only when the divisor is known and @y mod x@ constant-folds to a
    -- value that is provably non-zero.
    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      let remainder =
            TPrimExp $ constFoldPrimExp $ untyped $ Futhark.Util.IntegralExp.mod y x
       in case primBool ((0 :: TPrimExp Int64 VName) .==. remainder) of
            Just isZero -> not isZero
            Nothing -> False
    doesNotDivide Nothing _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  maybe False (uncurry (disjoint less_thans non_negatives)) $
    (,) <$> conservativeFlatten lmad1 <*> conservativeFlatten lmad2
-- | Interval-based disjointness test. 'True' means the two LMADs were shown
-- not to overlap; 'False' means this could not be established. The first
-- two arguments are ignored.
--
-- Steps:
--
-- 1. Convert both LMADs to intervals.
-- 2. Pair the intervals and sort the pairs by decreasing stride complexity.
-- 3. Distribute the positive terms of @offset1 - offset2@ into the first
--    side and the negated negative terms into the second.
-- 4. Require that neither side overlaps itself and that not every interval
--    pair overlaps.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  case ( distributeOffset pos_offset sorted1,
         distributeOffset (map AlgSimplify.negate neg_offset) sorted2
       ) of
    (Just shifted1, Just shifted2) ->
      noSelfOverlap shifted1
        && noSelfOverlap shifted2
        && not (all (uncurry (intervalOverlap less_thans non_negatives)) (zip shifted1 shifted2))
    _ -> False
  where
    (offset1, intervals1) = lmadToIntervals lmad1
    (offset2, intervals2) = lmadToIntervals lmad2

    (neg_offset, pos_offset) =
      partition AlgSimplify.negated (offset1 `AlgSimplify.sub` offset2)

    (sorted1, sorted2) =
      unzip $
        sortBy (flip AlgSimplify.compareComplexity `on` strideComplexity) $
          intervalPairs intervals1 intervals2

    strideComplexity = AlgSimplify.simplify0 . untyped . stride . fst

    -- The non-negative names as Int64 leaf expressions, for 'selfOverlap'.
    nonNegativeExps = map (`LeafExp` IntType Int64) (namesToList non_negatives)

    noSelfOverlap = isNothing . selfOverlap () () less_thans nonNegativeExps
disjoint3 :: M.Map VName Type -> [PrimExp VName] -> [(VName, PrimExp VName)] -> [PrimExp VName] -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint3 :: Map VName Type
-> [PrimExp VName]
-> [(VName, PrimExp VName)]
-> [PrimExp VName]
-> LMAD (TPrimExp Int64 VName)
-> LMAD (TPrimExp Int64 VName)
-> Bool
disjoint3 Map VName Type
scope [PrimExp VName]
asserts [(VName, PrimExp VName)]
less_thans [PrimExp VName]
non_negatives LMAD (TPrimExp Int64 VName)
lmad1 LMAD (TPrimExp Int64 VName)
lmad2 =
let (SofP
offset1, [Interval]
interval1) = LMAD (TPrimExp Int64 VName) -> (SofP, [Interval])
lmadToIntervals LMAD (TPrimExp Int64 VName)
lmad1
(SofP
offset2, [Interval]
interval2) = LMAD (TPrimExp Int64 VName) -> (SofP, [Interval])
lmadToIntervals LMAD (TPrimExp Int64 VName)
lmad2
interval1' :: [Interval]
interval1' = forall a. Eq a => (a -> a) -> a -> a
fixPoint ([Interval] -> [Interval]
mergeDims forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Interval] -> [Interval]
joinDims) forall a b. (a -> b) -> a -> b
$ forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (forall a b c. (a -> b -> c) -> b -> a -> c
flip SofP -> SofP -> Ordering
AlgSimplify.compareComplexity forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (PrimExp VName -> SofP
AlgSimplify.simplify0 forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Interval -> TPrimExp Int64 VName
stride)) [Interval]
interval1
interval2' :: [Interval]
interval2' = forall a. Eq a => (a -> a) -> a -> a
fixPoint ([Interval] -> [Interval]
mergeDims forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. [Interval] -> [Interval]
joinDims) forall a b. (a -> b) -> a -> b
$ forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (forall a b c. (a -> b -> c) -> b -> a -> c
flip SofP -> SofP -> Ordering
AlgSimplify.compareComplexity forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (PrimExp VName -> SofP
AlgSimplify.simplify0 forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Interval -> TPrimExp Int64 VName
stride)) [Interval]
interval2
([Interval]
interval1'', [Interval]
interval2'') =
forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (forall a b c. (a -> b -> c) -> b -> a -> c
flip SofP -> SofP -> Ordering
AlgSimplify.compareComplexity forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (PrimExp VName -> SofP
AlgSimplify.simplify0 forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Interval -> TPrimExp Int64 VName
stride forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a b. (a, b) -> a
fst)) forall a b. (a -> b) -> a -> b
$
[Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs [Interval]
interval1' [Interval]
interval2'
in Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper Int
4 [Interval]
interval1'' [Interval]
interval2'' forall a b. (a -> b) -> a -> b
$ SofP
offset1 SofP -> SofP -> SofP
`AlgSimplify.sub` SofP
offset2
where
disjointHelper :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
disjointHelper :: Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper Int
0 [Interval]
_ [Interval]
_ SofP
_ = Bool
False
disjointHelper Int
i [Interval]
is10 [Interval]
is20 SofP
offset =
let ([Interval]
is1, [Interval]
is2) =
forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (forall a b c. (a -> b -> c) -> b -> a -> c
flip SofP -> SofP -> Ordering
AlgSimplify.compareComplexity forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (PrimExp VName -> SofP
AlgSimplify.simplify0 forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. Interval -> TPrimExp Int64 VName
stride forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a b. (a, b) -> a
fst)) forall a b. (a -> b) -> a -> b
$
[Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs [Interval]
is10 [Interval]
is20
(SofP
neg_offset, SofP
pos_offset) = forall a. (a -> Bool) -> [a] -> ([a], [a])
partition Prod -> Bool
AlgSimplify.negated SofP
offset
in case ( forall (m :: * -> *).
MonadFail m =>
SofP -> [Interval] -> m [Interval]
distributeOffset SofP
pos_offset [Interval]
is1,
forall (m :: * -> *).
MonadFail m =>
SofP -> [Interval] -> m [Interval]
distributeOffset (forall a b. (a -> b) -> [a] -> [b]
map Prod -> Prod
AlgSimplify.negate SofP
neg_offset) [Interval]
is2
) of
(Just [Interval]
is1', Just [Interval]
is2') -> do
let overlap1 :: Maybe Interval
overlap1 = forall scope asserts.
scope
-> asserts
-> [(VName, PrimExp VName)]
-> [PrimExp VName]
-> [Interval]
-> Maybe Interval
selfOverlap Map VName Type
scope [PrimExp VName]
asserts [(VName, PrimExp VName)]
less_thans [PrimExp VName]
non_negatives [Interval]
is1'
let overlap2 :: Maybe Interval
overlap2 = forall scope asserts.
scope
-> asserts
-> [(VName, PrimExp VName)]
-> [PrimExp VName]
-> [Interval]
-> Maybe Interval
selfOverlap Map VName Type
scope [PrimExp VName]
asserts [(VName, PrimExp VName)]
less_thans [PrimExp VName]
non_negatives [Interval]
is2'
case (Maybe Interval
overlap1, Maybe Interval
overlap2) of
(Maybe Interval
Nothing, Maybe Interval
Nothing) ->
case [VName] -> Names
namesFromList forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM PrimExp VName -> Maybe VName
justLeafExp [PrimExp VName]
non_negatives of
Just Names
non_negatives' ->
Bool -> Bool
not forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all
(forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ([(VName, PrimExp VName)] -> Names -> Interval -> Interval -> Bool
intervalOverlap [(VName, PrimExp VName)]
less_thans Names
non_negatives'))
(forall a b. [a] -> [b] -> [(a, b)]
zip [Interval]
is1 [Interval]
is2)
Maybe Names
_ -> Bool
False
(Just Interval
overlapping_dim, Maybe Interval
_) ->
let expanded_offset :: Maybe SofP
expanded_offset = SofP -> SofP
AlgSimplify.simplifySofP' forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SofP -> [Interval] -> Maybe SofP
expandOffset SofP
offset [Interval]
is1
splits :: [(SofP, [Interval])]
splits = Interval -> [Interval] -> [(SofP, [Interval])]
splitDim Interval
overlapping_dim [Interval]
is1'
in forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(SofP
new_offset, [Interval]
new_is1) -> Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper (Int
i forall a. Num a => a -> a -> a
- Int
1) ([Interval] -> [Interval]
joinDims [Interval]
new_is1) ([Interval] -> [Interval]
joinDims [Interval]
is2') SofP
new_offset) [(SofP, [Interval])]
splits
Bool -> Bool -> Bool
|| forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False (Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper (Int
i forall a. Num a => a -> a -> a
- Int
1) [Interval]
is1 [Interval]
is2) Maybe SofP
expanded_offset
(Maybe Interval
_, Just Interval
overlapping_dim) ->
let expanded_offset :: Maybe SofP
expanded_offset = SofP -> SofP
AlgSimplify.simplifySofP' forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SofP -> [Interval] -> Maybe SofP
expandOffset SofP
offset [Interval]
is2
splits :: [(SofP, [Interval])]
splits = Interval -> [Interval] -> [(SofP, [Interval])]
splitDim Interval
overlapping_dim [Interval]
is2'
in forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all
( \(SofP
new_offset, [Interval]
new_is2) ->
Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper (Int
i forall a. Num a => a -> a -> a
- Int
1) ([Interval] -> [Interval]
joinDims [Interval]
is1') ([Interval] -> [Interval]
joinDims [Interval]
new_is2) forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map Prod -> Prod
AlgSimplify.negate SofP
new_offset
)
[(SofP, [Interval])]
splits
Bool -> Bool -> Bool
|| forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False (Int -> [Interval] -> [Interval] -> SofP -> Bool
disjointHelper (Int
i forall a. Num a => a -> a -> a
- Int
1) [Interval]
is1 [Interval]
is2) Maybe SofP
expanded_offset
(Maybe [Interval], Maybe [Interval])
_ -> Bool
False
-- | Greedily join neighbouring intervals that have the same stride and
-- both start at lower bound 0. Each such pair is replaced by a single
-- interval (keeping the left one's stride and lower bound) whose element
-- count is the product of the two counts. A freshly joined interval is
-- compared against its new right neighbour, so a run of compatible
-- intervals collapses into one. Otherwise, the order of the list is kept.
--
-- NOTE(review): assuming the offsets of the two dimensions add up (as in
-- an LMAD), this only over-approximates for positive counts @nx@, @ny@,
-- since @nx * ny - 1 >= (nx - 1) + (ny - 1)@.
joinDims :: [Interval] -> [Interval]
joinDims (x : y : rest)
  | sameStrideFromZero =
      joinDims $ x {numElements = numElements x * numElements y} : rest
  | otherwise = x : joinDims (y : rest)
  where
    -- Uses '==' on the expressions directly; nothing is simplified here.
    sameStrideFromZero =
      stride x == stride y && lowerBound x == 0 && lowerBound y == 0
joinDims shortList = shortList
-- | Merge contiguous neighbouring intervals. Two intervals, @left@ followed
-- by @right@, are merged when
-- @stride right * numElements right == stride left@ and both start at
-- lower bound 0. The merged interval keeps @right@'s stride and lower
-- bound, and its element count is the product of the two counts.
--
-- The list is scanned from the right end, so a merged interval can keep
-- absorbing further neighbours to its left. Apart from merging, the result
-- keeps the input order.
mergeDims :: [Interval] -> [Interval]
mergeDims = reverse . mergeFromRight . reverse
  where
    -- Operates on the reversed list: the head is the rightmost interval.
    mergeFromRight :: [Interval] -> [Interval]
    mergeFromRight (right : left : rest)
      | contiguous right left =
          mergeFromRight $
            right {numElements = numElements right * numElements left} : rest
      | otherwise = right : mergeFromRight (left : rest)
    mergeFromRight shortList = shortList

    contiguous right left =
      stride right * numElements right == stride left
        && lowerBound right == 0
        && lowerBound left == 0
-- | Split the interval list @is@ to eliminate a self-overlap. The interval
-- acted on is the one immediately FOLLOWING @overlapping_dim0@ in @is@
-- (called @overlapping_dim@ below). Returns a list of
-- @(extra offset, intervals)@ cases that together cover the original.
--
-- * If the strides of @overlapping_dim0@ and @overlapping_dim@, and the
--   span @stride overlapping_dim * numElements overlapping_dim@, each
--   simplify to a single product term, @overlapping_dim@ has lower bound
--   0, and the divisions below succeed symbolically, then both intervals
--   are replaced by two new zero-based intervals. The first has
--   @span / stride overlapping_dim0@ elements at @overlapping_dim0@'s
--   stride. The second has @stride overlapping_dim0 / stride
--   overlapping_dim@ elements at @overlapping_dim@'s stride. There is a
--   single case with no extra offset.
--   NOTE(review): @numElements overlapping_dim0@ is not used in this
--   branch; confirm this is intended.
--
-- * Otherwise the last element of @overlapping_dim@ is peeled off. This
--   gives one case with that element's offset and @overlapping_dim@
--   removed, and one case with @overlapping_dim@ shrunk by one element.
--
-- Precondition: @overlapping_dim0@ occurs in @is@ and is not its last
-- element. Violating it is a programming error and raises 'error'.
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is
  | [st] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim0,
    [st1] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim,
    [spn] <-
      AlgSimplify.simplify0 $
        untyped $
          stride overlapping_dim * numElements overlapping_dim,
    lowerBound overlapping_dim == 0,
    Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
    Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
      [ ( [],
          -- 'before' always ends with overlapping_dim0, so 'init' is safe
          -- here and drops exactly that interval.
          init before
            <> [ Interval
                   0
                   (isInt64 $ AlgSimplify.prodToExp big_dim_elems)
                   (stride overlapping_dim0),
                 Interval
                   0
                   (isInt64 $ AlgSimplify.prodToExp small_dim_elems)
                   (stride overlapping_dim)
               ]
            <> after
        )
      ]
  | otherwise =
      let shrunk_dim :: Interval
          shrunk_dim =
            overlapping_dim {numElements = numElements overlapping_dim - 1}
          -- Offset of the last element of overlapping_dim.
          point_offset :: AlgSimplify.SofP
          point_offset =
            AlgSimplify.simplify0 $
              untyped $
                (numElements overlapping_dim - 1 + lowerBound overlapping_dim)
                  * stride overlapping_dim
       in [ (point_offset, before <> after),
            ([], before <> [shrunk_dim] <> after)
          ]
  where
    -- 'before' is every interval up to and including overlapping_dim0;
    -- 'overlapping_dim' is the interval right after it; 'after' is the rest.
    (before, overlapping_dim, after) =
      case elemIndex overlapping_dim0 is >>= (flip focusNth is . (+ 1)) of
        Just focused -> focused
        Nothing ->
          error
            "splitDim: overlapping dimension is missing from, or last in, the interval list"
-- | Convert an LMAD into its simplified offset and one 'Interval' per
-- dimension. The dimensions are first reordered with 'permuteInv' using
-- the LMAD's permutation. Each interval starts at lower bound 0 and has
-- the dimension's (simplified) shape as its element count and its
-- (simplified) stride as its stride.
--
-- A zero-dimensional LMAD yields the single interval @Interval 0 1 1@.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals lmad@(LMAD offset dims0) = (offset', intervals)
  where
    offset' :: AlgSimplify.SofP
    offset' = AlgSimplify.simplify0 $ untyped offset

    intervals :: [Interval]
    intervals
      | null dims0 = [Interval 0 1 1]
      | otherwise = map toInterval $ permuteInv (lmadPermutation lmad) dims0

    -- Permutation index and monotonicity do not affect the interval.
    toInterval :: LMADDim (TPrimExp Int64 VName) -> Interval
    toInterval (LMADDim strd shp _ _) =
      Interval 0 (AlgSimplify.simplify' shp) (AlgSimplify.simplify' strd)
-- | Build a boolean expression that holds when two LMAD dimensions are
-- equal. Strides and shapes are compared as (possibly symbolic)
-- expressions with '.==.'. Permutation indices and monotonicities are
-- known values, so they are compared directly and lifted with 'fromBool'.
dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  sameExp ldStride
    .&&. sameExp ldShape
    .&&. fromBool (((==) `on` ldPerm) dim1 dim2)
    .&&. fromBool (((==) `on` ldMon) dim1 dim2)
  where
    sameExp field = field dim1 .==. field dim2
-- | Build a boolean expression that holds when two LMADs are equal. They
-- must have equal offsets and pairwise-equal dimensions (see
-- 'dynamicEqualsLMADDim').
--
-- LMADs of different rank are never equal, and this is decided
-- statically. Without the check, 'zip' would silently drop the surplus
-- dimensions of the longer LMAD, and the result could wrongly be true.
dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  | length dims1 /= length dims2 = fromBool False
  | otherwise =
      lmadOffset lmad1 .==. lmadOffset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip dims1 dims2)
  where
    dims1 = lmadDims lmad1
    dims2 = lmadDims lmad2