{-# LANGUAGE TypeFamilies #-}

-- | Match simplification rules.
module Futhark.Optimise.Simplify.Rules.Match (matchRules) where

import Control.Monad
import Data.Either
import Data.List (partition, transpose, unzip4, zip5)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Util

-- | Does this case always match the scrutinees?  Each pattern element is
-- paired with its scrutinee; a wildcard ('Nothing') is always satisfied,
-- and a literal is only known to be satisfied when the scrutinee is that
-- very same constant.  Anything else (e.g. a variable scrutinee facing a
-- literal) is treated as "not known to match".
caseAlwaysMatches :: [SubExp] -> Case a -> Bool
caseAlwaysMatches ses c =
  and [maybe True ((se ==) . Constant) p | (se, p) <- zip ses (casePat c)]

-- | Can this case never match the scrutinees?  This is only provable when
-- some scrutinee is a constant and the corresponding pattern element
-- demands a different constant; wildcards and non-constant scrutinees
-- never rule a case out.
caseNeverMatches :: [SubExp] -> Case a -> Bool
caseNeverMatches ses = any conflicts . zip ses . casePat
  where
    conflicts (Constant actual, Just wanted) = actual /= wanted
    conflicts _ = False

-- | Top-down simplification rules for 'Match'.  Each equation below is
-- a separate rewrite.  They are tried in order; the first equation whose
-- patterns and guards succeed decides the result, and the final
-- equation gives up with 'Skip'.
ruleMatch :: (BuilderOps rep) => TopDownRuleMatch rep
-- Remove cases that provably can never match the scrutinees.
ruleMatch _ pat _ (cond, cases, defbody, ifdec)
  | (impossible, cases') <- partition (caseNeverMatches cond) cases,
    not $ null impossible =
      Simplify $ letBind pat $ Match cond cases' defbody ifdec
-- Find new default case.  The first case that always matches becomes
-- the new default body, and only the cases preceding it are kept,
-- because it shadows every later case as well as the old default body.
-- NOTE(review): this assumes 'Match' cases are tried in order with the
-- first match winning -- confirm against the IR documentation.  Under
-- that semantics, promoting a *later* always-matching case, or keeping
-- cases that follow the first one, could select a shadowed case.
ruleMatch _ pat _ (cond, cases, _, ifdec)
  | (reachable, new_default : _) <- break (caseAlwaysMatches cond) cases =
      Simplify $ letBind pat $ Match cond reachable (caseBody new_default) ifdec
-- Remove caseless match: bind the default body's statements directly and
-- copy each of its results into the corresponding pattern element,
-- keeping both the statement's and each result's certificates.
ruleMatch _ pat (StmAux cs _ _) (_, [], defbody, _) = Simplify $ do
  defbody_res <- bodyBind defbody
  certifying cs $ forM_ (zip (patElems pat) defbody_res) $ \(pe, res) ->
    certifying (resCerts res) . letBind (Pat [pe]) $
      BasicOp (SubExp $ resSubExp res)
-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleMatch
  _
  pat
  _
  ( [cond],
    [ Case
        [Just (BoolValue True)]
        (Body _ tstms [SubExpRes tcs (Constant (BoolValue True))])
      ],
    Body _ fstms [SubExpRes fcs se],
    MatchDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
        Simplify $ certifying (tcs <> fcs) $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y).
-- The statements of both branches are hoisted out of the match, so this
-- only fires when every one of them passes 'safeExp'.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, MatchDec ts _)
  | Body _ tstms [SubExpRes tcs tres] <- tb,
    Body _ fstms [SubExpRes fcs fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <-
        eBinOp
          LogOr
          (pure $ BasicOp $ BinOp LogAnd cond tres)
          ( eBinOp
              LogAnd
              (pure $ BasicOp $ UnOp Not cond)
              (pure $ BasicOp $ SubExp fres)
          )
      certifying (tcs <> fcs) $ letBind pat e
-- A 'MatchFallback' match with exactly one case: if every statement of
-- that case passes 'safeExp', inline the case body and bind its results
-- to the pattern, discarding the default body.
ruleMatch _ pat _ (_, [Case _ tbranch], _, MatchDec _ MatchFallback)
  | all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_
        [ certifying cs $ letBindNames [patElemName p] $ BasicOp $ SubExp se
          | (p, SubExpRes cs se) <- zip (patElems pat) ses
        ]
-- if c then 1 else 0 == btoi c
-- if c then 0 else 1 == btoi (not c)
-- Both rewrites throw away the statements of the two branch bodies, so
-- neither may be applied when a branch result carries certificates:
-- those would either be dropped (losing the checks they stand for) or
-- refer to names bound by the discarded statements.  This check used to
-- guard only the first rewrite.
ruleMatch _ pat _ ([cond], [Case [Just (BoolValue True)] tb], fb, _)
  | Body _ _ [SubExpRes tcs (Constant (IntValue t))] <- tb,
    Body _ _ [SubExpRes fcs (Constant (IntValue f))] <- fb =
      if tcs /= mempty || fcs /= mempty
        then Skip
        else
          if oneIshInt t && zeroIshInt f
            then
              Simplify . letBind pat . BasicOp $
                ConvOp (BToI (intValueType t)) cond
            else
              if zeroIshInt t && oneIshInt f
                then Simplify $ do
                  cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
                  letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
                else Skip
-- Simplify
--
--   let z = if c then x else y
--
-- to
--
--   let z = y
--
-- in the case where 'x' is a loop parameter with initial value 'y'
-- and the new value of the loop parameter is 'z'.  ('x' and 'y' can
-- be flipped.)
ruleMatch vtable (Pat [pe]) aux (_c, [Case _ tb], fb, MatchDec [_] _)
  | Body _ tstms [SubExpRes xcs x] <- tb,
    null tstms,
    Body _ fstms [SubExpRes ycs y] <- fb,
    null fstms,
    matches x y || matches y x =
      Simplify . certifying (stmAuxCerts aux <> xcs <> ycs) $
        letBind (Pat [pe]) (BasicOp $ SubExp y)
  where
    z = patElemName pe
    -- Is the first operand a loop parameter whose initial value is the
    -- second operand and whose new value is this match's result 'z'?
    matches (Var x) y
      | Just (initial, res) <- ST.lookupLoopParam x vtable =
          initial == y && res == Var z
    matches _ _ = False
ruleMatch _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results used for
-- existentials), or the same in both branches.
hoistBranchInvariant :: (BuilderOps rep) => TopDownRuleMatch rep
hoistBranchInvariant :: forall rep. BuilderOps rep => TopDownRuleMatch rep
hoistBranchInvariant TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ ([SubExp]
cond, [Case (Body rep)]
cases, Body rep
defbody, MatchDec [BranchType rep]
ret MatchSort
ifsort) =
  let case_reses :: [Result]
case_reses = (Case (Body rep) -> Result) -> [Case (Body rep)] -> [Result]
forall a b. (a -> b) -> [a] -> [b]
map (Body rep -> Result
forall rep. Body rep -> Result
bodyResult (Body rep -> Result)
-> (Case (Body rep) -> Body rep) -> Case (Body rep) -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body rep) -> Body rep
forall body. Case body -> body
caseBody) [Case (Body rep)]
cases
      defbody_res :: Result
