{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module defines a collection of simplification rules, as per
-- "Futhark.Optimise.Simplify.Rule".  They are used in the
-- simplifier.
--
-- For performance reasons, many sufficiently simple logically
-- separate rules are merged into single "super-rules", like ruleIf
-- and ruleBasicOp.  This is because it is relatively expensive to
-- activate a rule just to determine that it does not apply.  Thus, it
-- is more efficient to have a few very fat rules than a lot of small
-- rules.  This does not affect the compiler result in any way; it is
-- purely an optimisation to speed up compilation.
module Futhark.Optimise.Simplify.Rules
  ( standardRules,
    removeUnnecessaryCopy,
  )
where

import Control.Monad
import Data.Either
import Data.List (find, unzip4, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.BasicOp
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Util

-- | The top-down rules contributed by this module, in the order in
-- which they are placed in the rule book: 'constantFoldPrimFun',
-- 'ruleIf', 'hoistBranchInvariant', and 'withAccTopDown'.
topDownRules :: BinderOps rep => [TopDownRule rep]
topDownRules =
  [ RuleGeneric constantFoldPrimFun,
    RuleIf ruleIf,
    RuleIf hoistBranchInvariant,
    RuleGeneric withAccTopDown
  ]

-- | The bottom-up rules contributed by this module, in the order in
-- which they are placed in the rule book: 'removeDeadBranchResult',
-- 'withAccBottomUp', and 'simplifyIndex'.
bottomUpRules :: BinderOps rep => [BottomUpRule rep]
bottomUpRules =
  [ RuleIf removeDeadBranchResult,
    RuleGeneric withAccBottomUp,
    RuleBasicOp simplifyIndex
  ]

-- | A set of standard simplification rules.  These assume pure
-- functional semantics, and so probably should not be applied after
-- memory block merging.
--
-- The rule book is the combination of the rules defined in this
-- module ('topDownRules' and 'bottomUpRules') with 'loopRules' and
-- 'basicOpRules'.
standardRules :: (BinderOps rep, Aliased rep) => RuleBook rep
standardRules =
  ruleBook topDownRules bottomUpRules <> loopRules <> basicOpRules

-- | Turn @copy(x)@ into @x@ iff @x@ is not used after this copy
-- statement and it can be consumed.
--
-- This simplistic rule is only valid before we introduce memory.
--
-- The rule only matches a 'Copy' bound to a pattern with no context
-- elements and exactly one value element; everything else is
-- skipped.
removeUnnecessaryCopy :: (BinderOps rep, Aliased rep) => BottomUpRuleBasicOp rep
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  -- The source array must not itself be consumed later on.  Given
  -- that, the copy can go if either the source is not used at all
  -- afterwards and may be consumed, or the result of the copy is
  -- never consumed.
  | not (v `UT.isConsumed` used),
    (not (v `UT.used` used) && consumable) || not (patElemName d `UT.isConsumed` used) =
    Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- We need to make sure we can even consume the original.  The big
    -- missing piece here is that we cannot do copy removal inside of
    -- 'map' and other SOACs, but that is handled by SOAC-specific rules.
    --
    -- The source must be known to the symbol table and have been
    -- bound at the current loop depth.  A missing entry, or a failed
    -- depth check, counts as "not consumable".
    consumable = fromMaybe False $ do
      e <- ST.lookup v vtable
      guard $ ST.entryDepth e == ST.loopDepth vtable
      consumableStm e `mplus` consumableFParam e
    -- Always produces a 'Just': 'True' exactly when the entry is a
    -- function parameter whose declared type is unique.
    consumableFParam =
      Just . maybe False (unique . declTypeOf) . ST.entryFParam
    -- 'Just True' when the entry is bound by a statement and the
    -- pattern element binding @v@ has no aliases; 'Nothing' otherwise
    -- (in which case 'consumableFParam' is consulted).
    consumableStm e = do
      pat <- stmPattern <$> ST.entryStm e
      pe <- find ((== v) . patElemName) (patternElements pat)
      guard $ aliasesOf pe == mempty
      pure True
removeUnnecessaryCopy _ _ _ _ = Skip

-- | Constant-fold an application of a primitive function.
--
-- The rule fires for an 'Apply' statement when every argument is a
-- 'Constant', the function name has an entry in 'primFuns', and the
-- evaluator stored in that entry returns a value for the given
-- arguments.  The statement is then rebound to that constant, keeping
-- the certificates and attributes of the original statement.  In all
-- other cases the rule is skipped.
constantFoldPrimFun :: BinderOps rep => TopDownRuleGeneric rep
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just args' <- mapM (isConst . fst) args,
    Just (_, _, fun) <- M.lookup (nameToString fname) primFuns,
    Just result <- fun args' =
    Simplify $
      certifying cs $
        attributing attrs $
          letBind pat $ BasicOp $ SubExp $ Constant result
  where
    -- Extract the value of a constant operand; variables yield
    -- 'Nothing', which makes the whole rule inapplicable.
    isConst (Constant v) = Just v
    isConst _ = Nothing
constantFoldPrimFun _ _ = Skip

-- | Simplify an 'Index' expression by delegating to
-- 'simplifyIndexing'.
--
-- Only matches an 'Index' bound to a pattern with no context elements
-- and exactly one value element.  When 'simplifyIndexing' produces a
-- result, the pattern names are rebound either to a plain 'SubExp' or
-- to a new 'Index', under the original attributes and the union of
-- the original certificates with those returned alongside the result.
simplifyIndex :: BinderOps rep => BottomUpRuleBasicOp rep
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index idd inds)
  | Just m <- simplifyIndexing vtable seType idd inds consumed = Simplify $ do
    res <- m
    attributing attrs $ case res of
      SubExpResult cs' se ->
        certifying (cs <> cs') $
          letBindNames (patternNames pat) $ BasicOp $ SubExp se
      IndexResult extra_cs idd' inds' ->
        certifying (cs <> extra_cs) $
          letBindNames (patternNames pat) $ BasicOp $ Index idd' inds'
  where
    -- Whether the result of this index expression is consumed later,
    -- according to the usage table.
    consumed = patElemName pe `UT.isConsumed` used
    -- Type lookup passed to 'simplifyIndexing': variables are
    -- resolved through the symbol table, constants have the primitive
    -- type of their value.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
simplifyIndex _ _ _ _ = Skip

-- | The "super-rule" for @if@ expressions.  The clauses are tried in
-- order:
--
-- 1. A constant condition: inline the selected branch.
--
-- 2. @if c then True else v@ with empty branches and a single boolean
--    result becomes @c || v@.
--
-- 3. A single boolean result with only safe statements in both
--    branches becomes @(c && x) || (!c && y)@.
--
-- 4. An 'IfFallback' whose true branch contains only safe statements
--    (and whose pattern has no context) is replaced by that branch.
--
-- 5. Branches returning the integer constants one/zero (or zero/one)
--    become a boolean-to-integer conversion of the (negated)
--    condition.
--
-- Anything else is skipped.
ruleIf :: BinderOps rep => TopDownRuleIf rep
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  -- For an 'IfFallback' conditional only a constant-true condition is
  -- accepted here, so the false branch of a fallback is never inlined
  -- by this clause.
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 = Simplify $ do
    let ses = bodyResult branch
    addStms $ bodyStms branch
    -- Bind each pattern element to the corresponding branch result.
    sequence_
      [ letBindNames [patElemName p] $ BasicOp $ SubExp se
        | (p, se) <- zip (patternElements pat) ses
      ]
  where
    -- The branch selected by a constant condition, if any.
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing

-- IMPROVE: the following two rules can be generalised to work in more
-- cases, especially when the branches have bindings, or return more
-- than one value.
--
-- if c then True else v == c || v
ruleIf
  _
  pat
  _
  ( cond,
    Body _ tstms [Constant (BoolValue True)],
    Body _ fstms [se],
    IfDec ts _
    )
    | null tstms,
      null fstms,
      [Prim Bool] <- map extTypeOf ts =
      Simplify $ letBind pat $ BasicOp $ BinOp LogOr cond se
-- When type(x)==bool, if c then x else y == (c && x) || (!c && y)
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [tres] <- tb,
    Body _ fstms [fres] <- fb,
    -- Both branches are hoisted unconditionally below, so every
    -- statement in them must be safe to execute.
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
    addStms tstms
    addStms fstms
    e <-
      eBinOp
        LogOr
        (pure $ BasicOp $ BinOp LogAnd cond tres)
        ( eBinOp
            LogAnd
            (pure $ BasicOp $ UnOp Not cond)
            (pure $ BasicOp $ SubExp fres)
        )
    letBind pat e
-- A fallback conditional whose true branch is safe to execute and
-- whose pattern binds no context names: replace the whole conditional
-- by the true branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null $ patternContextNames pat,
    all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
    let ses = bodyResult tbranch
    addStms $ bodyStms tbranch
    sequence_
      [ letBindNames [patElemName p] $ BasicOp $ SubExp se
        | (p, se) <- zip (patternElements pat) ses
      ]
-- Both branches return a single integer constant.  Note that the
-- statements of the branches are not inspected and are dropped when
-- the rule fires.
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [Constant (IntValue t)] <- tb,
    Body _ _ [Constant (IntValue f)] <- fb =
    if oneIshInt t && zeroIshInt f
      then -- if c then 1 else 0 == btoi c
        Simplify $
          letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
      else
        if zeroIshInt t && oneIshInt f
          then Simplify $ do
            -- if c then 0 else 1 == btoi (!c)
            cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
            letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
          else Skip
ruleIf _ _ _ _ = Skip

-- | Move out results of a conditional expression whose computation is
-- either invariant to the branches (only done for results in the
-- context), or the same in both branches.
hoistBranchInvariant :: BinderOps rep => TopDownRuleIf rep
hoistBranchInvariant :: forall rep. BinderOps rep => TopDownRuleIf rep
hoistBranchInvariant TopDown rep
_ Pattern rep
pat StmAux (ExpDec rep)
_ (SubExp
cond, BodyT rep
tb, BodyT rep
fb, IfDec [BranchType rep]
ret IfSort
ifsort) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  let tses :: Result
tses = BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT rep
tb
      fses :: Result
fses = BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult BodyT rep
fb
  ([Maybe (Int, SubExp)]
hoistings, ([PatElemT (LetDec rep)]
pes, [Either Int (BranchType rep)]
ts, [(SubExp, SubExp)]
res)) <-
    ([Either
    (Maybe (Int, SubExp))
    (PatElemT (LetDec rep), Either Int (BranchType rep),
     (SubExp, SubExp))]
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
      [(SubExp, SubExp)])))
