{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
( recordMemRefUses,
freeVarSubstitutions,
translateAccessSummary,
aggSummaryLoopTotal,
aggSummaryLoopPartial,
aggSummaryMapPartial,
aggSummaryMapTotal,
noMemOverlap,
)
where
import Control.Monad
import Data.Function ((&))
import Data.List (intersect, partition, uncons)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util
-- | Try to express every free variable of @indfun@ that is not bound in
-- @scope0@ in terms of variables that are.
--
-- Each out-of-scope variable must have an entry in @scals0@ whose 'PrimExp'
-- has an integer type. That expression becomes the variable's substitution,
-- and the expression's own free variables are then resolved in the same way.
-- The result is the accumulated substitution map, or 'Nothing' as soon as one
-- out-of-scope variable cannot be resolved.
--
-- NOTE(review): resolution repeats until no out-of-scope variables remain, so
-- this assumes @scals0@ contains no cyclic definitions. Confirm.
freeVarSubstitutions ::
  (FreeIn a) =>
  ScopeTab rep ->
  ScalarTab ->
  a ->
  Maybe FreeVarSubsts
freeVarSubstitutions scope0 scals0 indfun =
  resolveFreeVars mempty (namesToList (freeIn indfun))
  where
    -- Resolve one batch of variables, threading the substitutions found so far.
    resolveFreeVars :: FreeVarSubsts -> [VName] -> Maybe FreeVarSubsts
    resolveFreeVars found [] = Just found
    resolveFreeVars found pending = do
      let unbound = [v | v <- pending, v `M.notMember` scope0]
      -- Every unbound variable must be substitutable; otherwise give up.
      (substs, next_vars) <- mapAndUnzipM substitutionFor unbound
      resolveFreeVars (found <> mconcat substs) (concat next_vars)

    -- The substitution for a single variable, together with the free
    -- variables of its defining expression.
    substitutionFor :: VName -> Maybe (FreeVarSubsts, [VName])
    substitutionFor v = do
      pe <- M.lookup v scals0
      case primExpType pe of
        IntType _ -> Just (M.singleton v (TPrimExp pe), namesToList (freeIn pe))
        _ -> Nothing
-- | Rewrite an access summary so that its LMADs refer only to variables bound
-- in @scope0@. The substitutions are computed by 'freeVarSubstitutions' from
-- the scalar table @scals0@.
--
-- The result is 'Undeterminable' if the input already is, or if some free
-- variable of the LMADs cannot be substituted.
translateAccessSummary :: ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary
translateAccessSummary scope0 scals0 summary
  | Set slmads <- summary,
    Just subs <- freeVarSubstitutions scope0 scals0 slmads =
      Set (S.map (IxFun.substituteInLMAD subs) slmads)
  | otherwise = Undeterminable
-- | Summarise the memory accesses of a single statement with respect to the
-- coalescing table @coal_tab@.
--
-- On success the result is @Just (writes, uses)@. Each element is a triple
-- produced by 'getDirAliasedIxfn' for the written or read variable. Where the
-- statement touches only part of the variable, the triple's index function is
-- sliced accordingly. Every write is also listed among the uses.
--
-- The result is 'Nothing' for statements that are neither a 'BasicOp' nor an
-- 'Alloc'. It is also 'Nothing' for some 'BasicOp's when 'getDirAliasedIxfn'
-- fails on a variable the summary depends on. Callers must treat 'Nothing'
-- conservatively.
getUseSumFromStm ::
  (Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  CoalsTab ->
  Stm (Aliases rep) ->
  Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
-- Indexing with a fixed index in every dimension reads from @arr@.
getUseSumFromStm td_env coal_tab (Let _ _ (BasicOp (Index arr (Slice slc))))
  | Just (MemBlock _ shp _ _) <- getScopeMemInfo arr (scope td_env),
    length slc == length (shapeDims shp) && all isFix slc = do
      (mem_b, mem_arr, ixfn_arr) <- getDirAliasedIxfn td_env coal_tab arr
      let new_ixfn = IxFun.slice ixfn_arr $ Slice $ map (fmap pe64) slc
      pure ([], [(mem_b, mem_arr, new_ixfn)])
  where
    isFix :: DimIndex d -> Bool
    isFix DimFix {} = True
    isFix _ = False
-- Any other (flat) indexing records no accesses.
-- NOTE(review): presumably because such an index only creates a view. Confirm.
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Index {})) = Just ([], [])
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp FlatIndex {})) = Just ([], [])
-- An array literal writes its result and reads every variable element.
getUseSumFromStm td_env coal_tab (Let (Pat pes) _ (BasicOp (ArrayLit ses _))) =
  let rds = mapMaybe (getDirAliasedIxfn td_env coal_tab) $ mapMaybe seName ses
      wrts = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) pes
   in Just (wrts, wrts ++ rds)
  where
    seName :: SubExp -> Maybe VName
    seName (Var a) = Just a
    seName (Constant _) = Nothing
