-- | Particularly simple simplification rules.
module Futhark.Optimise.Simplify.Rules.Simple
  ( TypeLookup,
    VarLookup,
    applySimpleRules,
  )
where

import Control.Monad
import Data.List (isSuffixOf)
import Data.List.NonEmpty qualified as NE
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR
import Futhark.Util (focusNth)

-- | A function that, given a variable name, returns its definition:
-- the defining expression together with the certificates attached to
-- it.  Rules that rewrite an operation in terms of such a definition
-- carry these certificates over to their result.  Presumably yields
-- 'Nothing' when no definition is known for the name — TODO confirm
-- against the caller.
type VarLookup rep = VName -> Maybe (Exp rep, Certs)

-- | A function that, given a subexpression, returns its type, if it is
-- known.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function.  Given a definition lookup, a type lookup, and the
-- operation to simplify, it returns either a replacement operation
-- together with the certificates the replacement depends on, or
-- 'Nothing' if the rule does not apply.
type SimpleRule rep = VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)

-- | Does this subexpression denote a constant that 'oneIsh' accepts?
-- Variables never count, whatever their runtime value.
isCt1 :: SubExp -> Bool
isCt1 se
  | Constant v <- se = oneIsh v
  | otherwise = False

-- | Does this subexpression denote a constant that 'zeroIsh' accepts?
-- Variables never count, whatever their runtime value.
isCt0 :: SubExp -> Bool
isCt0 se
  | Constant v <- se = zeroIsh v
  | otherwise = False

-- | Simplify comparison operators.
--
-- Three cases are handled:
--
-- * An operand compared with itself.
-- * Two constant operands, folded with 'doCmpOp'.
-- * An equality test of an integer constant against the result of a
--   'BToI' conversion.
simplifyCmpOp :: SimpleRule rep
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  -- Comparing an operand with itself.  Equality and non-strict ordering
  -- are only reflexive for non-floating-point types.  Under IEEE 754,
  -- NaN is unordered with respect to itself, so @x == x@ and @x <= x@
  -- are false when @x@ is NaN and must not be folded to true.  Strict
  -- less-than is false for every operand, including NaN, so it may
  -- still be folded.
  | e1 == e2,
    Just res <- selfComparison cmp =
      constRes $ BoolValue res
  where
    selfComparison (CmpEq FloatType {}) = Nothing
    selfComparison CmpEq {} = Just True
    selfComparison CmpSlt {} = Just False
    selfComparison CmpUlt {} = Just False
    selfComparison CmpSle {} = Just True
    selfComparison CmpUle {} = Just True
    selfComparison FCmpLt {} = Just False
    selfComparison FCmpLe {} = Nothing
    selfComparison CmpLlt = Just False
    selfComparison CmpLle = Just True
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  constRes . BoolValue =<< doCmpOp cmp v1 v2
-- Equality of an integer constant with the result of converting a
-- boolean @b@ to an integer.  Comparing with 1 gives @b@, comparing
-- with 0 gives @not b@, and comparing with any other constant is false.
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
      case valueIntegral x :: Int of
        1 -> Just (SubExp b, cs)
        0 -> Just (UnOp Not b, cs)
        _ -> Just (SubExp (Constant (BoolValue False)), cs)
simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of binary operators.
--
-- Handles three kinds of rewrite:
--
-- * Constant folding.
-- * Reassociation of constant operands.
-- * Identity, annihilator, and self-cancellation laws for the integer,
--   floating-point, bitwise, logical, and signed-max operators.
--
-- NOTE(review): several floating-point rules below (@0+x = x@,
-- @0*x = 0@, @0/x = 0@, @0**y = 0@) do not hold for every IEEE 754
-- value (NaN, infinities, negative zero).  Confirm this is intended.
simplifyBinOp :: SimpleRule rep
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res
-- By normalisation, constants are always on the left.
--
-- x+(y+z) = (x+y)+z (where x and y are constants).
simplifyBinOp look _ (BinOp op1 (Constant x1) (Var y1))
  | associativeBinOp op1,
    Just (BasicOp (BinOp op2 (Constant x2) y2), cs) <- look y1,
    op1 == op2,
    Just res <- doBinOp op1 x1 x2 =
      Just (BinOp op1 (Constant res) y2, cs)
