{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
  ( recordMemRefUses,
    freeVarSubstitutions,
    translateAccessSummary,
    aggSummaryLoopTotal,
    aggSummaryLoopPartial,
    aggSummaryMapPartial,
    aggSummaryMapTotal,
    noMemOverlap,
  )
where

import Control.Monad
import Data.Function ((&))
import Data.List (intersect, partition, uncons)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Analysis.AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-----------------------------------------------------
-- Some translations of Accesses and Ixfuns        --
-----------------------------------------------------

-- | Checks whether the index function can be translated at the current program
-- point and also returns the substitutions.  It comes down to answering the
-- question: "can one perform enough substitutions (from the bottom-up scalar
-- table) until all vars appearing in the index function are defined in the
-- current scope?"
--
-- Free variables that are already in @scope0@ are left alone.  Every other
-- free variable must be bound in @scals0@ to an integer-typed expression; its
-- substitution is recorded, and the free variables of that expression are in
-- turn checked the same way.  Returns 'Nothing' as soon as some variable is
-- neither in scope nor substitutable.
freeVarSubstitutions ::
  (FreeIn a) =>
  ScopeTab rep ->
  ScalarTab ->
  a ->
  Maybe FreeVarSubsts
freeVarSubstitutions scope0 scals0 indfun =
  expand mempty $ namesToList $ freeIn indfun
  where
    expand :: FreeVarSubsts -> [VName] -> Maybe FreeVarSubsts
    expand subs [] = Just subs
    expand subs0 fvs = do
      -- We require that all free variables not in scope can be substituted.
      -- Variables that already have a substitution are skipped: looking them
      -- up again would only reproduce the same entry, and skipping them
      -- guarantees termination even if the scalar table contains a cycle.
      let pending =
            filter (\v -> v `M.notMember` scope0 && v `M.notMember` subs0) fvs
      (subs, new_fvs) <- mapAndUnzipM getSubstitution pending
      expand (subs0 <> mconcat subs) $ concat new_fvs

    -- A variable is substitutable only if the scalar table binds it to an
    -- integer-typed expression; the free variables of that expression are
    -- returned so that they can be checked as well.
    getSubstitution :: VName -> Maybe (FreeVarSubsts, [VName])
    getSubstitution v
      | Just pe <- M.lookup v scals0,
        IntType _ <- primExpType pe =
          Just (M.singleton v $ TPrimExp pe, namesToList $ freeIn pe)
    getSubstitution _v = Nothing

-- | Translates free variables in an access summary, so that every index
-- function in it refers only to variables in @scope0@.  The substitutions are
-- those found by 'freeVarSubstitutions' from the scalar table @scals0@.
-- If some free variable cannot be translated, the result is
-- 'Undeterminable'.
translateAccessSummary :: ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary
translateAccessSummary scope0 scals0 (Set slmads)
  | Just subs <- freeVarSubstitutions scope0 scals0 slmads =
      Set $ S.map (IxFun.substituteInLMAD subs) slmads
-- Summaries that are already 'Undeterminable', and those whose free variables
-- cannot all be substituted, are undeterminable.
translateAccessSummary _ _ _ = Undeterminable

-- | This function computes the written and read memory references for the
-- current statement.  Each reference is a triple as returned by
-- 'getDirAliasedIxfn' (bound below as e.g. @(m_b, m_x, ixfn)@), with the index
-- function narrowed to the accessed region where the statement indexes or
-- updates a slice.
--
-- Returns 'Nothing' when the statement's memory behaviour is not summarised
-- here (if-then-else and loops are handled separately; calls and most 'Op's
-- are not supported), or when the destination of an in-place update cannot be
-- resolved.
getUseSumFromStm ::
  (Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  CoalsTab ->
  Stm (Aliases rep) ->
  -- | A pair of written and written+read memory locations, along with their
  -- associated array and the index function used
  Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
-- A full point-wise index (every dimension fixed) reads a single element.
getUseSumFromStm td_env coal_tab (Let _ _ (BasicOp (Index arr (Slice slc))))
  | Just (MemBlock _ shp _ _) <- getScopeMemInfo arr (scope td_env),
    length slc == length (shapeDims shp),
    all isFix slc = do
      (mem_b, mem_arr, ixfn_arr) <- getDirAliasedIxfn td_env coal_tab arr
      let new_ixfn = IxFun.slice ixfn_arr $ Slice $ map (fmap pe64) slc
      pure ([], [(mem_b, mem_arr, new_ixfn)])
  where
    isFix :: DimIndex d -> Bool
    isFix DimFix {} = True
    isFix _ = False
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp Index {})) =
  Just ([], []) -- incomplete slices
getUseSumFromStm _ _ (Let Pat {} _ (BasicOp FlatIndex {})) =
  Just ([], []) -- incomplete slices
getUseSumFromStm td_env coal_tab (Let (Pat pes) _ (BasicOp (ArrayLit ses _))) =
  let -- Only variable elements can be reads; constants touch no memory.
      rds = mapMaybe (getDirAliasedIxfn td_env coal_tab) [v | Var v <- ses]
      wrts = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) pes
   in Just (wrts, wrts ++ rds)