-> RuleM
     rep
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp))]
-> RuleM
     rep
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
       [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (([(PatElemT (LetDec rep), Either Int (BranchType rep),
   (SubExp, SubExp))]
 -> ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
     [(SubExp, SubExp)]))
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec rep), Either Int (BranchType rep),
      (SubExp, SubExp))])
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
     [(SubExp, SubExp)]))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [(PatElemT (LetDec rep), Either Int (BranchType rep),
  (SubExp, SubExp))]
-> ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
    [(SubExp, SubExp)])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 (([Maybe (Int, SubExp)],
  [(PatElemT (LetDec rep), Either Int (BranchType rep),
    (SubExp, SubExp))])
 -> ([Maybe (Int, SubExp)],
     ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
      [(SubExp, SubExp)])))
-> ([Either
       (Maybe (Int, SubExp))
       (PatElemT (LetDec rep), Either Int (BranchType rep),
        (SubExp, SubExp))]
    -> ([Maybe (Int, SubExp)],
        [(PatElemT (LetDec rep), Either Int (BranchType rep),
          (SubExp, SubExp))]))
-> [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec rep), Either Int (BranchType rep),
       (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
     [(SubExp, SubExp)]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec rep), Either Int (BranchType rep),
    (SubExp, SubExp))]