simplifyBinOp look _ (BinOp Add {} e1 e2)
  -- 0+x == x and x+0 == x
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (SubExp e2_a, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp sub@(Sub t _) e1 e2)
  -- x-0 == x
  | isCt0 e2 = resIsSubExp e1
  -- Cases for simplifying (a+b)-b and permutations.

  -- (e1_a+e1_b)-e1_a == e1_b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
      Just (SubExp e1_b, cs)
  -- (e1_a+e1_b)-e1_b == e1_a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
      Just (SubExp e1_a, cs)
  -- e2_a-(e2_a+e2_b) == 0-e2_b
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_a == e1 =
      Just (BinOp sub (intConst t 0) e2_b, cs)
  -- e2_b-(e2_a+e2_b) == 0-e2_a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (BinOp sub (intConst t 0) e2_a, cs)
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  -- 0*x == 0 and x*0 == 0 (the zero operand is returned).
  | isCt0 e1 = resIsSubExp e1
  | isCt0 e2 = resIsSubExp e2
  -- 1*x == x and x*1 == x
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt0 e2 = resIsSubExp e2
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  -- x mod 1 == 0
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x mod x == 0
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (a mod b) mod b == a mod b
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
      Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  -- Never fold a division by zero.
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  -- x rem 1 == 0
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x rem x == 0: a value divides itself exactly, so the remainder is
  -- zero (not one), just as for 'SMod' above.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (Pow t) e1 e2)
  -- 2**e == 1<<e
  | e1 == intConst t 2 =
      Just (BinOp (Shl t) (intConst t 1) e2, mempty)
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  -- x**0 == 1
  | isCt0 e2 = resIsSubExp $ floatConst t 1
  -- 0**y == 0, 1**y == 1, x**1 == x
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = resIsSubExp e1
  | isCt0 e1 = resIsSubExp $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = resIsSubExp $ intConst t 0
  | isCt0 e2 = resIsSubExp $ intConst t 0
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp $ intConst t 0
simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
  -- (not x) && x == false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) == false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x == true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) == true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  -- smax(x, x) == x
  | e1 == e2 =
      resIsSubExp e1
  -- smax(smax(a, b), a) == smax(b, a)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  -- smax(smax(a, b), b) == smax(a, b)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  -- smax(a, smax(a, b)) == smax(b, a)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  -- smax(b, smax(a, b)) == smax(a, b)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing

-- | Produce a rule result that replaces the operation by the given
-- constant, with no certificate dependencies.
constRes :: PrimValue -> Maybe (BasicOp, Certs)
constRes v = Just (SubExp (Constant v), mempty)

-- | Produce a rule result that replaces the operation by the given
-- subexpression, with no certificate dependencies.
resIsSubExp :: SubExp -> Maybe (BasicOp, Certs)
resIsSubExp se = Just (SubExp se, mempty)

-- | Simplify unary operators.  Operators applied to constants are
-- folded with 'doUnOp', and a double boolean negation is cancelled
-- (@not (not x) == x@), keeping the certificates of the inner negation.
simplifyUnOp :: SimpleRule rep
simplifyUnOp defOf _ op =
  case op of
    UnOp unop (Constant v) ->
      doUnOp unop v >>= constRes
    UnOp Not (Var v)
      | Just (BasicOp (UnOp Not inner), v_cs) <- defOf v ->
          Just (SubExp inner, v_cs)
    _ -> Nothing

-- | Simplify conversion operators.
--
-- Three rewrites are performed:
--
-- * Conversions of constants are folded with 'doConvOp'.
-- * Conversions whose source and target types coincide are dropped.
-- * A conversion applied to the result of an inner extension (or, for
--   'FPConv', an inner float conversion) is fused into a single
--   conversion from the inner operand's type.  This only happens when
--   the intermediate type is not smaller than the original one
--   (according to 'Ord'), i.e. the inner step loses no information.
simplifyConvOp :: SimpleRule rep
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  doConvOp op v >>= constRes
simplifyConvOp _ _ (ConvOp op se)
  | uncurry (==) $ convOpType op = resIsSubExp se
