module Futhark.Optimise.Simplify.Rules.Loop (loopRules) where
import Control.Monad
import Data.Bifunctor (second)
import Data.List (partition)
import Data.Maybe
import Futhark.Analysis.DataDependencies
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.ClosedForm
import Futhark.Optimise.Simplify.Rules.Index
import Futhark.Transform.Rename
-- | Remove merge parameters (loop-carried values) that are not needed.
--
-- A merge parameter @v@ is kept when at least one of these holds:
--
-- * the pattern element corresponding to @v@ is used directly after the
--   loop (according to the usage table);
-- * @v@ is in the set computed by 'findNecessaryForReturned' from the
--   body's 'dataDependencies', seeded with the parameters that are used
--   after the loop or occur free in the loop form;
-- * @v@ occurs free in the declarations of the merge parameters
--   (for example as a size in another parameter's type);
-- * @v@ occurs free in the loop form.
--
-- All other parameters are dropped from the loop, together with their
-- pattern elements and body results. Each dropped parameter is re-bound
-- at the start of the loop body to its initial value, so that references
-- to it that remain inside the body are still bound. For a unique
-- parameter whose initial value is a variable, a 'Copy' is bound instead,
-- presumably so that the body cannot consume the initial value itself.
--
-- The rule does not fire ('Skip') when every merge parameter is used
-- after the loop, or when nothing would be removed.
removeRedundantMergeVariables :: BuilderOps rep => BottomUpRuleDoLoop rep
removeRedundantMergeVariables (_, used) pat aux (merge, form, body)
  | not $ all (usedAfterLoop . fst) merge =
      let necessaryForReturned =
            findNecessaryForReturned
              usedAfterLoopOrInForm
              (zip (map fst merge) (map resSubExp $ bodyResult body))
              (dataDependencies body)

          resIsNecessary ((v, _), _) =
            usedAfterLoop v
              || paramName v `nameIn` necessaryForReturned
              || referencedInPat v
              || referencedInForm v

          -- Each entry is (pattern element, ((param, initial value), body result)).
          (keep_valpart, discard_valpart) =
            partition (resIsNecessary . snd) $
              zip (patElems pat) $
                zip merge $
                  bodyResult body

          (keep_valpatelems, keep_val) = unzip keep_valpart
          discard_val = map snd discard_valpart
          (merge', val_es') = unzip keep_val

          body' = body {bodyResult = val_es'}
          pat' = Pat keep_valpatelems
       in if null discard_val
            then Skip
            else Simplify $ do
              -- Re-bind the dropped parameters inside the body. Dead-code
              -- removal is expected to clean these up later if they turn
              -- out to be unused.
              body'' <- insertStmsM $ do
                mapM_ (uncurry letBindNames . dummyStm) discard_val
                pure body'
              auxing aux $ letBind pat' $ DoLoop merge' form body''
  where
    -- Names of the merge parameters whose pattern element is used
    -- directly after the loop. This is a 'Names' set, so membership tests
    -- do not scan a list once per parameter.
    used_vals =
      namesFromList
        [ paramName p
          | ((p, _), pe_name) <- zip merge (patNames pat),
            pe_name `UT.isUsedDirectly` used
        ]

    -- Free names of the loop form, shared by the predicates below.
    form_names = freeIn form

    usedAfterLoop p = paramName p `nameIn` used_vals

    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` form_names

    patAnnotNames = freeIn $ map fst merge

    referencedInPat p = paramName p `nameIn` patAnnotNames

    referencedInForm p = paramName p `nameIn` form_names

    -- Binding that re-introduces a dropped merge parameter inside the body.
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
          ([paramName p], BasicOp $ Copy v)
      | otherwise = ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip
-- | Hoist loop-invariant merge parameters out of a loop.
--
-- A merge parameter counts as invariant (see @isInvariant@) when one of
-- the following holds:
--
-- 0. The body result is the parameter itself, and every other merge
--    parameter free in its declaration has already been found invariant.
--    The candidates are folded with 'foldr', so "already" means
--    parameters that appear later in the merge list.
-- 1. The body result is syntactically equal to the initial value.
-- 2. The initial value is a variable for which 'ST.lookupLoopParam'
--    yields @(p_init, p_res)@, with @p_init@ equal to this body result and
--    @p_res@ equal to this loop's corresponding pattern name. Presumably
--    this describes a parameter of an enclosing loop.
--
-- Parameters that occur free in the loop form are never hoisted. A
-- hoisted parameter is bound to its initial value before the loop. If
-- exactly one pattern element corresponds to it, that element is removed
-- from the loop's pattern and is bound to the same initial value.
hoistLoopInvariantMergeVariables :: BuilderOps rep => TopDownRuleDoLoop rep
hoistLoopInvariantMergeVariables vtable pat aux (merge, form, loopbody) =
  case foldr checkInvariance ([], explpat, [], []) candidates of
    ([], _, _, _) ->
      -- Nothing was found to be invariant.
      Skip
    (invariant, explpat', merge', res') -> Simplify $ do
      forM_ invariant $ \(v1, v2) ->
        letBindNames [identName v1] $ BasicOp $ SubExp v2
      auxing aux $
        letBind (Pat (map fst explpat')) $
          DoLoop merge' form (loopbody {bodyResult = res'})
  where
    res = bodyResult loopbody

    -- Pattern elements paired with the name of their merge parameter.
    explpat = zip (patElems pat) $ map (paramName . fst) merge

    candidates = zip3 (patNames pat) merge res

    namesOfMergeParams = namesFromList $ map (paramName . fst) merge

    -- Take the pattern element belonging to a merge parameter out of the
    -- pattern, but only if exactly one element matches.
    removeFromResult (mergeParam, mergeInit) explpat' =
      case partition ((== paramName mergeParam) . snd) explpat' of
        ([(patelem, _)], rest) -> (Just (patElemIdent patelem, mergeInit), rest)
        _ -> (Nothing, explpat')

    checkInvariance
      (pat_name, (mergeParam, mergeInit), resExp)
      (invariant, explpat', merge', resExps)
        -- Never hoist a parameter that the loop form refers to, such as
        -- a while-loop condition.
        | isInvariant,
          paramName mergeParam `notNameIn` freeIn form =
            let (stm, explpat'') = removeFromResult (mergeParam, mergeInit) explpat'
             in ( maybe id (:) stm $ (paramIdent mergeParam, mergeInit) : invariant,
                  explpat'',
                  merge',
                  resExps
                )
        | otherwise =
            (invariant, explpat', (mergeParam, mergeInit) : merge', resExp : resExps)
        where
          isInvariant
            -- (0) The result is the parameter itself.
            | Var v2 <- resSubExp resExp,
              paramName mergeParam == v2 =
                allExistentialInvariant
                  (namesFromList $ map (identName . fst) invariant)
                  mergeParam
            -- (1) The result equals the initial value.
            | mergeInit == resSubExp resExp = True
            -- (2) Match through a loop parameter found in the symbol table.
            | Var init_v <- mergeInit,
              Just (p_init, p_res) <- ST.lookupLoopParam init_v vtable,
              p_init == resSubExp resExp,
              p_res == Var pat_name =
                True
            | otherwise = False

    -- Every name free in the parameter's declaration, other than the
    -- parameter itself, is either not a merge parameter or already invariant.
    allExistentialInvariant namesOfInvariant mergeParam =
      all (invariantOrNotMergeParam namesOfInvariant) $
        namesToList $
          freeIn mergeParam `namesSubtract` oneName (paramName mergeParam)

    invariantOrNotMergeParam namesOfInvariant name =
      (name `notNameIn` namesOfMergeParams) || (name `nameIn` namesOfInvariant)
-- | Try to replace a @for@ loop that has no loop-variable arrays with a
-- closed form.
--
-- The actual work is delegated to 'loopClosedForm'. It receives the
-- pattern, the merge parameters, the iteration variable (as a singleton
-- 'Names'), its integer type, the bound and the body. Any other loop form
-- is skipped.
simplifyClosedFormLoop :: BuilderOps rep => TopDownRuleDoLoop rep
simplifyClosedFormLoop _ pat _ (val, form, body)
  | ForLoop i it bound [] <- form =
      Simplify $ loopClosedForm pat val (oneName i) it bound body
simplifyClosedFormLoop _ _ _ _ = Skip
-- | Simplify the loop-variable arrays of a @for@ loop.
--
-- For every loop variable @p@ over @arr@ (so @p = arr[i]@ in iteration
-- @i@), 'simplifyIndexing' is asked whether that indexing can be
-- simplified. The possible outcomes are:
--
-- * 'SubExpResult': if none of the produced statements is an 'Index',
--   the loop variable is dropped. @p@ is then bound to the resulting
--   subexpression by statements placed at the start of the loop body.
-- * 'IndexResult' @arr'[i, slice']@, with no produced statement
--   mentioning @i@: the loop variable is made to iterate over
--   @arr'[0:w:1, slice']@, where @w@ is the outer size of @arr'@.
-- * Anything else keeps the loop variable as it is.
--
-- If no loop variable changes, the rule reports 'cannotSimplify'.
simplifyLoopVariables :: (BuilderOps rep, Aliased rep) => TopDownRuleDoLoop rep
simplifyLoopVariables vtable pat aux (merge, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    any isJust simplifiable =
      Simplify $ do
        (maybe_loop_vars, body_prefix_stms) <-
          localScope (scopeOf form) $
            unzip <$> zipWithM onLoopVar loop_vars simplifiable
        if maybe_loop_vars == map Just loop_vars
          then cannotSimplify
          else do
            body' <- buildBody_ $ do
              addStms $ mconcat body_prefix_stms
              bodyBind body
            let form' = ForLoop i it num_iters $ catMaybes maybe_loop_vars
            auxing aux $ letBind pat $ DoLoop merge form' body'
  where
    -- Types of subexpressions. The iteration variable is not in 'vtable'
    -- yet, so it is special-cased.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    vtable' = ST.fromScope (scopeOf form) <> vtable

    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (Slice $ DimFix (Var i) : unSlice (fullSlice (paramType p) []))
        (paramName p `nameIn` consumed_in_body)

    onLoopVar (p, arr) Nothing =
      pure (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        -- NOTE(review): this branch can never run as written. @slice@
        -- starts with @DimFix (Var i)@, so @i@ is always free in it, and
        -- the last guard always fails. @slice'@ was probably intended.
        -- It is deliberately left alone: enabling the branch changes which
        -- loops get rewritten, and it may re-fire on its own output.
        IndexResult cs arr' (Slice slice)
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            i `notNameIn` freeIn slice -> do
              addStms x_stms
              w <- arraySize 0 <$> lookupType arr'
              for_in_partial <-
                certifying cs . letExp "for_in_partial" . BasicOp . Index arr' . Slice $
                  DimSlice (intConst Int64 0) w (intConst Int64 1) : slice'
              pure (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
              x_stms' <- collectStms_ $
                certifying cs $ do
                  addStms x_stms
                  letBindNames [paramName p] $ BasicOp $ SubExp se
              pure (Nothing, x_stms')
        _ ->
          pure (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ = Skip
-- | Narrow the iteration type of an 'Int64' @for@ loop that has no
-- loop-variable arrays.
--
-- This applies in two cases:
--
-- * The bound is a variable defined (per 'ST.lookupBasicOp') as a sign
--   extension of some @n''@ of type @it'@. The loop then counts to @n''@
--   in @it'@ and is certified with that definition's certificates.
-- * The bound is an 'Int64' constant that is at most
--   @maxBound :: Int32@. The loop then counts in 'Int32'.
--
-- A fresh iteration variable is introduced. The original one is re-bound
-- at the start of the body as its sign extension back to 'Int64'.
narrowLoopType :: (BuilderOps rep) => TopDownRuleDoLoop rep
narrowLoopType vtable pat aux (merge, ForLoop i Int64 n [], body)
  | Just (n', it', cs) <- smallerType = Simplify $ do
      i' <- newVName $ baseString i
      let form' = ForLoop i' it' n' []
      body' <- insertStmsM $
        inScopeOf form' $ do
          letBindNames [i] $ BasicOp $ ConvOp (SExt it' Int64) (Var i')
          pure body
      auxing aux . certifying cs . letBind pat $ DoLoop merge form' body'
  where
    smallerType = case n of
      Var v
        | Just (ConvOp (SExt from_t _) v', v_cs) <- ST.lookupBasicOp v vtable ->
            Just (v', from_t, v_cs)
      Constant (IntValue (Int64Value k))
        | toInteger k <= toInteger (maxBound :: Int32) ->
            Just (intConst Int32 (toInteger k), Int32, mempty)
      _ -> Nothing
narrowLoopType _ _ _ _ = Skip
-- | Unroll iterations @first_iter, first_iter + 1, .., n - 1@ of a loop
-- by emitting their statements one after another.
--
-- Each iteration emits, in order:
--
-- * bindings of the merge parameters to the current merge values, each
--   under its result's certificates;
-- * a binding of the iteration variable to the constant iteration number;
-- * a binding of each loop-variable parameter to its row of the array;
-- * a renamed copy of the body's statements.
--
-- The body's results become the merge values for the next iteration.
-- The merge values after the last iteration are returned; if
-- @first_iter >= n@ these are the given values unchanged.
unroll ::
  BuilderOps rep =>
  Integer ->
  [(FParam rep, SubExpRes)] ->
  (VName, IntType, Integer) ->
  [(LParam rep, VName)] ->
  Body rep ->
  RuleM rep [SubExpRes]
unroll n merge (iv, it, first_iter) loop_vars body =
  map snd
    <$> foldM
      ( \cur k -> do
          iter_body <- insertStmsM $ do
            forM_ cur $ \(mergevar, SubExpRes cs mergeinit) ->
              certifying cs . letBindNames [paramName mergevar] . BasicOp $
                SubExp mergeinit
            letBindNames [iv] . BasicOp . SubExp $ intConst it k
            forM_ loop_vars $ \(p, arr) ->
              letBindNames [paramName p] . BasicOp . Index arr . Slice $
                DimFix (intConst Int64 k) : unSlice (fullSlice (paramType p) [])
            pure body
          -- Rename the copy before splicing it in, presumably so that
          -- successive copies do not rebind the same names.
          iter_body' <- renameBody iter_body
          addStms $ bodyStms iter_body'
          pure $ zip (map fst cur) (bodyResult iter_body')
      )
      merge
      [first_iter .. n - 1]
-- | Replace a @for@ loop whose trip count is a constant integer by the
-- straight-line code produced by 'unroll'.
--
-- The rewrite fires only when the constant trip count is zero or one,
-- or when the loop statement carries the @unroll@ attribute. The
-- unrolling starts at iteration index 0, and the merge initialisers
-- are wrapped as results.
--
-- Each name in the loop's pattern is then bound to the matching
-- unrolled result. The certificates attached to that result are kept.
-- Any other loop shape yields 'Skip'.
simplifyKnownIterationLoop :: BuilderOps rep => TopDownRuleDoLoop rep
simplifyKnownIterationLoop _ pat aux (merge, ForLoop i it (Constant (IntValue n)) loop_vars, body)
  | worthUnrolling = Simplify $ do
      unrolled <- unroll (valueIntegral n) initialResults (i, it, 0) loop_vars body
      forM_ (zip (patNames pat) unrolled) bindOne
  where
    -- Trivial trip counts are always unrolled; anything else only on request.
    worthUnrolling =
      zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux
    initialResults = map (second subExpRes) merge
    -- Bind a pattern name to a result, keeping the result's certificates.
    bindOne (name, SubExpRes certs se) =
      certifying certs $ letBindNames [name] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ = Skip
-- | The top-down simplification rules for @DoLoop@ statements, in this
-- order:
--
-- 1. hoisting of loop-invariant merge variables
-- 2. closed-form simplification
-- 3. unrolling of loops with a known iteration count
-- 4. loop-variable simplification
-- 5. narrowing of the loop type
topDownRules :: (BuilderOps rep, Aliased rep) => [TopDownRule rep]
topDownRules =
  map
    RuleDoLoop
    [ hoistLoopInvariantMergeVariables,
      simplifyClosedFormLoop,
      simplifyKnownIterationLoop,
      simplifyLoopVariables,
      narrowLoopType
    ]
-- | The bottom-up simplification rules for @DoLoop@ statements.
-- Currently this is only the removal of redundant merge variables.
bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
bottomUpRules = pure $ RuleDoLoop removeRedundantMergeVariables
-- | The loop-related simplification rule book. It combines
-- 'topDownRules' and 'bottomUpRules' via 'ruleBook'.
loopRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
loopRules =
  let topDown = topDownRules
      bottomUp = bottomUpRules
   in ruleBook topDown bottomUp