-- An in-place update writes the sliced part of the result and, if the new
-- value is a variable, also reads it.
getUseSumFromStm td_env coal_tab (Let (Pat [x']) _ (BasicOp (Update _ _x (Slice slc) a_se))) = do
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x')
  let x_ixfn_slc = IxFun.slice x_ixfn $ Slice $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
  case a_se of
    Constant _ -> Just ([r1], [r1])
    Var a -> case getDirAliasedIxfn td_env coal_tab a of
      Nothing -> Just ([r1], [r1])
      Just r2 -> Just ([r1], [r1, r2])
-- A replicate with an empty shape of a variable writes @y@ and reads @x@.
-- Every other empty-shape replicate (a constant, or a multi-element pattern)
-- is handled by the general 'Replicate' case below rather than crashing.
getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Replicate (Shape []) (Var x)))) = do
  wrt <- getDirAliasedIxfn td_env coal_tab $ patElemName y
  rd <- getDirAliasedIxfn td_env coal_tab x
  pure ([wrt], [wrt, rd])
-- Concatenation writes the result and reads every operand.
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Concat _i (a :| bs) _ses))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) (a : bs)
   in Just (ws, ws ++ rs)
-- Manifest writes the result and reads its operand.
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Manifest _perm x))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) [x]
   in Just (ws, ws ++ rs)
-- Replicate writes the result and, if the replicated value is a variable,
-- also reads it.
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Replicate _shp se))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
   in case se of
        Constant _ -> Just (ws, ws)
        Var x -> Just (ws, ws ++ mapMaybe (getDirAliasedIxfn td_env coal_tab) [x])
-- A flat update writes the flat-sliced part of the result and reads @v@.
-- As for 'Update', failing to resolve the destination yields 'Nothing'
-- (conservative) instead of silently reporting no accesses.
getUseSumFromStm td_env coal_tab (Let (Pat [x]) _ (BasicOp (FlatUpdate _ (FlatSlice offset slc) v))) = do
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x)
  let x_ixfn_slc =
        IxFun.flatSlice x_ixfn $ FlatSlice (pe64 offset) $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
  case getDirAliasedIxfn td_env coal_tab v of
    Nothing -> Just ([r1], [r1])
    Just r2 -> Just ([r1], [r1, r2])
-- Iota only writes its result.
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp Iota {})) =
  let wrt = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
   in pure (wrt, wrt)