simplifyConvOp lookupVar _ (ConvOp outer (Var v))
  | Just (BasicOp (ConvOp inner se), v_cs) <- lookupVar v,
    Just fused <- fuse outer inner =
      Just (ConvOp fused se, v_cs)
  where
    -- @fuse outer inner@ is the single conversion equivalent to
    -- applying @inner@ and then @outer@, if one is known.
    fuse (SExt t2 t1) (SExt t3 _) | t2 >= t3 = Just $ SExt t3 t1
    fuse (ZExt t2 t1) (ZExt t3 _) | t2 >= t3 = Just $ ZExt t3 t1
    fuse (SIToFP t2 t1) (SExt t3 _) | t2 >= t3 = Just $ SIToFP t3 t1
    fuse (UIToFP t2 t1) (ZExt t3 _) | t2 >= t3 = Just $ UIToFP t3 t1
    fuse (FPConv t2 t1) (FPConv t3 _) | t2 >= t3 = Just $ FPConv t3 t1
    fuse _ _ = Nothing
simplifyConvOp _ _ _ = Nothing

-- | An assertion whose condition is the constant @true@ can never fail,
-- so it is replaced by the unit value.
simplifyAssert :: SimpleRule rep
simplifyAssert _ _ op =
  case op of
    Assert (Constant (BoolValue True)) _ _ -> constRes UnitValue
    _ -> Nothing

-- | A reshape to the shape the array already has (according to the type
-- lookup) is a no-op, and is replaced by the array itself.
simplifyIdentityReshape :: SimpleRule rep
simplifyIdentityReshape _ seType (Reshape _ newshape v) = do
  t <- seType $ Var v
  guard $ newshape == arrayShape t
  resIsSubExp $ Var v
simplifyIdentityReshape _ _ _ = Nothing

-- | A reshape of the result of another reshape reads directly from the
-- inner reshape's source array.  The combined reshape uses the 'max' of
-- the two reshape kinds and keeps the inner definition's certificates.
simplifyReshapeReshape :: SimpleRule rep
simplifyReshapeReshape defOf _ op
  | Reshape k1 newshape v <- op,
    Just (BasicOp (Reshape k2 _ v2), v_cs) <- defOf v =
      Just (Reshape (max k1 k2) newshape v2, v_cs)
  | otherwise = Nothing

-- | Reshaping a scratch array is the same as creating a scratch array of
-- the new shape directly.
simplifyReshapeScratch :: SimpleRule rep
simplifyReshapeScratch defOf _ op =
  case op of
    Reshape _ newshape v
      | Just (BasicOp (Scratch bt _), v_cs) <- defOf v ->
          Just (Scratch bt (shapeDims newshape), v_cs)
    _ -> Nothing

-- | Reshaping the result of a replicate, where the new shape ends with
-- the shape of the replicated value, is a single replicate whose
-- replication shape is the remaining leading dimensions of the new
-- shape.
simplifyReshapeReplicate :: SimpleRule rep
simplifyReshapeReplicate defOf seType (Reshape _ newshape v) = do
  (BasicOp (Replicate _ se), v_cs) <- defOf v
  oldshape <- arrayShape <$> seType se
  let newDims = shapeDims newshape
  guard $ shapeDims oldshape `isSuffixOf` newDims
  let outerDims = take (length newshape - shapeRank oldshape) newDims
  Just (Replicate (Shape outerDims) se, v_cs)
simplifyReshapeReplicate _ _ _ = Nothing

-- | Reshaping an iota into a one-dimensional shape of size @n@ is an
-- iota of length @n@ with the same offset, stride, and integer type.
simplifyReshapeIota :: SimpleRule rep
simplifyReshapeIota defOf _ (Reshape _ newshape v) = do
  (BasicOp (Iota _ offset stride it), v_cs) <- defOf v
  [n] <- pure $ shapeDims newshape
  Just (Iota n offset stride it, v_cs)
simplifyReshapeIota _ _ _ = Nothing

