-- | Particularly simple simplification rules.
module Futhark.Optimise.Simplify.Rules.Simple
  ( TypeLookup,
    VarLookup,
    applySimpleRules,
  )
where

import Control.Monad
import Data.List (isSuffixOf)
import Data.List.NonEmpty qualified as NE
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR
import Futhark.Util (focusNth)

-- | A function that, given a variable name, returns its definition: the
-- expression the variable is bound to, paired with a set of certificates.
-- The simple rules in this module attach those certificates to any
-- rewrite that relies on the definition.  'Nothing' means that no
-- definition is available to the rule.
type VarLookup rep = VName -> Maybe (Exp rep, Certs)

-- | A function that, given a subexpression, returns its type, or
-- 'Nothing' if no type is available for it.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function.  It receives a way to look up variable definitions, a way
-- to look up types, and the operation to simplify.  It yields either
-- the replacement operation together with the certificates it must
-- carry, or 'Nothing' when the rule does not apply.
type SimpleRule rep =
  VarLookup rep ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certs)

-- | Is the subexpression a constant for which 'oneIsh' holds?  Anything
-- that is not a 'Constant' (e.g. a variable) never counts, whatever its
-- runtime value.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is the subexpression a constant for which 'zeroIsh' holds?  Anything
-- that is not a 'Constant' (e.g. a variable) never counts, whatever its
-- runtime value.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | Simplify comparisons.  Three cases are handled:
--
-- * comparing a subexpression with itself, when the answer does not
--   depend on the runtime value;
--
-- * comparing two constants (folded with 'doCmpOp');
--
-- * comparing an integer constant for equality with the result of a
--   'BToI' conversion, which is turned into a test on the original
--   boolean.
simplifyCmpOp :: SimpleRule rep
simplifyCmpOp _ _ (CmpOp cmp e1 e2)
  | e1 == e2,
    Just res <- cmpWithSelf cmp =
      constRes $ BoolValue res
  where
    -- The result of comparing any value with itself, if it is
    -- independent of the value.  Floating-point equality and
    -- less-or-equal are deliberately excluded: under IEEE 754 both are
    -- false when the operand is NaN, so @x == x@ and @x <= x@ cannot be
    -- folded to true.  @x < x@ is false even for NaN, so it is still
    -- folded.
    cmpWithSelf (CmpEq (FloatType _)) = Nothing
    cmpWithSelf CmpEq {} = Just True
    cmpWithSelf CmpSlt {} = Just False
    cmpWithSelf CmpUlt {} = Just False
    cmpWithSelf CmpSle {} = Just True
    cmpWithSelf CmpUle {} = Just True
    cmpWithSelf FCmpLt {} = Just False
    cmpWithSelf FCmpLe {} = Nothing
    cmpWithSelf CmpLlt = Just False
    cmpWithSelf CmpLle = Just True
simplifyCmpOp _ _ (CmpOp cmp (Constant v1) (Constant v2)) =
  constRes . BoolValue =<< doCmpOp cmp v1 v2
-- (btoi b == 1) => b, (btoi b == 0) => not b, (btoi b == k) => false.
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue x)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
      case valueIntegral x :: Int of
        1 -> Just (SubExp b, cs)
        0 -> Just (UnOp Not b, cs)
        _ -> Just (SubExp (Constant (BoolValue False)), cs)
simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of binary operators.  This covers constant
-- folding, reassociation of constants, and identity/annihilator rules
-- for the individual operators.  Rewrites that look through a variable's
-- definition carry the certificates returned by the lookup.
simplifyBinOp :: SimpleRule rep
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res
-- By normalisation, constants are always on the left.
--
-- x+(y+z) = (x+y)+z (where x and y are constants).
simplifyBinOp look _ (BinOp op1 (Constant x1) (Var y1))
  | associativeBinOp op1,
    Just (BasicOp (BinOp op2 (Constant x2) y2), cs) <- look y1,
    op1 == op2,
    Just res <- doBinOp op1 x1 x2 =
      Just (BinOp op1 (Constant res) y2, cs)