-> ([Maybe (Int, SubExp)],
    [(PatElemT (LetDec rep), Either Int (BranchType rep),
      (SubExp, SubExp))])
forall a b. [Either a b] -> ([a], [b])
partitionEithers) (RuleM
   rep
   [Either
      (Maybe (Int, SubExp))
      (PatElemT (LetDec rep), Either Int (BranchType rep),
       (SubExp, SubExp))]
 -> RuleM
      rep
      ([Maybe (Int, SubExp)],
       ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
        [(SubExp, SubExp)])))
-> RuleM
     rep
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp))]
-> RuleM
     rep
     ([Maybe (Int, SubExp)],
      ([PatElemT (LetDec rep)], [Either Int (BranchType rep)],
       [(SubExp, SubExp)]))
forall a b. (a -> b) -> a -> b
$
      ((PatElemT (LetDec rep), Either Int (BranchType rep),
  (SubExp, SubExp))
 -> RuleM
      rep
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec rep), Either Int (BranchType rep),
          (SubExp, SubExp))))
-> [(PatElemT (LetDec rep), Either Int (BranchType rep),
     (SubExp, SubExp))]
-> RuleM
     rep
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec rep), Either Int (BranchType rep),
 (SubExp, SubExp))
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
branchInvariant ([(PatElemT (LetDec rep), Either Int (BranchType rep),
   (SubExp, SubExp))]
 -> RuleM
      rep
      [Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec rep), Either Int (BranchType rep),
          (SubExp, SubExp))])
-> [(PatElemT (LetDec rep), Either Int (BranchType rep),
     (SubExp, SubExp))]
-> RuleM
     rep
     [Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp))]
forall a b. (a -> b) -> a -> b
$
        [PatElemT (LetDec rep)]
-> [Either Int (BranchType rep)]
-> [(SubExp, SubExp)]
-> [(PatElemT (LetDec rep), Either Int (BranchType rep),
     (SubExp, SubExp))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3
          (Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern rep
pat)
          ((Int -> Either Int (BranchType rep))
-> [Int] -> [Either Int (BranchType rep)]
forall a b. (a -> b) -> [a] -> [b]
map Int -> Either Int (BranchType rep)
forall a b. a -> Either a b
Left [Int
0 .. Int
num_ctx Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Either Int (BranchType rep)]
-> [Either Int (BranchType rep)] -> [Either Int (BranchType rep)]
forall a. [a] -> [a] -> [a]
++ (BranchType rep -> Either Int (BranchType rep))
-> [BranchType rep] -> [Either Int (BranchType rep)]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> Either Int (BranchType rep)
forall a b. b -> Either a b
Right [BranchType rep]
ret)
          (Result -> Result -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
tses Result
fses)
  let ctx_fixes :: [(Int, SubExp)]
ctx_fixes = [Maybe (Int, SubExp)] -> [(Int, SubExp)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Int, SubExp)]
hoistings
      (Result
tses', Result
fses') = [(SubExp, SubExp)] -> (Result, Result)
forall a b. [(a, b)] -> ([a], [b])
unzip [(SubExp, SubExp)]
res
      tb' :: BodyT rep
tb' = BodyT rep
tb {bodyResult :: Result
bodyResult = Result
tses'}
      fb' :: BodyT rep
fb' = BodyT rep
fb {bodyResult :: Result
bodyResult = Result
fses'}
      ret' :: [BranchType rep]
ret' = ((Int, SubExp) -> [BranchType rep] -> [BranchType rep])
-> [BranchType rep] -> [(Int, SubExp)] -> [BranchType rep]
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((Int -> SubExp -> [BranchType rep] -> [BranchType rep])
-> (Int, SubExp) -> [BranchType rep] -> [BranchType rep]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Int -> SubExp -> [BranchType rep] -> [BranchType rep]
forall t. FixExt t => Int -> SubExp -> t -> t
fixExt) ([Either Int (BranchType rep)] -> [BranchType rep]
forall a b. [Either a b] -> [b]
rights [Either Int (BranchType rep)]
ts) [(Int, SubExp)]
ctx_fixes
      ([PatElemT (LetDec rep)]
ctx_pes, [PatElemT (LetDec rep)]
val_pes) = Int
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([BranchType rep] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [BranchType rep]
ret') [PatElemT (LetDec rep)]
pes
  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Maybe (Int, SubExp)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Maybe (Int, SubExp)]
hoistings -- Was something hoisted?
    then do
      -- We may have to add some reshapes if we made the type
      -- less existential.
      BodyT rep
tb'' <- Body (Rep (RuleM rep))
-> [TypeBase ExtShape NoUniqueness]
-> RuleM rep (Body (Rep (RuleM rep)))
forall {m :: * -> *}.
MonadBinder m =>
Body (Rep m)
-> [TypeBase ExtShape NoUniqueness] -> m (Body (Rep m))
reshapeBodyResults BodyT rep
Body (Rep (RuleM rep))
tb' ([TypeBase ExtShape NoUniqueness]
 -> RuleM rep (Body (Rep (RuleM rep))))
-> [TypeBase ExtShape NoUniqueness]
-> RuleM rep (Body (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ (BranchType rep -> TypeBase ExtShape NoUniqueness)
-> [BranchType rep] -> [TypeBase ExtShape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> TypeBase ExtShape NoUniqueness
forall t. ExtTyped t => t -> TypeBase ExtShape NoUniqueness
extTypeOf [BranchType rep]
ret'
      BodyT rep
fb'' <- Body (Rep (RuleM rep))
-> [TypeBase ExtShape NoUniqueness]
-> RuleM rep (Body (Rep (RuleM rep)))
forall {m :: * -> *}.
MonadBinder m =>
Body (Rep m)
-> [TypeBase ExtShape NoUniqueness] -> m (Body (Rep m))
reshapeBodyResults BodyT rep
Body (Rep (RuleM rep))
fb' ([TypeBase ExtShape NoUniqueness]
 -> RuleM rep (Body (Rep (RuleM rep))))
-> [TypeBase ExtShape NoUniqueness]
-> RuleM rep (Body (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$ (BranchType rep -> TypeBase ExtShape NoUniqueness)
-> [BranchType rep] -> [TypeBase ExtShape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map BranchType rep -> TypeBase ExtShape NoUniqueness
forall t. ExtTyped t => t -> TypeBase ExtShape NoUniqueness
extTypeOf [BranchType rep]
ret'
      Pattern (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Pattern rep
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec rep)]
ctx_pes [PatElemT (LetDec rep)]
val_pes) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
        SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond BodyT rep
tb'' BodyT rep
fb'' ([BranchType rep] -> IfSort -> IfDec (BranchType rep)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType rep]
ret' IfSort
ifsort)
    else RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify
  where
    num_ctx :: Int
num_ctx = [PatElemT (LetDec rep)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([PatElemT (LetDec rep)] -> Int) -> [PatElemT (LetDec rep)] -> Int
forall a b. (a -> b) -> a -> b
$ Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern rep
pat
    bound_in_branches :: Names
bound_in_branches =
      [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$
        (Stm rep -> [VName]) -> Stms rep -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Pattern rep -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Pattern rep -> [VName])
-> (Stm rep -> Pattern rep) -> Stm rep -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm rep -> Pattern rep
forall rep. Stm rep -> Pattern rep
stmPattern) (Stms rep -> [VName]) -> Stms rep -> [VName]
forall a b. (a -> b) -> a -> b
$
          BodyT rep -> Stms rep
forall rep. BodyT rep -> Stms rep
bodyStms BodyT rep
tb Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> BodyT rep -> Stms rep
forall rep. BodyT rep -> Stms rep
bodyStms BodyT rep
fb
    mem_sizes :: Names
mem_sizes = [PatElemT (LetDec rep)] -> Names
forall a. FreeIn a => a -> Names
freeIn ([PatElemT (LetDec rep)] -> Names)
-> [PatElemT (LetDec rep)] -> Names
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep) -> Bool)
-> [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Type -> Bool
forall {shape} {u}. TypeBase shape u -> Bool
isMem (Type -> Bool)
-> (PatElemT (LetDec rep) -> Type) -> PatElemT (LetDec rep) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT (LetDec rep) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType) ([PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)])
-> [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)]
forall a b. (a -> b) -> a -> b
$ Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern rep
pat
    invariant :: SubExp -> Bool
