{-# LANGUAGE OverloadedStrings #-}
module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where
import Control.Monad
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Transform.Rename
-- | Bottom-up rule that drops loop merge parameters (and the matching
-- pattern elements and body results) that nothing depends on.
--
-- The rule only fires when at least one value merge parameter is not
-- used after the loop /and/ the loop has no context merge parameters
-- (@null ctx@); in every other case it yields 'Skip'.
--
-- A merge parameter is kept when any of the following holds:
--
-- * its pattern element is used directly according to the usage table;
--
-- * it is in the set computed by 'findNecessaryForReturned' from the
--   body's data dependencies (NOTE(review): presumably the parameters
--   that kept results transitively depend on -- confirm against that
--   helper's definition);
--
-- * its name is free in one of the merge parameters;
--
-- * its name is free in the loop form.
--
-- If nothing can be discarded the rule yields 'Skip'; otherwise it
-- rebinds the pruned pattern to a pruned 'DoLoop'.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    -- Only loops without context merge parameters are handled, so the
    -- context halves of the partitions below are always empty here.
    null ctx =
    let (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body
        necessaryForReturned =
          findNecessaryForReturned
            usedAfterLoopOrInForm
            (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
            (dataDependencies body)

        -- Decides whether a (merge parameter, initial value) pair and
        -- its body result must be retained.
        resIsNecessary ((v, _), _) =
          usedAfterLoop v
            || paramName v `nameIn` necessaryForReturned
            || referencedInPat v
            || referencedInForm v

        (keep_ctx, discard_ctx) =
          partition resIsNecessary $ zip ctx ctx_es
        (keep_valpart, discard_valpart) =
          partition (resIsNecessary . snd) $
            zip (patternValueElements pat) $ zip val val_es

        (keep_valpatelems, keep_val) = unzip keep_valpart
        (_discard_valpatelems, discard_val) = unzip discard_valpart
        (ctx', ctx_es') = unzip keep_ctx
        (val', val_es') = unzip keep_val

        -- Only the result list is pruned; the statements of the body
        -- are left as they are.
        body' = body {bodyResult = ctx_es' ++ val_es'}
        free_in_keeps = freeIn keep_valpatelems

        -- A context pattern element survives if its name is free in a
        -- kept value element or in any *other* context element.
        stillUsedContext pat_elem =
          patElemName pat_elem
            `nameIn` ( free_in_keeps
                         <> freeIn (filter (/= pat_elem) $ patternContextElements pat)
                     )

        pat' =
          pat
            { patternValueElements = keep_valpatelems,
              patternContextElements =
                filter stillUsedContext $ patternContextElements pat
            }
     in if ctx' ++ val' == ctx ++ val
          then Skip
          else Simplify $ do
            -- The body statements are unchanged and may still mention
            -- the discarded parameters, so each discarded name is bound
            -- to a dummy value (see 'dummyStm') inside the new body.
            -- NOTE(review): relies on 'insertStmsM' attaching the
            -- statements bound in this action to the returned body --
            -- confirm against its definition.
            body'' <- insertStmsM $ do
              mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
              mapM_ (uncurry letBindNames) $ dummyStms discard_val
              return body'
            auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    -- For each value pattern element: is it used directly according
    -- to the usage table?
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat
    -- Names of the value merge parameters whose pattern element is used.
    used_vals = map fst $ filter snd $ zip (map (paramName . fst) val) pat_used
    usedAfterLoop = flip elem used_vals . paramName
    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` freeIn form
    -- Names free in the merge parameters themselves.
    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` freeIn form) . paramName

    dummyStms = map dummyStm
    -- Binding for a discarded merge parameter: a 'Copy' of the initial
    -- value when the parameter's declared type is unique and the
    -- initial value is a variable, otherwise the initial value itself.
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
        ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: forall lore. BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
vtable Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
case ((VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VName, (FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall {dec} {dec}.
(DeclTyped dec, Typed dec, FreeIn dec, Typed dec) =>
(VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
[(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
[(Param dec, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> [(VName, (FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
[VName]
-> [(FParam lore, SubExp)]
-> [SubExp]
-> [(VName, (FParam lore, SubExp), SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Pattern lore -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
pat) [(FParam lore, SubExp)]
merge [SubExp]
res of
([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
Rule lore
forall lore. Rule lore
Skip
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody {bodyResult :: [SubExp]
bodyResult = [SubExp]
res'}
invariantShape :: (a, VName) -> Bool
invariantShape :: forall a. (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) =
VName
shapemerge
VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
([(PatElemT (LetDec lore), VName)]
implpat', [(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
[(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
implinvariant' :: [(Ident, SubExp)]
implinvariant' = [(PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p, VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant]
implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
[(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1, SubExp
v2) ->
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
where
merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody
implpat :: [(PatElemT (LetDec lore), VName)]
implpat =
[PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
explpat :: [(PatElemT (LetDec lore), VName)]
explpat =
[PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val
namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, b
mergeInit) [(PatElemT dec, VName)]
explpat' =
case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
([(PatElemT dec
patelem, VName
_)], [(PatElemT dec, VName)]
rest) ->
((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
(Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')
-- Fold step that decides whether a single merge parameter is loop-invariant.
--
-- The first argument is the pattern name, the (merge parameter, initial
-- value) pair, and the corresponding body result.  The second argument is
-- the accumulator: (invariant bindings found so far, remaining pattern
-- elements, merge parameters that are kept, body results that are kept).
checkInvariance :: (VName, (Param dec, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
[(Param dec, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT dec, VName)],
[(Param dec, SubExp)], [SubExp])
checkInvariance
(VName
pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps)
-- This equation applies only when all three guards hold: the parameter is
-- non-unique or has array rank 1, it is invariant (see isInvariant below),
-- and its name does not occur free in the loop form.
| Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam))
Bool -> Bool -> Bool
|| TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Param dec -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType Param dec
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
Bool
isInvariant,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
-- Drop the pattern element that corresponds to this parameter, if
-- removeFromResult finds one; 'bnd' is then its binding to the initial value.
let (Maybe (Ident, SubExp)
bnd, [(PatElemT dec, VName)]
explpat'') =
(Param dec, SubExp)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, SubExp), [(PatElemT dec, VName)])
forall {dec} {dec} {b}.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam, SubExp
mergeInit) [(PatElemT dec, VName)]
explpat'
-- The parameter itself (and 'bnd', when present) is added to the invariant
-- bindings; the kept merge parameters and kept results are left unchanged.
in ( ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param dec -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent Param dec
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
[(PatElemT dec, VName)]
explpat'',
[(Param dec, SubExp)]
merge',
[SubExp]
resExps
)
where
-- The parameter is considered invariant in three situations:
--   1. the body result is the parameter itself, and
--      allExistentialInvariant holds for the names already found invariant;
--   2. the body result equals the initial value;
--   3. the initial value is a variable that ST.lookupLoopParam knows as a
--      loop parameter whose initial value equals this body result and whose
--      result is this loop's pattern name.
isInvariant :: Bool
isInvariant
| Var VName
v2 <- SubExp
resExp,
Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
Names -> Param dec -> Bool
forall {dec}. FreeIn dec => Names -> Param dec -> Bool
allExistentialInvariant
([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant)
Param dec
mergeParam
| SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp = Bool
True
| Var VName
init_v <- SubExp
mergeInit,
Just (SubExp
p_init, SubExp
p_res) <- VName -> TopDown lore -> Maybe (SubExp, SubExp)
forall lore. VName -> SymbolTable lore -> Maybe (SubExp, SubExp)
ST.lookupLoopParam VName
init_v TopDown lore
vtable,
SubExp
p_init SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp,
SubExp
p_res SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== VName -> SubExp
Var VName
pat_name =
Bool
True
| Bool
otherwise = Bool
False
-- Fallback equation: the parameter is kept as a merge parameter, so it and
-- its body result are prepended to the kept lists.
checkInvariance
(VName
_pat_name, (Param dec
mergeParam, SubExp
mergeInit), SubExp
resExp)
([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', [(Param dec, SubExp)]
merge', [SubExp]
resExps) =
([(Ident, SubExp)]
invariant, [(PatElemT dec, VName)]
explpat', (Param dec
mergeParam, SubExp
mergeInit) (Param dec, SubExp)
-> [(Param dec, SubExp)] -> [(Param dec, SubExp)]
forall a. a -> [a] -> [a]
: [(Param dec, SubExp)]
merge', SubExp
resExp SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
resExps)
-- True when every name free in the merge parameter, other than the
-- parameter's own name, satisfies invariantOrNotMergeParam with respect to
-- the given set of invariant names.
allExistentialInvariant :: Names -> Param dec -> Bool
allExistentialInvariant Names
namesOfInvariant Param dec
mergeParam =
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
-- The parameter's own name is excluded before checking.
Param dec -> Names
forall a. FreeIn a => a -> Names
freeIn Param dec
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam)
-- True when the name is not one of the merge parameter names, or is a
-- member of the given set of invariant names.
-- NOTE(review): namesOfMergeParams is bound in the enclosing scope, outside
-- this excerpt; presumably it holds the names of all merge parameters of the
-- loop -- confirm against its definition.
invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams)
Bool -> Bool -> Bool
|| VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant
-- | Hand a @for@ loop over to 'loopClosedForm'.
--
-- The rule only fires for a 'ForLoop' that has no context merge parameters
-- and no for-in loop variables; every other loop is skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ pat _ loop =
  case loop of
    ([], merge, ForLoop counter int_ty trip_count [], loop_body) ->
      Simplify (loopClosedForm pat merge (oneName counter) int_ty trip_count loop_body)
    _ ->
      Skip
-- | Simplify the for-in loop variables of a @for@ loop.
--
-- A loop variable @(p, arr)@ makes @p@ stand for @arr[i]@, where @i@ is the
-- loop counter.  For every loop variable, 'simplifyIndexing' is asked whether
-- that index expression can be simplified.  Depending on the outcome the loop
-- variable is kept as-is, redirected to iterate over a different array, or
-- replaced by an ordinary binding placed at the start of the loop body.
--
-- The rule does not apply (the guard fails and the final equation returns
-- 'Skip') when no loop variable is simplifiable or the loop is not a
-- 'ForLoop'.  It reports 'cannotSimplify' when the candidate simplifications
-- leave every loop variable unchanged.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat aux (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    any isJust simplifiable = Simplify $ do
    -- Run the candidate simplifications with the names bound by the loop
    -- form in scope.  Each loop variable yields either its (possibly new)
    -- definition or 'Nothing' together with statements for the body prefix.
    (maybe_loop_vars, body_prefix_stms) <-
      localScope (scopeOf form) $
        unzip <$> zipWithM onLoopVar loop_vars simplifiable
    if maybe_loop_vars == map Just loop_vars
      then cannotSimplify
      else do
        -- The new body is the collected prefix followed by the old body.
        body' <- buildBody_ $ do
          addStms $ mconcat body_prefix_stms
          bodyBind body
        auxing aux $
          letBind pat $
            DoLoop
              ctx
              val
              (ForLoop i it num_iters $ catMaybes maybe_loop_vars)
              body'
  where
    -- Type lookup handed to 'simplifyIndexing'; the loop counter is not in
    -- the incoming symbol table, so it is answered directly.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    -- Symbol table extended with the names bound by the loop form.
    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Ask whether @arr[i, ...]@ can be simplified.  The final argument tells
    -- 'simplifyIndexing' whether the loop parameter is consumed in the body.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        $ paramName p `nameIn` consumed_in_body

    -- Decide what to do with one loop variable, given its candidate
    -- simplification (if any).
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        -- The element is @arr'[i, slice']@ where neither the generated
        -- statements nor the remaining slice mention the loop counter:
        -- iterate over @arr'[0:w:1, slice']@ instead.
        --
        -- The leading @DimFix (Var i)@ is matched separately, so only the
        -- remainder @slice'@ has to be free of @i@.  Testing the whole
        -- slice could never succeed, because it always contains @i@.
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            not $ i `nameIn` freeIn slice' -> do
            addStms x_stms
            w <- arraySize 0 <$> lookupType arr'
            for_in_partial <-
              certifying cs $
                letExp "for_in_partial" $
                  BasicOp $
                    Index arr' $
                      DimSlice (intConst Int64 0) w (intConst Int64 1) : slice'
            return (Just (p, for_in_partial), mempty)
        -- The element is a plain subexpression and none of the generated
        -- statements is an index operation: drop the loop variable and bind
        -- the parameter name at the start of the body instead.
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
            x_stms' <- collectStms_ $
              certifying cs $ do
                addStms x_stms
                letBindNames [paramName p] $ BasicOp $ SubExp se
            return (Nothing, x_stms')
        -- Anything else: keep the loop variable unchanged and discard the
        -- statements produced by the attempted simplification.
        _ -> return (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ = Skip
-- | Let an 'Int64' @for@ loop without for-in loop variables count in a
-- narrower integer type.
--
-- This applies when the loop bound is a variable that the symbol table
-- knows as the sign extension of another value, or a constant no larger
-- than the maximum 'Int32'.  The loop gets a fresh counter of the narrower
-- type, and the original counter is rebound at the top of the body as the
-- sign extension of the fresh one.  All other loops are skipped.
narrowLoopType :: (BinderOps lore) => TopDownRuleDoLoop lore
narrowLoopType vtable pat aux (ctx, val, ForLoop i Int64 n [], body)
  | Just (narrow_n, narrow_it, cs) <- smallerType = Simplify $ do
    narrow_i <- newVName (baseString i)
    let narrow_form = ForLoop narrow_i narrow_it narrow_n []
        widened_i = BasicOp (ConvOp (SExt narrow_it Int64) (Var narrow_i))
    body' <-
      insertStmsM $
        inScopeOf narrow_form $
          letBindNames [i] widened_i >> pure body
    auxing aux . certifying cs . letBind pat $
      DoLoop ctx val narrow_form body'
  where
    -- The narrower bound, its integer type, and the certificates attached
    -- to the statement it was found in (none for a constant).
    smallerType = case n of
      Var bound_v
        | Just (ConvOp (SExt from_it _) src, src_cs) <- ST.lookupBasicOp bound_v vtable ->
          Just (src, from_it, src_cs)
      Constant (IntValue (Int64Value k))
        | toInteger k <= toInteger (maxBound :: Int32) ->
          Just (intConst Int32 (toInteger k), Int32, mempty)
      _ ->
        Nothing
narrowLoopType _ _ _ _ = Skip
-- | Emit the statements of loop iterations @i@ up to (but excluding) @n@,
-- and return the merge values that result after the last of them.
--
-- Each iteration binds the merge parameters to their current values, the
-- loop counter to the constant @i@, and every for-in parameter to element
-- @i@ of its array.  The body is then renamed and its statements are added
-- to the binder.  The results of the renamed body become the merge values
-- for the next iteration.
unroll ::
  BinderOps lore =>
  Integer ->
  [(FParam lore, SubExp)] ->
  (VName, IntType, Integer) ->
  [(LParam lore, VName)] ->
  Body lore ->
  RuleM lore [SubExp]
unroll n merge (iv, it, i) loop_vars body =
  if i >= n
    then pure (map snd merge)
    else do
      iter_body <- insertStmsM $ do
        forM_ merge $ \(param, current) ->
          letBindNames [paramName param] (BasicOp (SubExp current))
        letBindNames [iv] (BasicOp (SubExp (intConst it i)))
        forM_ loop_vars $ \(p, arr) -> do
          let elem_slice = DimFix (intConst Int64 i) : fullSlice (paramType p) []
          letBindNames [paramName p] (BasicOp (Index arr elem_slice))
        pure body
      -- Rename so that repeated copies of the body do not bind the same names.
      renamed <- renameBody iter_body
      addStms (bodyStms renamed)
      let next_merge = zip (map fst merge) (bodyResult renamed)
      unroll n next_merge (iv, it, i + 1) loop_vars body
-- | Top-down rule that replaces a @for@-loop by its unrolled iterations.
--
-- The rule only fires when the loop bound is an integer constant and at
-- least one of the following holds:
--
--   * the constant satisfies 'zeroIshInt',
--   * the constant satisfies 'oneIshInt',
--   * the statement's attributes contain @"unroll"@.
--
-- When it fires, 'unroll' is invoked with the context and value merge
-- parameters concatenated (in that order) and an iteration counter
-- starting at zero.  Each sub-expression it returns is then bound to the
-- corresponding name of the loop's pattern.  Every other loop is left
-- alone ('Skip').
simplifyKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyKnownIterationLoop
  _
  pat
  aux
  (ctx, val, ForLoop i it (Constant (IntValue n)) loop_vars, body)
    | shouldUnroll = Simplify $ do
      results <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
      -- Pair pattern names with results positionally; like the zip it
      -- replaces, this stops at the shorter of the two lists.
      zipWithM_
        (\v se -> letBindNames [v] (BasicOp (SubExp se)))
        (patternNames pat)
        results
    where
      -- Trivial trip counts are always unrolled; anything else only on
      -- explicit request through the attribute.
      shouldUnroll =
        or
          [ zeroIshInt n,
            oneIshInt n,
            "unroll" `inAttrs` stmAuxAttrs aux
          ]
simplifyKnownIterationLoop _ _ _ _ = Skip
-- | The top-down loop rules defined in this module, each wrapped in
-- 'RuleDoLoop'.  The list order is the same as in the original
-- definition.
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules =
  map
    RuleDoLoop
    [ hoistLoopInvariantMergeVariables,
      simplifyClosedFormLoop,
      simplifyKnownIterationLoop,
      simplifyLoopVariables,
      narrowLoopType
    ]
-- | The bottom-up loop rules defined in this module, each wrapped in
-- 'RuleDoLoop'.  Currently this is only 'removeRedundantMergeVariables'.
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules =
  map
    RuleDoLoop
    [ removeRedundantMergeVariables
    ]
-- | The rule book exported by this module: 'topDownRules' and
-- 'bottomUpRules' combined with 'ruleBook'.
loopRules :: (BinderOps lore, Aliased lore) => RuleBook lore
loopRules = ruleBook topDown bottomUp
  where
    topDown = topDownRules
    bottomUp = bottomUpRules