module Futhark.Analysis.AlgSimplify
  ( Prod (..),
    SofP,
    simplify0,
    simplify,
    simplify',
    simplifySofP,
    simplifySofP',
    sumOfProducts,
    sumToExp,
    prodToExp,
    add,
    sub,
    negate,
    isMultipleOf,
    maybeDivide,
    removeLessThans,
    lessThanish,
    compareComplexity,
  )
where

import Data.Bits (xor)
import Data.Function ((&))
import Data.List (findIndex, intersect, partition, sort, (\\))
import Data.Maybe (mapMaybe)
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop.Names
import Futhark.IR.Syntax.Core (SubExp (..), VName)
import Futhark.Util
import Futhark.Util.Pretty
import Prelude hiding (negate)

-- | The expression language this module simplifies: untyped primitive
-- expressions whose variables are 'VName's.
type Exp = PrimExp VName

-- | An 'Exp' wrapped in 'TPrimExp' at type 'Int64'; this is the argument
-- and result type of 'simplify''.
type TExp = TPrimExp Int64 VName

-- | One term of a sum of products: the product of its 'atoms', multiplied
-- by @-1@ when 'negated' is set. An empty 'atoms' list stands for the
-- product @1@ (see 'prodToExp').
data Prod = Prod
  { -- | Whether the whole product is negated.
    negated :: Bool,
    -- | The factors multiplied together.
    atoms :: [Exp]
  }
  deriving (Show, Eq, Ord)

-- | A sum of products. The empty list stands for @0@ (see 'sumToExp').
type SofP = [Prod]