invariant Constant {} = Bool
True
    invariant (Var VName
v) = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_branches

    isMem :: TypeBase shape u -> Bool
isMem Mem {} = Bool
True
    isMem TypeBase shape u
_ = Bool
False
    sizeOfMem :: VName -> Bool
sizeOfMem VName
v = VName
v VName -> Names -> Bool
`nameIn` Names
mem_sizes

    branchInvariant :: (PatElemT (LetDec rep), Either Int (BranchType rep),
 (SubExp, SubExp))
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
branchInvariant (PatElemT (LetDec rep)
pe, Either Int (BranchType rep)
t, (SubExp
tse, SubExp
fse))
      -- Do both branches return the same value?
      | SubExp
tse SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
fse = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
tse
        PatElemT (LetDec rep)
-> Either Int (BranchType rep)
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
forall {m :: * -> *} {dec} {a} {b} {b}.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec rep)
pe Either Int (BranchType rep)
t

      -- Do both branches return values that are free in the
      -- branch, and are we not the only pattern element?  The
      -- latter is to avoid infinite application of this rule.
      | SubExp -> Bool
invariant SubExp
tse,
        SubExp -> Bool
invariant SubExp
fse,
        Pattern rep -> Int
forall dec. PatternT dec -> Int
patternSize Pattern rep
pat Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1,
        Prim PrimType
_ <- PatElemT (LetDec rep) -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT (LetDec rep)
pe,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
sizeOfMem (VName -> Bool) -> VName -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe = do
        [BranchType rep]
