-- | Loop simplification rules.
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.Bifunctor (second)
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Transform.Rename

-- | Remove loop parameters (merge variables) whose values are not needed.
--
-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.  A loop parameter is
-- kept if any of the following hold:
--
--   * the corresponding pattern element is used directly after the loop;
--
--   * it is in the set computed by 'findNecessaryForReturned' from the
--     body's data dependencies, seeded by the parameters that are used
--     after the loop or occur in the loop form;
--
--   * its name occurs free in the loop parameters themselves (for
--     example in another parameter's type);
--
--   * its name occurs free in the loop form.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
removeRedundantMergeVariables :: (BuilderOps rep) => BottomUpRuleLoop rep
removeRedundantMergeVariables (_, used) pat aux (merge, form, body)
  | not $ all (usedAfterLoop . fst) merge =
      let necessaryForReturned =
            findNecessaryForReturned
              usedAfterLoopOrInForm
              (zip (map fst merge) (map resSubExp $ bodyResult body))
              (dataDependencies body)

          resIsNecessary ((v, _), _) =
            usedAfterLoop v
              || (paramName v `nameIn` necessaryForReturned)
              || referencedInPat v
              || referencedInForm v

          -- Each element is (pattern element, ((parameter, initial value), body result)).
          (keep_valpart, discard_valpart) =
            partition (resIsNecessary . snd) $
              zip (patElems pat) $
                zip merge $
                  bodyResult body

          (keep_valpatelems, keep_val) = unzip keep_valpart
          (_discard_valpatelems, discard_val) = unzip discard_valpart
          (merge', val_es') = unzip keep_val

          body' = body {bodyResult = val_es'}

          pat' = Pat keep_valpatelems
       in if merge' == merge
            then Skip
            else Simplify $ do
              -- We can't just remove the bindings in 'discard', since the loop
              -- body may still use their names in (now-dead) expressions.
              -- Hence, we add them inside the loop, fully aware that dead-code
              -- removal will eventually get rid of them.  Some care is
              -- necessary to handle unique bindings.
              body'' <- insertStmsM $ do
                mapM_ (uncurry letBindNames) $ dummyStms discard_val
                pure body'
              auxing aux $ letBind pat' $ Loop merge' form body''
  where
    -- For each name bound by the loop pattern: is it used directly?
    pat_used = map (`UT.isUsedDirectly` used) $ patNames pat
    -- Names of the loop parameters whose pattern element is used directly.
    -- Kept as a 'Names' set so membership tests do not scan a list.
    used_vals =
      namesFromList . map fst . filter snd $
        zip (map (paramName . fst) merge) pat_used
    -- Names free in the loop form, computed once.
    form_names = freeIn form
    -- Names free in the loop parameters themselves.
    patAnnotNames = freeIn $ map fst merge

    usedAfterLoop = (`nameIn` used_vals) . paramName
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` form_names) . paramName
    usedAfterLoopOrInForm p = usedAfterLoop p || referencedInForm p

    -- Rebind each discarded parameter inside the body to its initial
    -- value.  A unique parameter initialised from a variable is bound to
    -- a 'Replicate' with empty shape of that variable instead of a
    -- plain 'SubExp'.
    dummyStms = map dummyStm
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
          ([paramName p], BasicOp $ Replicate mempty $ Var v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- | Hoist loop-invariant merge variables out of a loop.
--
-- Each invariant parameter is removed from the loop.  Its name is bound
-- before the loop to the parameter's initial value, under the
-- certificates of its body result.  When exactly one pattern element
-- corresponds to the parameter, that element is bound the same way.
--
-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: (BuilderOps rep) => TopDownRuleLoop rep
hoistLoopInvariantMergeVariables vtable pat aux (merge, form, loopbody) =
  -- Figure out which of the elements of loopresult are
  -- loop-invariant, and hoist them out.
  case foldr checkInvariance ([], explpat, [], []) $ zip3 (patNames pat) merge res of
    ([], _, _, _) ->
      -- Nothing is invariant.
      Skip
    (invariant, explpat', merge', res') -> Simplify . auxing aux $ do
      -- We have moved something invariant out of the loop.
      let loopbody' = loopbody {bodyResult = res'}
          explpat'' = map fst explpat'
      forM_ invariant $ \(v1, (v2, cs)) ->
        certifying cs $ letBindNames [identName v1] $ BasicOp $ SubExp v2
      letBind (Pat explpat'') $ Loop merge' form loopbody'
  where
    res = bodyResult loopbody

    -- Each pattern element paired with the name of its loop parameter.
    explpat = zip (patElems pat) $ map (paramName . fst) merge

    namesOfMergeParams = namesFromList $ map (paramName . fst) merge

    -- Remove from the pattern the element belonging to 'mergeParam'.  If
    -- exactly one element matches, also return a binding of its ident to
    -- the initial value under the certificates 'cs'.  Otherwise return
    -- Nothing and leave the pattern unchanged.
    removeFromResult cs (mergeParam, mergeInit) explpat' =
      case partition ((== paramName mergeParam) . snd) explpat' of
        ([(patelem, _)], rest) ->
          (Just (patElemIdent patelem, (mergeInit, cs)), rest)
        (_, _) ->
          (Nothing, explpat')

    -- Folded from the right over (pattern name, (param, init), result).
    -- The accumulator holds the hoisted bindings so far, the remaining
    -- pattern, and the kept parameters and results.
    checkInvariance
      (pat_name, (mergeParam, mergeInit), resExp)
      (invariant, explpat', merge', resExps)
        | isInvariant,
          -- Certificates must be available.
          all (`ST.elem` vtable) $ unCerts $ resCerts resExp =
            let (stm, explpat'') =
                  removeFromResult
                    (resCerts resExp)
                    (mergeParam, mergeInit)
                    explpat'
             in ( maybe id (:) stm $
                    (paramIdent mergeParam, (mergeInit, resCerts resExp)) : invariant,
                  explpat'',
                  merge',
                  resExps
                )
        where
          -- A non-unique merge variable is invariant if one of the
          -- following is true:
          isInvariant
            -- (0) The result is a variable of the same name as the
            -- parameter, where all existential parameters are already
            -- known to be invariant
            | Var v2 <- resSubExp resExp,
              paramName mergeParam == v2 =
                allExistentialInvariant
                  (namesFromList $ map (identName . fst) invariant)
                  mergeParam
            -- (1) The result is identical to the initial parameter value.
            | mergeInit == resSubExp resExp = True
            -- (2) The initial parameter value is equal to an outer
            -- loop parameter 'P', where the initial value of 'P' is
            -- equal to 'resExp', AND 'resExp' ultimately becomes the
            -- new value of 'P'.  XXX: it's a bit clumsy that this
            -- only works for one level of nesting, and I think it
            -- would not be too hard to generalise.
            | Var init_v <- mergeInit,
              Just (p_init, p_res) <- ST.lookupLoopParam init_v vtable,
              p_init == resSubExp resExp,
              p_res == Var pat_name =
                True
            -- (3) It is a statically empty array.
            | isJust $ isEmptyArray (paramType mergeParam) = True
            | otherwise = False
    -- Not invariant (or certificates unavailable): keep the parameter.
    checkInvariance
      (_pat_name, (mergeParam, mergeInit), resExp)
      (invariant, explpat', merge', resExps) =
        (invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps)

    -- Every name free in 'mergeParam' (other than its own name) is either
    -- not a merge parameter at all or already known to be invariant.
    allExistentialInvariant namesOfInvariant mergeParam =
      all (invariantOrNotMergeParam namesOfInvariant) $
        namesToList $
          freeIn mergeParam `namesSubtract` oneName (paramName mergeParam)
    invariantOrNotMergeParam namesOfInvariant name =
      (name `notNameIn` namesOfMergeParams)
        || (name `nameIn` namesOfInvariant)

-- | Try to replace a @for@-loop by a closed-form computation of its
-- results, delegating to 'loopClosedForm'.  The loop index is passed to
-- 'loopClosedForm' as a singleton name set; whether a rewrite actually
-- happens is decided there.  Loops that are not @for@-loops are skipped.
simplifyClosedFormLoop :: (BuilderOps rep) => TopDownRuleLoop rep
simplifyClosedFormLoop _ pat _ (val, ForLoop i it bound, body) =
  Simplify $ loopClosedForm pat val (oneName i) it bound body
simplifyClosedFormLoop _ _ _ _ = Skip

-- | Emit consecutive copies of a for-loop body for iterations @i@ up to
-- (but excluding) @n@.
--
-- Before each copy, every merge parameter is bound to its current value
-- (under that value's certificates), and the loop index @iv@ is bound to
-- the constant @i@ of integer type @it@.  The resulting body is renamed
-- before its statements are added, so that every unrolled copy binds
-- distinct names.  Its results become the merge values for the next
-- iteration.
--
-- Returns the merge values after the last emitted iteration.  When
-- @i >= n@, nothing is emitted and the current merge values are returned.
unroll ::
  (BuilderOps rep) =>
  Integer ->
  [(FParam rep, SubExpRes)] ->
  (VName, IntType, Integer) ->
  Body rep ->
  RuleM rep [SubExpRes]
unroll n merge (iv, it, i) body
  | i >= n =
      pure $ map snd merge
  | otherwise = do
      iter_body <- insertStmsM $ do
        forM_ merge $ \(mergevar, SubExpRes cs mergeinit) ->
          certifying cs $
            letBindNames [paramName mergevar] $
              BasicOp $
                SubExp mergeinit

        letBindNames [iv] $ BasicOp $ SubExp $ intConst it i

        -- Some of the sizes in the types here might be temporarily wrong
        -- until copy propagation fixes it up.
        pure body

      -- Rename so this copy does not reuse names bound by other copies.
      iter_body' <- renameBody iter_body
      addStms $ bodyStms iter_body'

      let merge' = zip (map fst merge) $ bodyResult iter_body'
      unroll n merge' (iv, it, i + 1) body

-- | Fully unroll a for-loop whose iteration count is an integer constant,
-- provided the count is zero or one, or the loop statement carries the
-- @unroll@ attribute.
--
-- The body copies are produced by 'unroll', starting at index 0 with the
-- loop's initial merge values.  Each element of the loop's pattern is then
-- bound to the corresponding final result, under that result's
-- certificates.  Every other loop is skipped.
simplifyKnownIterationLoop :: (BuilderOps rep) => TopDownRuleLoop rep
simplifyKnownIterationLoop _ pat aux (merge, ForLoop i it (Constant iters), body)
  | IntValue n <- iters,
    zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux =
      Simplify $ do
        res <-
          unroll (valueIntegral n) (map (second subExpRes) merge) (i, it, 0) body
        forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
          certifying cs $ letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ =
  Skip

-- | The loop rules applied in the top-down simplification pass, in order:
-- 'hoistLoopInvariantMergeVariables', 'simplifyClosedFormLoop' and
-- 'simplifyKnownIterationLoop'.
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleLoop hoistLoopInvariantMergeVariables,
    RuleLoop simplifyClosedFormLoop,
    RuleLoop simplifyKnownIterationLoop
  ]

-- | The loop rules applied in the bottom-up simplification pass; currently
-- only 'removeRedundantMergeVariables'.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleLoop removeRedundantMergeVariables
  ]

-- | Standard loop simplification rules: this module's top-down and
-- bottom-up loop rules combined into a single 'RuleBook'.
loopRules :: (BuilderOps rep) => RuleBook rep
loopRules = ruleBook topDownRules bottomUpRules