simplifyBinOp look _ (BinOp (Add it ovf) e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (SubExp e2_a, cs)
  -- x+(-1*y) => x-y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Mul {} (Constant (IntValue x)) e3), cs) <- look v2,
    valueIntegral x == (-1 :: Int) =
      Just (BinOp (Sub it ovf) e1 e3, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp sub@(Sub t _) e1 e2)
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = Just (SubExp (intConst t 0), mempty)
  --
  -- Below are cases for simplifying (a+b)-b and permutations.
  --
  -- (e1_a+e1_b)-e1_a == e1_b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
      Just (SubExp e1_b, cs)
  -- (e1_a+e1_b)-e1_b == e1_a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
      Just (SubExp e1_a, cs)
  -- e2_a-(e2_a+e2_b) == 0-e2_b
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_a == e1 =
      Just (BinOp sub (intConst t 0) e2_b, cs)
  -- e2_b-(e2_a+e2_b) == 0-e2_a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
      Just (BinOp sub (intConst t 0) e2_a, cs)
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt0 e2 = resIsSubExp e2
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  -- NOTE(review): 0*x => 0 does not hold under IEEE 754 when x is NaN
  -- or infinite (and the sign of zero may differ); presumably accepted
  -- deliberately for performance -- confirm.
  | isCt0 e1 = resIsSubExp e1
  | isCt0 e2 = resIsSubExp e2
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod y) mod y == x mod y
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
      Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  -- Leave division by a literal zero alone.
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  -- NOTE(review): 0/x => 0 does not hold under IEEE 754 when x is zero
  -- or NaN; presumably accepted deliberately -- confirm.
  | isCt0 e1 = resIsSubExp e1
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- The remainder of a value divided by itself is zero (not one).
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = resIsSubExp e1
  | isCt0 e2 = Nothing
-- 2**y => 1<<y
simplifyBinOp _ _ (BinOp (Pow t) e1 e2)
  | e1 == intConst t 2 =
      Just (BinOp (Shl t) (intConst t 1) e2, mempty)
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  -- x**0 == 1 for every x.
  | isCt0 e2 = resIsSubExp $ floatConst t 1
  -- 1**y == 1 and x**1 == x.  Note that 0**y is *not* rewritten to 0:
  -- it is 1 when y is zero and infinity when y is negative.
  | isCt1 e1 || isCt1 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = resIsSubExp e1
  | isCt0 e1 = resIsSubExp $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = resIsSubExp $ intConst t 0
  | isCt0 e2 = resIsSubExp $ intConst t 0
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | e1 == e2 = resIsSubExp $ intConst t 0
simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = resIsSubExp e2
  | isCt1 e2 = resIsSubExp e1
  -- (not x) && x == false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) == false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = resIsSubExp e2
  | isCt0 e2 = resIsSubExp e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x == true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) == true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
      resIsSubExp e1
  -- max(max(a,b),a) == max(b,a)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  -- max(max(a,b),b) == max(a,b)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  -- max(a,max(a,b)) == max(b,a)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  -- max(b,max(a,b)) == max(a,b)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing

-- | A rule result that replaces the operation with the given constant
-- and attaches no certificates.
constRes :: PrimValue -> Maybe (BasicOp, Certs)
constRes v = Just (SubExp (Constant v), mempty)

-- | A rule result that replaces the operation with the given
-- subexpression and attaches no certificates.
resIsSubExp :: SubExp -> Maybe (BasicOp, Certs)
resIsSubExp se = Just (SubExp se, mempty)

-- | Constant-fold unary operators with 'doUnOp', and cancel a double
-- application of 'Not'.  The second rewrite carries the certificates of
-- the inner 'Not''s definition.
simplifyUnOp :: SimpleRule rep
simplifyUnOp _ _ (UnOp op (Constant v)) =
  constRes =<< doUnOp op v
simplifyUnOp defOf _ (UnOp Not (Var v))
  | Just (BasicOp (UnOp Not v2), v_cs) <- defOf v =
      Just (SubExp v2, v_cs)
simplifyUnOp _ _ _ =
  Nothing

