{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.ArrayShortCircuiting
( optimiseSeqMem,
optimiseGPUMem,
optimiseMCMem,
)
where
import Control.Monad
import Control.Monad.Reader
import Data.Function ((&))
import Data.List qualified as L
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.Mem.IxFun (substituteInIxFun)
import Futhark.IR.SeqMem
import Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Pass (Pass (..))
import Futhark.Pass qualified as Pass
import Futhark.Util
-- | Read-only context used while rewriting a function body after the
-- coalescing analysis has run.
data Env inner = Env
  { -- | Coalescing table of the function currently being rewritten.
    envCoalesceTab :: CoalsTab,
    -- | Rewriter applied to representation-specific inner operations.
    onInner :: inner -> UpdateM inner inner,
    -- | Names whose defining statements 'removeAllocsInStms' drops.
    memAllocsToRemove :: Names
  }

-- | The rewriting monad: a plain reader over 'Env'.
type UpdateM inner a = Reader (Env inner) a
-- | Array short-circuiting for the sequential memory representation.
-- Inner operations are left untouched (the inner rewriter is 'pure').
optimiseSeqMem :: Pass SeqMem SeqMem
optimiseSeqMem =
  pass
    "short-circuit"
    "Array Short-Circuiting"
    mkCoalsTab
    pure
    replaceInParams
-- | Array short-circuiting for the GPU memory representation.  Host
-- operations are rewritten with 'replaceInHostOp'.
optimiseGPUMem :: Pass GPUMem GPUMem
optimiseGPUMem =
  pass
    "short-circuit-gpu"
    "Array Short-Circuiting (GPU)"
    mkCoalsTabGPU
    replaceInHostOp
    replaceInParams
-- | Array short-circuiting for the multicore memory representation.
-- Multicore operations are rewritten with 'replaceInMCOp'.
optimiseMCMem :: Pass MCMem MCMem
optimiseMCMem =
  pass
    "short-circuit-mc"
    "Array Short-Circuiting (MC)"
    mkCoalsTabMC
    replaceInMCOp
    replaceInParams
-- | Redirect function parameters onto the memory chosen by the coalescing
-- table.
--
-- * A memory parameter whose name is a key of the table is renamed to that
--   entry's 'dstmem', and the 'dstmem' is added to the returned 'Names'.
-- * An array parameter whose memory block is a key of the table is re-pointed
--   at that entry's 'dstmem', keeping its index function.
-- * Every other parameter is returned unchanged.
--
-- The parameter order is preserved.  'L.mapAccumL' threads the accumulated
-- names left to right in one pass, replacing a lazy 'foldl' that built the
-- list in reverse and then had to 'reverse' it.
replaceInParams :: CoalsTab -> [Param FParamMem] -> (Names, [Param FParamMem])
replaceInParams coalstab = L.mapAccumL replaceInParam mempty
  where
    replaceInParam to_remove param@(Param attrs name dec) =
      case dec of
        MemMem _
          | Just entry <- M.lookup name coalstab ->
              (oneName (dstmem entry) <> to_remove, Param attrs (dstmem entry) dec)
        MemArray pt shp u (ArrayIn m ixf)
          | Just entry <- M.lookup m coalstab ->
              (to_remove, Param attrs name $ MemArray pt shp u $ ArrayIn (dstmem entry) ixf)
        _ -> (to_remove, param)
-- | Drop every statement whose first bound name is listed in
-- 'memAllocsToRemove'.
--
-- Statements that bind no names at all are kept.  The previous version used
-- the partial 'head' here and would crash on an empty pattern.
removeAllocsInStms :: Stms rep -> UpdateM inner (Stms rep)
removeAllocsInStms stms = do
  to_remove <- asks memAllocsToRemove
  let keep stm =
        case patNames (stmPat stm) of
          v : _ -> not (v `nameIn` to_remove)
          [] -> True
  pure $ stmsFromList $ filter keep $ stmsToList stms
-- | Build an array short-circuiting pass.
--
-- * @flag@ and @desc@ are the pass name and description.
-- * @mk@ computes a coalescing table per function name from the
--   alias-annotated program.
-- * @on_inner@ rewrites representation-specific inner operations.
-- * @on_fparams@ rewrites function parameters and returns names whose
--   defining statements are to be removed from the function body.
--
-- Top-level constants are passed through unchanged.  The per-function lookup
-- uses 'M.!', so @mk@ must produce a table for every function.
pass ::
  (Mem rep inner, LetDec rep ~ LetDecMem, AliasableRep rep) =>
  String ->
  String ->
  (Prog (Aliases rep) -> Pass.PassM (M.Map Name CoalsTab)) ->
  (inner rep -> UpdateM (inner rep) (inner rep)) ->
  (CoalsTab -> [FParam (Aliases rep)] -> (Names, [FParam (Aliases rep)])) ->
  Pass rep rep
pass flag desc mk on_inner on_fparams = Pass flag desc run
  where
    run prog = do
      tables <- mk $ AnlAls.aliasAnalysis prog
      Pass.intraproceduralTransformationWithConsts pure (onFun tables) prog

    onFun tables _consts fundef = do
      let table = tables M.! funDefName fundef
          (to_remove, params') = on_fparams table $ funDefParams fundef
      pure
        fundef
          { funDefBody = onBody table to_remove $ funDefBody fundef,
            funDefParams = params'
          }

    onBody table to_remove body =
      let env = Env table on_inner to_remove
       in body
            { bodyStms = runReader (updateStms $ bodyStms body) env,
              bodyResult = map (replaceResMem table) $ bodyResult body
            }
-- | If a result is a variable whose name is a key of the coalescing table,
-- replace it with that entry's 'dstmem'; otherwise return it unchanged.
replaceResMem :: CoalsTab -> SubExpRes -> SubExpRes
replaceResMem coaltab res =
  maybe res redirect $ subExpResVName res >>= (`M.lookup` coaltab)
  where
    redirect entry = res {resSubExp = Var $ dstmem entry}
-- | Rewrite every statement with 'replaceInStm', then drop the statements
-- selected by 'removeAllocsInStms'.
updateStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  UpdateM (inner rep) (Stms rep)
updateStms = removeAllocsInStms <=< traverse replaceInStm
-- | Rewrite a single statement under the current coalescing table.
--
-- * Array pattern elements are redirected with 'lookupAndReplace'.
-- * The expression is rewritten with 'replaceInExp', which receives the
--   redirected pattern elements.
-- * The 'certs' of every coalescing entry whose 'vartab' binds one of the
--   statement's names are appended to the statement's certificates.
replaceInStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stm rep ->
  UpdateM (inner rep) (Stm rep)
replaceInStm (Let (Pat elems) (StmAux c a d) e) = do
  elems' <- mapM replaceInPatElem elems
  e' <- replaceInExp elems' e
  entries <- asks (M.elems . envCoalesceTab)
  let bound = map patElemName elems
      -- Map membership test rather than 'L.intersect' against 'M.keys',
      -- which cost O(|bound| * |vartab|) for every entry.
      mentionsBound entry = any (`M.member` vartab entry) bound
      -- With no matching entries this is c <> mempty, i.e. c.
      c' = c <> foldMap certs (filter mentionsBound entries)
  pure $ Let (Pat elems') (StmAux c' a d) e'
  where
    replaceInPatElem :: PatElem LetDecMem -> UpdateM inner (PatElem LetDecMem)
    replaceInPatElem p@(PatElem vname (MemArray _ _ u _)) =
      fromMaybe p <$> lookupAndReplace vname PatElem u
    replaceInPatElem p = pure p
-- | Rewrite an expression under the current coalescing table.
--
-- The pattern elements bound by the enclosing statement are used only for
-- 'Match', whose return types are passed through 'generalizeIxfun'.
--
-- * 'Match': every case body and the default body go through
--   'replaceInIfBody'.
-- * 'Loop': array loop parameters go through 'replaceInFParam' (initial
--   values are kept), and the loop body goes through 'replaceInIfBody'.
-- * 'Op': 'Inner' operations are handed to the 'onInner' rewriter; other
--   operations are returned unchanged.
-- * 'BasicOp', 'WithAcc' and 'Apply' are returned unchanged.
replaceInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  [PatElem LetDecMem] ->
  Exp rep ->
  UpdateM (inner rep) (Exp rep)
replaceInExp pat_elems (Match cond_ses cases defbody dec) = do
  cases' <- forM cases $ \(Case vals body) -> Case vals <$> replaceInIfBody body
  defbody' <- replaceInIfBody defbody
  rets' <- zipWithM (generalizeIxfun pat_elems) pat_elems (matchReturns dec)
  pure $ Match cond_ses cases' defbody' dec {matchReturns = rets'}
replaceInExp _ (Loop merge form body) = do
  params' <- forM merge (replaceInFParam . fst)
  body' <- replaceInIfBody body
  let merge' = zipWith (\param (_, initial) -> (param, initial)) params' merge
  pure $ Loop merge' form body'
replaceInExp _ (Op (Inner i)) = do
  on_op <- asks onInner
  Op . Inner <$> on_op i
replaceInExp _ e@Op {} = pure e
replaceInExp _ e@BasicOp {} = pure e
replaceInExp _ e@WithAcc {} = pure e
replaceInExp _ e@Apply {} = pure e
-- | Rewrite the kernel body of a segmented operation; every other field
-- (level, space, operators, types) is kept as-is.
replaceInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  UpdateM (inner rep) (SegOp lvl rep)
replaceInSegOp op =
  case op of
    SegMap lvl sp tps kbody ->
      SegMap lvl sp tps <$> replaceInKernelBody kbody
    SegRed lvl sp binops tps kbody ->
      SegRed lvl sp binops tps <$> replaceInKernelBody kbody
    SegScan lvl sp binops tps kbody ->
      SegScan lvl sp binops tps <$> replaceInKernelBody kbody
    SegHist lvl sp hist_ops tps kbody ->
      SegHist lvl sp hist_ops tps <$> replaceInKernelBody kbody

-- | Run 'updateStms' over the statements of a kernel body, keeping its
-- results unchanged.
replaceInKernelBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  KernelBody rep ->
  UpdateM (inner rep) (KernelBody rep)
replaceInKernelBody kbody = do
  stms' <- updateStms $ kernelBodyStms kbody
  pure kbody {kernelBodyStms = stms'}
-- | Inner-operation rewriter for 'GPUMem'.  A 'SegOp' goes through
-- 'replaceInSegOp'; any other host operation is returned unchanged.
replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
replaceInHostOp host_op =
  case host_op of
    SegOp seg_op -> SegOp <$> replaceInSegOp seg_op
    _ -> pure host_op
-- | Inner-operation rewriter for 'MCMem'.  Both the optional and the
-- mandatory segmented operation of a 'ParOp' go through 'replaceInSegOp';
-- any other multicore operation is returned unchanged.
replaceInMCOp :: MCOp NoOp MCMem -> UpdateM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
replaceInMCOp mc_op =
  case mc_op of
    ParOp par_op op -> do
      par_op' <- traverse replaceInSegOp par_op
      op' <- replaceInSegOp op
      pure $ ParOp par_op' op'
    _ -> pure mc_op
-- | Adjust one 'Match' return type after coalescing.
--
-- Rewriting happens only when the pattern element is an array in memory
-- @mem@ with index function @ixf@, the return type is an array, and the
-- element's name @vname@ is bound in the 'vartab' of some coalescing entry.
-- In that case the return type is rewritten to 'ReturnsInBlock' @mem@, with
-- @ixf@ existentialised over the names of all @pat_elems@.  Otherwise the
-- return type is returned unchanged.
generalizeIxfun :: [PatElem dec] -> PatElem LetDecMem -> BodyReturns -> UpdateM inner BodyReturns
generalizeIxfun
  pat_elems
  (PatElem vname (MemArray _ _ _ (ArrayIn mem ixf)))
  ret@(MemArray pt shp u _) = do
    coalesced <- asks $ any (M.member vname . vartab) . envCoalesceTab
    pure $
      if coalesced
        then
          MemArray pt shp u . ReturnsInBlock mem $
            existentialiseIxFun (map patElemName pat_elems) ixf
        else ret
generalizeIxfun _ _ ret = pure ret
-- | Rewrite a body: its statements go through 'updateStms' and each result
-- goes through 'replaceResMem'.
replaceInIfBody :: (Mem rep inner, LetDec rep ~ LetDecMem) => Body rep -> UpdateM (inner rep) (Body rep)
replaceInIfBody body@(Body _ stms res) = do
  stms' <- updateStms stms
  table <- asks envCoalesceTab
  let res' = map (replaceResMem table) res
  pure body {bodyStms = stms', bodyResult = res'}
-- | Redirect an array parameter (used for loop parameters) onto its
-- coalesced memory, if any.  Non-array parameters, and arrays that were not
-- coalesced, are returned unchanged.
--
-- The parameter's attributes are preserved, consistent with
-- 'replaceInParams'.  The previous version rebuilt the parameter with
-- @Param mempty@, which silently discarded them.
replaceInFParam :: Param FParamMem -> UpdateM inner (Param FParamMem)
replaceInFParam p@(Param attrs vname (MemArray _ _ u _)) =
  fromMaybe p <$> lookupAndReplace vname (Param attrs) u
replaceInFParam p = pure p
-- | Look up @vname@ in the 'vartab's of the coalescing table and, if it was
-- coalesced, build its redirected array annotation.
--
-- On a hit, the recorded index function has the entry's free-variable
-- substitutions applied with 'fixPoint'.  It is then placed in the recorded
-- memory with the recorded element type and shape and the caller-supplied
-- uniqueness @u@, and the result is handed to @f vname@.  Returns 'Nothing'
-- when @vname@ is in no 'vartab'.
lookupAndReplace ::
  VName ->
  (VName -> MemBound u -> a) ->
  u ->
  UpdateM inner (Maybe a)
lookupAndReplace vname f u = do
  coaltab <- asks envCoalesceTab
  let redirect (Coalesced _ (MemBlock pt shp mem ixf) subs) =
        ixf
          & fixPoint (substituteInIxFun subs)
          & ArrayIn mem
          & MemArray pt shp u
          & f vname
  -- 'msum' returns the first hit in key order of the table.  That is exactly
  -- what looking vname up in the left-biased union 'foldMap vartab coaltab'
  -- would give, but without rebuilding that union on every call.
  pure $ redirect <$> msum (fmap (M.lookup vname . vartab) coaltab)