{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
module Futhark.Optimise.Simplify.Rules
( standardRules
, removeUnnecessaryCopy
)
where
import Control.Monad
import Data.Either
import Data.List (find, isSuffixOf, partition, sort)
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Analysis.DataDependencies
import Futhark.Optimise.Simplify.ClosedForm
import Futhark.Optimise.Simplify.Rule
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Rename
import Futhark.Construct
import Futhark.Util
-- | The top-down simplification rules defined by this module, each
-- wrapped in the constructor matching the kind of expression it
-- applies to ('RuleDoLoop', 'RuleGeneric', 'RuleIf' or 'RuleBasicOp').
topDownRules :: (BinderOps lore, Aliased lore) => [TopDownRule lore]
topDownRules =
  [ RuleDoLoop hoistLoopInvariantMergeVariables
  , RuleDoLoop simplifyClosedFormLoop
  , RuleDoLoop simplifyKnownIterationLoop
  , RuleDoLoop simplifyLoopVariables
  , RuleGeneric constantFoldPrimFun
  , RuleIf ruleIf
  , RuleIf hoistBranchInvariant
  , RuleBasicOp ruleBasicOp
  ]
-- | The bottom-up simplification rules defined by this module, each
-- wrapped in the constructor matching the kind of expression it
-- applies to ('RuleDoLoop', 'RuleIf' or 'RuleBasicOp').
bottomUpRules :: BinderOps lore => [BottomUpRule lore]
bottomUpRules =
  [ RuleDoLoop removeRedundantMergeVariables
  , RuleIf removeDeadBranchResult
  , RuleBasicOp simplifyIndex
  , RuleBasicOp simplifyConcat
  ]
-- | Coerce an integer-typed 'PrimExp' to 'Int32'.
--
-- If the expression has an integer type other than 'Int32', it is
-- wrapped with @'sExt' 'Int32'@.  Expressions that are already
-- 'Int32', or that are not of integer type at all, are returned
-- unchanged.
asInt32PrimExp :: PrimExp v -> PrimExp v
asInt32PrimExp pe
  | IntType it <- primExpType pe
  , it /= Int32 =
      sExt Int32 pe
  | otherwise =
      pe
-- | The rule book built from this module's 'topDownRules' and
-- 'bottomUpRules'.
standardRules :: (BinderOps lore, Aliased lore) => RuleBook lore
standardRules = ruleBook topDownRules bottomUpRules
-- | Bottom-up 'DoLoop' rule that removes merge parameters whose
-- values are not needed.
--
-- The rule is only considered when at least one value merge
-- parameter corresponds to a pattern name that is not directly used
-- according to the usage table, and the context merge list is empty;
-- in every other case it yields 'Skip'.
--
-- A merge parameter is kept when any of the following holds:
--
--   * its corresponding pattern name is used after the loop;
--
--   * it is in the set computed by 'findNecessaryForReturned' from
--     the data dependencies of the loop body;
--
--   * its name occurs free in the merge parameters (context and
--     value);
--
--   * its name occurs free in the loop form.
--
-- If nothing can be discarded the rule yields 'Skip'.  Otherwise the
-- loop is rebuilt with only the kept parameters, results and pattern
-- elements.
removeRedundantMergeVariables :: BinderOps lore => BottomUpRuleDoLoop lore
removeRedundantMergeVariables (_, used) pat aux (ctx, val, form, body)
  | not $ all (usedAfterLoop . fst) val,
    null ctx =
      let -- Split the body result into the parts that feed the context
          -- and value merge parameters, respectively.
          (ctx_es, val_es) = splitAt (length ctx) $ bodyResult body

          necessaryForReturned =
            findNecessaryForReturned
              usedAfterLoopOrInForm
              (zip (map fst $ ctx ++ val) $ ctx_es ++ val_es)
              (dataDependencies body)

          resIsNecessary ((v, _), _) =
            usedAfterLoop v
              || paramName v `nameIn` necessaryForReturned
              || referencedInPat v
              || referencedInForm v

          (keep_ctx, discard_ctx) =
            partition resIsNecessary $ zip ctx ctx_es
          (keep_valpart, discard_valpart) =
            partition (resIsNecessary . snd) $
              zip (patternValueElements pat) $ zip val val_es

          (keep_valpatelems, keep_val) = unzip keep_valpart
          (_discard_valpatelems, discard_val) = unzip discard_valpart
          (ctx', ctx_es') = unzip keep_ctx
          (val', val_es') = unzip keep_val

          body' = body { bodyResult = ctx_es' ++ val_es' }
          free_in_keeps = freeIn keep_valpatelems

          -- A context pattern element survives only if it is mentioned
          -- by a kept value pattern element or by another context
          -- pattern element.
          stillUsedContext pat_elem =
            patElemName pat_elem
              `nameIn` ( free_in_keeps
                           <> freeIn
                             ( filter (/= pat_elem) $
                                 patternContextElements pat
                             )
                       )

          pat' =
            pat
              { patternValueElements = keep_valpatelems
              , patternContextElements =
                  filter stillUsedContext $ patternContextElements pat
              }
       in if ctx' ++ val' == ctx ++ val
            then Skip
            else Simplify $ do
              -- The names of the discarded merge parameters are bound
              -- inside the loop body instead.
              -- NOTE(review): presumably this keeps leftover references
              -- in the body in scope until dead-code removal runs;
              -- confirm against the upstream source.
              body'' <- insertStmsM $ do
                mapM_ (uncurry letBindNames) $ dummyStms discard_ctx
                mapM_ (uncurry letBindNames) $ dummyStms discard_val
                return body'
              auxing aux $ letBind pat' $ DoLoop ctx' val' form body''
  where
    -- For each value pattern name, whether it is directly used.
    pat_used = map (`UT.isUsedDirectly` used) $ patternValueNames pat
    -- Names of the value merge parameters whose pattern name is used.
    used_vals =
      map fst $ filter snd $ zip (map (paramName . fst) val) pat_used
    usedAfterLoop = flip elem used_vals . paramName
    usedAfterLoopOrInForm p =
      usedAfterLoop p || paramName p `nameIn` freeIn form
    -- Names occurring free in any of the merge parameters.
    patAnnotNames = freeIn $ map fst $ ctx ++ val
    referencedInPat = (`nameIn` patAnnotNames) . paramName
    referencedInForm = (`nameIn` freeIn form) . paramName

    dummyStms = map dummyStm
    -- A unique parameter initialised from a variable is bound to a
    -- 'Copy' of that variable; every other parameter is bound directly
    -- to its initial value.
    dummyStm ((p, e), _)
      | unique (paramDeclType p),
        Var v <- e =
          ([paramName p], BasicOp $ Copy v)
      | otherwise =
          ([paramName p], BasicOp $ SubExp e)
removeRedundantMergeVariables _ _ _ _ =
  Skip
hoistLoopInvariantMergeVariables :: BinderOps lore => TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables :: TopDownRuleDoLoop lore
hoistLoopInvariantMergeVariables TopDown lore
_ Pattern lore
pat StmAux (ExpDec lore)
aux ([(FParam lore, SubExp)]
ctx, [(FParam lore, SubExp)]
val, LoopForm lore
form, BodyT lore
loopbody) =
case (((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
checkInvariance ([], [(PatElemT (LetDec lore), VName)]
explpat, [], []) ([((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp]))
-> [((FParam lore, SubExp), SubExp)]
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
forall a b. (a -> b) -> a -> b
$
[(FParam lore, SubExp)]
-> [SubExp] -> [((FParam lore, SubExp), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(FParam lore, SubExp)]
merge [SubExp]
res of
([], [(PatElemT (LetDec lore), VName)]
_, [(FParam lore, SubExp)]
_, [SubExp]
_) ->
Rule lore
forall lore. Rule lore
Skip
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
res') -> RuleM lore () -> Rule lore
forall lore. RuleM lore () -> Rule lore
Simplify (RuleM lore () -> Rule lore) -> RuleM lore () -> Rule lore
forall a b. (a -> b) -> a -> b
$ do
let loopbody' :: BodyT lore
loopbody' = BodyT lore
loopbody { bodyResult :: [SubExp]
bodyResult = [SubExp]
res' }
invariantShape :: (a, VName) -> Bool
invariantShape :: (a, VName) -> Bool
invariantShape (a
_, VName
shapemerge) = VName
shapemerge VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem`
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
merge'
([(PatElemT (LetDec lore), VName)]
implpat',[(PatElemT (LetDec lore), VName)]
implinvariant) = ((PatElemT (LetDec lore), VName) -> Bool)
-> [(PatElemT (LetDec lore), VName)]
-> ([(PatElemT (LetDec lore), VName)],
[(PatElemT (LetDec lore), VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition (PatElemT (LetDec lore), VName) -> Bool
forall a. (a, VName) -> Bool
invariantShape [(PatElemT (LetDec lore), VName)]
implpat
implinvariant' :: [(Ident, SubExp)]
implinvariant' = [ (PatElemT (LetDec lore) -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT (LetDec lore)
p, VName -> SubExp
Var VName
v) | (PatElemT (LetDec lore)
p,VName
v) <- [(PatElemT (LetDec lore), VName)]
implinvariant ]
implpat'' :: [PatElemT (LetDec lore)]
implpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
implpat'
explpat'' :: [PatElemT (LetDec lore)]
explpat'' = ((PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore))
-> [(PatElemT (LetDec lore), VName)] -> [PatElemT (LetDec lore)]
forall a b. (a -> b) -> [a] -> [b]
map (PatElemT (LetDec lore), VName) -> PatElemT (LetDec lore)
forall a b. (a, b) -> a
fst [(PatElemT (LetDec lore), VName)]
explpat'
([(FParam lore, SubExp)]
ctx', [(FParam lore, SubExp)]
val') = Int
-> [(FParam lore, SubExp)]
-> ([(FParam lore, SubExp)], [(FParam lore, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(PatElemT (LetDec lore), VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(PatElemT (LetDec lore), VName)]
implpat') [(FParam lore, SubExp)]
merge'
[(Ident, SubExp)]
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Ident, SubExp)]
invariant [(Ident, SubExp)] -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Ident, SubExp)]
implinvariant') (((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ())
-> ((Ident, SubExp) -> RuleM lore ()) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ \(Ident
v1,SubExp
v2) ->
[VName] -> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Ident -> VName
identName Ident
v1] (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT lore
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT lore) -> BasicOp -> ExpT lore
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
v2
StmAux (ExpDec lore) -> RuleM lore () -> RuleM lore ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux (ExpDec lore)
aux (RuleM lore () -> RuleM lore ()) -> RuleM lore () -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$ Pattern (Lore (RuleM lore))
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind ([PatElemT (LetDec lore)]
-> [PatElemT (LetDec lore)] -> Pattern lore
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [PatElemT (LetDec lore)]
implpat'' [PatElemT (LetDec lore)]
explpat'') (Exp (Lore (RuleM lore)) -> RuleM lore ())
-> Exp (Lore (RuleM lore)) -> RuleM lore ()
forall a b. (a -> b) -> a -> b
$
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam lore, SubExp)]
ctx' [(FParam lore, SubExp)]
val' LoopForm lore
form BodyT lore
loopbody'
where merge :: [(FParam lore, SubExp)]
merge = [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
val
res :: [SubExp]
res = BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT lore
loopbody
implpat :: [(PatElemT (LetDec lore), VName)]
implpat = [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternContextElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
ctx
explpat :: [(PatElemT (LetDec lore), VName)]
explpat = [PatElemT (LetDec lore)]
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pattern lore -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements Pattern lore
pat) ([VName] -> [(PatElemT (LetDec lore), VName)])
-> [VName] -> [(PatElemT (LetDec lore), VName)]
forall a b. (a -> b) -> a -> b
$
((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) [(FParam lore, SubExp)]
val
namesOfMergeParams :: Names
namesOfMergeParams = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((FParam lore, SubExp) -> VName)
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (FParam lore -> VName
forall dec. Param dec -> VName
paramName (FParam lore -> VName)
-> ((FParam lore, SubExp) -> FParam lore)
-> (FParam lore, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (FParam lore, SubExp) -> FParam lore
forall a b. (a, b) -> a
fst) ([(FParam lore, SubExp)] -> [VName])
-> [(FParam lore, SubExp)] -> [VName]
forall a b. (a -> b) -> a -> b
$ [(FParam lore, SubExp)]
ctx[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. [a] -> [a] -> [a]
++[(FParam lore, SubExp)]
val
removeFromResult :: (Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (Param dec
mergeParam,b
mergeInit) [(PatElemT dec, VName)]
explpat' =
case ((PatElemT dec, VName) -> Bool)
-> [(PatElemT dec, VName)]
-> ([(PatElemT dec, VName)], [(PatElemT dec, VName)])
forall a. (a -> Bool) -> [a] -> ([a], [a])
partition ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
==Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
mergeParam) (VName -> Bool)
-> ((PatElemT dec, VName) -> VName)
-> (PatElemT dec, VName)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT dec, VName) -> VName
forall a b. (a, b) -> b
snd) [(PatElemT dec, VName)]
explpat' of
([(PatElemT dec
patelem,VName
_)], [(PatElemT dec, VName)]
rest) ->
((Ident, b) -> Maybe (Ident, b)
forall a. a -> Maybe a
Just (PatElemT dec -> Ident
forall dec. Typed dec => PatElemT dec -> Ident
patElemIdent PatElemT dec
patelem, b
mergeInit), [(PatElemT dec, VName)]
rest)
([(PatElemT dec, VName)]
_, [(PatElemT dec, VName)]
_) ->
(Maybe (Ident, b)
forall a. Maybe a
Nothing, [(PatElemT dec, VName)]
explpat')
checkInvariance :: ((FParam lore, SubExp), SubExp)
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
-> ([(Ident, SubExp)], [(PatElemT (LetDec lore), VName)],
[(FParam lore, SubExp)], [SubExp])
checkInvariance
((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp)
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
| Bool -> Bool
not (TypeBase Shape Uniqueness -> Bool
forall shape. TypeBase shape Uniqueness -> Bool
unique (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam)) Bool -> Bool -> Bool
||
TypeBase Shape Uniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (FParam lore -> TypeBase Shape Uniqueness
forall dec. DeclTyped dec => Param dec -> TypeBase Shape Uniqueness
paramDeclType FParam lore
mergeParam) Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
SubExp -> Bool
isInvariant SubExp
resExp,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> Names -> Bool
`nameIn` LoopForm lore -> Names
forall a. FreeIn a => a -> Names
freeIn LoopForm lore
form =
let (Maybe (Ident, SubExp)
bnd, [(PatElemT (LetDec lore), VName)]
explpat'') =
(FParam lore, SubExp)
-> [(PatElemT (LetDec lore), VName)]
-> (Maybe (Ident, SubExp), [(PatElemT (LetDec lore), VName)])
forall dec dec b.
Typed dec =>
(Param dec, b)
-> [(PatElemT dec, VName)]
-> (Maybe (Ident, b), [(PatElemT dec, VName)])
removeFromResult (FParam lore
mergeParam,SubExp
mergeInit) [(PatElemT (LetDec lore), VName)]
explpat'
in (([(Ident, SubExp)] -> [(Ident, SubExp)])
-> ((Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)])
-> Maybe (Ident, SubExp)
-> [(Ident, SubExp)]
-> [(Ident, SubExp)]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> a
id (:) Maybe (Ident, SubExp)
bnd ([(Ident, SubExp)] -> [(Ident, SubExp)])
-> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a b. (a -> b) -> a -> b
$ (FParam lore -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent FParam lore
mergeParam, SubExp
mergeInit) (Ident, SubExp) -> [(Ident, SubExp)] -> [(Ident, SubExp)]
forall a. a -> [a] -> [a]
: [(Ident, SubExp)]
invariant,
[(PatElemT (LetDec lore), VName)]
explpat'', [(FParam lore, SubExp)]
merge', [SubExp]
resExps)
where
isInvariant :: SubExp -> Bool
isInvariant (Var VName
v2)
| FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v2 =
Names -> FParam lore -> Bool
allExistentialInvariant
([VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((Ident, SubExp) -> VName) -> [(Ident, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> VName
identName (Ident -> VName)
-> ((Ident, SubExp) -> Ident) -> (Ident, SubExp) -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Ident, SubExp) -> Ident
forall a b. (a, b) -> a
fst) [(Ident, SubExp)]
invariant) FParam lore
mergeParam
isInvariant SubExp
_ = SubExp
mergeInit SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
resExp
checkInvariance ((FParam lore
mergeParam,SubExp
mergeInit), SubExp
resExp) ([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', [(FParam lore, SubExp)]
merge', [SubExp]
resExps) =
([(Ident, SubExp)]
invariant, [(PatElemT (LetDec lore), VName)]
explpat', (FParam lore
mergeParam,SubExp
mergeInit)(FParam lore, SubExp)
-> [(FParam lore, SubExp)] -> [(FParam lore, SubExp)]
forall a. a -> [a] -> [a]
:[(FParam lore, SubExp)]
merge', SubExp
resExpSubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:[SubExp]
resExps)
allExistentialInvariant :: Names -> FParam lore -> Bool
allExistentialInvariant Names
namesOfInvariant FParam lore
mergeParam =
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant) ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$
FParam lore -> Names
forall a. FreeIn a => a -> Names
freeIn FParam lore
mergeParam Names -> Names -> Names
`namesSubtract` VName -> Names
oneName (FParam lore -> VName
forall dec. Param dec -> VName
paramName FParam lore
mergeParam)
invariantOrNotMergeParam :: Names -> VName -> Bool
invariantOrNotMergeParam Names
namesOfInvariant VName
name =
Bool -> Bool
not (VName
name VName -> Names -> Bool
`nameIn` Names
namesOfMergeParams) Bool -> Bool -> Bool
||
VName
name VName -> Names -> Bool
`nameIn` Names
namesOfInvariant
-- | A function from a 'SubExp' to its type, if one can be produced.
type TypeLookup = SubExp -> Maybe Type
-- | A rewrite of a single 'BasicOp'.  Given a variable lookup and a
-- 'TypeLookup', a rule produces either a replacement 'BasicOp' paired
-- with 'Certificates', or 'Nothing' (presumably meaning that the rule
-- does not apply -- confirm against the code that runs these rules).
type SimpleRule lore = VarLookup lore -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certificates)
-- | The 'SimpleRule's defined in this module, in the order listed
-- here.
simpleRules :: [SimpleRule lore]
simpleRules =
  [ simplifyBinOp
  , simplifyCmpOp
  , simplifyUnOp
  , simplifyConvOp
  , simplifyAssert
  , copyScratchToScratch
  , simplifyIdentityReshape
  , simplifyReshapeReshape
  , simplifyReshapeScratch
  , simplifyReshapeReplicate
  , simplifyReshapeIota
  , improveReshape
  ]
-- | Hand a @for@-loop over to 'loopClosedForm'.
--
-- The rule only fires for a 'ForLoop' that has no context merge
-- parameters and no loop-variable arrays; the loop index is passed to
-- 'loopClosedForm' as a singleton name set.  Every other loop is
-- skipped.
simplifyClosedFormLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyClosedFormLoop _ pat _ ([], val, ForLoop i _ bound [], body) =
  Simplify $ loopClosedForm pat val (oneName i) bound body
simplifyClosedFormLoop _ _ _ _ =
  Skip
-- | Try to simplify the per-iteration indexing implied by the
-- loop-variable arrays of a @for@-loop.
--
-- For each @(p, arr)@ in the loop variables, 'simplifyIndexing' is
-- asked about the index expression @arr[i, ...]@ (with @i@ the loop
-- index and a full slice for the remaining dimensions).  Depending on
-- the answer, the loop variable is
--
-- * kept as is,
--
-- * replaced by a loop variable over a different array (bound before
--   the loop as @for_in_partial@), or
--
-- * removed, with @p@ instead bound by statements placed at the start
--   of the loop body.
--
-- If no loop variable ends up changed, the rule reports
-- 'cannotSimplify'.  Loops that are not 'ForLoop's are skipped.
simplifyLoopVariables :: (BinderOps lore, Aliased lore) => TopDownRuleDoLoop lore
simplifyLoopVariables vtable pat aux (ctx, val, form@(ForLoop i it num_iters loop_vars), body)
  | simplifiable <- map checkIfSimplifiable loop_vars,
    not $ all isNothing simplifiable = Simplify $ do
      -- Run the candidate simplifications with the loop form's names
      -- in scope.
      (maybe_loop_vars, body_prefix_stms) <-
        localScope (scopeOf form) $
          unzip <$> zipWithM onLoopVar loop_vars simplifiable
      if maybe_loop_vars == map Just loop_vars
        then -- Every loop variable came back unchanged.
          cannotSimplify
        else do
          -- Statements for removed loop variables go first in the
          -- new body, followed by the original body.
          body' <- insertStmsM $ do
            addStms $ mconcat body_prefix_stms
            resultBodyM =<< bodyBind body
          auxing aux $
            letBind pat $
              DoLoop
                ctx
                val
                (ForLoop i it num_iters $ catMaybes maybe_loop_vars)
                body'
  where
    -- Type lookup that also knows the type of the loop index.
    seType (Var v)
      | v == i = Just $ Prim $ IntType it
      | otherwise = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v

    consumed_in_body = consumedInBody body

    -- The incoming symbol table extended with the loop form's scope.
    vtable' = ST.fromScope (scopeOf form) <> vtable

    -- Ask 'simplifyIndexing' about @arr[i, ...]@; the final argument
    -- says whether the parameter is consumed in the loop body.
    checkIfSimplifiable (p, arr) =
      simplifyIndexing
        vtable'
        seType
        arr
        (DimFix (Var i) : fullSlice (paramType p) [])
        $ paramName p `nameIn` consumed_in_body

    -- Produces the (possibly replaced or removed) loop variable and
    -- the statements to put at the start of the loop body.
    onLoopVar (p, arr) Nothing =
      return (Just (p, arr), mempty)
    onLoopVar (p, arr) (Just m) = do
      (x, x_stms) <- collectStms m
      case x of
        IndexResult cs arr' slice
          | not $ any ((i `nameIn`) . freeIn) x_stms,
            DimFix (Var j) : slice' <- slice,
            j == i,
            -- Only the remaining indices have to be independent of
            -- the loop index.  The leading @DimFix (Var i)@ always
            -- mentions it, so testing the whole slice here would
            -- make this branch unreachable.
            not $ i `nameIn` freeIn slice' -> do
            -- The auxiliary statements do not mention the loop
            -- index, so they are added outside the loop.
            addStms x_stms
            w <- arraySize 0 <$> lookupType arr'
            for_in_partial <-
              certifying cs $
                letExp "for_in_partial" $
                  BasicOp $
                    Index arr' $
                      DimSlice (intConst Int32 0) w (intConst Int32 1) : slice'
            return (Just (p, for_in_partial), mempty)
        SubExpResult cs se
          | all (notIndex . stmExp) x_stms -> do
            -- Drop the loop variable and bind the parameter name
            -- directly inside the loop body.
            x_stms' <- collectStms_ $
              certifying cs $ do
                addStms x_stms
                letBindNames [paramName p] $ BasicOp $ SubExp se
            return (Nothing, x_stms')
        _ ->
          return (Just (p, arr), mempty)

    notIndex (BasicOp Index {}) = False
    notIndex _ = True
simplifyLoopVariables _ _ _ _ =
  Skip
-- | Emit the statements of a loop body once per iteration, for
-- iterations @i@ up to (but not including) @n@, and return the
-- values of the merge parameters after the last iteration.
--
-- The arguments are the iteration count @n@, the merge parameters
-- paired with their current values, the loop index name together
-- with its integer type and the current iteration number, the
-- loop-variable arrays, and the loop body.
--
-- For each iteration the merge parameters, the loop index and the
-- loop-variable parameters are bound by explicit statements in front
-- of the body; the resulting body is renamed with 'renameBody' before
-- its statements are added, and its result becomes the merge values
-- of the next iteration.  When @i >= n@ the current merge values are
-- returned and nothing is emitted.
unroll :: BinderOps lore =>
          Integer
       -> [(FParam lore, SubExp)]
       -> (VName, IntType, Integer)
       -> [(LParam lore, VName)]
       -> Body lore
       -> RuleM lore [SubExp]
unroll n merge (iv, it, i) loop_vars body
  | i >= n =
      return $ map snd merge
  | otherwise = do
      iter_body <- insertStmsM $ do
        -- Merge parameters take their current values.
        forM_ merge $ \(mergevar, mergeinit) ->
          letBindNames [paramName mergevar] $ BasicOp $ SubExp mergeinit

        -- The loop index becomes a constant.
        letBindNames [iv] $ BasicOp $ SubExp $ intConst it i

        -- Each loop-variable parameter is element @i@ of its array.
        forM_ loop_vars $ \(p, arr) ->
          letBindNames [paramName p] $
            BasicOp $
              Index arr $
                DimFix (intConst Int32 i) : fullSlice (paramType p) []

        pure body

      iter_body' <- renameBody iter_body
      addStms $ bodyStms iter_body'

      let merge' = zip (map fst merge) $ bodyResult iter_body'
      unroll n merge' (iv, it, i + 1) loop_vars body
-- | Replace a @for@-loop whose iteration count is an integer constant
-- by its unrolled iterations.
--
-- The rule fires when the constant satisfies 'zeroIshInt' or
-- 'oneIshInt', or when the statement carries the @unroll@ attribute.
-- The iterations are produced by 'unroll', starting from iteration 0
-- with the context and value merge parameters concatenated; the names
-- of the loop pattern are then bound to the values 'unroll' returns.
-- All other loops are skipped.
simplifyKnownIterationLoop :: BinderOps lore => TopDownRuleDoLoop lore
simplifyKnownIterationLoop _ pat aux (ctx, val, ForLoop i it (Constant iters) loop_vars, body)
  | IntValue n <- iters,
    zeroIshInt n || oneIshInt n || "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do
      res <- unroll (valueIntegral n) (ctx ++ val) (i, it, 0) loop_vars body
      forM_ (zip (patternNames pat) res) $ \(v, se) ->
        letBindNames [v] $ BasicOp $ SubExp se
simplifyKnownIterationLoop _ _ _ _ =
  Skip
-- | Turn @let d = copy v@ into @let d = v@ when the usage table
-- allows it.
--
-- The rule only looks at a 'Copy' bound by a pattern with no context
-- elements and exactly one value element.  It fires when @v@ is not
-- consumed according to the usage table, and additionally either
--
-- * @v@ is not used at all and is @consumable@ (see below), or
--
-- * the copy @d@ is itself not consumed.
--
-- Everything else is skipped.
removeUnnecessaryCopy :: BinderOps lore => BottomUpRuleBasicOp lore
removeUnnecessaryCopy (vtable, used) (Pattern [] [d]) _ (Copy v)
  | not (v `UT.isConsumed` used),
    (not (v `UT.used` used) && consumable)
      || not (patElemName d `UT.isConsumed` used) =
      Simplify $ letBindNames [patElemName d] $ BasicOp $ SubExp $ Var v
  where
    -- Only a function parameter whose declared type is unique counts
    -- as consumable; any other kind of name, or an unknown name, does
    -- not.
    consumable =
      case M.lookup v $ ST.toScope vtable of
        Just (FParamName info) -> unique $ declTypeOf info
        _ -> False
removeUnnecessaryCopy _ _ _ _ =
  Skip
-- | Simplify a comparison ('CmpOp').
--
-- Three rewrites are attempted, in order:
--
--   1. Both operands are the same 'SubExp': the comparison is replaced by
--      the constant it yields for equal operands.
--   2. Both operands are constants: the comparison is evaluated with
--      'doCmpOp'.
--   3. An integer constant is compared for equality with a variable whose
--      definition is a 'BToI' conversion of some boolean @b@: the result
--      is @b@ (constant 1), @not b@ (constant 0), or @false@ (any other
--      constant).
--
-- NOTE(review): rule 1 is applied to floating-point comparisons as well
-- ('CmpEq' on floats, 'FCmpLe').  If the operand can be NaN at run time
-- the folded constant may differ from the evaluated comparison -- confirm
-- that this is acceptable.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp lookupDef _ bop = case bop of
  CmpOp cmp lhs rhs
    | lhs == rhs ->
        constRes (BoolValue (reflexiveResult cmp))
  CmpOp cmp (Constant c1) (Constant c2) ->
    doCmpOp cmp c1 c2 >>= constRes . BoolValue
  CmpOp CmpEq{} (Constant (IntValue k)) (Var v)
    | Just (BasicOp (ConvOp BToI{} b), cs) <- lookupDef v ->
        Just (compareBToI (valueIntegral k :: Int) b, cs)
  _ ->
    Nothing
  where
    -- Result of @x `cmp` x@: the non-strict comparisons hold, the strict
    -- ones do not.
    reflexiveResult :: CmpOp -> Bool
    reflexiveResult CmpEq{}  = True
    reflexiveResult CmpSle{} = True
    reflexiveResult CmpUle{} = True
    reflexiveResult FCmpLe{} = True
    reflexiveResult CmpLle   = True
    reflexiveResult CmpSlt{} = False
    reflexiveResult CmpUlt{} = False
    reflexiveResult FCmpLt{} = False
    reflexiveResult CmpLlt   = False

    -- What @k == btoi b@ reduces to for an integer constant @k@.
    compareBToI :: Int -> SubExp -> BasicOp
    compareBToI 1 b = SubExp b
    compareBToI 0 b = UnOp Not b
    compareBToI _ _ = SubExp (Constant (BoolValue False))
-- | Algebraic simplification of a binary operation ('BinOp').
--
-- The first argument looks up the definition (and certificates) of a
-- variable; the second argument (a type lookup) is unused.  The result is
-- 'Nothing' when no rule applies, otherwise the replacement 'BasicOp'
-- together with the certificates the replacement depends on.
--
-- Equations are tried top to bottom, so constant folding via 'doBinOp'
-- always takes precedence over the per-operator identities below.
simplifyBinOp :: SimpleRule lore
-- Both operands constant: evaluate, if the operation is defined.
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
      constRes res

simplifyBinOp defOf _ (BinOp Add{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  -- e1 + (a - e1) == a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub{} e2_a e2_b), cs) <- defOf v2,
    e2_b == e1 =
      Just (SubExp e2_a, cs)

simplifyBinOp _ _ (BinOp FAdd{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1

simplifyBinOp defOf _ (BinOp Sub{} e1 e2)
  | isCt0 e2 = subExpRes e1
  -- (a + b) - a == b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- defOf v1,
    e1_a == e2 =
      Just (SubExp e1_b, cs)
  -- (a + b) - b == a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add{} e1_a e1_b), cs) <- defOf v1,
    e1_b == e2 =
      Just (SubExp e1_a, cs)
  -- a - (a + b) == 0 - b.  The result must be negated; returning @b@
  -- itself would flip the sign.
  | Var v2 <- e2,
    Just (BasicOp (BinOp (Add t _) e2_a e2_b), cs) <- defOf v2,
    e2_a == e1 =
      Just (BinOp (Sub t OverflowWrap) (intConst t 0) e2_b, cs)
  -- b - (a + b) == 0 - a.  The addition is the *subtrahend* (e2), not
  -- the minuend.
  | Var v2 <- e2,
    Just (BasicOp (BinOp (Add t _) e2_a e2_b), cs) <- defOf v2,
    e2_b == e1 =
      Just (BinOp (Sub t OverflowWrap) (intConst t 0) e2_a, cs)

simplifyBinOp _ _ (BinOp FSub{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Mul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FMul{} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1

simplifyBinOp defOf _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x mod m) mod m == x mod m
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod{} _ e4), v1_cs) <- defOf v1,
    e4 == e2 =
      Just (SubExp e1, v1_cs)

-- For the division operators the "divisor is the constant zero" guard
-- comes first, so that a constant 0/0 is left alone instead of being
-- folded to 0 by the "dividend is zero" rule.
simplifyBinOp _ _ (BinOp SDiv{} e1 e2)
  | isCt0 e2 = Nothing
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp SDivUp{} e1 e2)
  | isCt0 e2 = Nothing
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp FDiv{} e1 e2)
  | isCt0 e2 = Nothing
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- x rem x == 0 (not 1), matching the SMod rule above.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)

simplifyBinOp _ _ (BinOp SQuot{} e1 e2)
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing

simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0

simplifyBinOp _ _ (BinOp AShr{} e1 e2)
  | isCt0 e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = subExpRes $ intConst t 0
  | isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp Or{} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1

simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0

simplifyBinOp defOf _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (not x) && x == false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) == false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue False, v_cs)

simplifyBinOp defOf _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x == true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- defOf v,
    e1' == e2 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) == true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- defOf v,
    e2' == e1 =
      Just (SubExp $ Constant $ BoolValue True, v_cs)

simplifyBinOp defOf _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
      subExpRes e1
  -- max (max a b) a == max b a, and the three symmetric variants: an
  -- operand that already occurs in the nested max is dropped from it.
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_1 == e2 =
      Just (BinOp (SMax it) e1_2 e2, v1_cs)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- defOf v1,
    e1_2 == e2 =
      Just (BinOp (SMax it) e1_1 e2, v1_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_1 == e1 =
      Just (BinOp (SMax it) e2_2 e1, v2_cs)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- defOf v2,
    e2_2 == e1 =
      Just (BinOp (SMax it) e2_1 e1, v2_cs)

-- Not a binary operation, or no rule applies.
simplifyBinOp _ _ _ = Nothing
-- | A successful simplification whose result is the given constant.  The
-- replacement carries no certificates ('mempty').
constRes :: PrimValue -> Maybe (BasicOp, Certificates)
constRes v = Just (SubExp (Constant v), mempty)
-- | A successful simplification whose result is the given 'SubExp'.  The
-- replacement carries no certificates ('mempty').
subExpRes :: SubExp -> Maybe (BasicOp, Certificates)
subExpRes se = Just (SubExp se, mempty)
-- | Simplify a unary operation ('UnOp').
--
-- A unary operation on a constant is evaluated with 'doUnOp'.  A 'Not'
-- applied to a variable that is itself defined as @Not x@ is replaced by
-- @x@, keeping the certificates of the inner definition.  Anything else
-- is left alone.
simplifyUnOp :: SimpleRule lore
simplifyUnOp lookupDef _ bop = case bop of
  UnOp op (Constant c) ->
    doUnOp op c >>= constRes
  UnOp Not (Var v)
    | Just (BasicOp (UnOp Not inner), inner_cs) <- lookupDef v ->
        Just (SubExp inner, inner_cs)
  _ ->
    Nothing
-- | Simplify a conversion ('ConvOp').
--
-- Rules, tried in order:
--
--   * a conversion of a constant is evaluated with 'doConvOp';
--   * a conversion whose source and target type coincide (according to
--     'convOpType') is replaced by its operand;
--   * a conversion of a variable that is itself defined by an extending
--     conversion is fused with that conversion, converting directly from
--     the innermost operand.  This is done for 'SExt' after 'SExt',
--     'ZExt' after 'ZExt', 'SIToFP' after 'SExt', 'UIToFP' after 'ZExt'
--     and 'FPConv' after 'FPConv', each time only when the outer source
--     type is at least as large as the inner source type.
--
-- In the fused cases the certificates of the inner definition are kept.
simplifyConvOp :: SimpleRule lore
simplifyConvOp lookupDef _ bop = case bop of
  ConvOp op (Constant c) ->
    doConvOp op c >>= constRes
  ConvOp op operand
    | (srcT, dstT) <- convOpType op,
      srcT == dstT ->
        subExpRes operand
  ConvOp (SExt outerFrom outerTo) (Var v)
    | Just (BasicOp (ConvOp (SExt innerFrom _) inner), inner_cs) <- lookupDef v,
      outerFrom >= innerFrom ->
        Just (ConvOp (SExt innerFrom outerTo) inner, inner_cs)
  ConvOp (ZExt outerFrom outerTo) (Var v)
    | Just (BasicOp (ConvOp (ZExt innerFrom _) inner), inner_cs) <- lookupDef v,
      outerFrom >= innerFrom ->
        Just (ConvOp (ZExt innerFrom outerTo) inner, inner_cs)
  ConvOp (SIToFP outerFrom outerTo) (Var v)
    | Just (BasicOp (ConvOp (SExt innerFrom _) inner), inner_cs) <- lookupDef v,
      outerFrom >= innerFrom ->
        Just (ConvOp (SIToFP innerFrom outerTo) inner, inner_cs)
  ConvOp (UIToFP outerFrom outerTo) (Var v)
    | Just (BasicOp (ConvOp (ZExt innerFrom _) inner), inner_cs) <- lookupDef v,
      outerFrom >= innerFrom ->
        Just (ConvOp (UIToFP innerFrom outerTo) inner, inner_cs)
  ConvOp (FPConv outerFrom outerTo) (Var v)
    | Just (BasicOp (ConvOp (FPConv innerFrom _) inner), inner_cs) <- lookupDef v,
      outerFrom >= innerFrom ->
        Just (ConvOp (FPConv innerFrom outerTo) inner, inner_cs)
  _ ->
    Nothing
-- | An assertion on the constant @true@ is replaced by the constant
-- 'Checked'.  Every other operation is left alone.
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ bop
  | Assert (Constant (BoolValue True)) _ _ <- bop = constRes Checked
  | otherwise = Nothing
-- | Constant-fold an application of a primitive function.
--
-- The rule fires on an 'Apply' statement when every argument is a
-- constant, the function name is found in 'primFuns', and evaluating the
-- function on those constants succeeds.  The statement is then rebound,
-- under its original certificates and attributes, to the resulting
-- constant.
constantFoldPrimFun :: BinderOps lore => TopDownRuleGeneric lore
constantFoldPrimFun _ (Let pat (StmAux cs attrs _) (Apply fname args _ _))
  | Just result <- folded =
      Simplify $
        certifying cs $
          attributing attrs $
            letBind pat (BasicOp (SubExp (Constant result)))
  where
    -- The value of the call, if it can be computed at compile time.
    folded :: Maybe PrimValue
    folded = do
      consts <- traverse (constantValue . fst) args
      (_, _, evaluate) <- M.lookup (nameToString fname) primFuns
      evaluate consts

    constantValue :: SubExp -> Maybe PrimValue
    constantValue se = case se of
      Constant v -> Just v
      _ -> Nothing
constantFoldPrimFun _ _ = Skip
-- | Bottom-up rule for 'Index' statements that bind exactly one value
-- (no context pattern elements).
--
-- The actual analysis is delegated to 'simplifyIndexing'.  When it
-- produces a result, the statement is rebound -- under the original
-- attributes and the union of the original certificates and the ones
-- reported by the analysis -- to either a plain 'SubExp' or a new
-- 'Index' expression.
simplifyIndex :: BinderOps lore => BottomUpRuleBasicOp lore
simplifyIndex (vtable, used) pat@(Pattern [] [pe]) (StmAux cs attrs _) (Index arr slice)
  | Just analysis <- simplifyIndexing vtable seType arr slice consumed =
      Simplify $ do
        outcome <- analysis
        -- Both outcomes are bound the same way; only the extra
        -- certificates and the replacement operation differ.
        let (extra_cs, replacement) = case outcome of
              SubExpResult cs' se -> (cs', SubExp se)
              IndexResult cs' arr' slice' -> (cs', Index arr' slice')
        attributing attrs $
          certifying (cs <> extra_cs) $
            letBindNames (patternNames pat) (BasicOp replacement)
  where
    -- Whether the name bound by this statement is consumed later on.
    consumed = patElemName pe `UT.isConsumed` used

    -- Type of a subexpression: variables are looked up in the symbol
    -- table, constants have the primitive type of their value.
    seType se = case se of
      Var v -> ST.lookupType v vtable
      Constant v -> Just (Prim (primValueType v))
simplifyIndex _ _ _ _ = Skip
-- | The outcome of successfully simplifying an array indexing
-- operation; this is the type produced (inside the binder monad) by
-- 'simplifyIndexing'.  Both constructors carry a set of certificates
-- which the consumer ('simplifyIndex') combines with the certificates
-- of the original statement before binding the replacement.
data IndexResult = IndexResult Certificates VName (Slice SubExp)
-- ^ The indexing is replaced by another indexing: the fields are the
-- extra certificates, the array to index, and the slice to index it
-- with.  'simplifyIndex' turns this into an @Index@ expression.
| SubExpResult Certificates SubExp
-- ^ The indexing is replaced by a subexpression: the fields are the
-- extra certificates and the resulting t'SubExp'.  'simplifyIndex'
-- turns this into a @SubExp@ expression.
simplifyIndexing :: MonadBinder m =>
ST.SymbolTable (Lore m) -> TypeLookup
-> VName -> Slice SubExp -> Bool
-> Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType VName
idd Slice SubExp
inds Bool
consuming =
case VName -> Maybe (BasicOp, Certificates)
defOf VName
idd of
Maybe (BasicOp, Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
idd),
Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
t [] ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.IndexedArray Certificates
cs VName
arr [PrimExp VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
(PrimExp VName -> Bool) -> [PrimExp VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining [PrimExp VName]
inds'',
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix ([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
(PrimExp VName -> m SubExp) -> [PrimExp VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [PrimExp VName]
inds''
Maybe (BasicOp, Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds
Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
| [DimFix SubExp
ii] <- Slice SubExp
inds,
Just (Prim (IntType IntType
from_it)) <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType SubExp
ii ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
(SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
* PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+ PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
| [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
SubExp
i_offset'' <- String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
+
PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s PrimExp VName -> PrimExp VName -> PrimExp VName
forall a. Num a => a -> a -> a
*
PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
SubExp
i_stride'' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int32 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
(SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$ String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it
Just (Rotate [SubExp]
offsets VName
a, Certificates
cs)
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp -> Bool)
-> [SubExp] -> Slice SubExp -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> DimIndex SubExp -> Bool
forall d. SubExp -> DimIndex d -> Bool
rotateAndSlice [SubExp]
offsets Slice SubExp
inds -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
[SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> m (TypeBase Shape NoUniqueness) -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
a
let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int32 Overflow
OverflowWrap) SubExp
i SubExp
o
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SMod IntType
Int32 Safety
Unsafe) SubExp
i_p_o SubExp
d)
adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)
where rotateAndSlice :: SubExp -> DimIndex d -> Bool
rotateAndSlice SubExp
r DimSlice{} = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
r
rotateAndSlice SubExp
_ DimIndex d
_ = Bool
False
Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
Slice (PrimExp VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice (PrimExp VName) -> m (Slice SubExp)
subExpSlice (Slice (PrimExp VName)
-> Slice (PrimExp VName) -> Slice (PrimExp VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (PrimExp VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (PrimExp VName)
primExpSlice Slice SubExp
inds))
Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
| [DimFix{}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
| DimFix{}:Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'
Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
| [DimFix{}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val
Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
| (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
[SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
VName
arr <- String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
where index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix{} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) SubExp
n SubExp
s)
Just (Rearrange [Int]
perm VName
src, Certificates
cs)
| [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
where isIndex :: DimIndex d -> Bool
isIndex DimFix{} = Bool
True
isIndex DimIndex d
_ = Bool
False
Just (Copy VName
src, Certificates
cs)
| Just [SubExp]
dims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
Bool -> Bool
not Bool
consuming, VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
[Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
src),
ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
[SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
| Just [SubExp
_] <- TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> Maybe (TypeBase Shape NoUniqueness) -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (VName -> SubExp
Var VName
v2) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds
Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
| Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
Just (Prim PrimType
res_t) <- (TypeBase Shape NoUniqueness
-> [SubExp] -> TypeBase Shape NoUniqueness
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds) (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
-> Maybe (TypeBase Shape NoUniqueness)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$>
VName
-> SymbolTable (Lore m) -> Maybe (TypeBase Shape NoUniqueness)
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe (TypeBase Shape NoUniqueness)
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
x_len <- Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
x
[SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((TypeBase Shape NoUniqueness -> SubExp)
-> m (TypeBase Shape NoUniqueness) -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m (TypeBase Shape NoUniqueness) -> m SubExp)
-> (VName -> m (TypeBase Shape NoUniqueness)) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType) [VName]
xs
let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int32 Overflow
OverflowWrap) SubExp
n SubExp
m
(SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
(SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts
let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
mkBranch ((VName
x', SubExp
start):[(VName, SubExp)]
xs_and_starts') = do
SubExp
cmp <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int32) SubExp
start SubExp
i
(SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
SubExp
i' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int32 Overflow
OverflowWrap) SubExp
i SubExp
start
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
(SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfDec (BranchType (Lore m))
-> Exp (Lore m)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfDec (BranchType (Lore m)) -> Exp (Lore m))
-> IfDec (BranchType (Lore m)) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
[BranchType (Lore m)] -> IfSort -> IfDec (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts
Just (ArrayLit [SubExp]
ses TypeBase Shape NoUniqueness
_, Certificates
cs)
| DimFix (Constant (IntValue (Int32Value Int32
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
Just SubExp
se <- Int32 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int32
i [SubExp]
ses ->
case Slice SubExp
inds' of
[] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
Slice SubExp
_ | Var VName
v2 <- SubExp
se -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Maybe (BasicOp, Certificates)
_ | Just TypeBase Shape NoUniqueness
t <- SubExp -> Maybe (TypeBase Shape NoUniqueness)
seType (SubExp -> Maybe (TypeBase Shape NoUniqueness))
-> SubExp -> Maybe (TypeBase Shape NoUniqueness)
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd, SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> TypeBase Shape NoUniqueness -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 TypeBase Shape NoUniqueness
t,
DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds, Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int32 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int32
0::Int32)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'
Maybe (BasicOp, Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
where defOf :: VName -> Maybe (BasicOp, Certificates)
defOf VName
v = do (BasicOp BasicOp
op, Certificates
def_cs) <- VName -> SymbolTable (Lore m) -> Maybe (Exp (Lore m), Certificates)
forall lore.
VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
ST.lookupExp VName
v SymbolTable (Lore m)
vtable
(BasicOp, Certificates) -> Maybe (BasicOp, Certificates)
forall (m :: * -> *) a. Monad m => a -> m a
return (BasicOp
op, Certificates
def_cs)
worthInlining :: PrimExp v -> Bool
worthInlining PrimExp v
e
| Int -> PrimExp v -> Bool
forall v. Int -> PrimExp v -> Bool
primExpSizeAtLeast Int
20 PrimExp v
e = Bool
False
| Bool
otherwise = PrimExp v -> Bool
forall v. PrimExp v -> Bool
worthInlining' PrimExp v
e
worthInlining' :: PrimExp v -> Bool
worthInlining' (BinOpExp Pow{} PrimExp v
_ PrimExp v
_) = Bool
False
worthInlining' (BinOpExp FPow{} PrimExp v
_ PrimExp v
_) = Bool
False
worthInlining' (BinOpExp BinOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
worthInlining' (CmpOpExp CmpOp
_ PrimExp v
x PrimExp v
y) = PrimExp v -> Bool
worthInlining' PrimExp v
x Bool -> Bool -> Bool
&& PrimExp v -> Bool
worthInlining' PrimExp v
y
worthInlining' (ConvOpExp ConvOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
worthInlining' (UnOpExp UnOp
_ PrimExp v
x) = PrimExp v -> Bool
worthInlining' PrimExp v
x
worthInlining' FunExp{} = Bool
False
worthInlining' PrimExp v
_ = Bool
True
-- | Bottom-up simplification rules for 'Concat'.  The equations are
-- tried in order; the first one whose guard succeeds produces the
-- replacement statements, and the final equation yields 'Skip' when
-- none of them applies.
simplifyConcat :: BinderOps lore => BottomUpRuleBasicOp lore

-- concat@i (rearrange perm x) (rearrange perm y) ...
--   == rearrange perm (concat@0 x y ...)
--
-- where @perm@ moves dimension @i@ to the front.  Fires only when
-- every operand is defined (according to the symbol table) as a
-- 'Rearrange' with exactly this permutation.
--
-- NOTE(review): the 'StmAux' of the original statement is ignored by
-- this equation, unlike in the equations below -- confirm that this
-- is intended (e.g. that the caller re-attaches it).
simplifyConcat (vtable, _) pat _ (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    let perm = [i] ++ [0..i-1] ++ [i+1..r-1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs = Simplify $ do
      -- The new concatenation inherits the certificates of the
      -- rearrangements it now bypasses.
      concat_rearrange <-
        certifying (x_cs <> mconcat xs_cs) $
        letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
      letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- The array underlying @v@ (and the certificates of its
    -- definition) if @v@ is bound to a 'Rearrange' by @perm1@.
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- concat@i xs (concat@i ys zs) == concat@i xs ys zs
--
-- Operands that are themselves concatenations along the same
-- dimension are spliced into this one.  The guard ensures that the
-- rule fires only if at least one operand was actually expanded.
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  | x' /= x || concat xs' /= xs =
      Simplify $
      certifying (cs <> x_cs <> mconcat xs_cs) $
      attributing attrs $
      letBind pat $ BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    (x', zs, x_cs) = flatten x
    (xs', xs_cs) = unzip [ (y : ys, y_cs) | (y, ys, y_cs) <- map flatten xs ]

    -- First operand, remaining operands and certificates of @v@ when
    -- it is a concatenation along dimension @i@; otherwise just @v@
    -- itself with no extra operands or certificates.
    flatten v =
      case ST.lookupBasicOp v vtable of
        Just (Concat j y ys _, v_cs) | j == i -> (y, ys, v_cs)
        _ -> (v, [], mempty)

-- Concatenating (along dimension 0) operands that are all array
-- literals, or replicates whose shape is the single constant 1,
-- produces one array literal holding all the elements.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs _)
  | Just (vs, vcs) <- unzip <$> mapM isArrayLit (x:xs) = Simplify $ do
      rt <- rowType <$> lookupType x
      certifying (mconcat vcs) $ auxing aux $
        letBind pat $ BasicOp $ ArrayLit (concat vs) rt
  where
    -- Elements and certificates of @v@ when its definition is an
    -- array literal or a unit-shaped replicate.
    isArrayLit v
      | Just (Replicate shape se, vcs) <- ST.lookupBasicOp v vtable,
        unitShape shape =
          Just ([se], vcs)
      | Just (ArrayLit ses _, vcs) <- ST.lookupBasicOp v vtable =
          Just (ses, vcs)
      | otherwise =
          Nothing

    unitShape = (== Shape [Constant $ IntValue $ Int32Value 1])

simplifyConcat _ _ _ _ = Skip
-- | Top-down simplification rules for @if@ expressions.  The
-- equations are tried in order; the final one yields 'Skip' when no
-- rewrite applies.
ruleIf :: BinderOps lore => TopDownRuleIf lore

-- The condition is recognised by 'isCt1' (pick the true branch) or
-- 'isCt0' (pick the false branch): inline the statements of the
-- chosen branch and bind each pattern element to the corresponding
-- branch result.  An 'IfFallback' conditional is only rewritten this
-- way when the true branch is the one selected.
ruleIf _ pat _ (e1, tb, fb, IfDec _ ifsort)
  | Just branch <- checkBranch,
    ifsort /= IfFallback || isCt1 e1 = Simplify $ do
      let ses = bodyResult branch
      addStms $ bodyStms branch
      sequence_ [ letBindNames [patElemName p] $ BasicOp $ SubExp se
                | (p, se) <- zip (patternElements pat) ses ]
  where
    checkBranch
      | isCt1 e1 = Just tb
      | isCt0 e1 = Just fb
      | otherwise = Nothing

-- if c then True else v == c || v
--
-- Only for a single boolean result where neither branch contains any
-- statements.
ruleIf _ pat _
  ( cond,
    Body _ tstms [Constant (BoolValue True)],
    Body _ fstms [se],
    IfDec ts _
  )
  | null tstms, null fstms, [Prim Bool] <- map extTypeOf ts =
      Simplify $ letBind pat $ BasicOp $ BinOp LogOr cond se

-- When the result is boolean:
--   if c then x else y == (c && x) || (!c && y)
--
-- Both branches are executed unconditionally afterwards, so every
-- statement in them must satisfy 'safeExp'.
ruleIf _ pat _ (cond, tb, fb, IfDec ts _)
  | Body _ tstms [tres] <- tb,
    Body _ fstms [fres] <- fb,
    all (safeExp . stmExp) $ tstms <> fstms,
    all ((== Prim Bool) . extTypeOf) ts = Simplify $ do
      addStms tstms
      addStms fstms
      e <- eBinOp LogOr
             (pure $ BasicOp $ BinOp LogAnd cond tres)
             (eBinOp LogAnd
                (pure $ BasicOp $ UnOp Not cond)
                (pure $ BasicOp $ SubExp fres))
      letBind pat e

-- An 'IfFallback' conditional whose pattern has no context names and
-- whose true branch contains only 'safeExp' statements is replaced
-- by that true branch.
ruleIf _ pat _ (_, tbranch, _, IfDec _ IfFallback)
  | null $ patternContextNames pat,
    all (safeExp . stmExp) $ bodyStms tbranch = Simplify $ do
      let ses = bodyResult tbranch
      addStms $ bodyStms tbranch
      sequence_ [ letBindNames [patElemName p] $ BasicOp $ SubExp se
                | (p, se) <- zip (patternElements pat) ses ]

-- Both branches return a single integer constant:
--   if c then 1 else 0 == btoi c
--   if c then 0 else 1 == btoi (!c)
-- as decided by 'oneIshInt' and 'zeroIshInt'.  Any statements in the
-- branches are ignored by the pattern and dropped by the rewrite.
ruleIf _ pat _ (cond, tb, fb, _)
  | Body _ _ [Constant (IntValue t)] <- tb,
    Body _ _ [Constant (IntValue f)] <- fb =
      if oneIshInt t && zeroIshInt f
        then Simplify $
          letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond
        else if zeroIshInt t && oneIshInt f
          then Simplify $ do
            cond_neg <- letSubExp "cond_neg" $ BasicOp $ UnOp Not cond
            letBind pat $ BasicOp $ ConvOp (BToI (intValueType t)) cond_neg
          else Skip

ruleIf _ _ _ _ = Skip
-- | Pull results out of an @If@ when they do not really depend on
-- which branch is taken.
--
-- Every pattern element is paired with its return type (context
-- elements are paired with their index instead) and with the
-- corresponding result of the true and false branch.  An element is
-- hoisted when both branches return the very same 'SubExp', or when
-- it is primitive and both results are not bound inside either
-- branch.  Whatever remains is re-bound by a smaller @If@.  If
-- nothing could be hoisted the rule gives up with 'cannotSimplify'.
hoistBranchInvariant :: BinderOps lore => TopDownRuleIf lore
hoistBranchInvariant _ pat _ (cond, tb, fb, IfDec ret ifsort) = Simplify $ do
  -- Classifying an element also emits its hoisted binding, if any.
  classified <- mapM branchInvariant candidates
  let (hoistings, kept) = partitionEithers classified
      (pes, ts, res) = unzip3 kept
      (tses', fses') = unzip res
      -- Only hoisted context elements carry an (index, name) pair.
      ctx_fixes = catMaybes hoistings
      -- NOTE(review): presumably 'fixExt' instantiates existential
      -- index i with the hoisted name in the remaining return types
      -- -- confirm against the FixExt class.
      ret' = foldr (uncurry fixExt) (rights ts) ctx_fixes
      -- The last @length ret'@ surviving elements are value elements;
      -- everything before them is context.
      (ctx_pes, val_pes) = splitFromEnd (length ret') pes
      ret_ts = map extTypeOf ret'
  if null hoistings
    then cannotSimplify
    else do
      -- The return types may have changed, so array results may need
      -- a shape coercion to match them.
      tb'' <- reshapeBodyResults (tb {bodyResult = tses'}) ret_ts
      fb'' <- reshapeBodyResults (fb {bodyResult = fses'}) ret_ts
      letBind (Pattern ctx_pes val_pes) $
        If cond tb'' fb'' (IfDec ret' ifsort)
  where
    num_ctx = length (patternContextElements pat)

    -- One triple per pattern element: the element, its type (Left i
    -- for the i'th context element, Right for value elements), and
    -- the results of the true/false branch.
    candidates =
      zip3
        (patternElements pat)
        (map Left [0 .. num_ctx - 1] ++ map Right ret)
        (zip (bodyResult tb) (bodyResult fb))

    -- Every name bound by a statement of either branch.
    bound_in_branches =
      namesFromList . concatMap (patternNames . stmPattern) $
        bodyStms tb <> bodyStms fb

    -- Names free in the memory-typed elements of the pattern.
    mem_sizes =
      freeIn [pe | pe <- patternElements pat, isMem (patElemType pe)]

    -- A result is invariant if it is a constant or a variable that is
    -- not bound inside the branches.
    invariant se = case se of
      Constant {} -> True
      Var v -> not (v `nameIn` bound_in_branches)

    isMem Mem {} = True
    isMem _ = False

    sizeOfMem v = v `nameIn` mem_sizes

    branchInvariant (pe, t, (tse, fse))
      -- Same result in both branches: bind it directly.
      | tse == fse = do
          letBindNames [patElemName pe] $ BasicOp $ SubExp tse
          hoisted pe t
      -- Both results are available outside the branches and the
      -- element is a primitive that no memory block in the pattern
      -- mentions: give it its own tiny @If@.  The @patternSize pat > 1@
      -- requirement presumably keeps the rule from re-firing on the
      -- single-result @If@ it produces here.
      | invariant tse,
        invariant fse,
        patternSize pat > 1,
        Prim _ <- patElemType pe,
        not $ sizeOfMem $ patElemName pe = do
          bt <- expTypesFromPattern $ Pattern [] [pe]
          tbranch <- resultBodyM [tse]
          fbranch <- resultBodyM [fse]
          letBindNames [patElemName pe] $
            If cond tbranch fbranch (IfDec bt ifsort)
          hoisted pe t
      -- Not hoistable: keep it in the remaining @If@.
      | otherwise =
          return $ Right (pe, t, (tse, fse))

    -- Report a hoisted element.  For context elements, also record
    -- which index is now known under which name.
    hoisted pe t = return . Left $ case t of
      Left i -> Just (i, Var (patElemName pe))
      Right _ -> Nothing

    -- Rebuild a body so that its trailing @length rets@ results go
    -- through 'reshapeResult'; the leading (context) results are left
    -- as they are.
    reshapeBodyResults body rets = insertStmsM $ do
      ses <- bodyBind body
      let (ctx_ses, val_ses) = splitFromEnd (length rets) ses
      val_ses' <- zipWithM reshapeResult val_ses rets
      resultBodyM (ctx_ses ++ val_ses')

    -- Coerce an array variable to the dimensions demanded by the
    -- given return type, but only when they differ from its own.
    reshapeResult (Var v) t@Array {} = do
      v_t <- lookupType v
      let old_dims = arrayDims v_t
          new_dims = arrayDims $ removeExistentials t v_t
      if new_dims == old_dims
        then return $ Var v
        else letSubExp "branch_ctx_reshaped" $ shapeCoerce new_dims v
    reshapeResult se _ = return se
-- | A reshape whose requested dimensions are exactly the dimensions
-- the operand already has is a no-op; replace it by the operand
-- itself.  Does nothing when the operand's type cannot be looked up.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ seType (Reshape newshape v) = do
  t <- seType (Var v)
  guard $ newDims newshape == arrayDims t
  subExpRes (Var v)
simplifyIdentityReshape _ _ _ = Nothing
-- | A reshape of something that is itself defined by a reshape is
-- collapsed into a single reshape of the innermost array, combining
-- the two shape changes with 'fuseReshape'.  The certificates of the
-- inner reshape are carried over.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape defOf _ (Reshape newshape v) =
  case defOf v of
    Just (BasicOp (Reshape oldshape v2), v_cs) ->
      Just (Reshape (fuseReshape oldshape newshape) v2, v_cs)
    _ -> Nothing
simplifyReshapeReshape _ _ _ = Nothing
-- | A reshape of a 'Scratch' array becomes a 'Scratch' of the same
-- element type with the requested dimensions.  The certificates of
-- the original scratch are carried over.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch defOf _ (Reshape newshape v) =
  case defOf v of
    Just (BasicOp (Scratch bt _), v_cs) ->
      Just (Scratch bt (newDims newshape), v_cs)
    _ -> Nothing
simplifyReshapeScratch _ _ _ = Nothing
-- | A reshape of a 'Replicate' becomes a 'Replicate' directly,
-- provided the dimensions of the replicated value form a suffix of
-- the requested dimensions.  The new replication shape is the part
-- of the requested dimensions in front of that suffix.  The
-- certificates of the original replicate are carried over.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate defOf seType (Reshape newshape v) =
  case defOf v of
    Just (BasicOp (Replicate _ se), v_cs) -> do
      oldshape <- arrayShape <$> seType se
      let target = newDims newshape
          -- Number of leading dimensions not covered by the value.
          n_outer = length newshape - shapeRank oldshape
      guard $ shapeDims oldshape `isSuffixOf` target
      Just (Replicate (Shape (take n_outer target)) se, v_cs)
    _ -> Nothing
simplifyReshapeReplicate _ _ _ = Nothing
-- | A reshape of an 'Iota' to a single dimension @n@ becomes an
-- 'Iota' of @n@ elements with the same offset, stride and integer
-- type.  The certificates of the original iota are carried over.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota defOf _ (Reshape newshape v) =
  case (defOf v, newDims newshape) of
    (Just (BasicOp (Iota _ offset stride it), v_cs), [n]) ->
      Just (Iota n offset stride it, v_cs)
    _ -> Nothing
simplifyReshapeIota _ _ _ = Nothing
-- | Run the requested shape change through 'informReshape' together
-- with the operand's current dimensions, and use the result if it
-- differs from what was requested.  Does nothing when the operand's
-- type cannot be looked up.  No extra certificates are attached.
improveReshape :: SimpleRule lore
improveReshape _ seType (Reshape newshape v) = do
  t <- seType (Var v)
  let informed = informReshape (arrayDims t) newshape
  -- Only fire when something actually changed.
  guard $ informed /= newshape
  Just (Reshape informed v, mempty)
improveReshape _ _ _ = Nothing
-- | Copying an array that is a 'Scratch' (possibly seen through a
-- chain of 'Rearrange' and 'Reshape') is replaced by a fresh
-- 'Scratch' with the element type and dimensions of the source.
-- Does nothing when the source's type cannot be looked up.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch defOf seType (Copy src)
  | isActuallyScratch src = do
      t <- seType (Var src)
      Just (Scratch (elemType t) (arrayDims t), mempty)
  | otherwise = Nothing
  where
    -- Follow the definition of a name through rearranges and
    -- reshapes until a scratch (True) or anything else (False) is
    -- found.
    isActuallyScratch v =
      case defOf v >>= asBasicOp . fst of
        Just Scratch {} -> True
        Just (Rearrange _ v') -> isActuallyScratch v'
        Just (Reshape _ v') -> isActuallyScratch v'
        _ -> False
copyScratchToScratch _ _ _ =
  Nothing
-- | Top-down simplification rules for 'BasicOp' statements.
ruleBasicOp :: BinderOps lore => TopDownRuleBasicOp lore
-- Try every rule in 'simpleRules' on the operation and take the first
-- one that yields a replacement.  The replacement is bound to the
-- original pattern under the certificates returned by the rule combined
-- with the certificates of the original statement.
ruleBasicOp vtable pat aux op
  | Just (simplified, rule_cs) <-
      msum (map (\rule -> rule lookupDef subExpType op) simpleRules) =
      Simplify $
        certifying (rule_cs <> stmAuxCerts aux) $
          letBind pat $ BasicOp simplified
  where
    -- Definition of a variable, as recorded in the symbol table.
    lookupDef name = ST.lookupExp name vtable

    -- Type of a subexpression: variables are looked up in the symbol
    -- table, constants carry their primitive type.
    subExpType (Var name) = ST.lookupType name vtable
    subExpType (Constant val) = Just (Prim (primValueType val))
-- When the value written by an 'Update' is a variable whose definition
-- in the symbol table is a 'Scratch' expression, the statement is
-- replaced by a binding of the pattern directly to the array being
-- updated.
ruleBasicOp vtable pat _ (Update src _ (Var v))
  | Just (BasicOp Scratch {}, _) <- ST.lookupExp v vtable =
      Simplify . letBind pat . BasicOp . SubExp $ Var src
-- An 'Update' through a single @DimSlice i n s@ where both @n@ and @s@
-- satisfy 'isCt1'.  If the symbol table can express index 0 of the
-- written array as a 'PrimExp', that expression is turned into a
-- subexpression and written with a 'DimFix' index instead.
ruleBasicOp vtable pat aux (Update src [DimSlice i n s] (Var v))
  | isCt1 n,
    isCt1 s,
    Just (ST.Indexed cs e) <- ST.index v [intConst Int32 0] vtable =
      Simplify $ do
        elem_se <- toSubExp "update_elem" e
        let update = BasicOp $ Update src [DimFix i] elem_se
        auxing aux $ certifying cs $ letBind pat update
-- An 'Update' whose written value is recognised by @comesFromDest@ is
-- replaced by a binding of the pattern to the destination itself.
ruleBasicOp vtable pat _ (Update dest destis (Var v))
  | Just (def, _) <- ST.lookupExp v vtable,
    comesFromDest def =
      Simplify $ letBind pat $ BasicOp $ SubExp $ Var dest
  where
    -- Decide, from the defining expression of the written value,
    -- whether the update can be dropped:
    --
    --  * a 'Copy' is looked through to the definition of its operand;
    --
    --  * an 'Index' qualifies when it reads 'dest' with exactly the
    --    slice being written;
    --
    --  * a 'Replicate' qualifies when 'dest' is itself defined by a
    --    'Replicate' of the same element and the value's shape is a
    --    suffix of the destination's shape.
    comesFromDest e = case e of
      BasicOp (Copy copied)
        | Just (copied_def, _) <- ST.lookupExp copied vtable ->
            comesFromDest copied_def
      BasicOp (Index src srcis) ->
        src == dest && destis == srcis
      BasicOp (Replicate v_shape v_se)
        | Just (Replicate dest_shape dest_se, _) <-
            ST.lookupBasicOp dest vtable,
          v_se == dest_se,
          shapeDims v_shape `isSuffixOf` shapeDims dest_shape ->
            True
      _ ->
        False
-- An 'Update' whose slice is a full slice of the destination (per
-- 'isFullSlice' on the destination's shape).  A variable value with a
-- slice that has dimensions is reshaped to the destination's
-- dimensions and copied; anything else becomes a one-element
-- 'ArrayLit' with the destination's row type.
ruleBasicOp vtable pat _ (Update dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is =
      Simplify $ case se of
        Var v
          | not (null (sliceDims is)) -> do
              let reshape = Reshape (map DimNew (arrayDims dest_t)) v
              reshaped <-
                letExp (baseString v ++ "_reshaped") (BasicOp reshape)
              letBind pat $ BasicOp $ Copy reshaped
        _ ->
          letBind pat $ BasicOp $ ArrayLit [se] (rowType dest_t)
-- Matches the chain
--
--   v3    = dest1[is1]
--   dest2 = copy v3
--   v1    = dest2 with [is2] <- se2
--   pat   = dest1 with [is1] <- v1
--
-- and rewrites the final statement to a single 'Update' of @dest1@ at
-- the slice obtained by 'sliceSlice' from @is1@ and @is2@.  The
-- certificates of all four statements are attached, and the attributes
-- of the original statement are kept.
ruleBasicOp vtable pat (StmAux cs1 attrs _) (Update dest1 is1 (Var v1))
  | Just (Update dest2 is2 se2, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Copy v3, cs3) <- ST.lookupBasicOp dest2 vtable,
    Just (Index v4 is4, cs4) <- ST.lookupBasicOp v3 vtable,
    is4 == is1,
    v4 == dest1 =
      let all_cs = cs1 <> cs2 <> cs3 <> cs4
          composed = sliceSlice (primExpSlice is1) (primExpSlice is2)
       in Simplify $
            certifying all_cs $ do
              is5 <- subExpSlice composed
              attributing attrs $
                letBind pat $ BasicOp $ Update dest1 is5 se2
-- Equality against the result of an 'If'.  When one operand is a
-- variable bound by @if p then ... y else ... z@ and neither @y@ nor
-- @z@ mentions names bound inside its branch, the comparison
-- @x == v@ is rewritten to
--
--   (p && x == y) || (not p && x == z)
--
-- Both operand orders are tried.  If neither applies, no guard
-- succeeds and matching continues with the following equations.
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    simplifyWith (Var v) x
      | Just bnd <- ST.lookupStm v vtable,
        If p tbranch fbranch _ <- stmExp bnd,
        Just (y, z) <- returns v (stmPattern bnd) tbranch fbranch,
        -- The branch results must be usable outside the branches.
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
          eq_x_y <-
            letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
          eq_x_z <-
            letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
          p_and_eq_x_y <-
            letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
          not_p <-
            letSubExp "not_p" $ BasicOp $ UnOp Not p
          -- The name hint previously repeated "p_and_eq_x_y", which
          -- gave two different intermediates the same base name.
          not_p_and_eq_x_z <-
            letSubExp "not_p_and_eq_x_z" $
              BasicOp $ BinOp LogAnd not_p eq_x_z
          letBind pat $
            BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- For the pattern element named @v@, the pair of results the true
    -- and false branches return for it, if @v@ is one of the pattern's
    -- value elements.
    returns v ifpat tbranch fbranch =
      fmap snd $
        find ((== v) . patElemName . fst) $
          zip (patternValueElements ifpat) $
            zip (bodyResult tbranch) (bodyResult fbranch)
-- A 'Replicate' with an empty shape of a constant is just that
-- constant.
ruleBasicOp _ pat _ (Replicate (Shape []) se@Constant {}) =
  Simplify . letBind pat . BasicOp $ SubExp se
-- A 'Replicate' with an empty shape of a variable: a variable of
-- primitive type is referenced directly, any other variable is
-- replaced by a 'Copy' of it.
ruleBasicOp _ pat _ (Replicate (Shape []) (Var v)) = Simplify $ do
  v_t <- lookupType v
  let replacement
        | primType v_t = SubExp (Var v)
        | otherwise = Copy v
  letBind pat $ BasicOp replacement
-- A 'Replicate' of a variable that is itself defined by a 'Replicate'
-- is fused into one 'Replicate' of the inner element, with the outer
-- shape prepended to the inner shape.  The certificates of the inner
-- statement are attached to the result.
ruleBasicOp vtable pat _ (Replicate shape (Var v))
  | Just (BasicOp (Replicate inner_shape se), cs) <- ST.lookupExp v vtable =
      let fused = Replicate (shape <> inner_shape) se
       in Simplify $ certifying cs $ letBind pat $ BasicOp fused
-- A non-empty 'ArrayLit' whose elements are all equal to the first one
-- is replaced by a one-dimensional 'Replicate' of that element.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses =
      Simplify $ letBind pat $ BasicOp $ Replicate (Shape [n]) se
  where
    -- Number of elements in the literal: the head plus the rest.
    n = constant (fromIntegral (length ses) + 1 :: Int32)
-- An 'Index' into the result of a 'Reshape', where the slice consists
-- only of indices (per 'sliceIndices') and there is one index per
-- dimension of the new shape.  The index is redirected to the array
-- that was reshaped:
--
--  * when 'shapeCoercion' yields a value for the new shape, the same
--    slice is used unchanged;
--
--  * otherwise the indices are translated with 'reshapeIndex' from the
--    new shape to the shape of the reshaped array.
--
-- In both cases the certificates of the 'Reshape' statement and the
-- auxiliary information of the original statement are attached.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
      Simplify $ case shapeCoercion newshape of
        Just _ ->
          certifying idd_cs $
            auxing aux $
              letBind pat $ BasicOp $ Index idd2 slice
        Nothing -> do
          oldshape <- arrayDims <$> lookupType idd2
          let asPrimExps = map (primExpFromSubExp int32)
              new_inds =
                reshapeIndex
                  (asPrimExps oldshape)
                  (asPrimExps (newDims newshape))
                  (asPrimExps inds)
          new_inds' <-
            mapM (toSubExp "new_index" . asInt32PrimExp) new_inds
          certifying idd_cs $
            auxing aux $
              letBind pat $ BasicOp $ Index idd2 (map DimFix new_inds')
-- Integer 'Pow' whose base is the constant 2 becomes a left shift
-- ('Shl') of the constant 1 by the exponent.
ruleBasicOp _ pat _ (BinOp (Pow t) e1 e2)
  | e1 == intConst t 2 =
      let shifted = BinOp (Shl t) (intConst t 1) e2
       in Simplify $ letBind pat $ BasicOp shifted
-- A 'Rearrange' whose permutation is already in sorted order is
-- replaced by a reference to its operand.
ruleBasicOp _ pat _ (Rearrange perm v)
  | sort perm == perm =
      Simplify . letBind pat . BasicOp . SubExp $ Var v
-- A 'Rearrange' of a variable defined by another 'Rearrange' is fused
-- into a single 'Rearrange' of the inner operand, using the
-- permutation given by 'rearrangeCompose'.  The certificates of the
-- inner statement and the auxiliary information of the original
-- statement are attached.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange inner_perm inner_v), v_cs) <-
      ST.lookupExp v vtable =
      let fused = Rearrange (perm `rearrangeCompose` inner_perm) inner_v
       in Simplify $
            certifying v_cs $
              auxing aux $
                letBind pat $ BasicOp fused
-- A 'Rearrange' of a 'Rotate' of a 'Rearrange'.  The rotation is moved
-- onto the innermost array, with its offsets permuted by the inverse
-- of the inner permutation, and the two permutations are then combined
-- with 'rearrangeCompose' into one 'Rearrange' of the rotated array.
-- The certificates of the two looked-up statements and the auxiliary
-- information of the original statement are attached to the final
-- 'Rearrange' only.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
      Simplify $ do
        let inner_offsets = rearrangeShape (rearrangeInverse perm3) offsets
        rotated <-
          letExp "rearrange_rotate" $ BasicOp $ Rotate inner_offsets v3
        let combined = Rearrange (perm `rearrangeCompose` perm3) rotated
        certifying (v_cs <> v2_cs) $
          auxing aux $
            letBind pat $ BasicOp combined
-- Rearrange of a Replicate of an array variable 'v2', where the
-- permutation leaves the leading @shapeRank dims@ (replicated)
-- dimensions in place and at least one further dimension remains.
-- The rearrangement is pushed inside the replicate: the remaining
-- permutation indices are shifted down by 'num_dims' and applied to
-- 'v2', and the result is replicated with the original 'dims'.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    num_dims <- shapeRank dims,
    (rep_perm, rest_perm) <- splitAt num_dims perm,
    not $ null rest_perm,
    -- The replicated dimensions must be an identity prefix.
    rep_perm == [0 .. length rep_perm - 1] =
      Simplify $ certifying v1_cs $ auxing aux $ do
        v <- letSubExp "rearrange_replicate" $
          BasicOp $ Rearrange (map (subtract num_dims) rest_perm) v2
        letBind pat $ BasicOp $ Replicate dims v
-- A Rotate whose offsets all satisfy 'isCt0' (this includes an empty
-- offset list) is replaced by a plain reference to its input array.
-- The statement's auxiliary information is not consulted here.
ruleBasicOp _ pat _ (Rotate offsets v)
  | all isCt0 offsets = Simplify $ letBind pat $ BasicOp $ SubExp $ Var v
-- Rotate of a Rearrange of a Rotate.  The two rotations are combined
-- into one that is applied after the rearrangement of the innermost
-- array 'v3'.
--
-- The inner offsets 'offsets2' are expressed in the dimension order
-- of 'v3'.  After @Rearrange perm@, dimension @i@ of the result is
-- dimension @perm !! i@ of 'v3', so the inner offsets must be
-- reordered with 'perm' itself (the same way the shape is), not with
-- its inverse.  Using the inverse only happens to work for
-- self-inverse permutations such as a 2D transpose.
--
-- Offsets are summed with wrapping 32-bit integer addition.
ruleBasicOp vtable pat aux (Rotate offsets v)
  | Just (BasicOp (Rearrange perm v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate offsets2 v3), v2_cs) <- ST.lookupExp v2 vtable = Simplify $ do
      let offsets2' = rearrangeShape perm offsets2
          addOffsets x y =
            letSubExp "summed_offset" $ BasicOp $ BinOp (Add Int32 OverflowWrap) x y
      offsets' <- zipWithM addOffsets offsets offsets2'
      rotate_rearrange <-
        auxing aux $ letExp "rotate_rearrange" $ BasicOp $ Rearrange perm v3
      certifying (v_cs <> v2_cs) $
        letBind pat $ BasicOp $ Rotate offsets' rotate_rearrange
-- Rotate of a Rotate.  The two offset lists are added pairwise (with
-- wrapping 32-bit integer addition) and a single Rotate is applied
-- to the innermost array 'v2', carrying the certificates of the
-- inner rotation.
ruleBasicOp vtable pat aux (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable = Simplify $ do
      offsets <- zipWithM add offsets1 offsets2
      certifying v_cs $ auxing aux $
        letBind pat $ BasicOp $ Rotate offsets v2
  where
    -- Bind the sum of two offsets to a fresh name.
    add x y = letSubExp "offset" $ BasicOp $ BinOp (Add Int32 OverflowWrap) x y
-- An Update that writes a variable 'v' at fully fixed indices, where
-- 'v' is itself the result of an Index into another available array
-- 'arr_y'.  The final fixed index of both slices is turned into a
-- slice of length 1 and stride 1, so that the Update copies a
-- one-element slice of 'arr_y' instead of going through 'v'.
ruleBasicOp vtable pat aux (Update arr_x slice_x (Var v))
  | Just _ <- sliceIndices slice_x,
    Just (Index arr_y slice_y, cs_y) <- ST.lookupBasicOp v vtable,
    ST.available arr_y vtable,
    -- Only a name comparison; this is not a full aliasing check.
    arr_y /= arr_x,
    -- Split off the last dimension of each slice; both must be fixed.
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y = Simplify $ do
      let slice_x' = slice_x_bef ++ [DimSlice i (intConst Int32 1) (intConst Int32 1)]
          slice_y' = slice_y_bef ++ [DimSlice j (intConst Int32 1) (intConst Int32 1)]
      -- NOTE(review): the new Index is bound outside 'certifying cs_y';
      -- only the Update below carries the certificates.  Confirm that
      -- this is intended.
      v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
      certifying cs_y $ auxing aux $
        letBind pat $ BasicOp $ Update arr_x slice_x' $ Var v'
-- A signed @0 <= v@ comparison, where the left operand is the 32-bit
-- constant zero and 'v' is known to the symbol table as a loop
-- variable ('ST.lookupLoopVar' succeeds), is replaced by True.
ruleBasicOp vtable pat aux (CmpOp CmpSle{} x y)
  | Constant (IntValue (Int32Value 0)) <- x,
    Var v <- y,
    Just _ <- ST.lookupLoopVar v vtable =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant True
-- A signed @v < y@ comparison, where 'v' is a loop variable and the
-- value 'ST.lookupLoopVar' associates with it (presumably the loop
-- bound -- confirm against the symbol table) is exactly 'y', is
-- replaced by True.
ruleBasicOp vtable pat aux (CmpOp CmpSlt{} x y)
  | Var v <- x,
    Just n <- ST.lookupLoopVar v vtable,
    n == y =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant True
-- A signed @x < y@ comparison, where 'y' satisfies 'isCt0' and the
-- symbol table entry for 'x' is marked as a size
-- ('ST.entryIsSize'), is replaced by False.  A variable with no
-- symbol table entry does not trigger the rule.
ruleBasicOp vtable pat aux (CmpOp CmpSlt{} (Var x) y)
  | isCt0 y,
    maybe False ST.entryIsSize $ ST.lookup x vtable =
      Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant False
-- No rule matched: leave the statement unchanged.
ruleBasicOp _ _ _ _ =
  Skip
-- | Bottom-up rule for @If@ expressions that drops branch results
-- whose pattern names are not directly used according to the usage
-- table.
--
-- The rule only fires when the pattern has exactly as many elements
-- as there are branch return types, and at least one pattern name is
-- unused.  The corresponding entries are then removed from both
-- branch results, the pattern, and the return type.  The statements
-- of the branch bodies are left untouched, and the statement's
-- auxiliary information is not consulted.
removeDeadBranchResult :: BinderOps lore => BottomUpRuleIf lore
removeDeadBranchResult (_, used) pat _ (e1, tb, fb, IfDec rettype ifsort)
  | -- The pattern must line up one-to-one with the return types.
    patternSize pat == length rettype,
    -- Which of the names bound by the pattern are used?
    patused <- map (`UT.isUsedDirectly` used) $ patternNames pat,
    -- Only apply when something can actually be removed.
    not (and patused) =
      let tses = bodyResult tb
          fses = bodyResult fb
          -- Keep only the elements at positions whose name is used.
          pick :: [a] -> [a]
          pick = map snd . filter fst . zip patused
          tb' = tb { bodyResult = pick tses }
          fb' = fb { bodyResult = pick fses }
          pat' = pick $ patternElements pat
          rettype' = pick rettype
       in -- The first component of the new pattern (presumably the
          -- context) is empty.
          Simplify $ letBind (Pattern [] pat') $ If e1 tb' fb' $ IfDec rettype' ifsort
  | otherwise = Skip
-- | True if the operand is a constant whose value satisfies
-- 'oneIsh'.  Variables are never considered to be one.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False
-- | True if the operand is a constant whose value satisfies
-- 'zeroIsh'.  Variables are never considered to be zero.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False