bt <- Pattern rep -> RuleM rep [BranchType rep]
forall rep (m :: * -> *).
(ASTRep rep, HasScope rep m, Monad m) =>
Pattern rep -> m [BranchType rep]
expTypesFromPattern (Pattern rep -> RuleM rep [BranchType rep])
-> Pattern rep -> RuleM rep [BranchType rep]
forall a b. (a -> b) -> a -> b
$ [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Pattern rep
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [PatElemT (LetDec rep)
pe]
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe]
          (ExpT rep -> RuleM rep ()) -> RuleM rep (ExpT rep) -> RuleM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ( SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
forall rep.
SubExp
-> BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep
If SubExp
cond (BodyT rep -> BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (BodyT rep)
-> RuleM rep (BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Result -> RuleM rep (Body (Rep (RuleM rep)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Rep m))
resultBodyM [SubExp
tse]
                  RuleM rep (BodyT rep -> IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (BodyT rep)
-> RuleM rep (IfDec (BranchType rep) -> ExpT rep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> RuleM rep (Body (Rep (RuleM rep)))
forall (m :: * -> *). MonadBinder m => Result -> m (Body (Rep m))
resultBodyM [SubExp
fse]
                  RuleM rep (IfDec (BranchType rep) -> ExpT rep)
-> RuleM rep (IfDec (BranchType rep)) -> RuleM rep (ExpT rep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> IfDec (BranchType rep) -> RuleM rep (IfDec (BranchType rep))
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([BranchType rep] -> IfSort -> IfDec (BranchType rep)
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [BranchType rep]
bt IfSort
ifsort)
              )
        PatElemT (LetDec rep)
-> Either Int (BranchType rep)
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
forall {m :: * -> *} {dec} {a} {b} {b}.
Monad m =>
PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT (LetDec rep)
pe Either Int (BranchType rep)
t
      | Bool
otherwise =
        Either
  (Maybe (Int, SubExp))
  (PatElemT (LetDec rep), Either Int (BranchType rep),
   (SubExp, SubExp))
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
forall (m :: * -> *) a. Monad m => a -> m a
return (Either
   (Maybe (Int, SubExp))
   (PatElemT (LetDec rep), Either Int (BranchType rep),
    (SubExp, SubExp))
 -> RuleM
      rep
      (Either
         (Maybe (Int, SubExp))
         (PatElemT (LetDec rep), Either Int (BranchType rep),
          (SubExp, SubExp))))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec rep), Either Int (BranchType rep),
      (SubExp, SubExp))
-> RuleM
     rep
     (Either
        (Maybe (Int, SubExp))
        (PatElemT (LetDec rep), Either Int (BranchType rep),
         (SubExp, SubExp)))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep), Either Int (BranchType rep),
 (SubExp, SubExp))
-> Either
     (Maybe (Int, SubExp))
     (PatElemT (LetDec rep), Either Int (BranchType rep),
      (SubExp, SubExp))
forall a b. b -> Either a b
Right (PatElemT (LetDec rep)
pe, Either Int (BranchType rep)
t, (SubExp
tse, SubExp
fse))

    hoisted :: PatElemT dec -> Either a b -> m (Either (Maybe (a, SubExp)) b)
hoisted PatElemT dec
pe (Left a
i) = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left (Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b)
-> Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. (a -> b) -> a -> b
$ (a, SubExp) -> Maybe (a, SubExp)
forall a. a -> Maybe a
Just (a
i, VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe)
    hoisted PatElemT dec
_ Right {} = Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b))
-> Either (Maybe (a, SubExp)) b -> m (Either (Maybe (a, SubExp)) b)
forall a b. (a -> b) -> a -> b
$ Maybe (a, SubExp) -> Either (Maybe (a, SubExp)) b
forall a b. a -> Either a b
Left Maybe (a, SubExp)
forall a. Maybe a
Nothing

    reshapeBodyResults :: Body (Rep m)
-> [TypeBase ExtShape NoUniqueness] -> m (Body (Rep m))
reshapeBodyResults Body (Rep m)
body [TypeBase ExtShape NoUniqueness]
rets = m Result -> m (Body (Rep m))
forall (m :: * -> *). MonadBinder m => m Result -> m (Body (Rep m))
buildBody_ (m Result -> m (Body (Rep m))) -> m Result -> m (Body (Rep m))
forall a b. (a -> b) -> a -> b
$ do
      Result
ses <- Body (Rep m) -> m Result
forall (m :: * -> *). MonadBinder m => Body (Rep m) -> m Result
bodyBind Body (Rep m)
body
      let (Result
ctx_ses, Result
val_ses) = Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([TypeBase ExtShape NoUniqueness] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TypeBase ExtShape NoUniqueness]
rets) Result
ses
      (Result
ctx_ses Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++) (Result -> Result) -> m Result -> m Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> TypeBase ExtShape NoUniqueness -> m SubExp)
-> Result -> [TypeBase ExtShape NoUniqueness] -> m Result
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM SubExp -> TypeBase ExtShape NoUniqueness -> m SubExp
forall {m :: * -> *}.
MonadBinder m =>
SubExp -> TypeBase ExtShape NoUniqueness -> m SubExp
reshapeResult Result
val_ses [TypeBase ExtShape NoUniqueness]
rets
    reshapeResult :: SubExp -> TypeBase ExtShape NoUniqueness -> m SubExp
reshapeResult (Var VName
v) t :: TypeBase ExtShape NoUniqueness
t@Array {} = do
      Type
v_t <- VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
v
      let newshape :: Result
newshape = Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims (Type -> Result) -> Type -> Result
forall a b. (a -> b) -> a -> b
$ TypeBase ExtShape NoUniqueness -> Type -> Type
removeExistentials TypeBase ExtShape NoUniqueness
t Type
v_t
      if Result
newshape Result -> Result -> Bool
forall a. Eq a => a -> a -> Bool
/= Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
v_t
        then String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"branch_ctx_reshaped" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ Result -> VName -> Exp (Rep m)
forall rep. Result -> VName -> Exp rep
shapeCoerce Result
newshape VName
v
        else SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp -> m SubExp) -> SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
    reshapeResult SubExp
se TypeBase ExtShape NoUniqueness
_ =
      SubExp -> m SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Remove the return values of a branch that are not actually used