-- Remaining basic operations and allocations record no accesses.
getUseSumFromStm _ _ (Let Pat {} _ BasicOp {}) = Just ([], [])
getUseSumFromStm _ _ (Let Pat {} _ (Op (Alloc _ _))) = Just ([], [])
-- Anything else cannot be summarised.
getUseSumFromStm _ _ _ = Nothing
-- | Record the memory references of statement @stm@ in every active
-- coalescing entry of @bu_env@. Entries for which the statement makes
-- coalescing unsafe are failed via 'markFailedCoal'.
--
-- * If 'getUseSumFromStm' cannot summarise @stm@, every active entry whose
--   'vartab' mentions one of the statement's pattern names is failed. All
--   other entries are left unchanged.
--
-- * Otherwise each active entry is checked by @checkOverlapAndExpand@.
--   Entries whose writes may overlap their uses are failed. The survivors get
--   their 'dstrefs' replaced by the translated previous uses, and the
--   statement's writes and uses added via @addLmads@.
--
-- Returns the updated active coalescing table and inhibit table.
recordMemRefUses ::
  (AliasableRep rep, Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  BotUpEnv ->
  Stm (Aliases rep) ->
  (CoalsTab, InhibitTab)
recordMemRefUses td_env bu_env stm =
  let active_tab = activeCoals bu_env
      inhibit_tab = inhibit bu_env
      active_entries = M.toList active_tab
   in case getUseSumFromStm td_env active_tab stm of
        Nothing ->
          -- The statement could not be summarised: fail every active
          -- coalescing whose variable table mentions a name it binds.
          -- A strict fold (ascending key order, as before) avoids building a
          -- chain of thunks over the whole table.
          M.foldlWithKey'
            ( \state m_b entry ->
                if not $ null $ patNames (stmPat stm) `intersect` M.keys (vartab entry)
                  then markFailedCoal state m_b
                  else state
            )
            (active_tab, inhibit_tab)
            active_tab
        Just use_sums ->
          let (mb_wrts, prev_uses, mb_lmads) =
                unzip3 $ map (checkOverlapAndExpand use_sums active_tab) active_entries

              -- Entries that passed the overlap check: store the translated
              -- destination references, then add this statement's accesses.
              active_tab1 =
                M.fromList
                  . map
                    ( \(wrts, (uses, prev_use, (k, etry))) ->
                        let mrefs' = (memrefs etry) {dstrefs = prev_use}
                            etry' = etry {memrefs = mrefs'}
                         in (k, addLmads wrts uses etry')
                    )
                  . mapMaybe (\(x, y) -> (,y) <$> x)
                  $ zip mb_wrts
                  $ zip3 mb_lmads prev_uses active_entries

              -- Entries that failed the overlap check.
              failed_tab =
                M.fromList $ map snd $ filter (isNothing . fst) $ zip mb_wrts active_entries

              (_, inhibit_tab1) =
                M.foldlWithKey'
                  (\state k _ -> markFailedCoal state k)
                  (failed_tab, inhibit_tab)
                  failed_tab
           in (active_tab1, inhibit_tab1)
  where
    -- For one active entry @(m_b, etry)@, return:
    --
    -- 1. the statement's writes to @m_b@ (or its aliases) as an access
    --    summary, or 'Nothing' if they may overlap with the uses, in which
    --    case the coalescing must fail;
    -- 2. the entry's previous destination references, translated to the
    --    current scope;
    -- 3. the statement's uses of memory aliased by the entry, as an access
    --    summary.
    checkOverlapAndExpand ::
      ([(VName, VName, IxFun)], [(VName, VName, IxFun)]) ->
      CoalsTab ->
      (VName, CoalsEntry) ->
      (Maybe AccessSummary, AccessSummary, AccessSummary)
    checkOverlapAndExpand (stm_wrts, stm_uses) active_tab (m_b, etry) =
      let alias_m_b = getAliases mempty m_b
          -- Uses whose first component is neither m_b nor one of its aliases.
          stm_uses' = filter ((`notNameIn` alias_m_b) . tupFst) stm_uses
          -- Union of the entry's alias set and their recorded aliases.
          all_aliases = foldMap (getAliases mempty) $ namesToList $ alsmem etry
          ixfns = map tupThd $ filter ((`nameIn` all_aliases) . tupSnd) stm_uses'
          lmads'' = ixfnsToSummary ixfns
          wrt_ixfns = map tupThd $ filter ((`nameIn` alias_m_b) . tupFst) stm_wrts
          wrt_lmads' = ixfnsToSummary wrt_ixfns
          prev_use =
            translateAccessSummary (scope td_env) (scalarTable td_env) $
              (dstrefs . memrefs) etry
          -- Aliases of the active entry registered for the first use's
          -- memory block.
          -- NOTE(review): only the first element of stm_uses is consulted.
          -- Confirm this is intended.
          original_mem_aliases =
            fmap tupFst stm_uses
              & uncons
              & fmap fst
              & (=<<) (`M.lookup` active_tab)
              & maybe mempty alsmem
          -- If m_b is among those aliases, the uses are treated as writes.
          (wrt_lmads'', lmads) =
            if m_b `nameIn` original_mem_aliases
              then (wrt_lmads' <> lmads'', Set mempty)
              else (wrt_lmads', lmads'')
          no_overlap = noMemOverlap td_env (lmads <> prev_use) wrt_lmads''
          wrt_lmads = if no_overlap then Just wrt_lmads'' else Nothing
       in (wrt_lmads, prev_use, lmads)

    -- A summary of the given index functions, or 'Undeterminable' as soon as
    -- one of them cannot be turned into a single LMAD.
    ixfnsToSummary :: [IxFun] -> AccessSummary
    ixfnsToSummary = maybe Undeterminable (Set . S.fromList) . traverse mbLmad

    tupFst :: (a, b, c) -> a
    tupFst (a, _, _) = a

    tupSnd :: (a, b, c) -> b
    tupSnd (_, b, _) = b

    tupThd :: (a, b, c) -> c
    tupThd (_, _, c) = c

    -- Add @m@ and its aliases, as recorded in @m_alias td_env@, to @acc@.
    getAliases :: Names -> VName -> Names
    getAliases acc m =
      oneName m
        <> acc
        <> fromMaybe mempty (M.lookup m (m_alias td_env))

    -- The single LMAD of an index function after substituting out-of-scope
    -- scalar variables, if every such variable can be substituted.
    mbLmad :: IxFun -> Maybe LmadRef
    mbLmad indfun
      | Just subs <- freeVarSubstitutions (scope td_env) (scals bu_env) indfun,
        (IxFun.IxFun lmad _) <- IxFun.substituteInIxFun subs indfun =
          Just lmad
    mbLmad _ = Nothing

    -- Prepend the given writes and uses to the entry's memory references.
    addLmads :: AccessSummary -> AccessSummary -> CoalsEntry -> CoalsEntry
    addLmads wrts uses etry =
      etry {memrefs = MemRefs uses wrts <> memrefs etry}
-- | Conservatively decide whether two access summaries are guaranteed not to
-- overlap in memory.
--
-- * If either summary is an empty set, the answer is 'True'.
--
-- * If both are sets, the scalar table of @td_env@ is substituted (to a fixed
--   point) into every LMAD, the known less-than facts and the asserts. The
--   answer is 'True' iff every LMAD of the first summary is proven disjoint
--   from every LMAD of the second by at least one of 'IxFun.disjoint',
--   'IxFun.disjoint2' or 'IxFun.disjoint3'. This only applies when all names
--   in 'nonNegatives' can be turned into 'PrimExp's.
--
-- * In every other case (including an 'Undeterminable' summary paired with a
--   non-empty one) the answer is 'False'.
noMemOverlap :: (AliasableRep rep) => TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap _ _ (Set mr)
  | S.null mr = True
noMemOverlap _ (Set mr) _
  | S.null mr = True
noMemOverlap td_env (Set is0) (Set js0)
  | Just non_negs <-
      mapM (primExpFromSubExpM (vnameToPrimExp (scope td_env) scalars) . Var) $
        namesToList non_negatives =
      -- Every pair (i, j) must be shown disjoint by some test.
      all (\i -> all (provablyDisjoint non_negs i) js) is
  where
    scalars = scalarTable td_env
    non_negatives = nonNegatives td_env

    -- Substitute known scalar definitions until nothing changes.
    substLMAD :: LmadRef -> LmadRef
    substLMAD = fixPoint (IxFun.substituteInLMAD $ TPrimExp <$> scalars)

    substPE :: PrimExp VName -> PrimExp VName
    substPE = fixPoint (substituteInPrimExp scalars)

    less_thans = map (fmap substPE) $ knownLessThan td_env
    asserts = map (substPE . primExpFromSubExp Bool) $ td_asserts td_env
    is = map substLMAD $ S.toList is0
    js = map substLMAD $ S.toList js0

    -- 'non_negs' is bound by the guard above, so it is passed explicitly.
    provablyDisjoint non_negs i j =
      IxFun.disjoint less_thans non_negatives i j
        || IxFun.disjoint2 () () less_thans non_negatives i j
        || IxFun.disjoint3 (typeOf <$> scope td_env) asserts less_thans non_negs i j
noMemOverlap _ _ _ = False
-- | Aggregate an access summary over all iterations of a loop.
--
-- The equations are tried in order:
--
-- 1. 'Undeterminable' stays 'Undeterminable'; an empty set stays empty.
--
-- 2. The summary is translated with 'translateAccessSummary' using the loop's
--    scope and scalar table. If that yields a set whose free variables are
--    all bound in @scope_bef@ (the scope before the loop), that set is
--    returned.
--
-- 3. Otherwise, given loop information @(i, (lb, ub))@ and a set summary,
--    the scalar table is substituted into each LMAD (to a fixed point) and
--    the LMAD is expanded over @i@ with 'aggSummaryOne', starting at @lb@
--    with extent @ub@.
--    NOTE(review): the extent equals the iteration count only when @lb@ is
--    0; confirm against the call sites before relying on a non-zero @lb@.
--
-- 4. Anything else is 'Undeterminable'.
aggSummaryLoopTotal ::
  (MonadFreshNames m) =>
  ScopeTab rep ->
  ScopeTab rep ->
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopTotal _ _ _ _ Undeterminable = pure Undeterminable
aggSummaryLoopTotal _ _ _ _ (Set l)
  | S.null l = pure $ Set mempty
aggSummaryLoopTotal scope_bef scope_loop scals_loop _ access
  | Set ls <- translateAccessSummary scope_loop scals_loop access,
    all (`M.member` scope_bef) $ namesToList $ foldMap freeIn ls =
      pure $ Set ls
aggSummaryLoopTotal _ _ scalars_loop (Just (iterator_var, (lower_bound, upper_bound))) (Set lmads) =
  concatMapM
    ( aggSummaryOne iterator_var lower_bound upper_bound
        . fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop)
    )
    (S.toList lmads)
aggSummaryLoopTotal _ _ _ _ _ = pure Undeterminable
-- | Aggregate an access summary over the loop iterations following the
-- current one.
--
-- For loop information @(i, (_, ub))@, the scalar table is substituted into
-- each LMAD (to a fixed point). The LMAD is then expanded over @i@ with
-- 'aggSummaryOne', starting at @i + 1@ with extent @ub - i - 1@. The result
-- is expressed in terms of the current @i@. The loop's lower bound is not
-- used.
--
-- Missing loop information or an 'Undeterminable' input gives
-- 'Undeterminable'.
aggSummaryLoopPartial ::
  (MonadFreshNames m) =>
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopPartial _ _ Undeterminable = pure Undeterminable
aggSummaryLoopPartial _ Nothing _ = pure Undeterminable
aggSummaryLoopPartial scalars_loop (Just (iterator_var, (_, upper_bound))) (Set lmads) =
  concatMapM summariseRest (S.toList lmads)
  where
    cur = typedLeafExp iterator_var

    substScalars :: LmadRef -> LmadRef
    substScalars = fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop)

    summariseRest =
      aggSummaryOne iterator_var (cur + 1) (upper_bound - cur - 1) . substScalars
-- | Summarise the accesses made for all /other/ index combinations of a
-- (possibly multi-dimensional) parallel space, given the access @lmad@ of
-- the current one.
--
-- @dims@ holds @(gtid, size)@ pairs. They are processed from the last pair
-- to the first (presumably innermost to outermost; confirm against callers).
-- For each pair:
--
-- * the partial summary for that dimension (see 'aggSummaryMapPartialOne')
--   is added to the result;
-- * the current access is then widened with 'aggSummaryOne' over @gtid@,
--   starting at 0 with extent @size@, before the next pair is processed.
--
-- Any 'Undeterminable' intermediate makes the result 'Undeterminable'. With
-- no dimensions the result is empty.
aggSummaryMapPartial :: (MonadFreshNames m) => ScalarTab -> [(VName, SubExp)] -> LmadRef -> m AccessSummary
aggSummaryMapPartial _ [] _ = pure mempty
aggSummaryMapPartial scalars dims lmad =
  go mempty (reverse dims) (Set $ S.singleton lmad)
  where
    go summary [] _ = pure summary
    go Undeterminable _ _ = pure Undeterminable
    go _ _ Undeterminable = pure Undeterminable
    go (Set summary) (dim@(gtid, size) : remaining) current@(Set current_lmads) = do
      others <- aggSummaryMapPartialOne scalars dim current
      widened <-
        concatMapM
          (aggSummaryOne gtid 0 (TPrimExp $ primExpFromSubExp (IntType Int64) size))
          (S.toList current_lmads)
      go (Set summary <> others) remaining widened
-- | Summarise the accesses made for the other values of @gtid@ along a
-- single parallel dimension @(gtid, size)@.
--
-- The scalar table is substituted into every LMAD (to a fixed point). Each
-- LMAD is then expanded with 'aggSummaryOne' over two ranges, and the
-- results are combined:
--
-- * from 0 with extent @gtid@ (the indices below @gtid@);
-- * from @gtid + 1@ with extent @size - gtid - 1@ (the indices above it).
--
-- If @size@ is a constant accepted by 'oneIsh', there are no other indices,
-- so the result is empty. 'Undeterminable' stays 'Undeterminable'.
aggSummaryMapPartialOne :: (MonadFreshNames m) => ScalarTab -> (VName, SubExp) -> AccessSummary -> m AccessSummary
aggSummaryMapPartialOne _ _ Undeterminable = pure Undeterminable
aggSummaryMapPartialOne _ (_, Constant n) (Set _)
  | oneIsh n = pure mempty
aggSummaryMapPartialOne scalars (gtid, size) (Set lmads0) =
  concatMapM summariseRange [before, after]
  where
    tid = typedLeafExp gtid
    -- (start, extent) of the indices on either side of the current one.
    before = (0, tid)
    after = (tid + 1, isInt64 (primExpFromSubExp (IntType Int64) size) - tid - 1)

    lmads = map (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)) $ S.toList lmads0

    summariseRange (start, extent) = concatMapM (aggSummaryOne gtid start extent) lmads