-- | Simplify conversions.  Three rewrites are performed:
--
-- * conversions of constants are folded with 'doConvOp';
--
-- * conversions whose source and target types coincide are dropped;
--
-- * a conversion applied to the result of an extension (or float
--   conversion) whose target type is at least its source type (per the
--   types' 'Ord' instance) becomes a single conversion from the original
--   operand.  The rewrite carries the inner definition's certificates.
simplifyConvOp :: SimpleRule rep
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  constRes =<< doConvOp op v
simplifyConvOp _ _ (ConvOp op se)
  | (from, to) <- convOpType op,
    from == to =
      resIsSubExp se
simplifyConvOp lookupVar _ (ConvOp (SExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (ZExt t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (ZExt t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (SIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (SExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (SIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (UIToFP t2 t1) (Var v))
  | Just (BasicOp (ConvOp (ZExt t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (UIToFP t3 t1) se, v_cs)
simplifyConvOp lookupVar _ (ConvOp (FPConv t2 t1) (Var v))
  | Just (BasicOp (ConvOp (FPConv t3 _) se), v_cs) <- lookupVar v,
    t2 >= t3 =
      Just (ConvOp (FPConv t3 t1) se, v_cs)
simplifyConvOp _ _ _ =
  Nothing

-- | If the asserted expression is the constant @True@, replace the
-- assertion with the unit constant.  Any other assertion is left alone.
simplifyAssert :: SimpleRule rep
simplifyAssert _ _ (Assert (Constant (BoolValue True)) _ _) =
  constRes UnitValue
simplifyAssert _ _ _ =
  Nothing

-- | No-op reshape: reshaping an array to the shape it already has
-- (according to the type lookup) yields the array itself.
simplifyIdentityReshape :: SimpleRule rep
simplifyIdentityReshape _ seType (Reshape _ newshape v)
  | Just t <- seType $ Var v,
    newshape == arrayShape t =
      resIsSubExp $ Var v
simplifyIdentityReshape _ _ _ = Nothing

-- | A reshape of a reshape becomes a single reshape of the innermost
-- array.  The new kind is the larger (per 'Ord') of the two reshape
-- kinds, and the rewrite carries the inner reshape's certificates.
simplifyReshapeReshape :: SimpleRule rep
simplifyReshapeReshape defOf _ (Reshape k1 newshape v)
  | Just (BasicOp (Reshape k2 _ v2), v_cs) <- defOf v =
      Just (Reshape (max k1 k2) newshape v2, v_cs)
simplifyReshapeReshape _ _ _ = Nothing

-- | Reshaping a 'Scratch' array is replaced by a 'Scratch' with the same
-- element type directly in the new shape.
simplifyReshapeScratch :: SimpleRule rep
simplifyReshapeScratch defOf _ (Reshape _ newshape v)
  | Just (BasicOp (Scratch bt _), v_cs) <- defOf v =
      Just (Scratch bt $ shapeDims newshape, v_cs)
simplifyReshapeScratch _ _ _ = Nothing

-- | Reshaping a 'Replicate' whose replicated value has a shape that is a
-- suffix of the new shape becomes a 'Replicate' straight into the new
-- shape.  The leading dimensions of the new shape (those not covered by
-- the replicated value) become the replicate's outer shape.
simplifyReshapeReplicate :: SimpleRule rep
simplifyReshapeReplicate defOf seType (Reshape _ newshape v)
  | Just (BasicOp (Replicate _ se), v_cs) <- defOf v,
    Just oldshape <- arrayShape <$> seType se,
    shapeDims oldshape `isSuffixOf` shapeDims newshape =
      let outer =
            take (length newshape - shapeRank oldshape) $
              shapeDims newshape
       in Just (Replicate (Shape outer) se, v_cs)
simplifyReshapeReplicate _ _ _ = Nothing

-- | Reshaping an 'Iota' into a one-dimensional shape @[n]@ becomes an
-- 'Iota' of length @n@ with the same offset, stride and integer type.
simplifyReshapeIota :: SimpleRule rep
simplifyReshapeIota defOf _ (Reshape _ newshape v)
  | Just (BasicOp (Iota _ offset stride it), v_cs) <- defOf v,
    [n] <- shapeDims newshape =
      Just (Iota n offset stride it, v_cs)
simplifyReshapeIota _ _ _ = Nothing

-- | A size coercion ('ReshapeCoerce') of a 'Concat' result can be folded
-- into the 'Concat' itself when the coercion only changes the size of
-- the concatenation dimension.  That is, every dimension before and after
-- dimension @d@ of the new shape must match the corresponding dimension
-- of the first concatenated array.  The new width along @d@ is taken from
-- the coerced shape; the certificates of the 'Concat' statement are kept.
simplifyReshapeConcat :: SimpleRule rep
simplifyReshapeConcat defOf seType (Reshape ReshapeCoerce newshape v)
  | Just (BasicOp (Concat d arrs _), v_cs) <- defOf v,
    Just (new_before, new_w, new_after) <- focusNth d (shapeDims newshape),
    Just first_t <- seType (Var (NE.head arrs)),
    Just (arr_before, _, arr_after) <- focusNth d (arrayDims first_t),
    arr_before == new_before && arr_after == new_after =
      Just (Concat d arrs new_w, v_cs)
simplifyReshapeConcat _ _ _ = Nothing

-- | Replace the widths of the 'DimSlice' components of a slice with the
-- given dimensions, consumed in order.  'DimFix' components are kept
-- as-is and do not consume a dimension.
--
-- If the dimension list runs out before the slice does, the remaining
-- slice components are returned unchanged rather than dropped.  The
-- result therefore always has the same length (rank) as the input slice.
-- Surplus dimensions are ignored.
reshapeSlice :: [DimIndex d] -> [d] -> [DimIndex d]
reshapeSlice (DimFix i : slice') scs =
  DimFix i : reshapeSlice slice' scs
reshapeSlice (DimSlice x _ s : slice') (d : ds') =
  DimSlice x d s : reshapeSlice slice' ds'
-- Out of dimensions: keep the rest of the slice intact.  The previous
-- behavior dropped it, yielding a slice of the wrong rank.
reshapeSlice slice [] = slice
reshapeSlice [] _ = []

-- | When an 'Index' result is only size-coerced ('ReshapeCoerce'), the
-- coercion can be folded into the index instead.  The widths of its
-- 'DimSlice' components are replaced by the coerced dimensions.  The
-- rule only fires if this actually changes the slice, and it keeps the
-- certificates of the 'Index' statement.
simplifyReshapeIndex :: SimpleRule rep
simplifyReshapeIndex defOf _ (Reshape ReshapeCoerce newshape v) = do
  (BasicOp (Index src slice), v_cs) <- defOf v
  let resized = Slice $ reshapeSlice (unSlice slice) (shapeDims newshape)
  guard $ resized /= slice
  Just (Index src resized, v_cs)
simplifyReshapeIndex _ _ _ = Nothing

-- | When the value written by an 'Update' is a size coercion
-- ('ReshapeCoerce') of another array, write that array directly instead.
-- The widths of the slice's 'DimSlice' components are adjusted to that
-- array's own dimensions.  The rule only fires if this actually changes
-- the slice, and it keeps the certificates of the coercion statement.
simplifyUpdateReshape :: SimpleRule rep
simplifyUpdateReshape defOf seType (Update safety dest slice (Var v)) = do
  (BasicOp (Reshape ReshapeCoerce _ uncoerced), v_cs) <- defOf v
  uncoerced_dims <- arrayDims <$> seType (Var uncoerced)
  let resized = Slice $ reshapeSlice (unSlice slice) uncoerced_dims
  guard $ resized /= slice
  Just (Update safety dest resized (Var uncoerced), v_cs)
simplifyUpdateReshape _ _ _ = Nothing

-- | A 'Replicate' of an array that is a 'Scratch' array, possibly reached
-- through a chain of 'Rearrange' and 'Reshape' statements, is turned
-- directly into a 'Scratch'.  The new scratch array has the element type
-- of the source and the replication shape followed by the source's
-- dimensions.  The certificates of every statement along the chain are
-- combined and attached to the result.
repScratchToScratch :: SimpleRule rep
repScratchToScratch defOf seType (Replicate shape (Var src))
  | Just src_t <- seType (Var src),
    Just cs <- scratchCerts src =
      Just (Scratch (elemType src_t) (shapeDims shape ++ arrayDims src_t), cs)
  where
    -- Walk back through rearranges/reshapes until reaching a Scratch,
    -- collecting certificates along the way.  Any other definition (or
    -- an unknown variable) means the source is not a scratch array.
    scratchCerts :: VName -> Maybe Certs
    scratchCerts name = case defOf name of
      Just (BasicOp Scratch {}, cs) -> Just cs
      Just (BasicOp (Rearrange _ name'), cs) -> (cs <>) <$> scratchCerts name'
      Just (BasicOp (Reshape _ _ name'), cs) -> (cs <>) <$> scratchCerts name'
      _ -> Nothing
repScratchToScratch _ _ _ = Nothing

-- | Every simple rule, in the order in which they are attempted.  The
-- order matters: the first rule producing a result is the one used.
simpleRules :: [SimpleRule rep]
simpleRules =
  [ -- Rules defined earlier in this module.
    simplifyBinOp,
    simplifyCmpOp,
    simplifyUnOp,
    simplifyConvOp,
    simplifyAssert,
    -- Replicating a (possibly rearranged/reshaped) scratch array.
    repScratchToScratch,
    -- Rules that look at a 'Reshape' and the definition of its operand.
    simplifyIdentityReshape,
    simplifyReshapeReshape,
    simplifyReshapeScratch,
    simplifyReshapeReplicate,
    simplifyReshapeIota,
    simplifyReshapeConcat,
    simplifyReshapeIndex,
    -- An 'Update' whose written value is a size coercion.
    simplifyUpdateReshape
  ]

-- | Try to simplify the given t'BasicOp', returning a new t'BasicOp'
-- and certificates that it must depend on.
--
-- The rules are tried in list order and the first one that succeeds is
-- used; later rules are not evaluated once a result has been found.
{-# NOINLINE applySimpleRules #-}
applySimpleRules ::
  VarLookup rep ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certs)
applySimpleRules defOf seType op =
  foldr (mplus . tryRule) Nothing simpleRules
  where
    tryRule rule = rule defOf seType op