-- In place update @x[slc] <- a@. In the "in-place update" case,
--   summaries should be added after the old variable @x@ has
--   been added in the active coalesced table.
getUseSumFromStm td_env coal_tab (Let (Pat [x']) _ (BasicOp (Update _ _x (Slice slc) a_se))) = do
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x')
  let x_ixfn_slc = IxFun.slice x_ixfn $ Slice $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
      -- The updating value counts as a read only when it is a variable that
      -- 'getDirAliasedIxfn' can resolve.
      rds = case a_se of
        Var a -> maybeToList $ getDirAliasedIxfn td_env coal_tab a
        Constant _ -> []
  pure ([r1], r1 : rds)
getUseSumFromStm td_env coal_tab (Let (Pat [y]) _ (BasicOp (Replicate (Shape []) (Var x)))) = do
  -- y = copy x
  wrt <- getDirAliasedIxfn td_env coal_tab $ patElemName y
  rd <- getDirAliasedIxfn td_env coal_tab x
  pure ([wrt], [wrt, rd])
-- Note: @Replicate (Shape []) <constant>@ binds a scalar; it is handled by the
-- general 'Replicate' clause below instead of aborting the compiler.
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Concat _i (a :| bs) _ses))) =
  -- concat
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = mapMaybe (getDirAliasedIxfn td_env coal_tab) (a : bs)
   in Just (ws, ws ++ rs)
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Manifest _perm x))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = maybeToList $ getDirAliasedIxfn td_env coal_tab x
   in Just (ws, ws ++ rs)
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp (Replicate _shp se))) =
  let ws = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
      rs = case se of
        Var x -> maybeToList $ getDirAliasedIxfn td_env coal_tab x
        Constant _ -> []
   in Just (ws, ws ++ rs)
-- Flat in-place update.  Like 'Update', an unresolvable destination yields
-- 'Nothing' (conservative) rather than silently reporting no writes.
getUseSumFromStm td_env coal_tab (Let (Pat [x]) _ (BasicOp (FlatUpdate _ (FlatSlice offset slc) v))) = do
  (m_b, m_x, x_ixfn) <- getDirAliasedIxfn td_env coal_tab (patElemName x)
  let x_ixfn_slc =
        IxFun.flatSlice x_ixfn $ FlatSlice (pe64 offset) $ map (fmap pe64) slc
      r1 = (m_b, m_x, x_ixfn_slc)
  pure ([r1], r1 : maybeToList (getDirAliasedIxfn td_env coal_tab v))
getUseSumFromStm td_env coal_tab (Let (Pat ys) _ (BasicOp Iota {})) =
  let wrt = mapMaybe (getDirAliasedIxfn td_env coal_tab . patElemName) ys
   in pure (wrt, wrt)
getUseSumFromStm _ _ (Let Pat {} _ BasicOp {}) = Just ([], [])
getUseSumFromStm _ _ (Let Pat {} _ (Op (Alloc _ _))) = Just ([], [])
getUseSumFromStm _ _ _ =
  -- if-then-else, loops are supposed to be treated separately,
  -- calls are not supported, and Ops are not yet supported
  Nothing

-- | This function:
--     1. computes the written and read memory references for the current statement
--          (by calling @getUseSumFromStm@)
--     2. fails the entries in active coalesced table for which the write set
--          overlaps the uses of the destination (to that point)
recordMemRefUses ::
  (AliasableRep rep, Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) =>
  TopdownEnv rep ->
  BotUpEnv ->
  Stm (Aliases rep) ->
  (CoalsTab, InhibitTab)
recordMemRefUses :: forall rep (inner :: * -> *).
(AliasableRep rep, Op rep ~ MemOp inner rep,
 HasMemBlock (Aliases rep)) =>
TopdownEnv rep
-> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab)
recordMemRefUses TopdownEnv rep
td_env BotUpEnv
bu_env Stm (Aliases rep)
stm =
  let active_tab :: CoalsTab
active_tab = BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env
      inhibit_tab :: InhibitTab
inhibit_tab = BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env
      active_etries :: [(VName, CoalsEntry)]
active_etries = CoalsTab -> [(VName, CoalsEntry)]
forall k a. Map k a -> [(k, a)]
M.toList CoalsTab
active_tab
   in case TopdownEnv rep
-> CoalsTab
-> Stm (Aliases rep)
-> Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
forall rep (inner :: * -> *).
(Op rep ~ MemOp inner rep, HasMemBlock (Aliases rep)) =>
TopdownEnv rep
-> CoalsTab
-> Stm (Aliases rep)
-> Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
getUseSumFromStm TopdownEnv rep
td_env CoalsTab
active_tab Stm (Aliases rep)
stm of
        Maybe ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
Nothing ->
          CoalsTab -> [(VName, CoalsEntry)]
forall k a. Map k a -> [(k, a)]
M.toList CoalsTab
active_tab
            [(VName, CoalsEntry)]