defbody_res = Body rep -> Result
forall rep. Body rep -> Result
bodyResult Body rep
defbody
      ([RuleM rep (Int, SubExp)]
hoistings, ([PatElem (LetDec rep)]
pes, [BranchType rep]
ts, [Result]
case_reses_tr, Result
defbody_res')) =
        (([(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
 -> ([PatElem (LetDec rep)], [BranchType rep], [Result], Result))
-> ([RuleM rep (Int, SubExp)],
    [(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)])
-> ([RuleM rep (Int, SubExp)],
    ([PatElem (LetDec rep)], [BranchType rep], [Result], Result))
forall a b.
(a -> b)
-> ([RuleM rep (Int, SubExp)], a) -> ([RuleM rep (Int, SubExp)], b)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> ([PatElem (LetDec rep)], [BranchType rep], [Result], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 (([RuleM rep (Int, SubExp)],
  [(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)])
 -> ([RuleM rep (Int, SubExp)],
     ([PatElem (LetDec rep)], [BranchType rep], [Result], Result)))
-> ([Either
       (RuleM rep (Int, SubExp))
       (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
    -> ([RuleM rep (Int, SubExp)],
        [(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]))
-> [Either
      (RuleM rep (Int, SubExp))
      (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> ([RuleM rep (Int, SubExp)],
    ([PatElem (LetDec rep)], [BranchType rep], [Result], Result))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (RuleM rep (Int, SubExp))
   (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> ([RuleM rep (Int, SubExp)],
    [(PatElem (LetDec rep), BranchType rep, Result, SubExpRes)])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) ([Either
    (RuleM rep (Int, SubExp))
    (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
 -> ([RuleM rep (Int, SubExp)],
     ([PatElem (LetDec rep)], [BranchType rep], [Result], Result)))
-> ([(Int, PatElem (LetDec rep), BranchType rep, Result,
      SubExpRes)]
    -> [Either
          (RuleM rep (Int, SubExp))
          (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)])
-> [(Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> ([RuleM rep (Int, SubExp)],
    ([PatElem (LetDec rep)], [BranchType rep], [Result], Result))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
 -> Either
      (RuleM rep (Int, SubExp))
      (PatElem (LetDec rep), BranchType rep, Result, SubExpRes))
-> [(Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> [Either
      (RuleM rep (Int, SubExp))
      (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
forall a b. (a -> b) -> [a] -> [b]
map (Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
branchInvariant ([(Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
 -> ([RuleM rep (Int, SubExp)],
     ([PatElem (LetDec rep)], [BranchType rep], [Result], Result)))
-> [(Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
-> ([RuleM rep (Int, SubExp)],
    ([PatElem (LetDec rep)], [BranchType rep], [Result], Result))
forall a b. (a -> b) -> a -> b
$
          [Int]
-> [PatElem (LetDec rep)]
-> [BranchType rep]
-> [Result]
-> Result
-> [(Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)]
forall a b c d e.
[a] -> [b] -> [c] -> [d] -> [e] -> [(a, b, c, d, e)]
zip5 [Int
0 ..] (Pat (LetDec rep) -> [PatElem (LetDec rep)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec rep)
pat) [BranchType rep]
ret ([Result] -> [Result]
forall a. [[a]] -> [[a]]
transpose [Result]
case_reses) Result
defbody_res
   in if [RuleM rep (Int, SubExp)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [RuleM rep (Int, SubExp)]
hoistings
        then Rule rep
forall rep. Rule rep
Skip
        else RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
          [(Int, SubExp)]
ctx_fixes <- [RuleM rep (Int, SubExp)] -> RuleM rep [(Int, SubExp)]
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
forall (m :: * -> *) a. Monad m => [m a] -> m [a]
sequence [RuleM rep (Int, SubExp)]
hoistings
          let onCase :: Case (Body rep) -> Result -> Case (Body rep)
onCase (Case [Maybe PrimValue]
vs Body rep
body) Result
case_res = [Maybe PrimValue] -> Body rep -> Case (Body rep)
forall body. [Maybe PrimValue] -> body -> Case body
Case [Maybe PrimValue]
vs (Body rep -> Case (Body rep)) -> Body rep -> Case (Body rep)
forall a b. (a -> b) -> a -> b
$ Body rep
body {bodyResult :: Result
bodyResult = Result
case_res}
              cases' :: [Case (Body rep)]
cases' = (Case (Body rep) -> Result -> Case (Body rep))
-> [Case (Body rep)] -> [Result] -> [Case (Body rep)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Case (Body rep) -> Result -> Case (Body rep)
forall {rep}. Case (Body rep) -> Result -> Case (Body rep)
onCase [Case (Body rep)]
cases ([Result] -> [Case (Body rep)]) -> [Result] -> [Case (Body rep)]
forall a b. (a -> b) -> a -> b
$ [Result] -> [Result]
forall a. [[a]] -> [[a]]
transpose [Result]
case_reses_tr
              defbody' :: Body rep
defbody' = Body rep
defbody {bodyResult :: Result
bodyResult = Result
defbody_res'}
              ret' :: [BranchType rep]
ret' = ((Int, SubExp) -> [BranchType rep] -> [BranchType rep])
-> [BranchType rep] -> [(Int, SubExp)] -> [BranchType rep]
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType rep] -> [BranchType rep])
-> (Int, SubExp) -> [BranchType rep] -> [BranchType rep]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType rep] -> [BranchType rep]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) [BranchType rep]
ts [(Int, SubExp)]
ctx_fixes
          -- We may have to add some reshapes if we made the type
          -- less existential.
          [Case (Body rep)]
cases'' <- (Case (Body rep) -> RuleM rep (Case (Body rep)))
-> [Case (Body rep)] -> RuleM rep [Case (Body rep)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body rep -> RuleM rep (Body rep))
-> Case (Body rep) -> RuleM rep (Case (Body rep))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse ((Body rep -> RuleM rep (Body rep))
 -> Case (Body rep) -> RuleM rep (Case (Body rep)))
-> (Body rep -> RuleM rep (Body rep))
-> Case (Body rep)
-> RuleM rep (Case (Body rep))
forall a b. (a -> b) -> a -> b
$ [TypeBase ExtShape NoUniqueness]
-> Body (Rep (RuleM rep)) -> RuleM rep (Body (Rep (RuleM rep)))
forall {m :: * -> *}.
MonadBuilder m =>
[TypeBase ExtShape NoUniqueness]
-> Body (Rep m) -> m (Body (Rep m))
reshapeBodyResults ([TypeBase ExtShape NoUniqueness]
 -> Body (Rep (RuleM rep)) -> RuleM rep (Body (Rep (RuleM rep))))
-> [TypeBase ExtShape NoUniqueness]
-> Body (Rep (RuleM rep))
-> RuleM rep (Body (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ (BranchType rep -> TypeBase ExtShape NoUniqueness)
-> [BranchType rep] -> [TypeBase ExtShape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> TypeBase ExtShape NoUniqueness
forall t. ExtTyped t => t -> TypeBase ExtShape NoUniqueness
extTypeOf [BranchType rep]
ret') [Case (Body rep)]
cases'
          Body rep
defbody'' <- [TypeBase ExtShape NoUniqueness]
-> Body (Rep (RuleM rep)) -> RuleM rep (Body (Rep (RuleM rep)))
forall {m :: * -> *}.
MonadBuilder m =>
[TypeBase ExtShape NoUniqueness]
-> Body (Rep m) -> m (Body (Rep m))
reshapeBodyResults ((BranchType rep -> TypeBase ExtShape NoUniqueness)
-> [BranchType rep] -> [TypeBase ExtShape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> TypeBase ExtShape NoUniqueness
forall t. ExtTyped t => t -> TypeBase ExtShape NoUniqueness
extTypeOf [BranchType rep]
ret') Body rep
Body (Rep (RuleM rep))
defbody'
          Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind ([PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)]
pes) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
cond [Case (Body rep)]
cases'' Body rep
defbody'' ([BranchType rep] -> MatchSort -> MatchDec (BranchType rep)
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [BranchType rep]
ret' MatchSort
ifsort)
  where
    bound_in_branches :: Names
bound_in_branches =
      [VName] -> Names
namesFromList ([VName] -> Names)
-> (Seq (Stm rep) -> [VName]) -> Seq (Stm rep) -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm rep -> [VName]) -> Seq (Stm rep) -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pat (LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames (Pat (LetDec rep) -> [VName])
-> (Stm rep -> Pat (LetDec rep)) -> Stm rep -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm rep -> Pat (LetDec rep)
forall rep. Stm rep -> Pat (LetDec rep)
stmPat) (Seq (Stm rep) -> Names) -> Seq (Stm rep) -> Names
forall a b. (a -> b) -> a -> b
$
        (Case (Body rep) -> Seq (Stm rep))
-> [Case (Body rep)] -> Seq (Stm rep)
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Body rep -> Seq (Stm rep)
forall rep. Body rep -> Stms rep
bodyStms (Body rep -> Seq (Stm rep))
-> (Case (Body rep) -> Body rep)
-> Case (Body rep)
-> Seq (Stm rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body rep) -> Body rep
forall body. Case body -> body
caseBody) [Case (Body rep)]
cases Seq (Stm rep) -> Seq (Stm rep) -> Seq (Stm rep)
forall a. Semigroup a => a -> a -> a
<> Body rep -> Seq (Stm rep)
forall rep. Body rep -> Stms rep
bodyStms Body rep
defbody

    branchInvariant :: (Int, PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
branchInvariant (Int
i, PatElem (LetDec rep)
pe, BranchType rep
t, Result
case_reses, SubExpRes
defres)
      -- If just one branch has a variant result, then we give up.
      | Names -> Names -> Bool
namesIntersect Names
bound_in_branches (Names -> Bool) -> Names -> Bool
forall a b. (a -> b) -> a -> b
$ Result -> Names
forall a. FreeIn a => a -> Names
freeIn (Result -> Names) -> Result -> Names
forall a b. (a -> b) -> a -> b
$ SubExpRes
defres SubExpRes -> Result -> Result
forall a. a -> [a] -> [a]
: Result
case_reses =
          Either
  (RuleM rep (Int, SubExp))
  (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
noHoisting
      -- Do all branches return the same value?
      | (SubExpRes -> Bool) -> Result -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExpRes -> SubExp
resSubExp SubExpRes
defres) (SubExp -> Bool) -> (SubExpRes -> SubExp) -> SubExpRes -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
case_reses = RuleM rep (Int, SubExp)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
forall a b. a -> Either a b
Left (RuleM rep (Int, SubExp)
 -> Either
      (RuleM rep (Int, SubExp))
      (PatElem (LetDec rep), BranchType rep, Result, SubExpRes))
-> RuleM rep (Int, SubExp)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
forall a b. (a -> b) -> a -> b
$ do
          Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts Result
case_reses Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> SubExpRes -> Certs
resCerts SubExpRes
defres) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
            [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe] (Exp rep -> RuleM rep ())
-> (SubExp -> Exp rep) -> SubExp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> (SubExp -> BasicOp) -> SubExp -> Exp rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp (SubExp -> RuleM rep ()) -> SubExp -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
              SubExpRes -> SubExp
resSubExp SubExpRes
defres
          Int -> PatElem (LetDec rep) -> RuleM rep (Int, SubExp)
forall {f :: * -> *} {a} {dec}.
Applicative f =>
a -> PatElem dec -> f (a, SubExp)
hoisted Int
i PatElem (LetDec rep)
pe

      -- Do all branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> Names -> Bool
namesIntersect Names
bound_in_branches (Names -> Bool) -> Names -> Bool
forall a b. (a -> b) -> a -> b
$ Result -> Names
forall a. FreeIn a => a -> Names
freeIn (Result -> Names) -> Result -> Names
forall a b. (a -> b) -> a -> b
$ SubExpRes
defres SubExpRes -> Result -> Result
forall a. a -> [a] -> [a]
: Result
case_reses,
        Pat (LetDec rep) -> Int
forall dec. Pat dec -> Int
patSize Pat (LetDec rep)
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
        Prim PrimType
_ <- PatElem (LetDec rep) -> TypeBase Shape NoUniqueness
forall dec. Typed dec => PatElem dec -> TypeBase Shape NoUniqueness
patElemType PatElem (LetDec rep)
pe = RuleM rep (Int, SubExp)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
forall a b. a -> Either a b
Left (RuleM rep (Int, SubExp)
 -> Either
      (RuleM rep (Int, SubExp))
      (PatElem (LetDec rep), BranchType rep, Result, SubExpRes))
-> RuleM rep (Int, SubExp)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
forall a b. (a -> b) -> a -> b
$ do
          [BranchType rep]
bt <- Pat (LetDec rep) -> RuleM rep [BranchType rep]
forall rep (m :: * -> *).
(ASTRep rep, HasScope rep m, Monad m) =>
Pat (LetDec rep) -> m [BranchType rep]
forall (m :: * -> *).
(HasScope rep m, Monad m) =>
Pat (LetDec rep) -> m [BranchType rep]
expTypesFromPat (Pat (LetDec rep) -> RuleM rep [BranchType rep])
-> Pat (LetDec rep) -> RuleM rep [BranchType rep]
forall a b. (a -> b) -> a -> b
$ [PatElem (LetDec rep)] -> Pat (LetDec rep)
forall dec. [PatElem dec] -> Pat dec
Pat [PatElem (LetDec rep)
pe]
          [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe]
            (Exp rep -> RuleM rep ()) -> RuleM rep (Exp rep) -> RuleM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( [SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
cond
                    ([Case (Body rep)]
 -> Body rep -> MatchDec (BranchType rep) -> Exp rep)
-> RuleM rep [Case (Body rep)]
-> RuleM rep (Body rep -> MatchDec (BranchType rep) -> Exp rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ( ([Maybe PrimValue] -> Body rep -> Case (Body rep))
-> [[Maybe PrimValue]] -> [Body rep] -> [Case (Body rep)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith [Maybe PrimValue] -> Body rep -> Case (Body rep)
forall body. [Maybe PrimValue] -> body -> Case body
Case ((Case (Body rep) -> [Maybe PrimValue])
-> [Case (Body rep)] -> [[Maybe PrimValue]]
forall a b. (a -> b) -> [a] -> [b]
map Case (Body rep) -> [Maybe PrimValue]
forall body. Case body -> [Maybe PrimValue]
casePat [Case (Body rep)]
cases)
                            ([Body rep] -> [Case (Body rep)])
-> RuleM rep [Body rep] -> RuleM rep [Case (Body rep)]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExpRes -> RuleM rep (Body rep))
-> Result -> RuleM rep [Body rep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ([SubExp] -> RuleM rep (Body rep)
[SubExp] -> RuleM rep (Body (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp] -> RuleM rep (Body rep))
-> (SubExpRes -> [SubExp]) -> SubExpRes -> RuleM rep (Body rep)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> [SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [SubExp])
-> (SubExpRes -> SubExp) -> SubExpRes -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
case_reses
                        )
                    RuleM rep (Body rep -> MatchDec (BranchType rep) -> Exp rep)
-> RuleM rep (Body rep)
-> RuleM rep (MatchDec (BranchType rep) -> Exp rep)
forall a b. RuleM rep (a -> b) -> RuleM rep a -> RuleM rep b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> RuleM rep (Body (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExpRes -> SubExp
resSubExp SubExpRes
defres]
                    RuleM rep (MatchDec (BranchType rep) -> Exp rep)
-> RuleM rep (MatchDec (BranchType rep)) -> RuleM rep (Exp rep)
forall a b. RuleM rep (a -> b) -> RuleM rep a -> RuleM rep b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> MatchDec (BranchType rep) -> RuleM rep (MatchDec (BranchType rep))
forall a. a -> RuleM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType rep] -> MatchSort -> MatchDec (BranchType rep)
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [BranchType rep]
bt MatchSort
ifsort)
                )
          Int -> PatElem (LetDec rep) -> RuleM rep (Int, SubExp)
forall {f :: * -> *} {a} {dec}.
Applicative f =>
a -> PatElem dec -> f (a, SubExp)
hoisted Int
i PatElem (LetDec rep)
pe
      | Bool
otherwise = Either
  (RuleM rep (Int, SubExp))
  (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
noHoisting
      where
        noHoisting :: Either
  (RuleM rep (Int, SubExp))
  (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
noHoisting = (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
-> Either
     (RuleM rep (Int, SubExp))
     (PatElem (LetDec rep), BranchType rep, Result, SubExpRes)
forall a b. b -> Either a b
Right (PatElem (LetDec rep)
pe, BranchType rep
t, Result
case_reses, SubExpRes
defres)

    hoisted :: a -> PatElem dec -> f (a, SubExp)
hoisted a
i PatElem dec
pe = (a, SubExp) -> f (a, SubExp)
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElem dec -> VName
forall dec. PatElem dec -> VName
patElemName PatElem dec
pe)

    reshapeBodyResults :: [TypeBase ExtShape NoUniqueness]
-> Body (Rep m) -> m (Body (Rep m))
reshapeBodyResults [TypeBase ExtShape NoUniqueness]
rets Body (Rep m)
body = m Result -> m (Body (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (m Result -> m (Body (Rep m))) -> m Result -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      Result
ses <- Body (Rep m) -> m Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind Body (Rep m)
body
      let (Result
ctx_ses, Result
val_ses) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([TypeBase ExtShape NoUniqueness] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TypeBase ExtShape NoUniqueness]
rets) Result
ses
      (Result
ctx_ses ++) (Result -> Result) -> m Result -> m Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExpRes -> TypeBase ExtShape NoUniqueness -> m SubExpRes)
-> Result -> [TypeBase ExtShape NoUniqueness] -> m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExpRes -> TypeBase ExtShape NoUniqueness -> m SubExpRes
forall {m :: * -> *}.
MonadBuilder m =>
SubExpRes -> TypeBase ExtShape NoUniqueness -> m SubExpRes
reshapeResult Result
val_ses [TypeBase ExtShape NoUniqueness]
rets
    reshapeResult :: SubExpRes -> TypeBase ExtShape NoUniqueness -> m SubExpRes
reshapeResult (SubExpRes Certs
cs (Var VName
v)) t :: TypeBase ExtShape NoUniqueness
t@Array {} = do
      TypeBase Shape NoUniqueness
v_t <- VName -> m (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
      let newshape :: [SubExp]
newshape = TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ TypeBase ExtShape NoUniqueness
-> TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
removeExistentials TypeBase ExtShape NoUniqueness
t TypeBase Shape NoUniqueness
v_t
      Certs -> SubExp -> SubExpRes
SubExpRes Certs
cs
        (SubExp -> SubExpRes) -> m SubExp -> m SubExpRes
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> if [SubExp]
newshape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
v_t
          then String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" ([SubExp] -> VName -> Exp (Rep m)
forall rep. [SubExp] -> VName -> Exp rep
shapeCoerce [SubExp]
newshape VName
v)
          else SubExp -> m SubExp
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
    reshapeResult SubExpRes
se TypeBase ExtShape NoUniqueness
_ =
      SubExpRes -> m SubExpRes
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExpRes
se

-- | Remove the return values of a branch that are not actually used
-- after the branch.  Standard dead code removal can remove the whole
-- branch if /none/ of the return values are used; this rule is more
-- precise and drops each unused return value individually.
removeDeadBranchResult :: (BuilderOps rep) => BottomUpRuleMatch rep
removeDeadBranchResult (_, used) pat _ (cond, cases, defbody, MatchDec rettype ifsort)
  -- The rule only applies when at least one pattern element is dead.
  | any not liveness =
      -- Drop the branch results that correspond to dead pattern
      -- elements.  This leaves dead statements inside the branch
      -- bodies, which later simplification passes clean up.
      Simplify $ do
        cases' <- traverse (traverse (pruneBody keepLive)) cases
        defbody' <- pruneBody keepLive defbody
        letBind (Pat (keepLive (patElems pat))) $
          Match cond cases' defbody' $
            MatchDec (map renumber (keepLive rettype)) ifsort
  | otherwise = Skip
  where
    -- One flag per pattern element: is it still needed after the match?
    liveness = map isLive (patNames pat)

    -- Select the elements of a list positionally aligned with the
    -- pattern whose pattern element is live.  The signature is needed
    -- so it can be used at several element types.
    keepLive :: [a] -> [a]
    keepLive xs = [x | (True, x) <- zip liveness xs]

    -- @offsets !! i@ is the number of live elements before position
    -- @i@, i.e. the new position of element @i@ after pruning.  The
    -- existential references in the branch types are remapped with it.
    offsets = scanl (+) 0 (map fromEnum liveness)
    renumber = mapExt (offsets !!)

    directlyUsed v = v `UT.isUsedDirectly` used

    -- A name is also kept if it is free in some pattern element that
    -- is itself used directly (presumably as part of that element's
    -- type, e.g. a size).
    usedThroughOther v =
      or
        [ v `nameIn` freeIn pe && directlyUsed (patElemName pe)
          | pe <- patElems pat
        ]

    isLive v = directlyUsed v || usedThroughOther v

    -- Rebuild a body with its result list filtered by @select@; the
    -- statements are left untouched.
    pruneBody select (Body _ stms res) = mkBodyM stms (select res)

-- | Top-down simplification rules for 'Match' expressions.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules = map RuleMatch [ruleMatch, hoistBranchInvariant]

-- | Bottom-up simplification rules for 'Match' expressions.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules = map RuleMatch [removeDeadBranchResult]

-- | The complete rule book for 'Match' expressions, combining the
-- top-down and bottom-up rule lists of this module.
matchRules :: (BuilderOps rep) => RuleBook rep
matchRules = ruleBook topDownRules bottomUpRules