{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.InPlaceLowering.SubstituteIndices
(
substituteIndices
, IndexSubstitution
, IndexSubstitutions
) where
import Control.Monad
import qualified Data.Map.Strict as M
import Futhark.Representation.AST.Attributes.Aliases
import Futhark.Representation.AST
import Futhark.Construct
import Futhark.Util
-- | Describes what a variable should be replaced with: the
-- certificates under which the replacement is emitted, the name of
-- the source array, the attribute of that source array (its type is
-- read off with 'typeOf'), and the slice of the source array that the
-- variable corresponds to.  An empty slice means the variable is
-- replaced by the source itself.
type IndexSubstitution attr = (Certificates, VName, attr, Slice SubExp)
-- | An association list from the variable to be replaced to its
-- 'IndexSubstitution'.  Lookups use the first matching entry.
type IndexSubstitutions attr = [(VName, IndexSubstitution attr)]
-- | Build the scope in which the substitutions are valid.  Each
-- substitution contributes one entry, keyed by the /source/ array
-- name (not by the variable being replaced) and carrying the source
-- attribute wrapped in 'LetInfo'.
typeEnvFromSubstitutions :: LetAttr lore ~ attr =>
                            IndexSubstitutions attr -> Scope lore
typeEnvFromSubstitutions = M.fromList . map (fromSubstitution . snd)
  where -- Certificates and slice play no role in the scope.
        fromSubstitution (_, name, t, _) =
          (name, LetInfo t)
-- | Perform the substitution on a sequence of statements.
--
-- The statements are rewritten inside a 'BinderT' computation whose
-- initial scope is derived from the substitutions themselves (see
-- 'typeEnvFromSubstitutions').  Returns the substitutions as they
-- stand after the last statement, together with the statements that
-- were produced.
substituteIndices :: (MonadFreshNames m, BinderOps lore, Bindable lore,
                      Aliased lore, LetAttr lore ~ attr) =>
                     IndexSubstitutions attr -> Stms lore
                  -> m (IndexSubstitutions attr, Stms lore)
substituteIndices substs bnds =
  runBinderT (substituteIndicesInStms substs bnds) types
  where types = typeEnvFromSubstitutions substs
-- | Apply the substitutions to every statement in order, threading
-- the (possibly updated) substitutions from one statement to the
-- next.  The rewritten statements are emitted in the binder monad;
-- the final substitutions are returned.
substituteIndicesInStms :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetAttr (Lore m))
                        -> Stms (Lore m)
                        -> m (IndexSubstitutions (LetAttr (Lore m)))
substituteIndicesInStms = foldM substituteIndicesInStm
-- | Rewrite a single statement.  The expression is rewritten first
-- (which may emit auxiliary statements), then the pattern is
-- processed, and finally the statement is re-emitted with its
-- original auxiliary information.  Returns the substitutions produced
-- by processing the pattern.
substituteIndicesInStm :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> Stm (Lore m)
                       -> m (IndexSubstitutions (LetAttr (Lore m)))
substituteIndicesInStm substs (Let pat lore e) = do
  e' <- substituteIndicesInExp substs e
  (substs', pat') <- substituteIndicesInPattern substs pat
  addStm $ Let pat' lore e'
  return substs'
-- | Traverse the context elements and then the value elements of a
-- pattern, threading the substitutions through both, and rebuild the
-- pattern from the results.
--
-- The per-element step @sub@ currently returns both the substitutions
-- and the pattern element unchanged, so the result is the input
-- substitutions paired with an equivalent pattern.
substituteIndicesInPattern :: (MonadBinder m, LetAttr (Lore m) ~ attr) =>
                              IndexSubstitutions (LetAttr (Lore m))
                           -> PatternT attr
                           -> m (IndexSubstitutions (LetAttr (Lore m)), PatternT attr)
substituteIndicesInPattern substs pat = do
  (substs', context) <- mapAccumLM sub substs $ patternContextElements pat
  (substs'', values) <- mapAccumLM sub substs' $ patternValueElements pat
  return (substs'', Pattern context values)
  where sub substs' patElem = return (substs', patElem)
-- | Rewrite an expression under the given substitutions.
--
-- This happens in two steps:
--
-- 1. For every name consumed by the expression (per 'consumedInExp')
--    that has a substitution, the corresponding slice of the source
--    array is read out and copied, and the substitution for that name
--    is redirected to the fresh copy (with no certificates and an
--    empty slice).
--
-- 2. The expression is traversed with a 'Mapper' that rewrites
--    subexpressions, variable names and bodies using the substitutions
--    from step 1.
substituteIndicesInExp :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m),
                           LetAttr (Lore m) ~ attr) =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> Exp (Lore m)
                       -> m (Exp (Lore m))