-> ([(VName, CoalsEntry)] -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
forall a b. a -> (a -> b) -> b
& ((CoalsTab, InhibitTab)
 -> (VName, CoalsEntry) -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
-> [(VName, CoalsEntry)]
-> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
              ( \(CoalsTab, InhibitTab)
state (VName
m_b, CoalsEntry
entry) ->
                  if Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ Pat (VarAliases, LetDec rep) -> [VName]
forall dec. Pat dec -> [VName]
patNames (Stm (Aliases rep) -> Pat (LetDec (Aliases rep))
forall rep. Stm rep -> Pat (LetDec rep)
stmPat Stm (Aliases rep)
stm) [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
`intersect` Map VName Coalesced -> [VName]
forall k a. Map k a -> [k]
M.keys (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry)
                    then (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab, InhibitTab)
state VName
m_b
                    else (CoalsTab, InhibitTab)
state
              )
              (CoalsTab
active_tab, InhibitTab
inhibit_tab)
        Just ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
use_sums ->
          let ([Maybe AccessSummary]
mb_wrts, [AccessSummary]
prev_uses, [AccessSummary]
mb_lmads) =
                ((VName, CoalsEntry)
 -> (Maybe AccessSummary, AccessSummary, AccessSummary))
-> [(VName, CoalsEntry)]
-> [(Maybe AccessSummary, AccessSummary, AccessSummary)]
forall a b. (a -> b) -> [a] -> [b]
map (([(VName, VName, IxFun)], [(VName, VName, IxFun)])
-> CoalsTab
-> (VName, CoalsEntry)
-> (Maybe AccessSummary, AccessSummary, AccessSummary)
checkOverlapAndExpand ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
use_sums CoalsTab
active_tab) [(VName, CoalsEntry)]
active_etries
                  [(Maybe AccessSummary, AccessSummary, AccessSummary)]
-> ([(Maybe AccessSummary, AccessSummary, AccessSummary)]
    -> ([Maybe AccessSummary], [AccessSummary], [AccessSummary]))
-> ([Maybe AccessSummary], [AccessSummary], [AccessSummary])
forall a b. a -> (a -> b) -> b
& [(Maybe AccessSummary, AccessSummary, AccessSummary)]
-> ([Maybe AccessSummary], [AccessSummary], [AccessSummary])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3

              -- keep only the entries that do not overlap with the memory
              -- blocks defined in @pat@ or @inner_free_vars@.
              -- the others must be recorded in @inhibit_tab@ because
              -- they violate the 3rd safety condition.
              active_tab1 :: CoalsTab
active_tab1 =
                [(VName, CoalsEntry)] -> CoalsTab
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList
                  ([(VName, CoalsEntry)] -> CoalsTab)
-> [(VName, CoalsEntry)] -> CoalsTab
forall a b. (a -> b) -> a -> b
$ ((AccessSummary,
  (AccessSummary, AccessSummary, (VName, CoalsEntry)))
 -> (VName, CoalsEntry))
-> [(AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
-> [(VName, CoalsEntry)]
forall a b. (a -> b) -> [a] -> [b]
map
                    ( \(AccessSummary
wrts, (AccessSummary
uses, AccessSummary
prev_use, (VName
k, CoalsEntry
etry))) ->
                        let mrefs' :: MemRefs
mrefs' = (CoalsEntry -> MemRefs
memrefs CoalsEntry
etry) {dstrefs :: AccessSummary
dstrefs = AccessSummary
prev_use}
                            etry' :: CoalsEntry
etry' = CoalsEntry
etry {memrefs :: MemRefs
memrefs = MemRefs
mrefs'}
                         in (VName
k, AccessSummary -> AccessSummary -> CoalsEntry -> CoalsEntry
addLmads AccessSummary
wrts AccessSummary
uses CoalsEntry
etry')
                    )
                  ([(AccessSummary,
   (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
 -> [(VName, CoalsEntry)])
-> [(AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
-> [(VName, CoalsEntry)]
forall a b. (a -> b) -> a -> b
$ ((Maybe AccessSummary,
  (AccessSummary, AccessSummary, (VName, CoalsEntry)))
 -> Maybe
      (AccessSummary,
       (AccessSummary, AccessSummary, (VName, CoalsEntry))))
-> [(Maybe AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
-> [(AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (\(Maybe AccessSummary
x, (AccessSummary, AccessSummary, (VName, CoalsEntry))
y) -> (,(AccessSummary, AccessSummary, (VName, CoalsEntry))
y) (AccessSummary
 -> (AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry))))
-> Maybe AccessSummary
-> Maybe
     (AccessSummary,
      (AccessSummary, AccessSummary, (VName, CoalsEntry)))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe AccessSummary
x) -- only keep successful coals
                  ([(Maybe AccessSummary,
   (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
 -> [(AccessSummary,
      (AccessSummary, AccessSummary, (VName, CoalsEntry)))])
-> [(Maybe AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
-> [(AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
forall a b. (a -> b) -> a -> b
$ [Maybe AccessSummary]
-> [(AccessSummary, AccessSummary, (VName, CoalsEntry))]
-> [(Maybe AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe AccessSummary]
mb_wrts
                  ([(AccessSummary, AccessSummary, (VName, CoalsEntry))]
 -> [(Maybe AccessSummary,
      (AccessSummary, AccessSummary, (VName, CoalsEntry)))])
-> [(AccessSummary, AccessSummary, (VName, CoalsEntry))]
-> [(Maybe AccessSummary,
     (AccessSummary, AccessSummary, (VName, CoalsEntry)))]
forall a b. (a -> b) -> a -> b
$ [AccessSummary]
-> [AccessSummary]
-> [(VName, CoalsEntry)]
-> [(AccessSummary, AccessSummary, (VName, CoalsEntry))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [AccessSummary]
mb_lmads [AccessSummary]
prev_uses [(VName, CoalsEntry)]
active_etries
              failed_tab :: CoalsTab
failed_tab =
                [(VName, CoalsEntry)] -> CoalsTab
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, CoalsEntry)] -> CoalsTab)
-> [(VName, CoalsEntry)] -> CoalsTab
forall a b. (a -> b) -> a -> b
$
                  ((Maybe AccessSummary, (VName, CoalsEntry)) -> (VName, CoalsEntry))
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
-> [(VName, CoalsEntry)]
forall a b. (a -> b) -> [a] -> [b]
map (Maybe AccessSummary, (VName, CoalsEntry)) -> (VName, CoalsEntry)
forall a b. (a, b) -> b
snd ([(Maybe AccessSummary, (VName, CoalsEntry))]
 -> [(VName, CoalsEntry)])
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
-> [(VName, CoalsEntry)]
forall a b. (a -> b) -> a -> b
$
                    ((Maybe AccessSummary, (VName, CoalsEntry)) -> Bool)
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
forall a. (a -> Bool) -> [a] -> [a]
filter (Maybe AccessSummary -> Bool
forall a. Maybe a -> Bool
isNothing (Maybe AccessSummary -> Bool)
-> ((Maybe AccessSummary, (VName, CoalsEntry))
    -> Maybe AccessSummary)
-> (Maybe AccessSummary, (VName, CoalsEntry))
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Maybe AccessSummary, (VName, CoalsEntry)) -> Maybe AccessSummary
forall a b. (a, b) -> a
fst) ([(Maybe AccessSummary, (VName, CoalsEntry))]
 -> [(Maybe AccessSummary, (VName, CoalsEntry))])
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
forall a b. (a -> b) -> a -> b
$
                      [Maybe AccessSummary]
-> [(VName, CoalsEntry)]
-> [(Maybe AccessSummary, (VName, CoalsEntry))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Maybe AccessSummary]
mb_wrts [(VName, CoalsEntry)]
active_etries
              (CoalsTab
_, InhibitTab
inhibit_tab1) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
failed_tab, InhibitTab
inhibit_tab) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
failed_tab
           in (CoalsTab
active_tab1, InhibitTab
inhibit_tab1)
  where
    checkOverlapAndExpand :: ([(VName, VName, IxFun)], [(VName, VName, IxFun)])
-> CoalsTab
-> (VName, CoalsEntry)
-> (Maybe AccessSummary, AccessSummary, AccessSummary)
checkOverlapAndExpand ([(VName, VName, IxFun)]
stm_wrts, [(VName, VName, IxFun)]
stm_uses) CoalsTab
active_tab (VName
m_b, CoalsEntry
etry) =
      let alias_m_b :: Names
alias_m_b = Names -> VName -> Names
getAliases Names
forall a. Monoid a => a
mempty VName
m_b
          stm_uses' :: [(VName, VName, IxFun)]
stm_uses' = ((VName, VName, IxFun) -> Bool)
-> [(VName, VName, IxFun)] -> [(VName, VName, IxFun)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`notNameIn` Names
alias_m_b) (VName -> Bool)
-> ((VName, VName, IxFun) -> VName)
-> (VName, VName, IxFun)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName, IxFun) -> VName
forall {a} {b} {c}. (a, b, c) -> a
tupFst) [(VName, VName, IxFun)]
stm_uses
          all_aliases :: Names
all_aliases = (Names -> VName -> Names) -> Names -> [VName] -> Names
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Names -> VName -> Names
getAliases Names
forall a. Monoid a => a
mempty ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Names
alsmem CoalsEntry
etry
          ixfns :: [IxFun]
ixfns = ((VName, VName, IxFun) -> IxFun)
-> [(VName, VName, IxFun)] -> [IxFun]
forall a b. (a -> b) -> [a] -> [b]
map (VName, VName, IxFun) -> IxFun
forall {a} {b} {c}. (a, b, c) -> c
tupThd ([(VName, VName, IxFun)] -> [IxFun])
-> [(VName, VName, IxFun)] -> [IxFun]
forall a b. (a -> b) -> a -> b
$ ((VName, VName, IxFun) -> Bool)
-> [(VName, VName, IxFun)] -> [(VName, VName, IxFun)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
all_aliases) (VName -> Bool)
-> ((VName, VName, IxFun) -> VName)
-> (VName, VName, IxFun)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName, IxFun) -> VName
forall {a} {b} {c}. (a, b, c) -> b
tupSnd) [(VName, VName, IxFun)]
stm_uses'
          lmads' :: [LmadRef]
lmads' = (IxFun -> Maybe LmadRef) -> [IxFun] -> [LmadRef]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe IxFun -> Maybe LmadRef
mbLmad [IxFun]
ixfns
          lmads'' :: AccessSummary
lmads'' =
            if [LmadRef] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LmadRef]
lmads' Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [IxFun] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IxFun]
ixfns
              then Set LmadRef -> AccessSummary
Set (Set LmadRef -> AccessSummary) -> Set LmadRef -> AccessSummary
forall a b. (a -> b) -> a -> b
$ [LmadRef] -> Set LmadRef
forall a. Ord a => [a] -> Set a
S.fromList [LmadRef]
lmads'
              else AccessSummary
Undeterminable
          wrt_ixfns :: [IxFun]
wrt_ixfns = ((VName, VName, IxFun) -> IxFun)
-> [(VName, VName, IxFun)] -> [IxFun]
forall a b. (a -> b) -> [a] -> [b]
map (VName, VName, IxFun) -> IxFun
forall {a} {b} {c}. (a, b, c) -> c
tupThd ([(VName, VName, IxFun)] -> [IxFun])
-> [(VName, VName, IxFun)] -> [IxFun]
forall a b. (a -> b) -> a -> b
$ ((VName, VName, IxFun) -> Bool)
-> [(VName, VName, IxFun)] -> [(VName, VName, IxFun)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
alias_m_b) (VName -> Bool)
-> ((VName, VName, IxFun) -> VName)
-> (VName, VName, IxFun)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, VName, IxFun) -> VName
forall {a} {b} {c}. (a, b, c) -> a
tupFst) [(VName, VName, IxFun)]
stm_wrts
          wrt_tmps :: [LmadRef]
wrt_tmps = (IxFun -> Maybe LmadRef) -> [IxFun] -> [LmadRef]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe IxFun -> Maybe LmadRef
mbLmad [IxFun]
wrt_ixfns
          prev_use :: AccessSummary
prev_use =
            ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary
forall rep.
ScopeTab rep -> ScalarTab -> AccessSummary -> AccessSummary
translateAccessSummary (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (TopdownEnv rep -> ScalarTab
forall rep. TopdownEnv rep -> ScalarTab
scalarTable TopdownEnv rep
td_env) (AccessSummary -> AccessSummary) -> AccessSummary -> AccessSummary
forall a b. (a -> b) -> a -> b
$
              (MemRefs -> AccessSummary
dstrefs (MemRefs -> AccessSummary)
-> (CoalsEntry -> MemRefs) -> CoalsEntry -> AccessSummary
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CoalsEntry -> MemRefs
memrefs) CoalsEntry
etry
          wrt_lmads' :: AccessSummary
wrt_lmads' =
            if [LmadRef] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LmadRef]
wrt_tmps Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [IxFun] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [IxFun]
wrt_ixfns
              then Set LmadRef -> AccessSummary
Set (Set LmadRef -> AccessSummary) -> Set LmadRef -> AccessSummary
forall a b. (a -> b) -> a -> b
$ [LmadRef] -> Set LmadRef
forall a. Ord a => [a] -> Set a
S.fromList [LmadRef]
wrt_tmps
              else AccessSummary
Undeterminable
          original_mem_aliases :: Names
original_mem_aliases =
            ((VName, VName, IxFun) -> VName)
-> [(VName, VName, IxFun)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (VName, VName, IxFun) -> VName
forall {a} {b} {c}. (a, b, c) -> a
tupFst [(VName, VName, IxFun)]
stm_uses
              [VName]
-> ([VName] -> Maybe (VName, [VName])) -> Maybe (VName, [VName])
forall a b. a -> (a -> b) -> b
& [VName] -> Maybe (VName, [VName])
forall a. [a] -> Maybe (a, [a])
uncons
              Maybe (VName, [VName])
-> (Maybe (VName, [VName]) -> Maybe VName) -> Maybe VName
forall a b. a -> (a -> b) -> b
& ((VName, [VName]) -> VName)
-> Maybe (VName, [VName]) -> Maybe VName
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (VName, [VName]) -> VName
forall a b. (a, b) -> a
fst
              Maybe VName
-> (Maybe VName -> Maybe CoalsEntry) -> Maybe CoalsEntry
forall a b. a -> (a -> b) -> b
& (VName -> Maybe CoalsEntry) -> Maybe VName -> Maybe CoalsEntry
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
(=<<) (VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` CoalsTab
active_tab)
              Maybe CoalsEntry -> (Maybe CoalsEntry -> Names) -> Names
forall a b. a -> (a -> b) -> b
& Names -> (CoalsEntry -> Names) -> Maybe CoalsEntry -> Names
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Names
forall a. Monoid a => a
mempty CoalsEntry -> Names
alsmem
          (AccessSummary
wrt_lmads'', AccessSummary
lmads) =
            if VName
m_b VName -> Names -> Bool
`nameIn` Names
original_mem_aliases
              then (AccessSummary
wrt_lmads' AccessSummary -> AccessSummary -> AccessSummary
forall a. Semigroup a => a -> a -> a
<> AccessSummary
lmads'', Set LmadRef -> AccessSummary
Set Set LmadRef
forall a. Monoid a => a
mempty)
              else (AccessSummary
wrt_lmads', AccessSummary
lmads'')
          no_overlap :: Bool
no_overlap = TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
forall rep.
AliasableRep rep =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env (AccessSummary
lmads AccessSummary -> AccessSummary -> AccessSummary
forall a. Semigroup a => a -> a -> a
<> AccessSummary
prev_use) AccessSummary
wrt_lmads''
          wrt_lmads :: Maybe AccessSummary
wrt_lmads =
            if Bool
no_overlap
              then AccessSummary -> Maybe AccessSummary
forall a. a -> Maybe a
Just AccessSummary
wrt_lmads''
              else Maybe AccessSummary
forall a. Maybe a
Nothing
       in (Maybe AccessSummary
wrt_lmads, AccessSummary
prev_use, AccessSummary
lmads)

    tupFst :: (a, b, c) -> a
tupFst (a
a, b
_, c
_) = a
a
    tupSnd :: (a, b, c) -> b
tupSnd (a
_, b
b, c
_) = b
b
    tupThd :: (a, b, c) -> c
tupThd (a
_, b
_, c
c) = c
c
    getAliases :: Names -> VName -> Names
getAliases Names
acc VName
m =
      VName -> Names
oneName VName
m
        Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
acc
        Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names -> Maybe Names -> Names
forall a. a -> Maybe a -> a
fromMaybe Names
forall a. Monoid a => a
mempty (VName -> InhibitTab -> Maybe Names
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m (TopdownEnv rep -> InhibitTab
forall rep. TopdownEnv rep -> InhibitTab
m_alias TopdownEnv rep
td_env))
    mbLmad :: IxFun -> Maybe LmadRef
mbLmad IxFun
indfun
      | Just FreeVarSubsts
subs <- ScopeTab rep -> ScalarTab -> IxFun -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep -> ScalarTab -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> ScalarTab
scals BotUpEnv
bu_env) IxFun
indfun,
        (IxFun.IxFun LmadRef
lmad Shape (TPrimExp Int64 VName)
_) <- FreeVarSubsts -> IxFun -> IxFun
forall {k} a (t :: k).
Ord a =>
Map a (TPrimExp t a)
-> IxFun (TPrimExp t a) -> IxFun (TPrimExp t a)
IxFun.substituteInIxFun FreeVarSubsts
subs IxFun
indfun =
          LmadRef -> Maybe LmadRef
forall a. a -> Maybe a
Just LmadRef
lmad
    mbLmad IxFun
_ = Maybe LmadRef
forall a. Maybe a
Nothing
    addLmads :: AccessSummary -> AccessSummary -> CoalsEntry -> CoalsEntry
addLmads AccessSummary
wrts AccessSummary
uses CoalsEntry
etry =
      CoalsEntry
etry {memrefs :: MemRefs
memrefs = AccessSummary -> AccessSummary -> MemRefs
MemRefs AccessSummary
uses AccessSummary
wrts MemRefs -> MemRefs -> MemRefs
forall a. Semigroup a => a -> a -> a
<> CoalsEntry -> MemRefs
memrefs CoalsEntry
etry}

-- | Check for memory overlap of two access summaries.
--
-- This check is conservative, so unless we can guarantee that there is no
-- overlap, we return 'False'.
--
-- * If either summary is an empty 'Set', there is nothing to overlap with and
--   the result is 'True' (even when the other side is 'Undeterminable').
-- * If both summaries are 'Set's and every name in 'nonNegatives' can be
--   turned into a 'PrimExp' via 'vnameToPrimExp', then every LMAD of the
--   first summary must be proven disjoint from every LMAD of the second by at
--   least one of 'IxFun.disjoint', 'IxFun.disjoint2' or 'IxFun.disjoint3'.
-- * In every other case the result is 'False'.
noMemOverlap :: (AliasableRep rep) => TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap _ _ (Set mr)
  | S.null mr = True
noMemOverlap _ (Set mr) _
  | S.null mr = True
noMemOverlap td_env (Set is0) (Set js0)
  | Just non_negs <-
      mapM
        (primExpFromSubExpM (vnameToPrimExp (scope td_env) scalars) . Var)
        (namesToList non_neg_names) =
      -- Every pair must be provably disjoint; 'all' stops at the first pair
      -- for which none of the checks succeeds.
      all (\i -> all (provablyDisjoint non_negs i) js) is
  where
    scalars = scalarTable td_env
    non_neg_names = nonNegatives td_env

    -- Tries the three disjointness checks in order, short-circuiting on the
    -- first success. @non_negs@ is passed explicitly because variables bound
    -- in a pattern guard are not visible inside a @where@ clause.
    provablyDisjoint :: [PrimExp VName] -> LmadRef -> LmadRef -> Bool
    provablyDisjoint non_negs i j =
      IxFun.disjoint less_thans non_neg_names i j
        || IxFun.disjoint2 () () less_thans non_neg_names i j
        || IxFun.disjoint3 (typeOf <$> scope td_env) asserts less_thans non_negs i j

    -- The 'knownLessThan' facts, with scalar definitions inlined until a
    -- fixed point is reached.
    less_thans :: [(VName, PrimExp VName)]
    less_thans = map (fmap $ fixPoint $ substituteInPrimExp scalars) $ knownLessThan td_env

    -- The 'td_asserts' conditions as boolean 'PrimExp's, with scalar
    -- definitions inlined until a fixed point is reached.
    asserts :: [PrimExp VName]
    asserts = map (fixPoint (substituteInPrimExp scalars) . primExpFromSubExp Bool) $ td_asserts td_env

    -- Inline scalar definitions into an LMAD until it stops changing.
    simplifyLmad :: LmadRef -> LmadRef
    simplifyLmad = fixPoint $ IxFun.substituteInLMAD $ TPrimExp <$> scalars

    is, js :: [LmadRef]
    is = map simplifyLmad $ S.toList is0
    js = map simplifyLmad $ S.toList js0
noMemOverlap _ _ _ = False

-- | Computes the total aggregated access summary for a loop by expanding the
-- access summary given according to the iterator variable and bounds of the
-- loop.
--
-- Corresponds to:
--
-- \[
--   \bigcup_{j=0}^{j<n} Access_j
-- \]
--
-- The cases are tried in order:
--
-- * 'Undeterminable' stays 'Undeterminable', and an empty 'Set' stays empty.
-- * If translating the summary with the loop's scope and scalar table
--   ('translateAccessSummary') yields a 'Set' whose free variables are all
--   bound in @scope_bef@, that translated summary is returned as is.
-- * Otherwise, if the iterator and its bounds are known, each (untranslated)
--   LMAD has the loop's scalar definitions inlined to a fixed point and is
--   then expanded over the iteration range by 'aggSummaryOne'.
-- * In all remaining cases the result is 'Undeterminable'.
aggSummaryLoopTotal ::
  (MonadFreshNames m) =>
  ScopeTab rep ->
  ScopeTab rep ->
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopTotal _ _ _ _ Undeterminable = pure Undeterminable
aggSummaryLoopTotal _ _ _ _ (Set l)
  | S.null l = pure $ Set mempty
aggSummaryLoopTotal scope_bef scope_loop scals_loop _ access
  | Set ls <- translateAccessSummary scope_loop scals_loop access,
    -- Collect the free variables of every translated LMAD and require each
    -- of them to be bound before the loop.
    all (`M.member` scope_bef) $ namesToList $ foldMap freeIn ls =
      pure $ Set ls
aggSummaryLoopTotal _ _ scalars_loop (Just (iterator_var, (lower_bound, upper_bound))) (Set lmads) =
  concatMapM
    ( aggSummaryOne iterator_var lower_bound upper_bound
        . fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop)
    )
    (S.toList lmads)
aggSummaryLoopTotal _ _ _ _ _ = pure Undeterminable

-- | For a given iteration of the loop $i$, computes the aggregated loop access
-- summary of all later iterations.
--
-- Corresponds to:
--
-- \[
--   \bigcup_{j=i+1}^{j<n} Access_j
-- \]
--
-- Yields 'Undeterminable' when the input summary is 'Undeterminable' or when
-- no iterator/bounds information is available. Only the upper bound of the
-- loop is consulted; the lower bound is ignored.
aggSummaryLoopPartial ::
  (MonadFreshNames m) =>
  ScalarTab ->
  Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName)) ->
  AccessSummary ->
  m AccessSummary
aggSummaryLoopPartial _ _ Undeterminable = pure Undeterminable
aggSummaryLoopPartial _ Nothing _ = pure Undeterminable
aggSummaryLoopPartial scalars_loop (Just (iter, (_, upper))) (Set lmads) =
  -- Each LMAD first has the loop's scalar definitions inlined (to a fixed
  -- point) and is then handed to 'aggSummaryOne' together with the range of
  -- later iterations. Design notes for that per-LMAD expansion:
  --   substitute a fresh variable k for the loop iterator;
  --   if k occurs in a stride or span of the LMAD: fall back to total;
  --   new_stride = offset[k := k+1] - offset;
  --   new_offset = offset where k = lower bound of iteration;
  --   new_span = upper bound of iteration.
  concatMapM (aggSummaryOne iter first_later num_later . inlineScalars) (S.toList lmads)
  where
    -- The iteration right after the current one: i + 1.
    first_later :: TPrimExp Int64 VName
    first_later = isInt64 (LeafExp iter $ IntType Int64) + 1

    -- upper - i - 1; if @upper@ is the exclusive bound n this is the number
    -- of iterations after i (NOTE(review): confirm against the caller).
    num_later :: TPrimExp Int64 VName
    num_later = upper - typedLeafExp iter - 1

    inlineScalars :: LmadRef -> LmadRef
    inlineScalars = fixPoint $ IxFun.substituteInLMAD $ fmap TPrimExp scalars_loop

-- | For a given map with \(k\) dimensions and an index \(i\) for each
-- dimension, compute the aggregated access summary of all /other/ threads.
--
-- For the innermost dimension, with size \(n\), this corresponds to
--
-- \[
--   \bigcup_{j=0}^{i-1} Access_j \cup \bigcup_{j=i+1}^{n-1} Access_j
-- \]
--
-- where \(Access_j\) describes the point accesses in the map. As we move up in
-- dimensionality, the previous access summaries are kept, in addition to the
-- total aggregation of the inner dimensions. For outer dimensions, the equation
-- is the same, but the point accesses in \(Access_j\) are replaced with the
-- total aggregation of the inner dimensions.
--
-- Scalars from the 'ScalarTab' are substituted into the LMADs before both the
-- partial and the total aggregation. This keeps a dependence on a thread index
-- that is hidden behind a scalar variable from being treated as independence.
--
-- An empty dimension list means there are no other threads, so the result is
-- empty. If any step yields 'Undeterminable', the whole result is
-- 'Undeterminable'.
aggSummaryMapPartial :: (MonadFreshNames m) => ScalarTab -> [(VName, SubExp)] -> LmadRef -> m AccessSummary
aggSummaryMapPartial _ [] = const $ pure mempty
aggSummaryMapPartial scalars dims =
  -- Reverse dims so we work from the inside out.
  helper mempty (reverse dims) . Set . S.singleton
  where
    -- Inline the scalar definitions into an LMAD until nothing changes. This
    -- matches what 'aggSummaryMapPartialOne' and 'aggSummaryMapTotal' do.
    substitute :: LmadRef -> LmadRef
    substitute = fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)

    -- helper acc dims as: acc holds the partial summaries collected so far, and
    -- as holds the total accesses of one iteration of the current dimension.
    helper acc [] _ = pure acc
    helper Undeterminable _ _ = pure Undeterminable
    helper _ _ Undeterminable = pure Undeterminable
    helper (Set acc) ((gtid, size) : rest) (Set as) = do
      -- Accesses of all other threads along this dimension.
      partial_as <- aggSummaryMapPartialOne scalars (gtid, size) (Set as)
      -- Accesses of all threads along this dimension. The next-outer dimension
      -- treats this as its single "point" access. Substitute scalars first so
      -- that aggSummaryOne sees every occurrence of gtid.
      total_as <-
        concatMapM
          ( aggSummaryOne gtid 0 (TPrimExp $ primExpFromSubExp (IntType Int64) size)
              . substitute
          )
          (S.toList as)
      helper (Set acc <> partial_as) rest total_as

-- | Given an access summary \(a\), a thread id \(i\) and the size \(n\) of the
-- dimension, compute the summary of the accesses made by every thread in that
-- dimension except \(i\) itself:
--
-- \[
--   \bigcup_{j=0}^{i-1} a_j \cup \bigcup_{j=i+1}^{n-1} a_j
-- \]
--
-- When the dimension is statically known to have size one, there are no other
-- threads and the result is empty. Scalars from the 'ScalarTab' are
-- substituted into the LMADs before aggregating.
aggSummaryMapPartialOne :: (MonadFreshNames m) => ScalarTab -> (VName, SubExp) -> AccessSummary -> m AccessSummary
aggSummaryMapPartialOne _ _ Undeterminable = pure Undeterminable
aggSummaryMapPartialOne _ (_, Constant n) (Set _)
  | oneIsh n = pure mempty
aggSummaryMapPartialOne scalars (gtid, size) (Set lmads0) =
  concatMapM aggregateRange [lowerThreads, upperThreads]
  where
    tid = isInt64 $ LeafExp gtid $ IntType Int64
    dimSize = isInt64 $ primExpFromSubExp (IntType Int64) size
    -- Threads 0 .. tid-1: start at 0, span tid.
    lowerThreads = (0, tid)
    -- Threads tid+1 .. dimSize-1: start at tid+1, span dimSize-tid-1.
    upperThreads = (tid + 1, dimSize - tid - 1)
    -- The input LMADs with all known scalars inlined.
    substituted =
      map
        (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars))
        (S.toList lmads0)
    -- Aggregate every LMAD over the thread range (start, span).
    aggregateRange (start, extent) =
      concatMapM (aggSummaryOne gtid start extent) substituted

-- | Computes the total access summary over a multi-dimensional map, i.e. the
-- union of the accesses made by all of its threads.
--
-- Scalars from the 'ScalarTab' are first substituted into the LMADs. The
-- dimensions are then aggregated from the innermost outwards, each adding an
-- outer LMAD dimension via 'aggSummaryOne'.
--
-- 'Undeterminable' is always propagated. It is checked before anything else
-- so that an unknown summary is never turned into an empty one. An empty
-- segment space describes a single execution, so the (substituted) input
-- summary is returned unchanged. Recording these accesses is the conservative
-- choice.
aggSummaryMapTotal :: (MonadFreshNames m) => ScalarTab -> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal _ _ Undeterminable = pure Undeterminable
aggSummaryMapTotal _ _ (Set lmads0)
  | S.null lmads0 = pure mempty
aggSummaryMapTotal scalars segspace (Set lmads0) =
  -- Reverse the segment space so we work from the inside out. With an empty
  -- segspace the fold simply returns its seed.
  foldM aggregateDim (Set lmads) (reverse segspace)
  where
    lmads =
      S.fromList $
        map (fixPoint (IxFun.substituteInLMAD $ fmap TPrimExp scalars)) $
          S.toList lmads0
    -- Aggregate the accesses over all iterations 0 .. size-1 of one dimension.
    aggregateDim Undeterminable _ = pure Undeterminable
    aggregateDim (Set lmads') (gtid, size) =
      concatMapM
        (aggSummaryOne gtid 0 $ TPrimExp $ primExpFromSubExp (IntType Int64) size)
        (S.toList lmads')

-- | Aggregates the accesses of a single LMAD over a range of values of an
-- iterator. The range is described by a lower bound and a span (number of
-- iterations).
--
-- If successful, the result is an LMAD with an extra outer dimension of the
-- given span. Its stride is the difference between the offsets of two
-- consecutive iterations, and its offset is the original offset evaluated at
-- the lower bound.
--
-- The result is 'Undeterminable' in two cases:
--
--   * the iterator occurs in the dimensions of the input LMAD, or
--   * the computed stride still depends on the iterator, i.e. the offset is
--     not affine in it.
--
-- If the iterator does not occur in the LMAD at all, the LMAD is returned
-- unchanged.
aggSummaryOne :: (MonadFreshNames m) => VName -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> LmadRef -> m AccessSummary
aggSummaryOne iterator_var lower_bound spn lmad@(IxFun.LMAD offset0 dims0)
  | iterator_var `nameIn` freeIn dims0 = pure Undeterminable
  | iterator_var `notNameIn` freeIn offset0 = pure $ Set $ S.singleton lmad
  | otherwise = do
      k <- newVName "k"
      let kExp = typedLeafExp k
          -- Offset at iteration k+1 minus offset at iteration k.
          stride =
            TPrimExp . constFoldPrimExp . simplify . untyped $
              replaceIteratorWith (kExp + 1) offset0 - replaceIteratorWith kExp offset0
          aggregated =
            IxFun.LMAD (replaceIteratorWith lower_bound offset0) $
              IxFun.LMADDim stride spn : dims0
      -- If k survived simplification, the stride is not constant across
      -- iterations and cannot be described by a single LMAD dimension.
      pure $
        if k `nameIn` freeIn aggregated
          then Undeterminable
          else Set $ S.singleton aggregated
  where
    -- Substitute the given expression for the iterator variable.
    replaceIteratorWith :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Int64 VName
    replaceIteratorWith se e =
      TPrimExp $ substituteInPrimExp (M.singleton iterator_var $ untyped se) (untyped e)

-- | Turns a variable name into a leaf expression of type @i64@, wrapped as an
-- 'Int64'-typed 'TPrimExp'.
typedLeafExp :: VName -> TPrimExp Int64 VName
typedLeafExp = isInt64 . flip LeafExp (IntType Int64)