-- | Aggregate an access summary over an entire parallel space, i.e. over
-- every value of every index in @segspace@.
--
-- The scalar table is substituted into every LMAD (to a fixed point). The
-- @(gtid, size)@ pairs of @segspace@ are then processed from last to first
-- (presumably innermost to outermost; confirm against callers). Each pair
-- expands every LMAD with 'aggSummaryOne' over @gtid@, starting at 0 with
-- extent @size@.
--
-- 'Undeterminable' input, or an 'Undeterminable' intermediate, gives
-- 'Undeterminable'. An empty set gives an empty set.
--
-- With an empty @segspace@ there is nothing to expand over, so the
-- substituted input is returned as is. Returning an empty summary here
-- would claim that no memory is touched (and would discard
-- 'Undeterminable'), which under-approximates the accesses.
aggSummaryMapTotal :: (MonadFreshNames m) => ScalarTab -> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal _ _ Undeterminable = pure Undeterminable
aggSummaryMapTotal _ _ (Set lmads)
  | S.null lmads = pure mempty
aggSummaryMapTotal scalars segspace (Set lmads0) =
  foldM aggregateDim (Set lmads) (reverse segspace)
  where
    lmads = S.map (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)) lmads0

    -- Widen every LMAD over the full range of one dimension.
    aggregateDim Undeterminable _ = pure Undeterminable
    aggregateDim (Set acc) (gtid, size) =
      concatMapM
        (aggSummaryOne gtid 0 $ TPrimExp $ primExpFromSubExp (IntType Int64) size)
        (S.toList acc)