substituteIndicesInExp substs e = do
  substs' <- copyAnyConsumed e
  let substitute = identityMapper { mapOnSubExp = substituteIndicesInSubExp substs'
                                  , mapOnVName = substituteIndicesInVar substs'
                                    -- The scope argument handed to 'mapOnBody' is ignored.
                                  , mapOnBody = const $ substituteIndicesInBody substs'
                                  }
  mapExpM substitute e
  where copyAnyConsumed =
          let -- The lookup is done in the substitutions originally
              -- passed in ('substs'); the accumulator ('substs'') is
              -- what gets updated.
              consumingSubst substs' v
                | Just (cs2, src2, src2attr, is2) <- lookup v substs = do
                    row <- certifying cs2 $
                           letExp (baseString v ++ "_row") $
                           BasicOp $ Index src2 $ fullSlice (typeOf src2attr) is2
                    row_copy <- letExp (baseString v ++ "_row_copy") $
                                BasicOp $ Copy row
                    -- The entry keeps its key 'v' but now refers to
                    -- the copy, whose type is the source type with
                    -- @length is2@ dimensions stripped.
                    return $ update v v (mempty,
                                         row_copy,
                                         src2attr `setType`
                                         stripArray (length is2) (typeOf src2attr),
                                         []) substs'
              consumingSubst substs' _ =
                return substs'
          in foldM consumingSubst substs . namesToList . consumedInExp
-- | Rewrite a subexpression: variables are passed through
-- 'substituteIndicesInVar' and wrapped in 'Var' again; anything else
-- is returned unchanged.
substituteIndicesInSubExp :: MonadBinder m =>
                             IndexSubstitutions (LetAttr (Lore m))
                          -> SubExp
                          -> m SubExp
substituteIndicesInSubExp substs (Var v) = Var <$> substituteIndicesInVar substs v
substituteIndicesInSubExp _ se = return se
-- | Rewrite a variable name.
--
-- * If the variable has a substitution with an empty slice, a new
--   statement binding the source variable itself is emitted (named
--   after the source) and its name returned.
--
-- * If it has a substitution with a non-empty slice, an 'Index' of the
--   source array with that slice (padded by 'fullSlice' using the
--   source's type) is emitted under the name @idx@ and returned.
--
-- * Otherwise the variable is returned unchanged.
--
-- In both substitution cases the new statement is emitted under the
-- substitution's certificates.
substituteIndicesInVar :: MonadBinder m =>
                          IndexSubstitutions (LetAttr (Lore m))
                       -> VName
                       -> m VName
substituteIndicesInVar substs v
  | Just (cs2, src2, _, []) <- lookup v substs =
      certifying cs2 $
      letExp (baseString src2) $ BasicOp $ SubExp $ Var src2
  | Just (cs2, src2, src2_attr, is2) <- lookup v substs =
      certifying cs2 $
      letExp "idx" $ BasicOp $ Index src2 $ fullSlice (typeOf src2_attr) is2
  | otherwise =
      return v
-- | Rewrite a body.  The body's statements are rewritten and
-- collected (rather than emitted into the enclosing binder), then the
-- result subexpressions are rewritten using the substitutions that
-- came out of the statements; any statements this produces are
-- collected too.  The new body consists of the rewritten statements
-- followed by the statements needed for the results.
substituteIndicesInBody :: (MonadBinder m, Bindable (Lore m), Aliased (Lore m)) =>
                           IndexSubstitutions (LetAttr (Lore m))
                        -> Body (Lore m)
                        -> m (Body (Lore m))
substituteIndicesInBody substs body = do
  (substs', bnds') <- inScopeOf bnds $
    collectStms $ substituteIndicesInStms substs bnds
  -- The results are rewritten in the scope of the rewritten statements.
  (ses, ses_bnds) <- inScopeOf bnds' $
    collectStms $ mapM (substituteIndicesInSubExp substs') $ bodyResult body
  mkBodyM (bnds' <> ses_bnds) ses
  where bnds = bodyStms body
-- | @update needle name subst substs@ replaces the first entry of
-- @substs@ whose key equals @needle@ with @(name, subst)@, leaving all
-- other entries and their order untouched.
--
-- Calls 'error' if no entry is keyed by @needle@.
update :: VName -> VName -> IndexSubstitution attr -> IndexSubstitutions attr
       -> IndexSubstitutions attr
update needle name subst ((othername, othersubst) : substs)
  | needle == othername = (name, subst) : substs
  | otherwise = (othername, othersubst) : update needle name subst substs
update needle _ _ [] = error $ "Cannot find substitution for " ++ pretty needle