{-# LANGUAGE OverloadedStrings #-}

-- | Loop simplification rules.
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where

import Control.Monad
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Transform.Rename

-- This next one is tricky - it's easy enough to determine that some
-- loop result is not used after the loop, but here, we must also make
-- sure that it does not affect any other values.
--
-- I do not claim that the current implementation of this rule is
-- perfect, but it should suffice for many cases, and should never
-- generate wrong code.
-- | Bottom-up rule that drops loop merge variables whose values are not
-- needed.
--
-- A merge parameter is kept if any of the following holds:
--
--   * it is a value parameter whose pattern element is used directly
--     after the loop (according to 'UT.isUsedDirectly');
--   * its name is in the set computed by 'findNecessaryForReturned' from
--     the body's 'dataDependencies', for the parameters that are used
--     after the loop or free in the loop form;
--   * its name is free in the merge parameter list itself ('freeIn' of
--     the parameters, presumably their type annotations - TODO confirm);
--   * its name is free in the loop form.
--
-- All other parameters are removed from the merge list, the body result
-- and the value pattern.  Context pattern elements that are no longer
-- referenced by a kept value element (or by another context element) are
-- removed as well.  The rule only fires when at least one value result is
-- unused after the loop and the loop has no context merge parameters, and
-- it returns 'Skip' when nothing would actually be removed.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    null ctx -- FIXME: things get tricky if we can remove all vals
    -- but some ctxs are still used.  We take the easy way
    -- out for now.
    =
    let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body

        -- Names that the results of parameters satisfying
        -- 'usedAfterLoopOrInForm' depend on, as computed by
        -- 'findNecessaryForReturned' over the body's data dependencies.
        necessaryForReturned =
          findNecessaryForReturned
            usedAfterLoopOrInForm
            (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
            (dataDependencies body)

        resIsNecessary ((v, _), _) =
          usedAfterLoop v
            || paramName v `nameIn` necessaryForReturned
            || referencedInPat v
            || referencedInForm v

        (keep_ctx, discard_ctx) =
          partition resIsNecessary $ zip ctx ctx_es
        (keep_valpart, discard_valpart) =
          partition (resIsNecessary . snd) $
            zip (patternValueElements pat) $ zip val val_es

        (keep_valpatelems, keep_val) = unzip keep_valpart
        discard_val = map snd discard_valpart
        (ctx', ctx_es') = unzip keep_ctx
        (val', val_es') = unzip keep_val

        body' = body {bodyResult = ctx_es' ++ val_es'}
        free_in_keeps = freeIn keep_valpatelems

        -- A context pattern element survives only if a kept value
        -- element, or some other context element, still refers to it.
        stillUsedContext pat_elem =
          patElemName pat_elem
            `nameIn` ( free_in_keeps
                         <> freeIn (filter (/= pat_elem) $ patternContextElements pat)
                     )

        pat' =
          pat
            { patternValueElements = keep_valpatelems,
              patternContextElements =
                filter stillUsedContext $ patternContextElements pat
            }
     in if ctx' ++ val' == ctx ++ val
          then Skip
          else Simplify $ do
            -- We can't just remove the bindings in 'discard', since the loop
            -- body may still use their names in (now-dead) expressions.
            -- Hence, we add them inside the loop, fully aware that dead-code
            -- removal will eventually get rid of them.  Some care is
            -- necessary to handle unique bindings.
            body'' <- insertStmsM $ do
              mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
              mapM_ (uncurry letBindNames) $ dummyStms discard_val
              return body'
            auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat

    -- Names of the value merge parameters whose pattern elements are used
    -- after the loop, kept as a 'Names' set so membership is cheap.
    used_vals =
      namesFromList $ map fst $ filter snd $ zip (map (paramName . fst) val) pat_used
    usedAfterLoop = (`nameIn` used_vals) . paramName

    -- Computed once; shared by 'referencedInForm' and
    -- 'usedAfterLoopOrInForm'.
    form_names = freeIn form
    referencedInForm = (`nameIn` form_names) . paramName
    usedAfterLoopOrInForm p = usedAfterLoop p || referencedInForm p

    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName

    -- Rebind each discarded merge parameter to its initial value inside
    -- the body.  Unique parameters initialised from a variable get a
    -- 'Copy' instead of a plain 'SubExp' binding.
    dummyStms = map dummyStm
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
        ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip

-- We may change the type of the loop if we hoist out a shape
-- annotation, in which case we also need to tweak the bound pattern.
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: forall lore. BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
  -- Figure out which of the elements of loopresult are
  -- loop-invariant, and hoist them out.
  case ((VName, (FParam lore, SubExp), SubExp)
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp])
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall {dec} {dec}.
(DeclTyped dec, Typed dec, FreeIn dec, Typed dec) =>
(VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([(VName, (FParam lore, SubExp), SubExp)]
 -> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
     [(FParam lore, SubExp)], [SubExp]))
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
    [(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
    [VName]
-> [(FParam lore, SubExp)]
-> [SubExp]
-> [(VName, (FParam lore, SubExp), SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) [(FParam lore, SubExp)]
merge [SubExp]
res of
    ([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
      -- Nothing is invariant.
      Rule lore
forall lore. Rule lore
Skip
    ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
      -- We have moved something invariant out of the loop.
      let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody {bodyResult :: [SubExp]
bodyResult = [SubExp]
res'}
          invariantShape :: (a, VName) -> Bool
          invariantShape :: forall a. (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) =
            VName
shapemerge
              VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
          ([(PatElemT (LetDec lore), VName)]
implpat', [(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
    [(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
          implinvariant' :: [(Ident, SubExp)]
implinvariant' = [(PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p, VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant]
          implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
          explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
          ([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
      [(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1, SubExp
v2) ->
        [VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
      StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
        Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
          [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
  where
    merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
    res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody

    implpat :: [(PatElemT (LetDec lore), VName)]
implpat =
      [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
    explpat :: [(PatElemT (LetDec lore), VName)]
explpat =
      [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
        ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val

    namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val

    removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, b
mergeInit) [(PatElemT dec, VName)]
explpat' =
      case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
        ([(PatElemT dec
patelem, VName
_)], [(PatElemT dec, VName)]
rest) ->
          ((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
        ([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
          (Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')

    checkInvariance :: (VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
    [(Param dec, SubExp)], [SubExp])
checkInvariance
      (VName
pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps)
        | Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam))
            Bool -> Bool -> Bool
|| TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
          Bool
isInvariant,
          -- Also do not remove the condition in a while-loop.
          Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
          let (Maybe (Ident, SubExp)
bnd, [(PatElemT dec, VName)]
explpat'') =
                (Param dec, SubExp)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, SubExp), [(PatElemT dec, VName)])
forall {dec} {dec} {b}.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, SubExp
mergeInit) [(PatElemT dec, VName)]
explpat'
           in ( ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param dec -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent Param dec
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
                [(PatElemT dec, VName)]
explpat'',
                [(Param dec, SubExp)]
merge',
                [SubExp]
resExps
              )
        where
          -- A non-unique merge variable is invariant if one of the
          -- following is true:
          --
          -- (0) The result is a variable of the same name as the
          -- parameter, where all existential parameters are already
          -- known to be invariant
          isInvariant :: Bool
isInvariant
            | Var VName
v2 <- SubExp
resExp,
              Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
              Names -> Param dec -> Bool
forall {dec}. FreeIn dec => Names -> Param dec -> Bool
allExistentialInvariant
                ([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant)
                Param dec
mergeParam
            -- (1) The result is identical to the initial parameter value.
            | SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp = Bool
True
            -- (2) The initial parameter value is equal to an outer
            -- loop parameter 'P', where the initial value of 'P' is
            -- equal to 'resExp', AND 'resExp' ultimately becomes the
            -- new value of 'P'.  XXX: it's a bit clumsy that this
            -- only works for one level of nesting, and I think it
            -- would not be too hard to generalise.
            | Var VName
init_v <- SubExp
mergeInit,
              Just (SubExp
p_init, SubExp
p_res) <- VName -> TopDown lore -> Maybe (SubExp, SubExp)
forall lore. VName -> SymbolTable lore -> Maybe (SubExp, SubExp)
ST.lookupLoopParam VName
init_v TopDown lore
vtable,
              SubExp
p_init SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp,
              SubExp
p_res SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> SubExp
Var VName
pat_name =
              Bool
True
            | Bool
otherwise = Bool
False
    checkInvariance
      (VName
_pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
      ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps) =
        ([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', (Param dec
mergeParam, SubExp
mergeInit) (Param dec, SubExp)
-> [(Param dec, SubExp)] -> [(Param dec, SubExp)]
forall a. a -> [a] -> [a]
: [(Param dec, SubExp)]
merge', SubExp
resExp SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
resExps)

    allExistentialInvariant :: Names -> Param dec -> Bool
allExistentialInvariant Names
namesOfInvariant Param dec
mergeParam =
      (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
        Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
          Param dec -> Names
forall a. FreeIn a => a -> Names
freeIn Param dec
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam)
    invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
      Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams)
        Bool -> Bool -> Bool
|| VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant

-- | Closed-form rule for simple for-loops.
--
-- Only loops with no context merge parameters and a 'ForLoop' form
-- without loop variables are considered.  For those, the rewrite is
-- delegated to 'loopClosedForm', which receives the pattern, the value
-- merge parameters, the counter's name (as 'oneName'), the counter type,
-- the bound, and the loop body.  Every other loop yields 'Skip'.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ loop_pat _ (ctx_params, val_params, loop_form, loop_body) =
  case (ctx_params, loop_form) of
    ([], ForLoop counter counter_t upper []) ->
      Simplify $
        loopClosedForm loop_pat val_params (oneName counter) counter_t upper loop_body
    _ -> Skip

-- | Try to simplify the loop variables of a for-loop.
--
-- For every loop variable @(p, arr)@, 'simplifyIndexing' is asked about
-- the per-iteration access @arr[i, <all dimensions of p>]@.  Two kinds of
-- result are acted upon:
--
-- * An 'IndexResult' on some array @arr'@ whose first index is exactly
--   the loop counter @i@.  The remaining indices and the generated
--   statements must not mention @i@.  The loop variable is kept, but it
--   now iterates over the slice @arr'[0:w:1, ...]@, which is bound
--   before the loop.
--
-- * A 'SubExpResult' whose generated statements contain no 'Index'
--   expressions.  The loop variable is removed, and its value is bound
--   at the start of the loop body instead.
--
-- If no loop variable ends up changed, the rule reports 'cannotSimplify'.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat aux (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    not $ all isNothing simplifiable = Simplify $ do
    -- Check if the simplifications throw away more information than
    -- we are comfortable with at this stage.
    (maybe_loop_vars, body_prefix_stms) <-
      localScope (scopeOf form) $
        unzip <$> zipWithM onLoopVar loop_vars simplifiable
    if maybe_loop_vars == map Just loop_vars
      then cannotSimplify
      else do
        body' <- buildBody_ $ do
          addStms $ mconcat body_prefix_stms
          bodyBind body
        auxing aux $
          letBind pat $
            DoLoop
              ctx
              val
              (ForLoop i it num_iters $ catMaybes maybe_loop_vars)
              body'
  where
    -- Type of a subexpression.  The loop counter is handled explicitly
    -- rather than being looked up in 'vtable'.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Ask 'simplifyIndexing' about @arr[i, <all of p's dimensions>]@,
    -- also telling it whether the loop parameter is consumed in the body.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        $ paramName p `nameIn` consumed_in_body

    -- We only want this simplification if the result does not refer
    -- to 'i' at all, or does not contain accesses.
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- Only the remaining indices may not mention 'i'.  The head
            -- of 'slice' is 'DimFix (Var i)' by the guards above, so
            -- testing all of 'slice' would always fail and leave this
            -- branch unreachable.
            not $ i `nameIn` freeIn slice' -> do
            -- These statements do not mention 'i', so they can be bound
            -- before the loop.
            addStms x_stms
            w <- arraySize 0 <$> lookupType arr'
            for_in_partial <-
              certifying cs $
                letExp "for_in_partial" $
                  BasicOp $
                    Index arr' $
                      DimSlice (intConst Int64 0) w (intConst Int64 1) : slice'
            return (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
            -- Bind the loop parameter directly; these statements become
            -- a prefix of the new loop body.
            x_stms' <- collectStms_ $
              certifying cs $ do
                addStms x_stms
                letBindNames [paramName p] $ BasicOp $ SubExp se
            return (Nothing, x_stms')
        _ -> return (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ = Skip

-- If a for-loop with no loop variables has a counter of type Int64, and
-- the bound is either a sign extension of an integer of smaller type or
-- a constant within the Int32 range, then change the loop to iterate
-- over the smaller type instead.  The sign extension then moves inside
-- the loop: the original Int64 counter is rebuilt from the narrow one at
-- the start of the body.  This addresses loops of the form
-- @for i in x..<y@ in the source language.
narrowLoopType :: (BinderOps lore) => TopDownRuleDoLoop lore
narrowLoopType vtable pat aux (ctx, val, ForLoop i Int64 n [], body)
  | Just (n', it', cs) <- smallerType =
    Simplify $ do
      i' <- newVName $ baseString i
      let form' = ForLoop i' it' n' []
      body' <- insertStmsM . inScopeOf form' $ do
        -- Recreate the original Int64 counter from the narrow counter.
        letBindNames [i] $ BasicOp $ ConvOp (SExt it' Int64) (Var i')
        pure body
      auxing aux $
        certifying cs $
          letBind pat $ DoLoop ctx val form' body'
  where
    smallerType
      | Var n' <- n,
        Just (ConvOp (SExt it' _) n'', cs) <- ST.lookupBasicOp n' vtable =
        Just (n'', it', cs)
      | Constant (IntValue (Int64Value n')) <- n,
        fitsInInt32 n' =
        Just (intConst Int32 (toInteger n'), Int32, mempty)
      | otherwise =
        Nothing

    -- Both ends must be checked: a constant below 'minBound :: Int32'
    -- cannot be represented as an Int32 constant, so narrowing it could
    -- change the loop's trip count.  Only the upper end used to be checked.
    fitsInInt32 x =
      toInteger x >= toInteger (minBound :: Int32)
        && toInteger x <= toInteger (maxBound :: Int32)
narrowLoopType _ _ _ _ = Skip

-- | Fully unroll a @for@ loop with a statically known trip count.
--
-- @unroll n merge (iv, it, i) loop_vars body@ emits one freshly renamed copy
-- of @body@ for every iteration index from @i@ up to, but not including, @n@.
-- Before each copy it binds:
--
-- * every merge parameter to the value it holds on entry to that iteration;
-- * the induction variable @iv@ to the constant @i@ of type @it@;
-- * every loop-variable parameter to element @i@ of its array.
--
-- It returns the values of the merge parameters after the last iteration.
-- When @i >= n@ (no iterations remain), these are simply the values in
-- @merge@.
unroll ::
  BinderOps lore =>
  Integer ->
  [(FParam lore, SubExp)] ->
  (VName, IntType, Integer) ->
  [(LParam lore, VName)] ->
  Body lore ->
  RuleM lore [SubExp]
unroll n merge (iv, it, i) loop_vars body
  | i < n = do
    -- Renaming gives each unrolled copy its own binders, so consecutive
    -- copies can share one scope.
    iteration <- renameBody =<< insertStmsM bindThenBody
    addStms (bodyStms iteration)
    let next_merge = zipWith (\(p, _) se -> (p, se)) merge (bodyResult iteration)
    unroll n next_merge (iv, it, i + 1) loop_vars body
  | otherwise = pure (map snd merge)
  where
    -- State bindings for iteration @i@, followed by the loop body itself.
    bindThenBody = do
      mapM_
        (\(p, se) -> letBindNames [paramName p] (BasicOp (SubExp se)))
        merge

      letBindNames [iv] (BasicOp (SubExp (intConst it i)))

      mapM_
        ( \(p, arr) ->
            letBindNames [paramName p] $
              BasicOp $
                Index arr $
                  DimFix (intConst Int64 i) : fullSlice (paramType p) []
        )
        loop_vars

      -- Some of the sizes in the types here might be temporarily wrong
      -- until copy propagation fixes it up.
      pure body

-- | Replace a @for@ loop whose bound is an integer constant with its
-- unrolled body. This fires when one of the following holds:
--
-- * the trip count is zero-ish;
-- * the trip count is one-ish;
-- * the statement carries the @unroll@ attribute.
--
-- Context and value merge parameters are unrolled together. Each name in
-- the loop's pattern is then bound to the matching result. Any other loop
-- is left untouched ('Skip').
simplifyKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyKnownIterationLoop _ pat aux (ctx, val, ForLoop i it (Constant (IntValue n)) loop_vars, body)
  | worthUnrolling = Simplify $ do
    res <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
    zipWithM_ bindResult (patternNames pat) res
  where
    worthUnrolling =
      zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux
    bindResult v se = letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ = Skip

-- | Loop rules applied during the top-down simplification pass, in the
-- order they are tried.
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules =
  map
    RuleDoLoop
    [ hoistLoopInvariantMergeVariables,
      simplifyClosedFormLoop,
      simplifyKnownIterationLoop,
      simplifyLoopVariables,
      narrowLoopType
    ]

-- | Loop rules applied during the bottom-up simplification pass.
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules = map RuleDoLoop [removeRedundantMergeVariables]

-- | Standard loop simplification rules: the top-down loop rules combined
-- with the bottom-up loop rules into a single 'RuleBook'.
loopRules :: (BinderOps lore, Aliased lore) => RuleBook lore
loopRules = topDownRules `ruleBook` bottomUpRules