-- | Flatten an expression into a sum of products, then sort the atoms of
-- every product so each product has a canonical factor order.
sumOfProducts :: Exp -> SofP
sumOfProducts e = [sortProduct p | p <- sumOfProducts' e]

-- | Sort the atoms of a product; the sign is left unchanged.
sortProduct :: Prod -> Prod
sortProduct p = p {atoms = sort (atoms p)}

-- | Flatten an expression into an unsimplified, unsorted sum of products.
--
-- * 'Int64' addition concatenates the summands of both operands.
-- * 'Int64' subtraction from the literal @0@ negates every summand of the
--   right operand; other 'Int64' subtractions append the negated summands
--   of the right operand to those of the left.
-- * 'Int64' multiplication distributes the two sums over each other
--   (see 'mult').
-- * An 'Int64' literal becomes one product holding its absolute value,
--   with the literal's sign recorded in 'negated'.
-- * Any other expression becomes an opaque single-atom product.
sumOfProducts' :: Exp -> SofP
sumOfProducts' expr = case expr of
  BinOpExp (Add Int64 _) lhs rhs ->
    sumOfProducts' lhs ++ sumOfProducts' rhs
  BinOpExp (Sub Int64 _) (ValueExp (IntValue (Int64Value 0))) rhs ->
    flipped rhs
  BinOpExp (Sub Int64 _) lhs rhs ->
    sumOfProducts' lhs ++ flipped rhs
  BinOpExp (Mul Int64 _) lhs rhs ->
    sumOfProducts' lhs `mult` sumOfProducts' rhs
  ValueExp (IntValue (Int64Value i)) ->
    [Prod (i < 0) [val (abs i)]]
  _ ->
    [Prod False [expr]]
  where
    -- The summands of a subexpression, each with its sign flipped.
    flipped :: Exp -> SofP
    flipped = map negate . sumOfProducts'

-- | Multiply two sums of products by distribution.
--
-- Every product of the first sum is paired with every product of the
-- second, in order. Each pair's atom lists are concatenated (not re-sorted)
-- and its signs are combined with 'xor'.
mult :: SofP -> SofP -> SofP
mult xs ys = do
  Prod b x <- xs
  Prod b' y <- ys
  pure $ Prod (b `xor` b') (x ++ y)

-- | Flip the sign of a product; its atoms are left untouched.
negate :: Prod -> Prod
negate (Prod sign factors) = Prod (not sign) factors

-- | Rebuild an expression from a sum of products.
--
-- The empty sum is the literal @0@. Otherwise each product is turned into
-- an expression with 'prodToExp', and the results are combined with
-- left-nested 'Int64' additions using 'OverflowUndef'.
sumToExp :: SofP -> Exp
sumToExp prods = case map prodToExp prods of
  [] -> val 0
  e : es -> foldl (BinOpExp (Add Int64 OverflowUndef)) e es

-- | Rebuild an expression from a single product.
--
-- * A product with no atoms is the literal @1@, whatever its sign.
-- * A negated product whose only atom is an 'Int64' literal becomes the
--   negated literal.
-- * Any other negated product is multiplied onto @-1@.
-- * A non-negated product multiplies its atoms together.
--
-- Multiplications are left-nested 'Int64' multiplications using
-- 'OverflowUndef'.
prodToExp :: Prod -> Exp
prodToExp (Prod neg factors) = case (neg, factors) of
  (_, []) -> val 1
  (True, [ValueExp (IntValue (Int64Value i))]) -> val (-i)
  (True, _) -> timesAll (val (-1)) factors
  (False, f : fs) -> timesAll f fs
  where
    timesAll :: Exp -> [Exp] -> Exp
    timesAll = foldl (BinOpExp (Mul Int64 OverflowUndef))

-- | Simplify a sum of products by repeating one pass until nothing changes.
--
-- Each pass, from first to last:
--
-- 1. cancels pairs of a product and its exact negation;
-- 2. constant-folds the products whose atoms are all literals;
-- 3. merges products with identical non-literal atoms by summing their
--    integer scales;
-- 4. drops literal-@1@ atoms, then discards products containing a literal
--    @0@.
simplifySofP :: SofP -> SofP
simplifySofP = fixPoint pass
  where
    -- TODO: Maybe 'constFoldValueExps' is not necessary after adding scaleConsts
    pass :: SofP -> SofP
    pass sofp =
      mapMaybe (applyZero . removeOnes) $
        scaleConsts $
          constFoldValueExps $
            removeNegations sofp

-- | Like 'simplifySofP', but without the constant-folding step
-- ('constFoldValueExps').
simplifySofP' :: SofP -> SofP
simplifySofP' = fixPoint pass
  where
    pass :: SofP -> SofP
    pass sofp =
      mapMaybe (applyZero . removeOnes) (scaleConsts (removeNegations sofp))

-- | Turn an expression into a simplified sum of products.
simplify0 :: Exp -> SofP
simplify0 e = simplifySofP (sumOfProducts e)

-- | Simplify an expression by converting it to a simplified sum of
-- products, rebuilding an expression from that, and constant-folding the
-- result.
simplify :: Exp -> Exp
simplify e = e & simplify0 & sumToExp & constFoldPrimExp

-- | 'simplify' for typed 'Int64' expressions: unwrap, simplify, re-wrap.
simplify' :: TExp -> TExp
simplify' te = TPrimExp (simplify (untyped te))

-- | Discard a product containing a literal 'Int64' @0@ atom, since it
-- contributes nothing to the sum. Any other product is kept unchanged.
applyZero :: Prod -> Maybe Prod
applyZero p
  | val 0 `notElem` atoms p = Just p
  | otherwise = Nothing

-- | Remove literal 'Int64' @1@ atoms from a product. If that would leave no
-- atoms, a single @1@ is kept instead, so the atom list is never emptied by
-- this function.
removeOnes :: Prod -> Prod
removeOnes (Prod neg factors) =
  Prod neg $ case filter (/= one) factors of
    [] -> [one]
    kept -> kept
  where
    one = val 1

-- | Cancel pairs of products that are exact negations of each other (same
-- atoms, opposite sign).
--
-- Each product is paired with the first matching negation later in the
-- list, and both are dropped. Products without a partner keep their
-- relative order.
removeNegations :: SofP -> SofP
removeNegations [] = []
removeNegations (t : ts)
  | (before, _ : after) <- break (== negate t) ts =
      removeNegations (before ++ after)
  | otherwise =
      t : removeNegations ts

-- | Replace all products whose atoms are literal values with the
-- constant-folded sum of those products (re-flattened by 'sumOfProducts').
-- The folded constants are placed in front of the remaining products.
--
-- NOTE(review): with no constant products this still contributes the
-- flattening of @0@ (from @sumToExp []@); presumably 'applyZero' removes
-- it later in 'simplifySofP' — confirm 'constFoldPrimExp' leaves a lone
-- literal as-is.
constFoldValueExps :: SofP -> SofP
constFoldValueExps prods = folded ++ symbolic
  where
    (constant, symbolic) = partition (all isPrimValue . atoms) prods
    folded = sumOfProducts (constFoldPrimExp (sumToExp constant))

-- | Read an integer literal of any width as an 'Int64' (via
-- 'valueIntegral'). Any other expression gives 'Nothing'.
intFromExp :: Exp -> Maybe Int64
intFromExp e = case e of
  ValueExp (IntValue x) -> Just (valueIntegral x)
  _ -> Nothing

-- | Split a product into a signed integer scale and its remaining atoms.
--
-- The scale is the product of all integer-literal atoms, negated when the
-- product is negated; the other atoms are kept in order.
-- Given @-[2, x]@ returns @(-2, [x])@.
prodToScale :: Prod -> (Int64, [Exp])
prodToScale (Prod b exps) = (applySign (product scalars), exps')
  where
    (scalars, exps') = partitionMaybe intFromExp exps
    applySign s
      | b = -s
      | otherwise = s

-- | Build a product from a signed scale and atoms: the absolute value of
-- the scale becomes the first atom, and its sign goes into 'negated'.
-- Given @(-2, [x])@ returns @-[2, x]@.
scaleToProd :: (Int64, [Exp]) -> Prod
scaleToProd (i, exps) = Prod (i < 0) (val (abs i) : exps)

-- | Merge products whose non-literal atoms are identical (same list, same
-- order) by summing their integer scales (see 'prodToScale').
--
-- Every product, merged or not, is rebuilt with 'scaleToProd', so its
-- literal atoms collapse into one leading literal. Results appear in order
-- of first occurrence. Given @[[2, x], -[x]]@ returns @[[1, x]]@; the unit
-- factor is removed later by 'removeOnes' in 'simplifySofP'.
scaleConsts :: SofP -> SofP
scaleConsts = go [] . map prodToScale
  where
    go :: [Prod] -> [(Int64, [Exp])] -> [Prod]
    go done [] = reverse done
    go done ((k, term) : todo) =
      case findIndex ((term ==) . snd) todo >>= (`focusNth` todo) of
        Just (pre, (k', _), post) ->
          -- Fold the partner into this entry and look for further partners.
          go done ((k + k', term) : pre ++ post)
        Nothing ->
          go (scaleToProd (k, term) : done) todo

-- | Is the expression a literal value ('ValueExp')?
isPrimValue :: Exp -> Bool
isPrimValue e = case e of
  ValueExp _ -> True
  _ -> False

-- | An 'Int64' literal expression.
val :: Int64 -> Exp
val i = ValueExp (IntValue (Int64Value i))

-- | Add two sums of products and simplify the result with 'simplifySofP'.
add :: SofP -> SofP -> SofP
add ps1 ps2 = simplifySofP (ps1 ++ ps2)

-- | Subtract the second sum of products from the first: negate every
-- product of the second and 'add' the two.
sub :: SofP -> SofP -> SofP
sub ps1 ps2 = add ps1 (negate <$> ps2)

-- | Does the product contain every atom of @term@, counting multiplicity?
-- The product's sign is ignored.
isMultipleOf :: Prod -> [Exp] -> Bool
isMultipleOf (Prod _ factors) term = sort (remainder ++ term) == sort factors
  where
    -- Removes one occurrence of each atom of @term@, where present.
    remainder = factors \\ term

-- | Try to divide one product by another.
--
-- First, structural cancellation is tried. If every atom of @divisor@
-- occurs in @dividend@ (as a multiset), the quotient is the leftover atoms,
-- and its sign is the 'xor' of the two signs.
--
-- Otherwise, both products are split into a signed integer scale and their
-- remaining atoms (see 'prodToScale'). Division succeeds when both of these
-- hold:
--
-- * the scales divide evenly;
-- * every remaining atom of the divisor occurs in the dividend.
--
-- The quotient is then rebuilt with 'scaleToProd'.
--
-- Returns 'Nothing' when neither applies. This includes a divisor whose
-- scale is zero (a literal @0@ atom), which previously made 'mod' throw.
maybeDivide :: Prod -> Prod -> Maybe Prod
maybeDivide dividend divisor
  | Prod dividend_b dividend_factors <- dividend,
    Prod divisor_b divisor_factors <- divisor,
    quotient <- dividend_factors \\ divisor_factors,
    sort (quotient <> divisor_factors) == sort dividend_factors =
      Just $ Prod (dividend_b `xor` divisor_b) quotient
  | (dividend_scale, dividend_rest) <- prodToScale dividend,
    (divisor_scale, divisor_rest) <- prodToScale divisor,
    -- 'mod' and 'div' throw on a zero divisor.
    divisor_scale /= 0,
    dividend_scale `mod` divisor_scale == 0,
    null $ divisor_rest \\ dividend_rest =
      -- 'scaleToProd' stores the absolute value of the scale as the literal
      -- atom and records the sign only in 'negated'. Storing the signed
      -- quotient while also setting 'negated' would cancel the sign out
      -- (e.g. @-6x / 2@ would come out as @3x@ rather than @-3x@).
      Just $
        scaleToProd
          (dividend_scale `div` divisor_scale, dividend_rest \\ divisor_rest)
  | otherwise = Nothing

-- | Given the 'Names' known to be non-negative (@>= 0@), decide whether the
-- given 'SofP' is certainly non-negative.
--
-- The check is conservative: every product must pass 'nonNegativeishProd',
-- and any doubt yields 'False'.
--
-- TODO: We need to expand this to be able to handle cases such as
-- @i*n + g < (i + 1) * n@, if it is known that @g < n@, eg. from a
-- 'SegSpace' or a loop form.
nonNegativeish :: Names -> SofP -> Bool
nonNegativeish non_negatives sofp =
  and [nonNegativeishProd non_negatives p | p <- sofp]

-- | A product is considered non-negative only if it is not negated and
-- every one of its atoms passes 'nonNegativeishExp'.
nonNegativeishProd :: Names -> Prod -> Bool
nonNegativeishProd non_negatives (Prod neg as) =
  not neg && all (nonNegativeishExp non_negatives) as

-- | An atom is considered non-negative when it is one of:
--
-- * a literal that 'negativeIsh' does not flag;
-- * a variable listed in the given 'Names'.
--
-- Every other expression is rejected.
nonNegativeishExp :: Names -> PrimExp VName -> Bool
nonNegativeishExp non_negatives e = case e of
  ValueExp v -> not (negativeIsh v)
  LeafExp vname _ -> vname `nameIn` non_negatives
  _ -> False

-- | Is @e1@ symbolically less than or equal to @e2@?
--
-- The difference @e2 - e1@ is simplified into a sum of products. An empty
-- sum (the difference is zero) is accepted immediately. Otherwise, each
-- pair @(i, bound)@ in @less_thans0@ is turned into two facts for
-- 'removeLessThans': one with @i@ and one with the constant @0@ in its
-- place. 'removeLessThans' is applied repeatedly until nothing changes, and
-- what remains must pass 'nonNegativeish'.
--
-- NOTE(review): the pairs are presumably known facts @0 <= i < bound@ —
-- confirm with callers.
lessThanOrEqualish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanOrEqualish less_thans0 non_negatives e1 e2
  | null diff = True
  | otherwise =
      nonNegativeish non_negatives $
        fixPoint (`removeLessThans` less_thans) diff
  where
    diff = simplify0 $ untyped $ e2 - e1
    less_thans = concatMap expand less_thans0
    expand (i, bound) =
      [ (Var i, bound),
        (Constant $ IntValue $ Int64Value 0, bound)
      ]

-- | Is @e1@ symbolically strictly less than @e2@? This is checked as
-- @e1 + 1 <= e2@ using 'lessThanOrEqualish'.
lessThanish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanish less_thans non_negatives e1 e2 =
  lessThanOrEqualish less_thans non_negatives (e1 + 1) e2

-- | Strip known-positive terms from a sum of products.
--
-- For each @(i, bound)@, in order, @bound - i@ is simplified into a sum of
-- products. If every one of those products occurs in the current sum, one
-- occurrence of each is removed; otherwise the sum is left unchanged.
removeLessThans :: SofP -> [(SubExp, PrimExp VName)] -> SofP
removeLessThans sofp0 facts = foldl dropFact sofp0 facts
  where
    dropFact sofp (i, bound)
      | to_remove `intersect` sofp == to_remove = sofp \\ to_remove
      | otherwise = sofp
      where
        to_remove =
          simplifySofP $
            Prod True [primExpFromSubExp (IntType Int64) i] : simplify0 bound

-- | Order two sums of products by complexity.
--
-- The sum with fewer products is simpler. On a tie, products are compared
-- pairwise (via 'prodToScale') until one pair differs:
--
-- * two purely constant products compare by their signed scale;
-- * a purely constant product is simpler than one with non-literal atoms;
-- * otherwise, the product with fewer non-literal atoms is simpler.
compareComplexity :: SofP -> SofP -> Ordering
compareComplexity xs0 ys0 =
  compare (length xs0) (length ys0) <> pairwise xs0 ys0
  where
    pairwise :: SofP -> SofP -> Ordering
    pairwise [] [] = EQ
    pairwise [] _ = LT
    pairwise _ [] = GT
    pairwise (px : xs) (py : ys) =
      byScale (prodToScale px) (prodToScale py) <> pairwise xs ys

    byScale :: (Int64, [Exp]) -> (Int64, [Exp]) -> Ordering
    byScale (ix, []) (iy, []) = compare ix iy
    byScale (_, []) _ = LT
    byScale _ (_, []) = GT
    byScale (_, x) (_, y) = compare (length x) (length y)