-- after the branch.  Standard dead code removal can remove the branch
-- if *none* of the return values are used, but this rule is more
-- precise: it prunes the individual results that are dead.
--
-- The rule fires only when both of the following hold:
--
-- * the pattern has exactly as many elements as the branch has return
--   types (i.e. there is no existential context), and
--
-- * at least one of the names bound by the pattern is not used
--   directly according to the usage table.
--
-- When it fires, the @if@ is rebound with a pattern, branch results
-- and return types restricted to the used positions.  Otherwise the
-- rule yields 'Skip'.
removeDeadBranchResult :: BinderOps rep => BottomUpRuleIf rep
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | -- Only if there is no existential context...
    patternSize pat == length rettype,
    -- Figure out which of the names in 'pat' are used...
    patused <- map (`UT.isUsedDirectly` used) $ patternNames pat,
    -- If they are not all used, then this rule applies.
    not (and patused) =
    -- Remove the parts of the branch-results that correspond to dead
    -- return value bindings.  Note that this leaves dead code in the
    -- branch bodies, but that will be removed later.
    let tses = bodyResult tb
        fses = bodyResult fb
        -- Keep only those elements whose position corresponds to a
        -- used pattern name.  The same mask is applied to the branch
        -- results, the pattern elements and the return types, so they
        -- stay aligned with each other.
        pick :: [a] -> [a]
        pick = map snd . filter fst . zip patused
        tb' = tb {bodyResult = pick tses}
        fb' = fb {bodyResult = pick fses}
        pat' = pick $ patternElements pat
        rettype' = pick rettype
     in -- The new pattern has an empty context part, matching the
        -- no-existential-context guard above.
        Simplify $ letBind (Pattern [] pat') $ If e1 tb' fb' $ IfDec rettype' ifsort
  | otherwise = Skip

withAccTopDown :: BinderOps rep => TopDownRuleGeneric rep
-- A WithAcc with no accumulators is sent to Valhalla.
withAccTopDown :: forall rep. BinderOps rep => TopDownRuleGeneric rep
withAccTopDown TopDown rep
_ (Let Pattern rep
pat StmAux (ExpDec rep)
aux (WithAcc [] Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBinder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  Result
lam_res <- Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBinder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
  [(VName, SubExp)]
-> ((VName, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> Result -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern rep -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern rep
pat) Result
lam_res) (((VName, SubExp) -> RuleM rep ()) -> RuleM rep ())
-> ((VName, SubExp) -> RuleM rep ()) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExp
se) ->
    [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
v] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
-- Identify those results in 'lam' that are free and move them out.
withAccTopDown TopDown rep
vtable (Let Pattern rep
pat StmAux (ExpDec rep)
aux (WithAcc [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs Lambda rep
lam)) = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBinder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
  let ([Param (LParamInfo rep)]
cert_params, [Param (LParamInfo rep)]
acc_params) =
        Int
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Shape, [VName], Maybe (Lambda rep, Result))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs) ([Param (LParamInfo rep)]
 -> ([Param (LParamInfo rep)], [Param (LParamInfo rep)]))
-> [Param (LParamInfo rep)]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. (a -> b) -> a -> b
$ Lambda rep -> [Param (LParamInfo rep)]
forall rep. LambdaT rep -> [LParam rep]
lambdaParams Lambda rep
lam
      (Result
acc_res, Result
nonacc_res) =
        Int -> Result -> (Result, Result)
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs (Result -> (Result, Result)) -> Result -> (Result, Result)
forall a b. (a -> b) -> a -> b
$ BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT rep -> Result) -> BodyT rep -> Result
forall a b. (a -> b) -> a -> b
$ Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam
      ([PatElemT (LetDec rep)]
acc_pes, [PatElemT (LetDec rep)]
nonacc_pes) =
        Int
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a. Int -> [a] -> ([a], [a])
splitFromEnd Int
num_nonaccs ([PatElemT (LetDec rep)]
 -> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)]))