-- | A size coercion of a concatenation that changes at most the size of
-- the concatenated dimension becomes a concatenation with the new size.
-- The rule requires every other dimension of the first input array to
-- already match the coerced shape.
simplifyReshapeConcat :: SimpleRule rep
simplifyReshapeConcat defOf seType (Reshape ReshapeCoerce newshape v) = do
  (BasicOp (Concat d arrs _), v_cs) <- defOf v
  firstArrType <- seType (Var (NE.head arrs))
  (arr_bef, _, arr_aft) <- focusNth d (arrayDims firstArrType)
  (bef, w', aft) <- focusNth d (shapeDims newshape)
  guard $ arr_bef == bef && arr_aft == aft
  pure (Concat d arrs w', v_cs)
simplifyReshapeConcat _ _ _ = Nothing

-- | Substitute new sizes into a slice.
--
-- Each 'DimSlice', in turn, has its middle component (presumably the
-- slice length) replaced by the next element of the size list.
-- 'DimFix' entries are kept unchanged and do not consume a size.  The
-- result is cut short at the first 'DimSlice' for which no size is
-- left.
reshapeSlice :: [DimIndex d] -> [d] -> [DimIndex d]
reshapeSlice slice sizes =
  case (slice, sizes) of
    (DimFix i : rest, _) ->
      DimFix i : reshapeSlice rest sizes
    (DimSlice start _ stride : rest, size : sizes') ->
      DimSlice start size stride : reshapeSlice rest sizes'
    _ -> []

-- | If we are size-coercing the result of an index (a slice), we might
-- as well index with a slice whose sizes are taken from the coerced
-- shape instead.  The rule only fires if the new slice actually
-- differs from the old one.
simplifyReshapeIndex :: SimpleRule rep
simplifyReshapeIndex defOf _ (Reshape ReshapeCoerce newshape v) = do
  (BasicOp (Index v' slice), v_cs) <- defOf v
  let slice' = Slice $ reshapeSlice (unSlice slice) (shapeDims newshape)
  guard $ slice' /= slice
  pure (Index v' slice', v_cs)
simplifyReshapeIndex _ _ _ = Nothing

-- | If we are updating a slice with the result of a size coercion, we
-- instead write the original (uncoerced) array.  The slice sizes are
-- adjusted to that array's dimensions, taken from the type lookup.  The
-- rule only fires if the adjusted slice differs from the old one.
simplifyUpdateReshape :: SimpleRule rep
simplifyUpdateReshape defOf seType (Update safety dest slice (Var v)) = do
  (BasicOp (Reshape ReshapeCoerce _ v'), v_cs) <- defOf v
  ds <- arrayDims <$> seType (Var v')
  let slice' = Slice $ reshapeSlice (unSlice slice) ds
  guard $ slice' /= slice
  pure (Update safety dest slice' (Var v'), v_cs)
simplifyUpdateReshape _ _ _ = Nothing

-- | If we are copying a scratch array, possibly seen through a chain of
-- rearranges and reshapes, just turn the copy into a fresh scratch
-- array.  The new scratch has the element type and dimensions of the
-- copied array.
copyScratchToScratch :: SimpleRule rep
copyScratchToScratch defOf seType (Copy src) = do
  t <- seType $ Var src
  guard $ scratchOrigin src
  pure (Scratch (elemType t) (arrayDims t), mempty)
  where
    -- Follow rearranges and reshapes back to their source, and report
    -- whether that source is defined by a 'Scratch'.
    scratchOrigin :: VName -> Bool
    scratchOrigin name =
      case defOf name >>= asBasicOp . fst of
        Just Scratch {} -> True
        Just (Rearrange _ name') -> scratchOrigin name'
        Just (Reshape _ _ name') -> scratchOrigin name'
        _ -> False
copyScratchToScratch _ _ _ =
  Nothing

-- | Every simple rule, in the order in which they are tried.
simpleRules :: [SimpleRule rep]
simpleRules =
  concat
    [ -- Scalar arithmetic, comparison, and conversion.
      [simplifyBinOp, simplifyCmpOp, simplifyUnOp, simplifyConvOp],
      -- Assertions and copies.
      [simplifyAssert, copyScratchToScratch],
      -- Reshapes and their interaction with other operations.
      [ simplifyIdentityReshape,
        simplifyReshapeReshape,
        simplifyReshapeScratch,
        simplifyReshapeReplicate,
        simplifyReshapeIota,
        simplifyReshapeConcat,
        simplifyReshapeIndex,
        simplifyUpdateReshape
      ]
    ]

-- | Try to simplify the given t'BasicOp', returning a new t'BasicOp'
-- and certificates that it must depend on.  The rules are tried in the
-- order of 'simpleRules', and the result of the first rule that
-- applies is returned.
{-# NOINLINE applySimpleRules #-}
applySimpleRules ::
  VarLookup rep ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certs)
applySimpleRules defOf seType op =
  msum $ map (\rule -> rule defOf seType op) simpleRules