{-# LANGUAGE OverloadedStrings #-}
module Futhark.IR.Mem.Interval
( Interval (..),
distributeOffset,
expandOffset,
intervalOverlap,
selfOverlap,
primBool,
intervalPairs,
justLeafExp,
)
where
import Data.Function (on)
import Data.List (maximumBy, minimumBy, (\\))
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop
import Futhark.IR.Syntax hiding (Result)
import Futhark.Util
-- | A strided interval of 64-bit index expressions.
--
-- NOTE(review): 'intervalOverlap' treats an interval's indices as running from
-- 'lowerBound' to @lowerBound + numElements - 1@, and 'selfOverlap' measures
-- its extent as @(lowerBound + numElements - 1) * stride@, so the lower bound
-- appears to be expressed in units of 'stride' -- confirm with the callers.
data Interval = Interval
  { -- | First index covered by the interval.
    lowerBound :: TPrimExp Int64 VName,
    -- | Number of indices covered by the interval.
    numElements :: TPrimExp Int64 VName,
    -- | Stride associated with the interval.
    stride :: TPrimExp Int64 VName
  }
  deriving (Eq, Show)
-- | The free variables of an interval are the union of the free variables of
-- its lower bound, its element count and its stride.
instance FreeIn Interval where
  freeIn' (Interval lb ne st) = foldMap freeIn' [lb, ne, st]
-- | Distribute the products of an offset (a sum of products) over a list of
-- intervals by folding them into the intervals' lower bounds.
--
-- Intervals are visited left to right. For the current interval with stride
-- @st0@, some product of the offset is divided by the stride. The stride is
-- first tried as the single-atom product @st0@; if that fails, it is tried as
-- the simplified form of @st0@, provided that form is a single product. If
-- the division succeeds, the quotient is added to the interval's lower bound,
-- the product is dropped from the offset, and the same interval is tried
-- again. Otherwise the interval is kept as is and the next one is tried.
--
-- Special cases:
--
-- * An empty offset returns the intervals unchanged.
-- * If only a single interval with stride @1@ remains, all remaining products
--   are added to its lower bound.
-- * If products remain after the last interval, this fails through
--   'MonadFail'.
distributeOffset :: (MonadFail m) => AlgSimplify.SofP -> [Interval] -> m [Interval]
distributeOffset [] intervals = pure intervals
distributeOffset offset [] =
  fail $ "Cannot distribute offset " <> show offset <> " across empty interval"
distributeOffset offset [Interval lb ne 1] =
  pure [Interval (lb + TPrimExp (AlgSimplify.sumToExp offset)) ne 1]
distributeOffset offset (Interval lb ne st0 : is) =
  case divided of
    Just (before, quotient, after) ->
      -- Absorb the quotient into this interval and retry it with the rest of
      -- the offset, since further products may also be divisible by st0.
      distributeOffset (before <> after) $
        Interval (lb + TPrimExp (AlgSimplify.sumToExp [quotient])) ne st0 : is
    Nothing ->
      (Interval lb ne st0 :) <$> distributeOffset offset is
  where
    -- Find the first offset product divisible by the given product.
    divideBy st = focusMaybe (`AlgSimplify.maybeDivide` st) offset
    -- Try the stride as-is first, then its simplified single-product form.
    divided =
      case divideBy (AlgSimplify.Prod False [untyped st0]) of
        Just found -> Just found
        Nothing ->
          case AlgSimplify.simplify0 (untyped st0) of
            [st] -> divideBy st
            _ -> Nothing
-- | Split a sum of products into the product with the most atoms and the
-- remaining products.
--
-- On ties, the product chosen is whichever 'maximumBy' selects. One product
-- equal to the chosen one is removed from the rest with '(\\)'.
--
-- Partial: 'maximumBy' errors on an empty list, so callers must pass a
-- non-empty sum.
findMostComplexTerm :: AlgSimplify.SofP -> (AlgSimplify.Prod, AlgSimplify.SofP)
findMostComplexTerm prods = (biggest, prods \\ [biggest])
  where
    atomCount = length . AlgSimplify.atoms
    biggest = maximumBy (compare `on` atomCount) prods
-- | Pick the interval stride whose simplified form best covers the atoms of
-- an offset term.
--
-- Each stride is simplified with 'AlgSimplify.simplify0'. Within that
-- simplification, the product leaving the fewest offset atoms uncovered
-- (measured by multiset difference '(\\)') is chosen. The stride with the
-- smallest such count wins. The function returns that stride together with
-- the offset atoms its best product does not cover.
--
-- Partial: 'minimumBy' errors if the interval list is empty, or if a stride
-- that gets inspected simplifies to an empty list of products.
findClosestStride :: [PrimExp VName] -> [Interval] -> (PrimExp VName, [PrimExp VName])
findClosestStride offset_term is =
  (closest, offset_term \\ AlgSimplify.atoms (bestProduct closest))
  where
    -- How many offset atoms a product fails to cover.
    uncovered s = length (offset_term \\ AlgSimplify.atoms s)
    -- The product of a stride's simplified form covering the most offset atoms.
    bestProduct = minimumBy (compare `on` uncovered) . AlgSimplify.simplify0
    closest =
      minimumBy
        (compare `on` (termDifferenceLength . bestProduct))
        (map (untyped . stride) is)
    termDifferenceLength (AlgSimplify.Prod _ xs) = length (offset_term \\ xs)
-- | Rewrite an offset (a sum of products) by splitting its most complex
-- product into a part built from one of the intervals' strides and a
-- remainder.
--
-- The construction uses these values:
--
-- * @t@ is the product with the most atoms (see 'findMostComplexTerm').
-- * @s@ is the stride chosen by 'findClosestStride' for @t@'s atoms.
-- * @r@ is the list of atoms of @t@ that @s@ does not cover.
--
-- The result is @[t'] <> d <> rest@, where @t'@ is the product with @t@'s
-- sign and atoms @s : r@. @d@ is @t@ plus the 'AlgSimplify.negate'd @t'@,
-- re-simplified through 'AlgSimplify.sumOfProducts'. Assuming those
-- AlgSimplify conversions preserve value, the new offset sums to the same
-- value as the old one.
--
-- NOTE(review): presumably this exists so that 'distributeOffset' can
-- afterwards divide @t'@ by @s@ -- confirm with the callers.
--
-- Returns 'Nothing' when the offset is empty or there are no intervals.
expandOffset :: AlgSimplify.SofP -> [Interval] -> Maybe AlgSimplify.SofP
expandOffset [] _ = Nothing
-- Without intervals there is no stride to expand against. Previously this
-- case returned a 'Just' wrapping an error thunk, because 'findClosestStride'
-- calls 'minimumBy' on the (empty) stride list.
expandOffset _ [] = Nothing
expandOffset offset i1
  | -- The product with the most atoms: its sign and its atoms.
    (AlgSimplify.Prod b term_to_add, offset_rest) <- findMostComplexTerm offset,
    -- The closest stride and the atoms of the product it leaves uncovered.
    (closest_stride, first_term_divisor) <- findClosestStride term_to_add i1,
    -- The stride-based part: same sign, stride times the uncovered atoms.
    target <- [AlgSimplify.Prod b $ closest_stride : first_term_divisor],
    -- The remainder: the original product minus target, re-simplified.
    diff <-
      AlgSimplify.sumOfProducts $
        AlgSimplify.sumToExp $
          AlgSimplify.Prod b term_to_add : map AlgSimplify.negate target,
    -- target + (product - target) stands in for the original product.
    replacement <- target <> diff =
      Just (replacement <> offset_rest)
-- | Conservatively decide whether two intervals may overlap.
--
-- The result is 'False' (no overlap) only when three conditions hold:
--
-- * Both intervals have syntactically equal ('==') strides.
-- * One interval's first index is provably less than the other's lower bound,
--   according to 'AlgSimplify.lessThanish'.
-- * That same interval's last index, @lb + ne - 1@, is also provably less
--   than the other's lower bound.
--
-- In every other case the result is 'True'.
intervalOverlap :: [(VName, PrimExp VName)] -> Names -> Interval -> Interval -> Bool
intervalOverlap less_thans non_negatives (Interval lb1 ne1 st1) (Interval lb2 ne2 st2) =
  not (st1 == st2 && (endsBefore lb1 ne1 lb2 || endsBefore lb2 ne2 lb1))
  where
    provablyLess = AlgSimplify.lessThanish less_thans non_negatives
    -- The interval starting at lb with ne elements lies entirely below other.
    endsBefore lb ne other =
      provablyLess lb other && provablyLess (lb + ne - 1) other
-- | Evaluate a boolean expression that contains no variables.
--
-- Variable lookups always fail, so any expression mentioning a variable
-- yields 'Nothing'. So does any evaluation failure, or a result that is not a
-- 'BoolValue'.
primBool :: TPrimExp Bool VName -> Maybe Bool
primBool p =
  case evalPrimExp (const Nothing) (untyped p) of
    Just (BoolValue b) -> Just b
    _ -> Nothing
-- | Pair up the intervals of two lists, matching intervals with equal ('==')
-- strides in order.
--
-- Pairing works as follows:
--
-- * When the heads of both lists have equal strides, they are paired
--   directly.
-- * Otherwise, either head may be paired with a one-element placeholder
--   @Interval lb 1 st@ built from that head's own lower bound and stride.
--   Both choices are explored, and the shorter overall result (the one with
--   more direct matches) is kept. On a tie, skipping the left interval wins.
-- * Leftover intervals of either list are padded the same way.
--
-- The first component of every pair comes from, or stands in for, the first
-- list.
--
-- NOTE(review): both alternatives are computed recursively on every stride
-- mismatch, so worst-case time is exponential in the input lengths.
intervalPairs :: [Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs = go []
  where
    go :: [(Interval, Interval)] -> [Interval] -> [Interval] -> [(Interval, Interval)]
    go acc [] [] = reverse acc
    go acc (i@(Interval lb _ st) : is) [] = go ((i, Interval lb 1 st) : acc) is []
    go acc [] (i@(Interval lb _ st) : is) = go ((Interval lb 1 st, i) : acc) [] is
    go acc lhs@(i1@(Interval lb1 _ st1) : is1) rhs@(i2@(Interval lb2 _ st2) : is2)
      | st1 == st2 = go ((i1, i2) : acc) is1 is2
      | otherwise =
          -- minimumBy keeps the first candidate on equal lengths.
          minimumBy
            (compare `on` length)
            [ go ((i1, Interval lb1 1 st1) : acc) is1 rhs,
              go ((Interval lb2 1 st2, i2) : acc) lhs is2
            ]
-- | Check whether a list of intervals may overlap with itself.
--
-- Returns 'Nothing' when no self-overlap is possible. Otherwise it returns
-- @Just@ the interval where the check failed. The @scope@ and @asserts@
-- arguments are ignored.
--
-- Behaviour by case:
--
-- * A single interval never self-overlaps.
-- * If every expression in @non_negatives'@ is a plain variable, the
--   intervals are walked from last to first while accumulating their spans,
--   where each span is @(lowerBound + numElements - 1) * stride@. Each
--   interval's stride must be provably greater ('AlgSimplify.lessThanish') than
--   the span accumulated so far, starting from @0@. The first interval
--   failing this test is returned.
-- * If some non-negative expression is not a plain variable, the first
--   interval is returned conservatively, or 'Nothing' for an empty list.
selfOverlap ::
  scope ->
  asserts ->
  [(VName, PrimExp VName)] ->
  [PrimExp VName] ->
  [Interval] ->
  Maybe Interval
selfOverlap _ _ _ _ [_] = Nothing
selfOverlap _ _ less_thans non_negatives' is
  | Just non_negatives <- namesFromList <$> mapM justLeafExp non_negatives' =
      let provablyLess = AlgSimplify.lessThanish less_thans non_negatives
          walk _ [] = Nothing
          walk spanSoFar (x : xs)
            | provablyLess
                (AlgSimplify.simplify' spanSoFar)
                (AlgSimplify.simplify' $ stride x) =
                walk (spanSoFar + (lowerBound x + numElements x - 1) * stride x) xs
            | otherwise = Just x
       in walk 0 (reverse is)
selfOverlap _ _ _ _ (x : _) = Just x
selfOverlap _ _ _ _ [] = Nothing
-- | Extract the variable from an expression that is just a variable reference
-- ('LeafExp'). Any other expression gives 'Nothing'.
justLeafExp :: PrimExp VName -> Maybe VName
justLeafExp e =
  case e of
    LeafExp v _ -> Just v
    _ -> Nothing