-> [PatElemT (LetDec rep)]
-> ([PatElemT (LetDec rep)], [PatElemT (LetDec rep)])
forall a b. (a -> b) -> a -> b
$ Pattern rep -> [PatElemT (LetDec rep)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern rep
pat

  -- Look at accumulator results.
  ([[PatElemT (LetDec rep)]]
acc_pes', [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs', [(Param (LParamInfo rep), Param (LParamInfo rep))]
params', Result
acc_res') <-
    ([Maybe
    ([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
 -> ([[PatElemT (LetDec rep)]],
     [(Shape, [VName], Maybe (Lambda rep, Result))],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> RuleM
     rep
     [Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, Result)),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, Result))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([([PatElemT (LetDec rep)],
  (Shape, [VName], Maybe (Lambda rep, Result)),
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> ([[PatElemT (LetDec rep)]],
    [(Shape, [VName], Maybe (Lambda rep, Result))],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
unzip4 ([([PatElemT (LetDec rep)],
   (Shape, [VName], Maybe (Lambda rep, Result)),
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
 -> ([[PatElemT (LetDec rep)]],
     [(Shape, [VName], Maybe (Lambda rep, Result))],
     [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([Maybe
       ([PatElemT (LetDec rep)],
        (Shape, [VName], Maybe (Lambda rep, Result)),
        (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
    -> [([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, Result)),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)])
-> [Maybe
      ([PatElemT (LetDec rep)],
       (Shape, [VName], Maybe (Lambda rep, Result)),
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> ([[PatElemT (LetDec rep)]],
    [(Shape, [VName], Maybe (Lambda rep, Result))],
    [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe
   ([PatElemT (LetDec rep)],
    (Shape, [VName], Maybe (Lambda rep, Result)),
    (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
forall a. [Maybe a] -> [a]
catMaybes) (RuleM
   rep
   [Maybe
      ([PatElemT (LetDec rep)],
       (Shape, [VName], Maybe (Lambda rep, Result)),
       (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
 -> RuleM
      rep
      ([[PatElemT (LetDec rep)]],
       [(Shape, [VName], Maybe (Lambda rep, Result))],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> ([([PatElemT (LetDec rep)],
      (Shape, [VName], Maybe (Lambda rep, Result)),
      (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
    -> RuleM
         rep
         [Maybe
            ([PatElemT (LetDec rep)],
             (Shape, [VName], Maybe (Lambda rep, Result)),
             (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)])
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, Result))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (([PatElemT (LetDec rep)],
  (Shape, [VName], Maybe (Lambda rep, Result)),
  (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)
 -> RuleM
      rep
      (Maybe
         ([PatElemT (LetDec rep)],
          (Shape, [VName], Maybe (Lambda rep, Result)),
          (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)))
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> RuleM
     rep
     [Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, Result)),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([PatElemT (LetDec rep)],
 (Shape, [VName], Maybe (Lambda rep, Result)),
 (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)
-> RuleM
     rep
     (Maybe
        ([PatElemT (LetDec rep)],
         (Shape, [VName], Maybe (Lambda rep, Result)),
         (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp))
forall {m :: * -> *} {dec} {a} {c} {a} {dec}.
MonadBinder m =>
([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp))
tryMoveAcc ([([PatElemT (LetDec rep)],
   (Shape, [VName], Maybe (Lambda rep, Result)),
   (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
 -> RuleM
      rep
      ([[PatElemT (LetDec rep)]],
       [(Shape, [VName], Maybe (Lambda rep, Result))],
       [(Param (LParamInfo rep), Param (LParamInfo rep))], Result))
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
-> RuleM
     rep
     ([[PatElemT (LetDec rep)]],
      [(Shape, [VName], Maybe (Lambda rep, Result))],
      [(Param (LParamInfo rep), Param (LParamInfo rep))], Result)
forall a b. (a -> b) -> a -> b
$
      [[PatElemT (LetDec rep)]]
-> [(Shape, [VName], Maybe (Lambda rep, Result))]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> Result
-> [([PatElemT (LetDec rep)],
     (Shape, [VName], Maybe (Lambda rep, Result)),
     (Param (LParamInfo rep), Param (LParamInfo rep)), SubExp)]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
zip4
        ([Int] -> [PatElemT (LetDec rep)] -> [[PatElemT (LetDec rep)]]
forall a. [Int] -> [a] -> [[a]]
chunks (((Shape, [VName], Maybe (Lambda rep, Result)) -> Int)
-> [(Shape, [VName], Maybe (Lambda rep, Result))] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Shape, [VName], Maybe (Lambda rep, Result)) -> Int
forall {t :: * -> *} {a} {a} {c}. Foldable t => (a, t a, c) -> Int
inputArrs [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs) [PatElemT (LetDec rep)]
acc_pes)
        [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs
        ([Param (LParamInfo rep)]
-> [Param (LParamInfo rep)]
-> [(Param (LParamInfo rep), Param (LParamInfo rep))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (LParamInfo rep)]
cert_params [Param (LParamInfo rep)]
acc_params)
        Result
acc_res
  let ([Param (LParamInfo rep)]
cert_params', [Param (LParamInfo rep)]
acc_params') = [(Param (LParamInfo rep), Param (LParamInfo rep))]
-> ([Param (LParamInfo rep)], [Param (LParamInfo rep)])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param (LParamInfo rep), Param (LParamInfo rep))]
params'

  -- Look at non-accumulator results.
  ([PatElemT (LetDec rep)]
nonacc_pes', Result
nonacc_res') <-
    [(PatElemT (LetDec rep), SubExp)]
-> ([PatElemT (LetDec rep)], Result)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(PatElemT (LetDec rep), SubExp)]
 -> ([PatElemT (LetDec rep)], Result))
-> ([Maybe (PatElemT (LetDec rep), SubExp)]
    -> [(PatElemT (LetDec rep), SubExp)])
-> [Maybe (PatElemT (LetDec rep), SubExp)]
-> ([PatElemT (LetDec rep)], Result)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (PatElemT (LetDec rep), SubExp)]
-> [(PatElemT (LetDec rep), SubExp)]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (PatElemT (LetDec rep), SubExp)]
 -> ([PatElemT (LetDec rep)], Result))
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExp)]
-> RuleM rep ([PatElemT (LetDec rep)], Result)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((PatElemT (LetDec rep), SubExp)
 -> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp)))
-> [(PatElemT (LetDec rep), SubExp)]
-> RuleM rep [Maybe (PatElemT (LetDec rep), SubExp)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
tryMoveNonAcc ([PatElemT (LetDec rep)]
-> Result -> [(PatElemT (LetDec rep), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT (LetDec rep)]
nonacc_pes Result
nonacc_res)

  Bool -> RuleM rep () -> RuleM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
acc_pes Bool -> Bool -> Bool
&& [PatElemT (LetDec rep)]
nonacc_pes' [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Bool
forall a. Eq a => a -> a -> Bool
== [PatElemT (LetDec rep)]
nonacc_pes) RuleM rep ()
forall rep a. RuleM rep a
cannotSimplify

  Lambda rep
lam' <-
    [LParam (Rep (RuleM rep))]
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall (m :: * -> *).
MonadBinder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param (LParamInfo rep)]
cert_params' [Param (LParamInfo rep)]
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. [a] -> [a] -> [a]
++ [Param (LParamInfo rep)]
acc_params') (RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep))))
-> RuleM rep Result -> RuleM rep (Lambda (Rep (RuleM rep)))
forall a b. (a -> b) -> a -> b
$
      Body (Rep (RuleM rep)) -> RuleM rep Result
forall (m :: * -> *). MonadBinder m => Body (Rep m) -> m Result
bodyBind (Body (Rep (RuleM rep)) -> RuleM rep Result)
-> Body (Rep (RuleM rep)) -> RuleM rep Result
forall a b. (a -> b) -> a -> b
$ (Lambda rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody Lambda rep
lam) {bodyResult :: Result
bodyResult = Result
acc_res' Result -> Result -> Result
forall a. Semigroup a => a -> a -> a
<> Result
nonacc_res'}

  Pattern (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Rep m) -> Exp (Rep m) -> m ()
letBind ([PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)] -> Pattern rep
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([[PatElemT (LetDec rep)]] -> [PatElemT (LetDec rep)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[PatElemT (LetDec rep)]]
acc_pes' [PatElemT (LetDec rep)]
-> [PatElemT (LetDec rep)] -> [PatElemT (LetDec rep)]
forall a. Semigroup a => a -> a -> a
<> [PatElemT (LetDec rep)]
nonacc_pes')) (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ [(Shape, [VName], Maybe (Lambda rep, Result))]
-> Lambda rep -> ExpT rep
forall rep.
[(Shape, [VName], Maybe (Lambda rep, Result))]
-> Lambda rep -> ExpT rep
WithAcc [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs' Lambda rep
lam'
  where
    num_nonaccs :: Int
num_nonaccs = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda rep -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda rep
lam) Int -> Int -> Int
forall a. Num a => a -> a -> a
- [(Shape, [VName], Maybe (Lambda rep, Result))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda rep, Result))]
inputs
    inputArrs :: (a, t a, c) -> Int
inputArrs (a
_, t a
arrs, c
_) = t a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length t a
arrs

    tryMoveAcc :: ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp))
tryMoveAcc ([PatElemT dec]
pes, (a
_, [VName]
arrs, c
_), (a
_, Param dec
acc_p), Var VName
v)
      | Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
acc_p VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v = do
        [(PatElemT dec, VName)] -> ((PatElemT dec, VName) -> m ()) -> m ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT dec] -> [VName] -> [(PatElemT dec, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT dec]
pes [VName]
arrs) (((PatElemT dec, VName) -> m ()) -> m ())
-> ((PatElemT dec, VName) -> m ()) -> m ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT dec
pe, VName
arr) ->
          [VName] -> Exp (Rep m) -> m ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe] (Exp (Rep m) -> m ()) -> Exp (Rep m) -> m ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
arr
        Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
forall a. Maybe a
Nothing
    tryMoveAcc ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
x =
      Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
 -> m (Maybe
         ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)))
-> Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> m (Maybe
        ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp))
forall a b. (a -> b) -> a -> b
$ ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
-> Maybe ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
forall a. a -> Maybe a
Just ([PatElemT dec], (a, [VName], c), (a, Param dec), SubExp)
x

    tryMoveNonAcc :: (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
tryMoveNonAcc (PatElemT (LetDec rep)
pe, Var VName
v)
      | VName
v VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` TopDown rep
vtable = do
        [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        Maybe (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExp)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElemT (LetDec rep)
pe, Constant PrimValue
v) = do
      [VName] -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe] (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant PrimValue
v
      Maybe (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (PatElemT (LetDec rep), SubExp)
forall a. Maybe a
Nothing
    tryMoveNonAcc (PatElemT (LetDec rep), SubExp)
x =
      Maybe (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Maybe (PatElemT (LetDec rep), SubExp)
 -> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp)))
-> Maybe (PatElemT (LetDec rep), SubExp)
-> RuleM rep (Maybe (PatElemT (LetDec rep), SubExp))
forall a b. (a -> b) -> a -> b
$ (PatElemT (LetDec rep), SubExp)
-> Maybe (PatElemT (LetDec rep), SubExp)
forall a. a -> Maybe a
Just (PatElemT (LetDec rep), SubExp)
x
withAccTopDown TopDown rep
_ Stm rep
_ = Rule rep
forall rep. Rule rep
Skip

-- | Bottom-up rule that eliminates dead results of a 'WithAcc'.
--
-- The rule only fires when at least one name bound by the statement's
-- pattern is reported as unused by the usage table.  It then
--
-- * keeps an accumulator only if some pattern element bound to it is
--   used; a dropped accumulator takes its input, its certificate and
--   accumulator parameters, and its lambda result with it;
--
-- * keeps a non-accumulator result only if its pattern element is
--   used.
--
-- If this removes nothing, the rule bails out with 'cannotSimplify'.
-- Otherwise the statement is re-emitted (under the original 'StmAux')
-- with the trimmed pattern, inputs, lambda parameters and results.
-- Any other statement is skipped.
withAccBottomUp :: BinderOps rep => BottomUpRuleGeneric rep
withAccBottomUp (_, utable) (Let pat aux (WithAcc inputs lam))
  | any (not . isUsed) (patternNames pat) = Simplify $ do
    -- An unused name may belong to an accumulator that is still kept
    -- because a sibling name is used; in that case nothing changes.
    when nothingRemoved cannotSimplify

    -- Re-bind the statements of the original body unchanged; only the
    -- list of results is trimmed.
    lam' <- mkLambda (kept_cert_params ++ kept_acc_params) $ do
      _ <- bodyBind (lambdaBody lam)
      pure (kept_acc_res ++ kept_nonacc_res)

    auxing aux $
      letBind (Pattern [] (concat kept_acc_pes ++ kept_nonacc_pes)) $
        WithAcc kept_inputs lam'
  where
    isUsed name = name `UT.used` utable

    num_accs = length inputs
    num_nonaccs = length (lambdaReturnType lam) - num_accs

    -- Accumulator-related entries come first, the non-accumulator
    -- ones occupy the last 'num_nonaccs' positions.
    (acc_res, nonacc_res) =
      splitFromEnd num_nonaccs (bodyResult (lambdaBody lam))
    (acc_pes, nonacc_pes) =
      splitFromEnd num_nonaccs (patternElements pat)
    (cert_params, acc_params) =
      splitAt num_accs (lambdaParams lam)

    -- Group the accumulator pattern elements per input: each input
    -- gets as many elements as it has arrays.
    pes_per_acc =
      chunks [length arrs | (_, arrs, _) <- inputs] acc_pes

    (kept_acc_pes, kept_inputs, kept_param_pairs, kept_acc_res) =
      unzip4
        [ entry
          | entry@(pes, _, _, _) <-
              zip4 pes_per_acc inputs (zip cert_params acc_params) acc_res,
            any (isUsed . patElemName) pes
        ]
    (kept_cert_params, kept_acc_params) = unzip kept_param_pairs

    (kept_nonacc_pes, kept_nonacc_res) =
      unzip
        [ (pe, res)
          | (pe, res) <- zip nonacc_pes nonacc_res,
            isUsed (patElemName pe)
        ]

    nothingRemoved =
      concat kept_acc_pes == acc_pes && kept_nonacc_pes == nonacc_pes
withAccBottomUp _ _ = Skip

-- Some helper functions

-- | 'True' exactly when the operand is a 'Constant' whose value
-- satisfies 'oneIsh'; every other operand yields 'False'.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False

-- | 'True' exactly when the operand is a 'Constant' whose value
-- satisfies 'zeroIsh'; every other operand yields 'False'.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False