-- | Expand a single LMAD over a range of values of an iteration variable.
--
-- @aggSummaryOne i lb spn lmad@ returns:
--
-- * 'Undeterminable' if @i@ occurs in the dimensions of @lmad@;
--
-- * @lmad@ unchanged if @i@ does not occur in its offset;
--
-- * otherwise an LMAD with a new outermost dimension:
--
--     * its stride is the offset's step between consecutive values of @i@,
--       computed with a fresh variable @k@ as @offset[i:=k+1] - offset[i:=k]@,
--       then simplified and constant-folded;
--     * its extent is @spn@;
--     * its offset is @offset[i:=lb]@.
--
--   If @k@ still occurs in that LMAD (e.g. the offset is not affine in @i@),
--   the result is 'Undeterminable'.
aggSummaryOne :: (MonadFreshNames m) => VName -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> LmadRef -> m AccessSummary
aggSummaryOne iterator_var lower_bound spn lmad@(IxFun.LMAD offset0 dims0)
  | iterator_var `nameIn` freeIn dims0 = pure Undeterminable
  | iterator_var `notNameIn` freeIn offset0 = pure $ Set $ S.singleton lmad
  | otherwise = expand <$> newVName "k"
  where
    -- Substitute the given expression for the iteration variable.
    at v = TPrimExp . substituteInPrimExp (M.singleton iterator_var $ untyped v) . untyped

    expand k =
      let step =
            TPrimExp . constFoldPrimExp . simplify . untyped $
              at (typedLeafExp k + 1) offset0 - at (typedLeafExp k) offset0
          widened = IxFun.LMAD (at lower_bound offset0) (IxFun.LMADDim step spn : dims0)
       in if k `nameIn` freeIn widened
            then Undeterminable
            else Set (S.singleton widened)
-- | View a variable as a 64-bit signed integer leaf expression.
typedLeafExp :: VName -> TPrimExp Int64 VName
typedLeafExp = isInt64 . flip LeafExp (IntType Int64)