{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | The bulk of the short-circuiting implementation.
module Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
  ( mkCoalsTab,
    CoalsTab,
    mkCoalsTabGPU,
    mkCoalsTabMC,
  )
where

import Control.Exception.Base qualified as Exc
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence (Seq (..))
import Data.Set qualified as S
import Futhark.Analysis.LastUse
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SeqMem
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-- | A helper type describing representations that can be short-circuited.
--
-- This is a constraint synonym over a representation @rep@ and its inner
-- operation type @inner@.  Among other things it fixes the operation type of
-- @rep@ to @'MemOp' inner rep@ and its let-bound decoration to 'LetDecMem'.
type Coalesceable rep inner =
  ( Mem rep inner,
    ASTRep rep,
    CanBeAliased inner,
    AliasableRep rep,
    Op rep ~ MemOp inner rep,
    HasMemBlock (Aliases rep),
    LetDec rep ~ LetDecMem,
    TopDownHelper (inner (Aliases rep))
  )

-- | The type of a function that, given a scope and something of type @op@,
-- produces (inside 'ScalarTableM') a mapping from names to 'PrimExp's.
type ComputeScalarTable rep op =
  ScopeTab rep -> op -> ScalarTableM rep (M.Map VName (PrimExp VName))

-- | Helper type for computing scalar tables on ops.
--
-- Wraps the representation-specific 'ComputeScalarTable' function that is
-- applied to an 'Op' of the aliased representation.
newtype ComputeScalarTableOnOp rep = ComputeScalarTableOnOp
  { scalarTableOnOp :: ComputeScalarTable rep (Op (Aliases rep))
  }

-- | The monad used while computing scalar tables: a 'Reader' whose environment
-- is the 'ComputeScalarTableOnOp' handler for the representation.
type ScalarTableM rep a = Reader (ComputeScalarTableOnOp rep) a

-- | The representation-specific handlers that are made available (through
-- 'MonadReader') to everything running in 'ShortCircuitM'.
data ShortCircuitReader rep = ShortCircuitReader
  { -- | Short-circuit handler for an 'Op' of the representation.  Receives
    -- the last-use table, the pattern and certificates of the enclosing
    -- statement, the op itself and the top-down environment, and transforms
    -- the bottom-up environment.
    onOp ::
      LUTabFun ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      TopdownEnv rep ->
      BotUpEnv ->
      ShortCircuitM rep BotUpEnv,
    -- | Produces the 'SSPointInfo's for an 'Op' of the representation, or
    -- 'Nothing' if there are none.
    -- NOTE(review): presumably these are the candidate short-circuit points
    -- created by the op -- confirm against the per-representation handlers.
    ssPointFromOp ::
      LUTabFun ->
      TopdownEnv rep ->
      ScopeTab rep ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      Maybe [SSPointInfo]
  }

-- | The monad in which short-circuiting runs: a reader over the
-- representation-specific 'ShortCircuitReader' handlers, layered on top of a
-- state monad that threads the 'VNameSource'.
newtype ShortCircuitM rep a = ShortCircuitM (ReaderT (ShortCircuitReader rep) (State VNameSource) a)
  deriving (Functor, Applicative, Monad, MonadReader (ShortCircuitReader rep), MonadState VNameSource)

-- | Fresh names are drawn from the 'VNameSource' held in the underlying state
-- monad, so the name source is simply the monad's state.
instance MonadFreshNames (ShortCircuitM rep) where
  putNameSource = put
  getNameSource = get

-- | A 'TopdownEnv' in which every field is 'mempty'.  Used as the base record
-- that callers override field-by-field.
emptyTopdownEnv :: TopdownEnv rep
emptyTopdownEnv =
  TopdownEnv
    { alloc = mempty,
      scope = mempty,
      inhibited = mempty,
      v_alias = mempty,
      m_alias = mempty,
      nonNegatives = mempty,
      scalarTable = mempty,
      knownLessThan = mempty,
      td_asserts = mempty
    }

-- | A 'BotUpEnv' in which every field is 'mempty': no known scalars, no
-- active or successful coalescings, and nothing inhibited.
emptyBotUpEnv :: BotUpEnv
emptyBotUpEnv =
  BotUpEnv
    { scals = mempty,
      activeCoals = mempty,
      successCoals = mempty,
      inhibit = mempty
    }

--------------------------------------------------------------------------------
--- Main Coalescing Transformation computes a successful coalescing table    ---
--------------------------------------------------------------------------------

-- | Given a 'Prog' in 'SeqMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for it.
-- 'SeqMem' has no ops that contribute scalars, so the scalar-table handler
-- ignores both of its arguments and returns an empty table.
mkCoalsTab :: (MonadFreshNames m) => Prog (Aliases SeqMem) -> m (M.Map Name CoalsTab)
mkCoalsTab prog =
  mkCoalsTabProg
    (lastUseSeqMem prog)
    (ShortCircuitReader shortCircuitSeqMem genSSPointInfoSeqMem)
    (ComputeScalarTableOnOp $ const $ const $ pure mempty)
    prog

-- | Given a 'Prog' in 'GPUMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for it.
-- Scalar tables for ops are computed by 'computeScalarTableGPUMem', lifted to
-- 'MemOp' via 'computeScalarTableMemOp'.
mkCoalsTabGPU :: (MonadFreshNames m) => Prog (Aliases GPUMem) -> m (M.Map Name CoalsTab)
mkCoalsTabGPU prog =
  mkCoalsTabProg
    (lastUseGPUMem prog)
    (ShortCircuitReader shortCircuitGPUMem genSSPointInfoGPUMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableGPUMem))
    prog

-- | Given a 'Prog' in 'MCMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for it.
-- Scalar tables for ops are computed by 'computeScalarTableMCMem', lifted to
-- 'MemOp' via 'computeScalarTableMemOp'.
mkCoalsTabMC :: (MonadFreshNames m) => Prog (Aliases MCMem) -> m (M.Map Name CoalsTab)
mkCoalsTabMC prog =
  mkCoalsTabProg
    (lastUseMCMem prog)
    (ShortCircuitReader shortCircuitMCMem genSSPointInfoMCMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableMCMem))
    prog

-- | Given a program, compute the coalescing table of each of its functions.
--
-- The arguments are the program-wide last-use information (only its second
-- component, the per-function table, is used), the representation-specific
-- short-circuiting handlers, the representation-specific scalar-table handler
-- for ops, and the program itself.  The result maps each function name to the
-- 'CoalsTab' produced by 'fixPointCoalesce' for that function.
mkCoalsTabProg ::
  (MonadFreshNames m, Coalesceable rep inner) =>
  LUTabProg ->
  ShortCircuitReader rep ->
  ComputeScalarTableOnOp rep ->
  Prog (Aliases rep) ->
  m (M.Map Name CoalsTab)
mkCoalsTabProg (_, lutab_prog) r computeScalarOnOp =
  fmap M.fromList . mapM onFun . progFuns
  where
    onFun fun@(FunDef _ _ fname _ fpars body) = do
      -- First compute last-use information
      let unique_mems = getUniqueMemFParam fpars
          -- Partial lookup: the last-use table is expected to have an entry
          -- for every function of the program.
          lutab = lutab_prog M.! fname
          -- Scalar table for the top-level statements of the function body,
          -- computed in a scope made of the function's own scope plus the
          -- names bound by those statements.
          scalar_table =
            runReader
              ( concatMapM
                  (computeScalarTable $ scopeOf fun <> scopeOf (bodyStms body))
                  (stmsToList $ bodyStms body)
              )
              computeScalarOnOp
          topenv =
            emptyTopdownEnv
              { scope = scopeOfFParams fpars,
                alloc = unique_mems,
                scalarTable = scalar_table,
                nonNegatives = foldMap paramSizes fpars
              }
          ShortCircuitM m = fixPointCoalesce lutab fpars body topenv
      -- Run the analysis against the caller's name source, so that any fresh
      -- names generated along the way are accounted for in @m@.
      (fname,) <$> modifyNameSource (runState (runReaderT m r))

-- | The names free in the shape of a function parameter, if that parameter is
-- a memory-annotated array ('MemArray'); 'mempty' for any other parameter.
paramSizes :: Param FParamMem -> Names
paramSizes (Param _ _ (MemArray _ shp _ _)) = freeIn shp
paramSizes _ = mempty

-- | Short-circuit handler for a 'SeqMem' 'Op'.
--
-- Because 'SeqMem' doesn't have any special operations, all arguments are
-- ignored and the input 'BotUpEnv' is returned unchanged.
shortCircuitSeqMem :: LUTabFun -> Pat (VarAliases, LetDecMem) -> Certs -> Op (Aliases SeqMem) -> TopdownEnv SeqMem -> BotUpEnv -> ShortCircuitM SeqMem BotUpEnv
shortCircuitSeqMem _ _ _ _ _ = pure

-- | Short-circuit handler for SegOp.
shortCircuitSegOp ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegOp lvl (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOp :: forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
(lvl -> Bool)
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegOp lvl (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOp lvl -> Bool
lvlOK InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs (SegMap lvl
lvl SegSpace
space [Type]
_ KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env BotUpEnv
bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
0 lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs SegSpace
space KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env
shortCircuitSegOp lvl -> Bool
lvlOK InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs (SegRed lvl
lvl SegSpace
space [SegBinOp (Aliases rep)]
binops [Type]
_ KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env BotUpEnv
bu_env =
  -- When handling 'SegRed', we first invalidate all active coalesce-entries
  -- where any of the variables in 'vartab' are also free in the list of
  -- 'SegBinOp'. In other words, anything that is used as part of the reduction
  -- step should probably not be coalesced.
  let to_fail :: CoalsTab
to_fail = (CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall a k. (a -> Bool) -> Map k a -> Map k a
M.filter (\CoalsEntry
entry -> [VName] -> Names
namesFromList (Map VName Coalesced -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName Coalesced -> [VName]) -> Map VName Coalesced -> [VName]
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry) Names -> Names -> Bool
`namesIntersect` (SegBinOp (Aliases rep) -> Names)
-> [SegBinOp (Aliases rep)] -> Names
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Lambda (Aliases rep) -> Names
forall a. FreeIn a => a -> Names
freeIn (Lambda (Aliases rep) -> Names)
-> (SegBinOp (Aliases rep) -> Lambda (Aliases rep))
-> SegBinOp (Aliases rep)
-> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp (Aliases rep)]
binops) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env
      (CoalsTab
active, InhibitTab
inh) =
        ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
to_fail
      bu_env' :: BotUpEnv
bu_env' = BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
active, inhibit :: InhibitTab
inhibit = InhibitTab
inh}
      num_reds :: Int
num_reds = [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
red_ts
   in Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
num_reds lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs SegSpace
space KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env'
  where
    segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
    red_ts :: [Type]
red_ts = do
      SegBinOp (Aliases rep)
op <- [SegBinOp (Aliases rep)]
binops
      let shp :: ShapeBase SubExp
shp = [SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
segment_dims ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> SegBinOp (Aliases rep) -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp (Aliases rep)
op
      (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> ShapeBase SubExp -> Type
`arrayOfShape` ShapeBase SubExp
shp) (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (Lambda (Aliases rep) -> [Type]) -> Lambda (Aliases rep) -> [Type]
forall a b. (a -> b) -> a -> b
$ SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp (Aliases rep)
op)
-- 'SegScan' case.
--
-- Every active coalescing entry whose coalesced variables (the keys of its
-- 'vartab') intersect the free variables of the scan operator lambdas is
-- marked as failed with 'markFailedCoal'.  The remaining work is delegated to
-- 'shortCircuitSegOpHelper' with @0@ as its first argument, so no result has
-- the last segment-space dimension dropped.
shortCircuitSegOp lvlOK lutab pat pat_certs (SegScan lvl space binops _ kernel_body) td_env bu_env =
  -- Like in the handling of 'SegRed', we do not want to coalesce anything that
  -- is used in the 'SegBinOp'
  let op_free = foldMap (freeIn . segBinOpLambda) binops
      usedInOp entry =
        namesFromList (M.keys $ vartab entry) `namesIntersect` op_free
      to_fail = M.filter usedInOp $ activeCoals bu_env
      -- Strict left fold: the accumulator is a pair of tables, so there is
      -- nothing to gain from building a chain of thunks.
      (active, inh) =
        L.foldl' markFailedCoal (activeCoals bu_env, inhibit bu_env) $
          M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
   in shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env'
-- 'SegHist' case.
--
-- First fails every active coalescing entry whose coalesced variables
-- intersect the free variables of the histogram operator lambdas, then runs
-- 'shortCircuitSegOpHelper' on the kernel body, and finally folds
-- 'insertHistCoals' over the pattern elements zipped with the (flattened)
-- histogram destinations.
shortCircuitSegOp lvlOK lutab pat pat_certs (SegHist lvl space histops _ kernel_body) td_env bu_env = do
  -- Need to take zipped patterns and histDest (flattened) and insert transitive coalesces
  let op_free = foldMap (freeIn . histOp) histops
      usedInOp entry =
        namesFromList (M.keys $ vartab entry) `namesIntersect` op_free
      to_fail = M.filter usedInOp $ activeCoals bu_env
      (active, inh) =
        L.foldl' markFailedCoal (activeCoals bu_env, inhibit bu_env) $
          M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
  bu_env'' <-
    shortCircuitSegOpHelper 0 lvlOK lvl lutab pat pat_certs space kernel_body td_env bu_env'
  -- 'zip' stops at the shorter list, so surplus pattern elements or
  -- destinations are silently ignored.
  pure $
    L.foldl' insertHistCoals bu_env'' $
      zip (patElems pat) $
        concatMap histDest histops
  where
    -- For one (pattern element, histogram destination) pair: when both names
    -- have array memory information in the top-down scope and the memory
    -- block of the pattern element has a successful coalescing entry, update
    -- 'successCoals' and 'activeCoals'.  In every other case the environment
    -- is returned unchanged.
    insertHistCoals acc (PatElem p _, hist_dest) =
      case ( getScopeMemInfo p $ scope td_env,
             getScopeMemInfo hist_dest $ scope td_env
           ) of
        (Just (MemBlock _ _ p_mem _), Just (MemBlock _ _ dest_mem _)) ->
          case M.lookup p_mem $ successCoals acc of
            Just entry ->
              -- Update this entry with an optdep for the memory block of hist_dest
              --
              -- NOTE(review): the mapping actually inserted is @p -> p_mem@,
              -- and 'activeCoals' receives the original @entry@ rather than
              -- @entry'@ -- confirm both are intended.
              let entry' = entry {optdeps = M.insert p p_mem $ optdeps entry}
               in acc
                    { successCoals = M.insert p_mem entry' $ successCoals acc,
                      activeCoals = M.insert dest_mem entry $ activeCoals acc
                    }
            Nothing -> acc
        _ -> acc

-- | Short-circuit handler for 'GPUMem' 'Op'.
--
-- * A 'GPU.SegOp' is passed to 'shortCircuitSegOp', using 'isSegThread' as
--   the level predicate.
--
-- * A 'GPU.GPUBody' is passed to 'shortCircuitSegOpHelper' as if it were a
--   thread-level 'SegOp' with a grid of one group of size one and a
--   one-dimensional segment space of size one.
--
-- * 'Alloc', 'GPU.SizeOp' and 'GPU.OtherOp' leave the bottom-up environment
--   unchanged.
shortCircuitGPUMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases GPUMem) ->
  TopdownEnv GPUMem ->
  BotUpEnv ->
  ShortCircuitM GPUMem BotUpEnv
shortCircuitGPUMem _ _ _ (Alloc _ _) _ bu_env = pure bu_env
shortCircuitGPUMem lutab pat certs (Inner (GPU.SegOp op)) td_env bu_env =
  shortCircuitSegOp isSegThread lutab pat certs op td_env bu_env
shortCircuitGPUMem lutab pat certs (Inner (GPU.GPUBody _ body)) td_env bu_env = do
  -- Two distinct fresh names: one for the flat index of the synthetic segment
  -- space, one for its single dimension.
  fresh1 <- newNameFromString "gpubody"
  fresh2 <- newNameFromString "gpubody"
  let one = Constant $ IntValue $ Int64Value 1
      -- Construct a 'SegLevel' corresponding to a single thread
      single_thread =
        GPU.SegThread GPU.SegNoVirt $
          Just $
            GPU.KernelGrid (GPU.Count one) (GPU.Count one)
  shortCircuitSegOpHelper
    0
    isSegThread
    single_thread
    lutab
    pat
    certs
    (SegSpace fresh1 [(fresh2, one)])
    (bodyToKernelBody body)
    td_env
    bu_env
shortCircuitGPUMem _ _ _ (Inner (GPU.SizeOp _)) _ bu_env = pure bu_env
shortCircuitGPUMem _ _ _ (Inner (GPU.OtherOp NoOp)) _ bu_env = pure bu_env

-- | Short-circuit handler for 'MCMem' 'Op'.
--
-- 'Alloc' and 'MC.OtherOp' leave the bottom-up environment unchanged.  For a
-- 'MC.ParOp', every contained 'SegOp' is passed to 'shortCircuitSegOp' with a
-- level predicate that always succeeds; when the optional first 'SegOp' is
-- present it is processed first and its resulting environment is threaded
-- into the processing of the second one.
shortCircuitMCMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases MCMem) ->
  TopdownEnv MCMem ->
  BotUpEnv ->
  ShortCircuitM MCMem BotUpEnv
shortCircuitMCMem _ _ _ (Alloc _ _) _ bu_env = pure bu_env
shortCircuitMCMem _ _ _ (Inner (MC.OtherOp NoOp)) _ bu_env = pure bu_env
shortCircuitMCMem lutab pat certs (Inner (MC.ParOp (Just par_op) op)) td_env bu_env =
  shortCircuitSegOp (const True) lutab pat certs par_op td_env bu_env
    >>= shortCircuitSegOp (const True) lutab pat certs op td_env
shortCircuitMCMem lutab pat certs (Inner (MC.ParOp Nothing op)) td_env bu_env =
  shortCircuitSegOp (const True) lutab pat certs op td_env bu_env

-- | Remove the last dimension from the dimension list of a 'SegSpace'.
--
-- This uses 'init' and is therefore partial: forcing the resulting
-- 'unSegSpace' of a space without dimensions raises an error.
dropLastSegSpace :: SegSpace -> SegSpace
dropLastSegSpace space = space {unSegSpace = init $ unSegSpace space}

-- | Is this 'GPU.SegLevel' built with the 'GPU.SegThread' constructor?
isSegThread :: GPU.SegLevel -> Bool
isSegThread GPU.SegThread {} = True
isSegThread _ = False

-- | Computes the slice written at the end of a thread in a 'SegOp'.
--
-- * For 'Returns', every dimension of the segment space is fixed to its index
--   variable (as a 64-bit integer leaf expression).
--
-- * For 'RegTileReturns', each pair of tile description and segment-space
--   dimension yields a 'DimSlice' whose first argument is
--   @x * block_tile_size * reg_tile_size@, whose second argument is
--   @block_tile_size * reg_tile_size@ and whose third argument is @1@, where
--   @x@ is the index variable of that dimension.  The pairing is done with
--   'zipWith', so surplus elements of the longer list are ignored.
--
-- * Any other kind of result gives 'Nothing'.
threadSlice :: SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice space Returns {} =
  Just $
    Slice $
      map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst) $
        unSegSpace space
threadSlice space (RegTileReturns _ dims _) =
  Just $ Slice $ zipWith tileSlice dims $ unSegSpace space
  where
    -- The first component of the tile triple and the size of the segment-space
    -- dimension are not used.
    tileSlice (_, block_tile_size0, reg_tile_size0) (x0, _) =
      let x = pe64 $ Var x0
          block_tile_size = pe64 block_tile_size0
          reg_tile_size = pe64 reg_tile_size0
       in DimSlice
            (x * block_tile_size * reg_tile_size)
            (block_tile_size * reg_tile_size)
            1
threadSlice _ _ = Nothing

-- | Turn a 'Body' into a 'KernelBody' with the same decoration and
-- statements, wrapping every result in 'Returns' with 'ResultNoSimplify' and
-- keeping its certificates.
bodyToKernelBody :: Body (Aliases GPUMem) -> KernelBody (Aliases GPUMem)
bodyToKernelBody (Body dec stms res) =
  KernelBody dec stms $ map toReturns res
  where
    toReturns (SubExpRes cert subexp) = Returns ResultNoSimplify cert subexp

-- | A helper for all the different kinds of 'SegOp'.
--
-- Consists of four parts:
--
-- 1. Create coalescing relations between the pattern elements and the kernel
-- body results using 'makeSegMapCoals'.
--
-- 2. Process the statements of the 'KernelBody'.
--
-- 3. Check the overlap between the different threads.
--
-- 4. Mark active coalescings as finished, since a 'SegOp' is an array creation
-- point.
shortCircuitSegOpHelper ::
  (Coalesceable rep inner) =>
  -- | The number of returns for which we should drop the last seg space
  Int ->
  -- | Whether we should look at a segop with this lvl.
  (lvl -> Bool) ->
  lvl ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegSpace ->
  KernelBody (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper :: forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
num_reds lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab pat :: Pat (VarAliases, LParamMem)
pat@(Pat [PatElem (VarAliases, LParamMem)]
ps0) Certs
pat_certs SegSpace
space0 KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env = do
  -- We need to drop the last element of the 'SegSpace' for pattern elements
  -- that correspond to reductions.
  let ps_space_and_res :: [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res =
        [PatElem (VarAliases, LParamMem)]
-> [SegSpace]
-> [KernelResult]
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (VarAliases, LParamMem)]
ps0 (Int -> SegSpace -> [SegSpace]
forall a. Int -> a -> [a]
replicate Int
num_reds (SegSpace -> SegSpace
dropLastSegSpace SegSpace
space0) [SegSpace] -> [SegSpace] -> [SegSpace]
forall a. Semigroup a => a -> a -> a
<> SegSpace -> [SegSpace]
forall a. a -> [a]
repeat SegSpace
space0) ([KernelResult]
 -> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)])
-> [KernelResult]
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
forall a b. (a -> b) -> a -> b
$
          KernelBody (Aliases rep) -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody (Aliases rep)
kernel_body
  -- Create coalescing relations between pattern elements and kernel body
  -- results
  let (CoalsTab
actv0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          (Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
pat)
      (CoalsTab
actv_return, InhibitTab
inhibit_return) =
        if Int
num_reds Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0
          then (CoalsTab
actv0, InhibitTab
inhibit0)
          else ((CoalsTab, InhibitTab)
 -> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((lvl -> Bool)
-> lvl
-> TopdownEnv rep
-> KernelBody (Aliases rep)
-> Certs
-> (CoalsTab, InhibitTab)
-> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> (CoalsTab, InhibitTab)
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
(lvl -> Bool)
-> lvl
-> TopdownEnv rep
-> KernelBody (Aliases rep)
-> Certs
-> (CoalsTab, InhibitTab)
-> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> (CoalsTab, InhibitTab)
makeSegMapCoals lvl -> Bool
lvlOK lvl
lvl TopdownEnv rep
td_env KernelBody (Aliases rep)
kernel_body Certs
pat_certs) (CoalsTab
actv0, InhibitTab
inhibit0) [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res

  -- Start from empty references, we'll update with aggregates later.
  let actv0' :: CoalsTab
actv0' = (CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (\CoalsEntry
etry -> CoalsEntry
etry {memrefs :: MemRefs
memrefs = MemRefs
forall a. Monoid a => a
mempty}) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
actv0 CoalsTab -> CoalsTab -> CoalsTab
forall a. Semigroup a => a -> a -> a
<> CoalsTab
actv_return
  -- Process kernel body statements
  BotUpEnv
bu_env' <-
    InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (KernelBody (Aliases rep) -> Stms (Aliases rep)
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$
      BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv0', inhibit :: InhibitTab
inhibit = InhibitTab
inhibit_return}

  let actv_coals_after :: CoalsTab
actv_coals_after =
        (VName -> CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall k a b. (k -> a -> b) -> Map k a -> Map k b
M.mapWithKey
          ( \VName
k CoalsEntry
etry ->
              CoalsEntry
etry
                { memrefs :: MemRefs
memrefs = CoalsEntry -> MemRefs
memrefs CoalsEntry
etry MemRefs -> MemRefs -> MemRefs
forall a. Semigroup a => a -> a -> a
<> MemRefs -> (CoalsEntry -> MemRefs) -> Maybe CoalsEntry -> MemRefs
forall b a. b -> (a -> b) -> Maybe a -> b
maybe MemRefs
forall a. Monoid a => a
mempty CoalsEntry -> MemRefs
memrefs (VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
k (CoalsTab -> Maybe CoalsEntry) -> CoalsTab -> Maybe CoalsEntry
forall a b. (a -> b) -> a -> b
$ CoalsTab
actv0 CoalsTab -> CoalsTab -> CoalsTab
forall a. Semigroup a => a -> a -> a
<> CoalsTab
actv_return)
                }
          )
          (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env'

  -- Check partial overlap.
  let checkPartialOverlap :: BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv
checkPartialOverlap BotUpEnv
bu_env_f (VName
k, CoalsEntry
entry) = do
        let sliceThreadAccess :: (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> AccessSummary
sliceThreadAccess (PatElem (VarAliases, LParamMem)
p, SegSpace
space, KernelResult
res) =
              case VName -> Map VName Coalesced -> Maybe Coalesced
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p) (Map VName Coalesced -> Maybe Coalesced)
-> Map VName Coalesced -> Maybe Coalesced
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry of
                Just (Coalesced CoalescedKind
_ (MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ IxFun
ixf) FreeVarSubsts
_) ->
                  AccessSummary
-> (Slice (TPrimExp Int64 VName) -> AccessSummary)
-> Maybe (Slice (TPrimExp Int64 VName))
-> AccessSummary
forall b a. b -> (a -> b) -> Maybe a -> b
maybe
                    AccessSummary
Undeterminable
                    ( IxFun -> AccessSummary
ixfunToAccessSummary
                        (IxFun -> AccessSummary)
-> (Slice (TPrimExp Int64 VName) -> IxFun)
-> Slice (TPrimExp Int64 VName)
-> AccessSummary
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IxFun -> Slice (TPrimExp Int64 VName) -> IxFun
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun
ixf
                        (Slice (TPrimExp Int64 VName) -> IxFun)
-> (Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName))
-> Slice (TPrimExp Int64 VName)
-> IxFun
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [TPrimExp Int64 VName]
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice (IxFun -> [TPrimExp Int64 VName]
forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
IxFun.shape IxFun
ixf)
                    )
                    (Maybe (Slice (TPrimExp Int64 VName)) -> AccessSummary)
-> Maybe (Slice (TPrimExp Int64 VName)) -> AccessSummary
forall a b. (a -> b) -> a -> b
$ SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice SegSpace
space KernelResult
res
                Maybe Coalesced
Nothing -> AccessSummary
forall a. Monoid a => a
mempty
            thread_writes :: AccessSummary
thread_writes = ((PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> AccessSummary)
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> AccessSummary
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> AccessSummary
sliceThreadAccess [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res
            source_writes :: AccessSummary
source_writes = MemRefs -> AccessSummary
srcwrts (CoalsEntry -> MemRefs
memrefs CoalsEntry
entry) AccessSummary -> AccessSummary -> AccessSummary
forall a. Semigroup a => a -> a -> a
<> AccessSummary
thread_writes
        AccessSummary
destination_uses <-
          case MemRefs -> AccessSummary
dstrefs (CoalsEntry -> MemRefs
memrefs CoalsEntry
entry)
            AccessSummary -> AccessSummary -> AccessSummary
`accessSubtract` MemRefs -> AccessSummary
dstrefs (MemRefs -> (CoalsEntry -> MemRefs) -> Maybe CoalsEntry -> MemRefs
forall b a. b -> (a -> b) -> Maybe a -> b
maybe MemRefs
forall a. Monoid a => a
mempty CoalsEntry -> MemRefs
memrefs (Maybe CoalsEntry -> MemRefs) -> Maybe CoalsEntry -> MemRefs
forall a b. (a -> b) -> a -> b
$ VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
k (CoalsTab -> Maybe CoalsEntry) -> CoalsTab -> Maybe CoalsEntry
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env) of
            Set Set LmadRef
s ->
              (LmadRef -> ShortCircuitM rep AccessSummary)
-> [LmadRef] -> ShortCircuitM rep AccessSummary
forall (m :: * -> *) b a.
(Monad m, Monoid b) =>
(a -> m b) -> [a] -> m b
concatMapM
                (Map VName (PrimExp VName)
-> [(VName, SubExp)] -> LmadRef -> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> LmadRef -> m AccessSummary
aggSummaryMapPartial (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) ([(VName, SubExp)] -> LmadRef -> ShortCircuitM rep AccessSummary)
-> [(VName, SubExp)] -> LmadRef -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0)
                (Set LmadRef -> [LmadRef]
forall a. Set a -> [a]
S.toList Set LmadRef
s)
            AccessSummary
Undeterminable -> AccessSummary -> ShortCircuitM rep AccessSummary
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure AccessSummary
Undeterminable
        -- Do not allow short-circuiting from a segop-local memory
        -- block (not in the topdown scope) to an outer memory block.
        if CoalsEntry -> VName
dstmem CoalsEntry
entry VName -> Map VName (NameInfo (Aliases rep)) -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`M.member` TopdownEnv rep -> Map VName (NameInfo (Aliases rep))
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env
          Bool -> Bool -> Bool
&& TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
forall rep.
AliasableRep rep =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env AccessSummary
destination_uses AccessSummary
source_writes
          then BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f
          else do
            let (CoalsTab
ac, InhibitTab
inh) = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env_f) VName
k
            BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ BotUpEnv
bu_env_f {activeCoals :: CoalsTab
activeCoals = CoalsTab
ac, inhibit :: InhibitTab
inhibit = InhibitTab
inh}

  BotUpEnv
bu_env'' <-
    (BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> [(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
      BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv
checkPartialOverlap
      (BotUpEnv
bu_env' {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv_coals_after})
      ([(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv)
-> [(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [(VName, CoalsEntry)]
forall k a. Map k a -> [(k, a)]
M.toList CoalsTab
actv_coals_after

  let updateMemRefs :: CoalsEntry -> ShortCircuitM rep CoalsEntry
updateMemRefs CoalsEntry
entry = do
        AccessSummary
wrts <- Map VName (PrimExp VName)
-> [(VName, SubExp)]
-> AccessSummary
-> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0) (AccessSummary -> ShortCircuitM rep AccessSummary)
-> AccessSummary -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
srcwrts (MemRefs -> AccessSummary) -> MemRefs -> AccessSummary
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> MemRefs
memrefs CoalsEntry
entry
        AccessSummary
uses <- Map VName (PrimExp VName)
-> [(VName, SubExp)]
-> AccessSummary
-> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0) (AccessSummary -> ShortCircuitM rep AccessSummary)
-> AccessSummary -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
dstrefs (MemRefs -> AccessSummary) -> MemRefs -> AccessSummary
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> MemRefs
memrefs CoalsEntry
entry

        -- Add destination uses from the pattern
        let uses' :: AccessSummary
uses' =
              (PatElem (VarAliases, LParamMem) -> AccessSummary)
-> [PatElem (VarAliases, LParamMem)] -> AccessSummary
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap
                ( \case
                    PatElem VName
_ (VarAliases
_, MemArray PrimType
_ ShapeBase SubExp
_ NoUniqueness
_ (ArrayIn VName
p_mem IxFun
p_ixf))
                      | VName
p_mem VName -> Names -> Bool
`nameIn` CoalsEntry -> Names
alsmem CoalsEntry
entry ->
                          IxFun -> AccessSummary
ixfunToAccessSummary IxFun
p_ixf
                    PatElem (VarAliases, LParamMem)
_ -> AccessSummary
forall a. Monoid a => a
mempty
                )
                [PatElem (VarAliases, LParamMem)]
ps0

        CoalsEntry -> ShortCircuitM rep CoalsEntry
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsEntry -> ShortCircuitM rep CoalsEntry
forall a b. (a -> b) -> a -> b
$ CoalsEntry
entry {memrefs :: MemRefs
memrefs = AccessSummary -> AccessSummary -> MemRefs
MemRefs (AccessSummary
uses AccessSummary -> AccessSummary -> AccessSummary
forall a. Semigroup a => a -> a -> a
<> AccessSummary
uses') AccessSummary
wrts}

  CoalsTab
actv <- (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Map VName a -> m (Map VName b)
mapM CoalsEntry -> ShortCircuitM rep CoalsEntry
updateMemRefs (CoalsTab -> ShortCircuitM rep CoalsTab)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env''
  let bu_env''' :: BotUpEnv
bu_env''' = BotUpEnv
bu_env'' {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv}

  -- Process pattern and return values
  let mergee_writes :: [(PatElem (VarAliases, LParamMem), (VName, VName, IxFun))]
mergee_writes =
        ((PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, IxFun)))
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> [(PatElem (VarAliases, LParamMem), (VName, VName, IxFun))]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe
          ( \(PatElem (VarAliases, LParamMem)
p, SegSpace
_, KernelResult
_) ->
              ((VName, VName, IxFun)
 -> (PatElem (VarAliases, LParamMem), (VName, VName, IxFun)))
-> Maybe (VName, VName, IxFun)
-> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, IxFun))
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (PatElem (VarAliases, LParamMem)
p,) (Maybe (VName, VName, IxFun)
 -> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, IxFun)))
-> Maybe (VName, VName, IxFun)
-> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, IxFun))
forall a b. (a -> b) -> a -> b
$
                TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, IxFun)
forall rep.
HasMemBlock (Aliases rep) =>
TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, IxFun)
getDirAliasedIxfn' TopdownEnv rep
td_env (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env''') (VName -> Maybe (VName, VName, IxFun))
-> VName -> Maybe (VName, VName, IxFun)
forall a b. (a -> b) -> a -> b
$
                  PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p
          )
          [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res

  -- Now, for each mergee write, we need to check that it doesn't overlap with any previous uses of the destination.
  let checkMergeeOverlap :: BotUpEnv
-> (PatElem (VarAliases, LParamMem), (VName, VName, IxFun))
-> ShortCircuitM rep BotUpEnv
checkMergeeOverlap BotUpEnv
bu_env_f (PatElem (VarAliases, LParamMem)
p, (VName
m_b, VName
_, IxFun
ixf)) =
        let as :: AccessSummary
as = IxFun -> AccessSummary
ixfunToAccessSummary IxFun
ixf
         in -- Should be @bu_env@ here, because we need to check overlap
            -- against previous uses.
            case VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b (CoalsTab -> Maybe CoalsEntry) -> CoalsTab -> Maybe CoalsEntry
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env of
              Just CoalsEntry
coal_entry -> do
                let mrefs :: MemRefs
mrefs =
                      CoalsEntry -> MemRefs
memrefs CoalsEntry
coal_entry
                    res :: Bool
res = TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
forall rep.
AliasableRep rep =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env AccessSummary
as (AccessSummary -> Bool) -> AccessSummary -> Bool
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
dstrefs MemRefs
mrefs
                    fail_res :: BotUpEnv
fail_res =
                      let (CoalsTab
ac, InhibitTab
inh) = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env_f) VName
m_b
                       in BotUpEnv
bu_env_f {activeCoals :: CoalsTab
activeCoals = CoalsTab
ac, inhibit :: InhibitTab
inhibit = InhibitTab
inh}

                if Bool
res
                  then case VName -> Map VName Coalesced -> Maybe Coalesced
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p) (Map VName Coalesced -> Maybe Coalesced)
-> Map VName Coalesced -> Maybe Coalesced
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
coal_entry of
                    Maybe Coalesced
Nothing -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f
                    Just (Coalesced CoalescedKind
knd mbd :: ArrayMemBound
mbd@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ IxFun
ixfn) FreeVarSubsts
_) -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$
                      case Map VName (NameInfo (Aliases rep))
-> Map VName (PrimExp VName) -> IxFun -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> Map VName (NameInfo (Aliases rep))
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) IxFun
ixfn of
                        Just FreeVarSubsts
fv_subst ->
                          let entry :: CoalsEntry
entry =
                                CoalsEntry
coal_entry
                                  { vartab :: Map VName Coalesced
vartab =
                                      VName -> Coalesced -> Map VName Coalesced -> Map VName Coalesced
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert
                                        (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p)
                                        (CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
knd ArrayMemBound
mbd FreeVarSubsts
fv_subst)
                                        (CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
coal_entry)
                                  }
                              (CoalsTab
ac, CoalsTab
suc) =
                                (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env_f) VName
m_b CoalsEntry
entry
                           in BotUpEnv
bu_env_f {activeCoals :: CoalsTab
activeCoals = CoalsTab
ac, successCoals :: CoalsTab
successCoals = CoalsTab
suc}
                        Maybe FreeVarSubsts
Nothing ->
                          BotUpEnv
fail_res
                  else BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
fail_res
              Maybe CoalsEntry
_ -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f

  (BotUpEnv
 -> (PatElem (VarAliases, LParamMem), (VName, VName, IxFun))
 -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv
-> [(PatElem (VarAliases, LParamMem), (VName, VName, IxFun))]
-> ShortCircuitM rep BotUpEnv
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM BotUpEnv
-> (PatElem (VarAliases, LParamMem), (VName, VName, IxFun))
-> ShortCircuitM rep BotUpEnv
checkMergeeOverlap BotUpEnv
bu_env''' [(PatElem (VarAliases, LParamMem), (VName, VName, IxFun))]
mergee_writes

-- | Try to short-circuit a single kernel result into the memory block of the
-- pattern element it is bound to.
--
-- The function receives the current @(active, inhibited)@ tables and returns
-- the updated pair:
--
-- * For a @Returns@ of a variable whose memory block is known, whose level is
--   accepted by the given predicate, and whose memory lives in the same space
--   as the memory of the pattern element, a coalescing entry keyed by the
--   memory block of the result is added to the active table (unless the
--   inhibit table forbids it or the variable table cannot be constructed, in
--   which case the tables are returned unchanged).
--
-- * For a @WriteReturns@, the memory block of the returned array is passed to
--   'markFailedCoal'.
--
-- * For anything else (including a @Returns@ that did not satisfy the
--   conditions above), the memory block of every array that is free in the
--   result and known in the top-down scope is passed to 'markFailedCoal'.
makeSegMapCoals ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  lvl ->
  TopdownEnv rep ->
  KernelBody (Aliases rep) ->
  Certs ->
  (CoalsTab, InhibitTab) ->
  (PatElem (VarAliases, LetDecMem), SegSpace, KernelResult) ->
  (CoalsTab, InhibitTab)
makeSegMapCoals
  lvlOK
  lvl
  td_env
  kernel_body
  pat_certs
  (active, inhb)
  ( PatElem pat_name (_, MemArray _ _ _ (ArrayIn pat_mem pat_ixf)),
    space,
    Returns _ _ (Var return_name)
    )
    | Just (MemBlock tp return_shp return_mem _) <-
        getScopeMemInfo return_name body_scope,
      lvlOK lvl,
      MemMem pat_space <-
        runReader (lookupMemInfo pat_mem) (removeScopeAliases (scope td_env)),
      MemMem return_space <-
        runReader
          (lookupMemInfo return_mem)
          (removeScopeAliases (body_scope <> scopeOfSegSpace space)),
      pat_space == return_space =
        let -- Is coalescing the result memory into @mem@ inhibited?
            inhibited mem =
              maybe False (mem `nameIn`) (M.lookup return_mem inhb)
            -- Variable table for the result, placed in @mem@ at the
            -- per-thread slice of @ixf@, extended by 'addInvAliasesVarTab'.
            varTabIn kind mem ixf =
              addInvAliasesVarTab
                td_env
                ( M.singleton return_name $
                    Coalesced
                      kind
                      (MemBlock tp return_shp mem (resultSlice ixf))
                      mempty
                )
                return_name
         in case M.lookup pat_mem active of
              -- We are not in a transitive case
              Nothing
                | not (inhibited pat_mem),
                  Just vtab <- varTabIn InPlaceCoal pat_mem pat_ixf ->
                    let entry =
                          CoalsEntry
                            pat_mem
                            pat_ixf
                            (oneName pat_mem)
                            vtab
                            mempty
                            mempty
                            pat_certs
                     in -- Left-biased union: an entry already present for
                        -- @return_mem@ is kept.
                        (active <> M.singleton return_mem entry, inhb)
              -- Transitive case: the memory of the pattern element is itself
              -- already being coalesced, as described by @trans@.
              Just trans
                | not (inhibited (dstmem trans)),
                  -- Deliberately a lazy binding: the lookup is only forced
                  -- if @trans_mem@ or @trans_ixf@ is demanded.
                  let Coalesced _ (MemBlock _ _ trans_mem trans_ixf) _ =
                        fromMaybe
                          (error "Impossible")
                          (M.lookup pat_name (vartab trans)),
                  Just vtab <- varTabIn TransitiveCoal trans_mem trans_ixf ->
                    let opts
                          | dstmem trans == pat_mem = mempty
                          | otherwise =
                              M.insert pat_name pat_mem (optdeps trans)
                        entry =
                          CoalsEntry
                            (dstmem trans)
                            (dstind trans)
                            (oneName pat_mem <> alsmem trans)
                            vtab
                            opts
                            mempty
                            (certs trans <> pat_certs)
                     in -- Unlike above, this overwrites any existing entry.
                        (M.insert return_mem entry active, inhb)
              _ -> (active, inhb)
  where
    -- Top-down scope extended with the statements of the kernel body.
    body_scope = scope td_env <> scopeOf (kernelBodyStms kernel_body)

    -- Fixes each dimension of the segment space to its index variable.
    thread_slice :: Slice (TPrimExp Int64 VName)
    thread_slice =
      Slice
        [ DimFix (TPrimExp (LeafExp gtid (IntType Int64)))
          | (gtid, _) <- unSegSpace space
        ]

    -- The part of an index function addressed by 'thread_slice'.
    resultSlice ixf =
      IxFun.slice ixf (fullSlice (IxFun.shape ixf) thread_slice)
makeSegMapCoals _ _ td_env _ _ tabs (_, _, WriteReturns _ return_name _) =
  maybe
    (error "Should not happen?")
    (markFailedCoal tabs . memName)
    (getScopeMemInfo return_name (scope td_env))
makeSegMapCoals _ _ td_env _ _ tabs (_, _, result) =
  foldr failMemOf tabs result_arrays
  where
    -- Memory information for every free name of the result that has any.
    result_arrays =
      mapMaybe
        (\v -> getScopeMemInfo v (scope td_env))
        (namesToList (freeIn result))
    failMemOf arr acc = markFailedCoal acc (memName arr)

-- | Pad a (possibly partial) slice so that it covers every entry of a shape.
--
-- The dimension indices already present in the slice are kept unchanged.  For
-- each entry @d@ of @shp@ beyond the first @length slc@ entries, a
-- @DimSlice 0 d 1@ is appended (presumably offset 0, extent @d@, stride 1 --
-- confirm against the 'DimIndex' documentation).  If the slice already has at
-- least as many indices as @shp@ has entries, it is returned as is.
fullSlice :: [TPrimExp Int64 VName] -> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice shp (Slice slc) =
  Slice $ slc ++ map wholeDim (drop (length slc) shp)
  where
    -- Index covering the dimension @d@ from position 0 with step 1.
    wholeDim d = DimSlice 0 d 1

-- | Compute the table of successful coalescings for a function body by
-- iterating the analysis until no new inhibitions are discovered.
--
-- One iteration:
--
--   1. Runs 'mkCoalsTabStms' over the statements of the body, starting from
--      'emptyBotUpEnv' with its inhibition table set to @inhibited topenv@.
--
--   2. Resolves the coalescings that involve the memory blocks of the array
--      function parameters (see @handleFunctionParams@ below).
--
--   3. Repeatedly drops successful entries whose optimistic dependencies did
--      not hold (see @fixPointFilterDeps@), recording them as inhibited.
--
--   4. If the set of inhibited memory blocks gained no new key compared to
--      @inhibited topenv@, the filtered success table is returned; otherwise
--      the whole analysis is rerun with the enlarged inhibition table.
--
-- Calls 'error' if any coalescing is still active once the parameters have
-- been processed, since that is treated as a broken invariant.
fixPointCoalesce ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  [Param FParamMem] ->
  Body (Aliases rep) ->
  TopdownEnv rep ->
  ShortCircuitM rep CoalsTab
fixPointCoalesce lutab fpar bdy topenv = do
  buenv <-
    mkCoalsTabStms
      lutab
      (bodyStms bdy)
      topenv
      (emptyBotUpEnv {inhibit = inhibited topenv})
  let succ_tab = successCoals buenv
      actv_tab = activeCoals buenv
      inhb_tab = inhibit buenv
      -- Allow short-circuiting function parameters that are unique and have
      -- matching index functions, otherwise mark as failed.  "Matching" means
      -- that the active entry for the parameter's memory block has a
      -- destination index function equal to the parameter's own, and that the
      -- entry's destination access summary is an empty 'Set'.
      handleFunctionParams (a, i, s) (_, u, MemBlock _ _ m ixf) =
        case (u, M.lookup m a) of
          (Unique, Just entry)
            | dstind entry == ixf,
              Set dst_uses <- dstrefs (memrefs entry),
              dst_uses == mempty ->
                let (a', s') = markSuccessCoal (a, s) m entry
                 in (a', i, s')
          _ ->
            let (a', i') = markFailedCoal (a, i) m
             in (a', i', s)
      -- Strict left fold: the accumulator is a tuple of maps that is
      -- inspected at every step anyway, so there is nothing to gain from
      -- building a chain of thunks.
      (actv_tab', inhb_tab', succ_tab') =
        L.foldl'
          handleFunctionParams
          (actv_tab, inhb_tab, succ_tab)
          (getArrMemAssocFParam fpar)

      (succ_tab'', failed_optdeps) = fixPointFilterDeps succ_tab' M.empty
      inhb_tab'' = M.unionWith (<>) failed_optdeps inhb_tab'
  if not $ M.null actv_tab'
    then error ("COALESCING ROOT: BROKEN INV, active not empty: " ++ show (M.keys actv_tab'))
    else
      -- 'M.difference' compares keys only, so another round is triggered
      -- exactly when some memory block is inhibited now that had no entry in
      -- the inhibition table this round started from.
      if M.null $ inhb_tab'' `M.difference` inhibited topenv
        then pure succ_tab''
        else fixPointCoalesce lutab fpar bdy (topenv {inhibited = inhb_tab''})
  where
    -- Apply 'filterDeps' to every key of the table, and repeat until a pass
    -- leaves the number of entries unchanged.
    fixPointFilterDeps :: CoalsTab -> InhibitTab -> (CoalsTab, InhibitTab)
    fixPointFilterDeps coaltab inhbtab =
      let (coaltab', inhbtab') =
            L.foldl' filterDeps (coaltab, inhbtab) (M.keys coaltab)
       in if M.size coaltab == M.size coaltab'
            then (coaltab', inhbtab')
            else fixPointFilterDeps coaltab' inhbtab'

    -- Check the optimistic dependencies of the entry for memory block @mb@.
    -- A block that is no longer in the table is left alone; a block with at
    -- least one failed dependency is handed to 'markFailedCoal'.
    filterDeps (coal, inhb) mb =
      case M.lookup mb coal of
        Nothing -> (coal, inhb)
        Just coal_etry
          | M.null (M.filterWithKey (failedOptDep coal) (optdeps coal_etry)) ->
              (coal, inhb) -- all ok
          | otherwise ->
              -- optimistic dependencies failed for the current
              -- memblock; extend inhibited mem-block mergings.
              markFailedCoal (coal, inhb) mb

    -- An optimistic dependency of variable @r@ on memory block @mr@ has failed
    -- when @mr@ has no entry in the table, or when its entry does not list
    -- @r@ in its 'vartab'.
    failedOptDep coal r mr =
      maybe True (M.notMember r . vartab) (M.lookup mr coal)

-- | Perform short-circuiting on 'Stms'.
--
-- The statements are visited with two environments that flow in opposite
-- directions: the top-down environment is updated with each statement on the
-- way in, while the bottom-up environment is threaded back from the last
-- statement to the first.  Consequently 'mkCoalsTabStm' is applied to a
-- statement only after all the statements following it have been processed.
--
-- The non-negative names found in the patterns of /all/ statements of the
-- sequence (via 'nonNegativesInPat') are added to the top-down environment
-- that is passed to 'mkCoalsTabStm' for every individual statement.
mkCoalsTabStms ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stms (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStms lutab stms0 = traverseStms stms0
  where
    -- Collected once for the whole statement sequence.
    non_negs_in_pats = foldMap (nonNegativesInPat . stmPat) stms0

    -- No statements left: the bottom-up environment is returned untouched.
    traverseStms Empty _ bu_env = pure bu_env
    traverseStms (stm :<| stms) td_env bu_env = do
      -- Compute @td_env@ top down
      let td_env' = updateTopdownEnv td_env stm
      -- Compute @bu_env@ bottom up
      bu_env' <- traverseStms stms td_env' bu_env
      let td_env_stm =
            td_env' {nonNegatives = nonNegatives td_env' <> non_negs_in_pats}
      mkCoalsTabStm lutab stm td_env_stm bu_env'

-- | Array (register) coalescing can have one of three shapes:
--      a) @let y    = copy(b^{lu})@
--      b) @let y    = concat(a, b^{lu})@
--      c) @let y[i] = b^{lu}@
--   The intent is to use the memory block of the left-hand side
--     for the right-hand side variable, meaning to store @b@ in
--     @m_y@ (rather than @m_b@).
--   The following five safety conditions are necessary:
--      1. the right-hand side is lastly-used in the current statement
--      2. the allocation of @m_y@ dominates the creation of @b@
--         ^ relax it by hoisting the allocation of @m_y@
--      3. there is no use of the left-hand side memory block @m_y@
--           during the liveness of @b@, i.e., in between its last use
--           and its creation.
--         ^ relax it by pointwise/interval-based checking
--      4. @b@ is a newly created array, i.e., does not alias anything
--         ^ relax it to support existential memory blocks for if-then-else
--      5. the new index function of @b@ corresponding to memory block @m_y@
--           can be translated at the definition of @b@, and the
--           same for all variables aliasing @b@.
--   Observation: during the live range of @b@, @m_b@ can only be used by
--                variables aliased with @b@, because @b@ is newly created.
--                relax it: in case @m_b@ is existential due to an if-then-else
--                          then the checks should be extended to the actual
--                          array-creation points.
mkCoalsTabStm ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stm (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStm :: forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stm (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStm InhibitTab
_ (Let (Pat [PatElem (LetDec (Aliases rep))
pe]) StmAux (ExpDec (Aliases rep))
_ Exp (Aliases rep)
e) TopdownEnv rep
td_env BotUpEnv
bu_env
  | Just PrimExp VName
primexp <- (VName -> Maybe (PrimExp VName))
-> Exp (Aliases rep) -> Maybe (PrimExp VName)
forall (m :: * -> *) rep v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp (ScopeTab rep
-> Map VName (PrimExp VName) -> VName -> Maybe (PrimExp VName)
forall rep.
AliasableRep rep =>
ScopeTab rep
-> Map VName (PrimExp VName) -> VName -> Maybe (PrimExp VName)
vnameToPrimExp (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)) Exp (Aliases rep)
e =
      BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ BotUpEnv
bu_env {scals :: Map VName (PrimExp VName)
scals = VName
-> PrimExp VName
-> Map VName (PrimExp VName)
-> Map VName (PrimExp VName)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
PatElem (LetDec (Aliases rep))
pe) PrimExp VName
primexp (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)}
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
patt StmAux (ExpDec (Aliases rep))
_ (Match [SubExp]
_ [Case (Body (Aliases rep))]
cases Body (Aliases rep)
defbody MatchDec (BranchType (Aliases rep))
_)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (VarAliases, LParamMem)]
pat_val_elms = Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt
      -- ToDo: 1. we need to record existential memory blocks in alias table on the top-down pass.
      --       2. need to extend the scope table

      --  i) Filter @activeCoals@ by the 2ND AND 5th safety conditions:
      (CoalsTab
activeCoals0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (VarAliases, LParamMem)]
pat_val_elms

      -- ii) extend @activeCoals@ by transfering the pattern-elements bindings existent
      --     in @activeCoals@ to the body results of the then and else branches, but only
      --     if the current pattern element can be potentially coalesced and also
      --     if the current pattern element satisfies safety conditions 2 & 5.
      res_mem_def :: [MemBodyResult]
res_mem_def = CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (VarAliases, LParamMem)]
pat_val_elms Body (Aliases rep)
defbody
      res_mem_cases :: [[MemBodyResult]]
res_mem_cases = (Case (Body (Aliases rep)) -> [MemBodyResult])
-> [Case (Body (Aliases rep))] -> [[MemBodyResult]]
forall a b. (a -> b) -> [a] -> [b]
map (CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (VarAliases, LParamMem)]
pat_val_elms (Body (Aliases rep) -> [MemBodyResult])
-> (Case (Body (Aliases rep)) -> Body (Aliases rep))
-> Case (Body (Aliases rep))
-> [MemBodyResult]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      subs_def :: FreeVarSubsts
subs_def = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp]) -> Result -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
defbody
      subs_cases :: [FreeVarSubsts]
subs_cases = (Case (Body (Aliases rep)) -> FreeVarSubsts)
-> [Case (Body (Aliases rep))] -> [FreeVarSubsts]
forall a b. (a -> b) -> [a] -> [b]
map (Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt ([SubExp] -> FreeVarSubsts)
-> (Case (Body (Aliases rep)) -> [SubExp])
-> Case (Body (Aliases rep))
-> FreeVarSubsts
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp])
-> (Case (Body (Aliases rep)) -> Result)
-> Case (Body (Aliases rep))
-> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult (Body (Aliases rep) -> Result)
-> (Case (Body (Aliases rep)) -> Body (Aliases rep))
-> Case (Body (Aliases rep))
-> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      actv_def_i :: CoalsTab
actv_def_i = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_def) CoalsTab
activeCoals0 [MemBodyResult]
res_mem_def
      actv_cases_i :: [CoalsTab]
actv_cases_i = (FreeVarSubsts -> [MemBodyResult] -> CoalsTab)
-> [FreeVarSubsts] -> [[MemBodyResult]] -> [CoalsTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\FreeVarSubsts
subs [MemBodyResult]
res -> (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs) CoalsTab
activeCoals0 [MemBodyResult]
res) [FreeVarSubsts]
subs_cases [[MemBodyResult]]
res_mem_cases

      -- eliminate the original pattern binding of the if statement,
      -- @let x = if y[0,0] > 0 then map (+y[0,0]) a else map (+1) b@
      -- @let y[0] = x@
      -- should succeed because @m_y@ is used before @x@ is created.
      aux :: Map VName a -> MemBodyResult -> Map VName a
aux Map VName a
ac (MemBodyResult VName
m_b VName
_ VName
_ VName
m_r) = if VName
m_b VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
m_r then Map VName a
ac else VName -> Map VName a -> Map VName a
forall k a. Ord k => k -> Map k a -> Map k a
M.delete VName
m_b Map VName a
ac
      actv_def :: CoalsTab
actv_def = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> MemBodyResult -> CoalsTab
forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux CoalsTab
actv_def_i [MemBodyResult]
res_mem_def
      actv_cases :: [CoalsTab]
actv_cases = (CoalsTab -> [MemBodyResult] -> CoalsTab)
-> [CoalsTab] -> [[MemBodyResult]] -> [CoalsTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith ((CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> MemBodyResult -> CoalsTab
forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux) [CoalsTab]
actv_cases_i [[MemBodyResult]]
res_mem_cases

  -- iii) process the then and else bodies
  BotUpEnv
res_def <- InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
defbody) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
actv_def})
  [BotUpEnv]
res_cases <- (Case (Body (Aliases rep))
 -> CoalsTab -> ShortCircuitM rep BotUpEnv)
-> [Case (Body (Aliases rep))]
-> [CoalsTab]
-> ShortCircuitM rep [BotUpEnv]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\Case (Body (Aliases rep))
c CoalsTab
a -> InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms (Body (Aliases rep) -> Stms (Aliases rep))
-> Body (Aliases rep) -> Stms (Aliases rep)
forall a b. (a -> b) -> a -> b
$ Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody Case (Body (Aliases rep))
c) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
a})) [Case (Body (Aliases rep))]
cases [CoalsTab]
actv_cases
  let (CoalsTab
actv_def0, CoalsTab
succ_def0, InhibitTab
inhb_def0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_def, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_def, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_def)

      -- iv) optimistically mark the pattern succesful:
      ((CoalsTab
activeCoals1, InhibitTab
inhibit1), CoalsTab
successCoals1) =
        (((CoalsTab, InhibitTab), CoalsTab)
 -> [MemBodyResult] -> ((CoalsTab, InhibitTab), CoalsTab))
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [[MemBodyResult]]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( [(CoalsTab, CoalsTab)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall {c}.
[(CoalsTab, Map VName c)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun
              ( (CoalsTab
actv_def0, CoalsTab
succ_def0)
                  (CoalsTab, CoalsTab)
-> [(CoalsTab, CoalsTab)] -> [(CoalsTab, CoalsTab)]
forall a. a -> [a] -> [a]
: [CoalsTab] -> [CoalsTab] -> [(CoalsTab, CoalsTab)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases) ((BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases)
              )
          )
          ((CoalsTab
activeCoals0, InhibitTab
inhibit0), BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
          ([[MemBodyResult]] -> [[MemBodyResult]]
forall a. [[a]] -> [[a]]
L.transpose ([[MemBodyResult]] -> [[MemBodyResult]])
-> [[MemBodyResult]] -> [[MemBodyResult]]
forall a b. (a -> b) -> a -> b
$ [MemBodyResult]
res_mem_def [MemBodyResult] -> [[MemBodyResult]] -> [[MemBodyResult]]
forall a. a -> [a] -> [a]
: [[MemBodyResult]]
res_mem_cases)

      --  v) unify coalescing results of all branches by taking the union
      --     of all entries in the current/then/else success tables.

      actv_res :: CoalsTab
actv_res = (CoalsTab -> CoalsTab -> CoalsTab)
-> CoalsTab -> [CoalsTab] -> CoalsTab
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
activeCoals1 ([CoalsTab] -> CoalsTab) -> [CoalsTab] -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
actv_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases

      succ_res :: CoalsTab
succ_res = (CoalsTab -> CoalsTab -> CoalsTab)
-> CoalsTab -> [CoalsTab] -> CoalsTab
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a. Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
M.unionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
successCoals1 ([CoalsTab] -> CoalsTab) -> [CoalsTab] -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
succ_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases

      -- vi) The step of filtering by 3rd safety condition is not
      --       necessary, because we perform index analysis of the
      --       source/destination uses, and they should have been
      --       filtered during the analysis of the then/else bodies.
      inhibit_res :: InhibitTab
inhibit_res =
        (Names -> Names -> Names) -> [InhibitTab] -> InhibitTab
forall (f :: * -> *) k a.
(Foldable f, Ord k) =>
(a -> a -> a) -> f (Map k a) -> Map k a
M.unionsWith
          Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
(<>)
          ( InhibitTab
inhibit1
              InhibitTab -> [InhibitTab] -> [InhibitTab]
forall a. a -> [a] -> [a]
: (CoalsTab -> InhibitTab -> InhibitTab)
-> [CoalsTab] -> [InhibitTab] -> [InhibitTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                ( \CoalsTab
actv InhibitTab
inhb ->
                    let failed :: CoalsTab
failed = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
actv (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ (CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry CoalsTab
actv CoalsTab
activeCoals0
                     in (CoalsTab, InhibitTab) -> InhibitTab
forall a b. (a, b) -> b
snd ((CoalsTab, InhibitTab) -> InhibitTab)
-> (CoalsTab, InhibitTab) -> InhibitTab
forall a b. (a -> b) -> a -> b
$ ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
failed, InhibitTab
inhb) (CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
failed)
                )
                (CoalsTab
actv_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases)
                (InhibitTab
inhb_def0 InhibitTab -> [InhibitTab] -> [InhibitTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> InhibitTab) -> [BotUpEnv] -> [InhibitTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> InhibitTab
inhibit [BotUpEnv]
res_cases)
          )
  BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    BotUpEnv
bu_env
      { activeCoals :: CoalsTab
activeCoals = CoalsTab
actv_res,
        successCoals :: CoalsTab
successCoals = CoalsTab
succ_res,
        inhibit :: InhibitTab
inhibit = InhibitTab
inhibit_res
      }
  where
    -- Fold step that decides the fate of one coalesced result, given the
    -- group @results@ of 'MemBodyResult's that belong to it.  The accumulator
    -- has the shape @((active, inhibited), successful)@.
    --
    -- The first argument is zipped positionally against @results@; presumably
    -- it holds one pair of tables per branch -- confirm at the call site.
    --
    -- Outcomes, tried in this order:
    --
    --   1. every result's body memory block is a key of the matching second
    --      table: the pattern memory block is promoted via 'markSuccessCoal',
    --      after each (body name, body memory) pair is recorded in @optdeps@;
    --
    --   2. every result's body memory block equals @m_b@ and is a key of the
    --      matching first table: the coalescing stays active, with its entry
    --      merged with the looked-up entries via 'unionCoalsEntry';
    --
    --   3. otherwise the coalescing is dropped via 'markFailedCoal'.
    --
    -- An empty group, or a pattern memory block without an entry in the
    -- active table, is treated as impossible and raises an error.
    foldfun ::
      [(CoalsTab, M.Map VName c)] ->
      ((CoalsTab, InhibitTab), CoalsTab) ->
      [MemBodyResult] ->
      ((CoalsTab, InhibitTab), CoalsTab)
    foldfun _ _ [] =
      error "Imposible Case 1!!!"
    foldfun branch_tabs ((act, inhb), succc) results@(res0@(MemBodyResult m_b _ _ _) : _)
      | isNothing (M.lookup (patMem res0) act) =
          error "Imposible Case 2!!!"
      | Just info <- M.lookup m_b act =
          decide info
      | otherwise =
          failed
      where
        -- One of the branches has failed coalescing,
        -- hence remove the coalescing of the result.
        failed = (markFailedCoal (act, inhb) (patMem res0), succc)

        decide info
          | isJust (zipWithM (M.lookup . bodyMem) results (map snd branch_tabs)) =
              -- Optimistically promote to successful coalescing and append!
              let deps =
                    foldr
                      (\res -> M.insert (bodyName res) (bodyMem res))
                      (optdeps info)
                      results
                  (act', succc') =
                    markSuccessCoal (act, succc) m_b (info {optdeps = deps})
               in ((act', inhb), succc')
          | all ((m_b ==) . bodyMem) results,
            Just branch_infos <-
              zipWithM (M.lookup . bodyMem) results (map fst branch_tabs) =
              -- Treating special case resembling:
              -- @let x0 = map (+1) a                                  @
              -- @let x3 = if cond then let x1 = x0 with [0] <- 2 in x1@
              -- @                 else let x2 = x0 with [1] <- 3 in x2@
              -- @let z[1] = x3                                        @
              -- In this case the result active table should be the union
              -- of the @m_x@ entries of the then and else active tables.
              --
              -- The fold is strict: the merged entry is needed in full anyway
              -- once it is inserted into the (strict) map.
              let merged = L.foldl' unionCoalsEntry info branch_infos
               in ((M.insert m_b merged act, inhb), succc)
          | otherwise =
              failed
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ (Loop [(FParam (Aliases rep), SubExp)]
arginis LoopForm
lform Body (Aliases rep)
body)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (VarAliases, LParamMem)]
pat_val_elms = Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat

      --  i) Filter @activeCoals@ by the 2nd, 3rd AND 5th safety conditions. In
      --  other words, for each active coalescing target, the creation of the
      --  array we're trying to merge should happen before the allocation of the
      --  merge target and the index function should be translateable.
      (CoalsTab
actv0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (VarAliases, LParamMem)]
pat_val_elms
      -- ii) Extend @activeCoals@ by transferring the pattern-element bindings
      --     existing in @activeCoals@ to the loop-body results, but only if:
      --       (a) the pattern element is a candidate for coalescing,        &&
      --       (b) the pattern element satisfies safety conditions 2 & 5,
      --           (conditions (a) and (b) have already been checked above), &&
      --       (c) the memory block of the corresponding body result is
      --           allocated outside the loop, i.e., non-existential,        &&
      --       (d) the init name is lastly-used in the initialization
      --           of the loop variant.
      --     Otherwise fail and remove from active-coalescing table!
      bdy_ress :: Result
bdy_ress = Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
body
      ([(VName, VName)]
patmems, [(VName, VName)]
argmems, [(VName, VName)]
inimems, [(VName, VName)]
resmems) =
        [((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
-> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
    [(VName, VName)])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
L.unzip4 ([((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
 -> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
     [(VName, VName)]))
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
    [(VName, VName)])
forall a b. (a -> b) -> a -> b
$
          ((PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
  SubExp)
 -> Maybe
      ((VName, VName), (VName, VName), (VName, VName), (VName, VName)))
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (CoalsTab
-> (PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
    SubExp)
-> Maybe
     ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
mapmbFun CoalsTab
actv0) ([PatElem (VarAliases, LParamMem)]
-> [(Param FParamMem, SubExp)]
-> [SubExp]
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (VarAliases, LParamMem)]
pat_val_elms [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis ([SubExp]
 -> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
      SubExp)])
-> [SubExp]
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bdy_ress) -- td_env'

      -- remove the other pattern elements from the active coalescing table:
      coal_pat_names :: Names
coal_pat_names = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((VName, VName) -> VName) -> [(VName, VName)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, VName) -> VName
forall a b. (a, b) -> a
fst [(VName, VName)]
patmems
      (CoalsTab
actv1, InhibitTab
inhibit1) =
        ((CoalsTab, InhibitTab)
 -> (VName, ArrayMemBound) -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
-> [(VName, ArrayMemBound)]
-> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( \(CoalsTab
act, InhibitTab
inhb) (VName
b, MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_b IxFun
_) ->
              if VName
b VName -> Names -> Bool
`nameIn` Names
coal_pat_names
                then (CoalsTab
act, InhibitTab
inhb) -- ok
                else (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) VName
m_b -- remove from active
          )
          (CoalsTab
actv0, InhibitTab
inhibit0)
          (Pat (VarAliases, LParamMem) -> [(VName, ArrayMemBound)]
forall aliases.
Pat (aliases, LParamMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat)

      -- iii) Process the loop's body.
      --      If the memory blocks of the loop result and loop variant param differ
      --      then make the original memory block of the loop result conflict with
      --      the original memory block of the loop parameter. This is done in
      --      order to prevent the coalescing of @a1@, @a0@, @x@ and @db@ in the
      --      same memory block of @y@ in the example below:
      --      @loop(a1 = a0) = for i < n do @
      --      @    let x = map (stencil a1) (iota n)@
      --      @    let db = copy x          @
      --      @    in db                    @
      --      @let y[0] = a1                @
      --      Meaning the coalescing of @x@ in @let db = copy x@ should fail because
      --      @a1@ appears in the definition of @let x = map (stencil a1) (iota n)@.
      res_mem_bdy :: [MemBodyResult]
res_mem_bdy = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
resmems
      res_mem_arg :: [MemBodyResult]
res_mem_arg = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
argmems
      res_mem_ini :: [MemBodyResult]
res_mem_ini = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
inimems

      actv2 :: CoalsTab
actv2 =
        let subs_res :: FreeVarSubsts
subs_res = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp]) -> Result -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
body
            actv11 :: CoalsTab
actv11 = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_res) CoalsTab
actv1 [MemBodyResult]
res_mem_bdy
            subs_arg :: FreeVarSubsts
subs_arg = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> SubExp)
-> [(Param FParamMem, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> ((Param FParamMem, SubExp) -> VName)
-> (Param FParamMem, SubExp)
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis
            actv12 :: CoalsTab
actv12 = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_arg) CoalsTab
actv11 [MemBodyResult]
res_mem_arg
            subs_ini :: FreeVarSubsts
subs_ini = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> SubExp)
-> [(Param FParamMem, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis
         in (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_ini) CoalsTab
actv12 [MemBodyResult]
res_mem_ini

      -- The code below adds an aliasing relation to the loop-arg memory
      --   so as to prevent, e.g., the coalescing of an iterative stencil
      --   (you need a buffer for the result and a separate one for the stencil).
      -- @ let b =               @
      -- @    loop (a) for i<N do@
      -- @        stencil a      @
      -- @  ...                  @
      -- @  y[slc_y] = b         @
      -- This should fail coalescing because we are aliasing @m_a@ with
      --   the memory block of the result.
      insertMemAliases :: CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
tab (MemBodyResult VName
_ VName
_ VName
_ VName
m_r, MemBodyResult VName
_ VName
_ VName
_ VName
m_a) =
        if VName
m_r VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
m_a
          then CoalsTab
tab
          else case VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_r CoalsTab
tab of
            Maybe CoalsEntry
Nothing -> CoalsTab
tab
            Just CoalsEntry
etry ->
              VName -> CoalsEntry -> CoalsTab -> CoalsTab
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_r (CoalsEntry
etry {alsmem :: Names
alsmem = CoalsEntry -> Names
alsmem CoalsEntry
etry Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> VName -> Names
oneName VName
m_a}) CoalsTab
tab
      actv3 :: CoalsTab
actv3 = (CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab)
-> CoalsTab -> [(MemBodyResult, MemBodyResult)] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
actv2 ([MemBodyResult]
-> [MemBodyResult] -> [(MemBodyResult, MemBodyResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [MemBodyResult]
res_mem_bdy [MemBodyResult]
res_mem_arg)
      -- analysing the loop body starts from a null memory-reference set;
      --  the results of the loop body iteration are aggregated later
      actv4 :: CoalsTab
actv4 = (CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (\CoalsEntry
etry -> CoalsEntry
etry {memrefs :: MemRefs
memrefs = MemRefs
forall a. Monoid a => a
mempty}) CoalsTab
actv3
  BotUpEnv
res_env_body <-
    InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms
      InhibitTab
lutab
      (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)
      TopdownEnv rep
td_env'
      ( BotUpEnv
bu_env
          { activeCoals :: CoalsTab
activeCoals = CoalsTab
actv4,
            inhibit :: InhibitTab
inhibit = InhibitTab
inhibit1
          }
      )
  let scals_loop :: Map VName (PrimExp VName)
scals_loop = BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
res_env_body
      (CoalsTab
res_actv0, CoalsTab
res_succ0, InhibitTab
res_inhb0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_env_body, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_env_body, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_env_body)
      -- iv) Aggregate memory references across loop and filter unsound coalescing
      -- a) Filter the active-table by the FIRST SOUNDNESS condition, namely:
      --     W_i does not overlap with Union_{j=i+1..n} U_j,
      --     where W_i corresponds to the Write set of src mem-block m_b,
      --     and U_j correspond to the uses of the destination
      --     mem-block m_y, in which m_b is coalesced into.
      --     W_i and U_j correspond to the accesses within the loop body.
      mb_loop_idx :: Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx = LoopForm
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mbLoopIndexRange LoopForm
lform
  CoalsTab
res_actv1 <- (CoalsEntry -> ShortCircuitM rep Bool)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall k (m :: * -> *) v.
(Eq k, Monad m) =>
(v -> m Bool) -> Map k v -> m (Map k v)
filterMapM1 (Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep Bool
loopSoundness1Entry Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv0

  -- b) Update the memory-reference summaries across loop:
  --   W = Union_{i=0..n-1} W_i Union W_{before-loop}
  --   U = Union_{i=0..n-1} U_i Union U_{before-loop}
  CoalsTab
res_actv2 <- (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Map VName a -> m (Map VName b)
mapM (ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep CoalsEntry
aggAcrossLoopEntry (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env' ScopeTab rep -> ScopeTab rep -> ScopeTab rep
forall a. Semigroup a => a -> a -> a
<> Stms (Aliases rep) -> ScopeTab rep
forall rep a. Scoped rep a => a -> Scope rep
scopeOf (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)) Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv1

  -- c) check soundness of the successful promotions for:
  --      - the entries that have been promoted to success during the loop-body pass
  --      - for all the entries of active table
  --    Filter the entries by the SECOND SOUNDNESS CONDITION, namely:
  --      Union_{i=1..n-1} W_i does not overlap the before-the-loop uses
  --        of the destination memory block.
  let res_actv3 :: CoalsTab
res_actv3 = (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
res_actv2

  let tmp_succ :: CoalsTab
tmp_succ =
        (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
forall {k} {a} {p}. Ord k => Map k a -> k -> p -> Bool
okLookup CoalsTab
actv3) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$
          CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_succ0 (BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
      ver_succ :: CoalsTab
ver_succ = (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
tmp_succ
  let suc_fail :: CoalsTab
suc_fail = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
tmp_succ CoalsTab
ver_succ
      (CoalsTab
res_succ, InhibitTab
res_inhb1) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_succ0, InhibitTab
res_inhb0) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
suc_fail
      --
      act_fail :: CoalsTab
act_fail = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_actv0 CoalsTab
res_actv3
      (CoalsTab
_, InhibitTab
res_inhb) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_actv0, InhibitTab
res_inhb1) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
act_fail
      res_actv :: CoalsTab
res_actv =
        (VName -> CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall k a b. (k -> a -> b) -> Map k a -> Map k b
M.mapWithKey (CoalsTab -> VName -> CoalsEntry -> CoalsEntry
forall {k}.
Ord k =>
Map k CoalsEntry -> k -> CoalsEntry -> CoalsEntry
addBeforeLoop CoalsTab
actv3) CoalsTab
res_actv3

      -- v) optimistically mark the pattern successful if there is any chance to succeed
      ((CoalsTab
fin_actv1, InhibitTab
fin_inhb1), CoalsTab
fin_succ1) =
        (((CoalsTab, InhibitTab), CoalsTab)
 -> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
 -> ((CoalsTab, InhibitTab), CoalsTab))
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((CoalsTab, InhibitTab), CoalsTab)
-> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
-> ((CoalsTab, InhibitTab), CoalsTab)
foldFunOptimPromotion ((CoalsTab
res_actv, InhibitTab
res_inhb), CoalsTab
res_succ) ([((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
 -> ((CoalsTab, InhibitTab), CoalsTab))
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall a b. (a -> b) -> a -> b
$
          [(VName, VName)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
L.zip4 [(VName, VName)]
patmems [(VName, VName)]
argmems [(VName, VName)]
resmems [(VName, VName)]
inimems
      (CoalsTab
fin_actv2, InhibitTab
fin_inhb2) =
        ((CoalsTab, InhibitTab)
 -> VName -> CoalsEntry -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> CoalsTab -> (CoalsTab, InhibitTab)
forall a k b. (a -> k -> b -> a) -> a -> Map k b -> a
M.foldlWithKey
          ( \(CoalsTab, InhibitTab)
acc VName
k CoalsEntry
_ ->
              if VName
k VName -> Names -> Bool
`nameIn` [VName] -> Names
namesFromList (((Param FParamMem, SubExp) -> VName)
-> [(Param FParamMem, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis)
                then (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab, InhibitTab)
acc VName
k
                else (CoalsTab, InhibitTab)
acc
          )
          (CoalsTab
fin_actv1, InhibitTab
fin_inhb1)
          CoalsTab
fin_actv1
  BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env {activeCoals :: CoalsTab
activeCoals = CoalsTab
fin_actv2, successCoals :: CoalsTab
successCoals = CoalsTab
fin_succ1, inhibit :: InhibitTab
inhibit = InhibitTab
fin_inhb2}
  where
    -- Allocation table in effect inside the loop body: the table carried by
    -- @td_env'@, extended by 'getAllocs' for each statement directly in the
    -- body (this fold does not descend into nested bodies).  The fold is
    -- strict so that no chain of suspended table updates is built up while
    -- walking the statements.
    allocs_bdy = L.foldl' getAllocs (alloc td_env') (bodyStms body)
    -- The loop's top-down environment as seen from inside the body:
    -- @td_env'@ with its allocation table replaced by @allocs_bdy@ and its
    -- scope extended with the bindings of the body's statements.
    td_env_allocs =
      td_env'
        { alloc = allocs_bdy,
          scope = scope td_env' <> scopeOf (bodyStms body)
        }
    -- Top-down environment for the loop: the incoming @td_env@ updated by
    -- 'updateTopdownEnvLoop' with the loop's parameter/initial-value pairs
    -- (@arginis@) and its loop form (@lform@).
    td_env' = updateTopdownEnvLoop td_env arginis lform
    -- Extends the allocation table @tab@ when @stm@ is an 'Alloc' operation
    -- whose pattern binds exactly one element: the bound name is mapped to
    -- the space of the allocation.  Any other statement leaves the table
    -- unchanged.
    getAllocs tab stm =
      case stm of
        Let (Pat [pe]) _ (Op (Alloc _ sp)) ->
          M.insert (patElemName pe) sp tab
        _ ->
          tab
    -- 'True' exactly when the key @m@ is bound in @tab@.  The third
    -- argument is accepted but never inspected.
    okLookup tab m _ = M.member m tab
    --
    -- Decide whether one loop-variant array is a candidate for
    -- promotion.  The argument triple consists of the loop's result
    -- pattern element, the matching loop parameter with its initial
    -- value, and the matching body result.
    --
    -- On success the result is
    -- @((b, m_b), (a, m_a), (a0, m_a0), (r, m_r))@: the pattern
    -- element, the loop parameter, the initial value and the body
    -- result, each paired with the memory block it resides in.
    -- 'Nothing' is returned as soon as any requirement below fails:
    --
    --   * the pattern element is an array placed in some memory @m_b@;
    --   * no loop parameter name occurs free in the declaration of
    --     this parameter;
    --   * the initial value and the body result are both variables;
    --   * @m_b@ has an entry in the active table @actv0@ whose variable
    --     table contains @b@;
    --   * memory information for @a@, @a0@ and @r@ is found in the
    --     scope of @td_env_allocs@;
    --   * @lutab@ maps @a@ to a name set containing @a0@ (presumably
    --     "@a0@ is last used by this loop" - confirm against the
    --     last-use analysis);
    --   * @m_r@ is one of the allocations known in @td_env_allocs@.
    mapmbFun actv0 (patel, (arg, ini), bdyres)
      | b <- patElemName patel,
        (_, MemArray _ _ _ (ArrayIn m_b _)) <- patElemDec patel,
        a <- paramName arg,
        arg_fvs <- freeIn (paramDec arg),
        -- Not safe to short-circuit if the index function of this
        -- parameter is variant to the loop.
        not (any ((`nameIn` arg_fvs) . paramName . fst) arginis),
        Var a0 <- ini,
        Var r <- bdyres,
        Just coal_etry <- M.lookup m_b actv0,
        b `M.member` vartab coal_etry,
        loop_scope <- scope td_env_allocs,
        Just (MemBlock _ _ m_a _) <- getScopeMemInfo a loop_scope,
        Just (MemBlock _ _ m_a0 _) <- getScopeMemInfo a0 loop_scope,
        Just (MemBlock _ _ m_r _) <- getScopeMemInfo r loop_scope,
        Just nms <- M.lookup a lutab,
        a0 `nameIn` nms,
        -- Direct map membership instead of scanning the list of keys.
        m_r `M.member` alloc td_env_allocs =
          Just ((b, m_b), (a, m_a), (a0, m_a0), (r, m_r))
      | otherwise = Nothing
    -- Fold step over quadruples of @(array, memory)@ pairs, threading
    -- the active table, the inhibit table and the success table.  The
    -- pairs are referred to here by the names this function gives them:
    -- @(b, m_b)@, @(a, m_a)@, @(_r, m_r)@ and @(b_i, m_i)@.
    -- NOTE(review): if the quadruples come from 'mapmbFun', the third
    -- pair is the loop's initial value and the fourth the body result -
    -- confirm at the call site.
    --
    --   * When @m_r == m_i@ (and, by assertion, all four memories
    --     coincide): the entry of @m_i@ gets its variable table
    --     extended via 'addInvAliasesVarTab' for @b_i@ and is stored
    --     under @m_b@; if the entry or the extension is unavailable,
    --     @m_b@ is marked as failed.
    --   * Otherwise, when entries for @m_b@, @m_a@ and @m_i@ are
    --     active, @m_r@ is already in the success table, the variable
    --     table of @m_i@ can be extended for @b_i@, and the index
    --     functions of @b@ and @a@ translate
    --     ('translateIxFnInScope'): the three entries record each other
    --     as optimistic dependencies, the memory references of @m_i@
    --     are reset, and @m_b@ and @m_a@ are marked successful.
    --   * In every remaining case all four memories are marked failed.
    foldFunOptimPromotion ::
      ((CoalsTab, InhibitTab), CoalsTab) ->
      ((VName, VName), (VName, VName), (VName, VName), (VName, VName)) ->
      ((CoalsTab, InhibitTab), CoalsTab)
    foldFunOptimPromotion ((act, inhb), succc) ((b, m_b), (a, m_a), (_r, m_r), (b_i, m_i))
      | m_r == m_i =
          let extended = do
                info <- M.lookup m_i act
                vtab_i <- addInvAliasesVarTab td_env (vartab info) b_i
                pure (M.insert m_b (info {vartab = vtab_i}) act)
           in Exc.assert
                (m_r == m_b && m_a == m_b)
                ( maybe
                    (markFailedCoal (act, inhb) m_b)
                    (\act' -> (act', inhb))
                    extended,
                  succc
                )
      | Just info_b0 <- M.lookup m_b act,
        Just info_a0 <- M.lookup m_a act,
        Just info_i <- M.lookup m_i act,
        M.member m_r succc,
        Just vtab_i <- addInvAliasesVarTab td_env (vartab info_i) b_i,
        Just info_b <- translateIxFnInScope (b, info_b0),
        Just info_a <- translateIxFnInScope (a, info_a0) =
          let -- Record @b_i@ (in @m_i@) as an optimistic dependency.
              dependOnBodyResult info =
                info {optdeps = M.insert b_i m_i (optdeps info)}
              info_i' =
                info_i
                  { optdeps = M.insert b m_b (optdeps info_i),
                    memrefs = mempty,
                    vartab = vtab_i
                  }
              -- Strict fold: the pair of tables is forced at each step
              -- rather than accumulated as nested thunks.
              (act1, succc1) =
                L.foldl'
                  (\acc (m, info) -> markSuccessCoal acc m info)
                  (M.insert m_i info_i' act, succc)
                  [ (m_b, dependOnBodyResult info_b),
                    (m_a, dependOnBodyResult info_a)
                  ]
           in -- ToDo: make sure that ixfun translates and update substitutions (?)
              ((act1, inhb), succc1)
      | otherwise =
          Exc.assert
            (m_r /= m_i)
            (L.foldl' markFailedCoal (act, inhb) [m_b, m_a, m_r, m_i], succc)

    -- Given an array name @x@ and a coalescing entry, try to refresh
    -- the free-variable substitutions stored for @x@ in that entry.
    --
    -- Yields the entry with @x@'s 'Coalesced' record rebuilt around the
    -- new substitutions, or 'Nothing' when @x@ is not in the entry's
    -- variable table, when the entry's destination memory is not in
    -- scope of @td_env@, or when 'freeVarSubstitutions' cannot
    -- translate the index function.  The substitutions are computed in
    -- the scope of @td_env@ extended with the loop parameters.
    translateIxFnInScope (x, info) = do
      Coalesced knd mbd@(MemBlock _ _ _ ixfn) _ <- M.lookup x (vartab info)
      guard (isInScope td_env (dstmem info))
      fv_subst <- freeVarSubstitutions scope_tab (scals bu_env) ixfn
      pure (info {vartab = M.insert x (Coalesced knd mbd fv_subst) (vartab info)})
      where
        scope_tab = scope td_env <> scopeOfFParams (map fst arginis)
    -- The 64-bit integer constant zero, as a 'SubExp'.
    se0 = intConst Int64 0
    -- For a for-loop, the loop index variable paired with the two
    -- expressions @(pe64 se0, pe64 seN)@, where @seN@ is the bound
    -- carried by the loop form.  While-loops have no index and give
    -- 'Nothing'.
    mbLoopIndexRange ::
      LoopForm ->
      Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
    mbLoopIndexRange = \case
      ForLoop inm _inttp seN -> Just (inm, (pe64 se0, pe64 seN))
      WhileLoop _ -> Nothing
    -- Combine an entry with the entry stored under the same key in
    -- @actv_bef@, if there is one: the stored entry's memory references
    -- are prepended (via '<>') to those of @etry@.  Without a stored
    -- entry, @etry@ is returned as is.
    addBeforeLoop actv_bef m_b etry =
      maybe etry withEarlierRefs (M.lookup m_b actv_bef)
      where
        withEarlierRefs etry0 =
          etry {memrefs = memrefs etry0 <> memrefs etry}
    -- Replace the memory references of an entry by their aggregation
    -- over the loop.  Both summaries go through 'aggSummaryLoopTotal'
    -- with the scope of @td_env@, the loop scope, the scalar table and
    -- the optional loop-index range; the source writes are aggregated
    -- first, then the destination references.
    aggAcrossLoopEntry scope_loop scal_tab idx etry = do
      let refs = memrefs etry
          aggregate = aggSummaryLoopTotal (scope td_env) scope_loop scal_tab idx
      wrts <- aggregate (srcwrts refs)
      uses <- aggregate (dstrefs refs)
      pure (etry {memrefs = MemRefs uses wrts})
    -- First loop soundness check for one entry.  The entry's
    -- destination references are aggregated with
    -- 'aggSummaryLoopPartial' (using the given scalar table joined with
    -- that of @td_env@); the result is whether 'noMemOverlap', under
    -- @td_env'@, holds between the entry's source writes and that
    -- aggregate.
    loopSoundness1Entry scal_tab idx etry =
      noMemOverlap td_env' (srcwrts refs)
        <$> aggSummaryLoopPartial
          (scal_tab <> scalarTable td_env)
          idx
          (dstrefs refs)
      where
        refs = memrefs etry
    -- Second loop soundness check for one entry.  If @old_actv@ holds
    -- an entry for the same memory, the source writes of @etry@ must
    -- not overlap (per 'noMemOverlap' under @td_env@) with the
    -- destination references of that older entry.  With no older entry
    -- the check trivially succeeds.
    loopSoundness2Entry :: CoalsTab -> VName -> CoalsEntry -> Bool
    loopSoundness2Entry old_actv m_b etry =
      maybe True disjointFromEarlierUses (M.lookup m_b old_actv)
      where
        disjointFromEarlierUses etry0 =
          noMemOverlap
            td_env
            (srcwrts (memrefs etry))
            (dstrefs (memrefs etry0))

-- The case of in-place update:
--   @let x' = x with slice <- elm@
--
-- Applies only when the pattern has exactly one array/memory
-- association; otherwise this equation is skipped.  The resulting
-- bottom-up environment carries the updated active and inhibit tables.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (Update safety x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      -- (a) filter by the 3rd safety for @elm@ and @x'@
      let (actv, inhbt) = recordMemRefUses td_env bu_env stm
          -- (b) if @x'@ is in active coalesced table, then add an entry for @x@ as well
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> (actv, inhbt)
              Just info ->
                maybe
                  (markFailedCoal (actv, inhbt) m_x)
                  (\info' -> (M.insert m_x info' actv, inhbt))
                  (entryCoveringSource info)

      -- (c) this stm is also a potential source for coalescing, so process it
      actv'' <-
        if safety == Unsafe
          then
            mkCoalsHelper3PatternMatch
              stm
              lutab
              td_env {inhibited = inhbt'}
              bu_env {activeCoals = actv'}
          else pure actv'
      pure (bu_env {activeCoals = actv'', inhibit = inhbt'})
  where
    -- Extend an active entry so that both @x'@ and @x@ map to the
    -- 'Coalesced' record of @x'@, rebuilt with freshly computed
    -- free-variable substitutions.  'Nothing' (which the caller turns
    -- into a failed coalescing) when @x'@ is not in the entry's
    -- variable table, when the substitutions cannot be computed, or
    -- when the entry's destination memory is not in scope.
    entryCoveringSource info = do
      Coalesced k mblk@(MemBlock _ _ _ x_indfun) _ <-
        M.lookup (patElemName x') (vartab info)
      fv_subs <- freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun
      guard (isInScope td_env (dstmem info))
      let coal_etry_x = Coalesced k mblk fv_subs
      pure
        ( info
            { vartab =
                M.insert x coal_etry_x $
                  M.insert (patElemName x') coal_etry_x (vartab info)
            }
        )

-- The case of flat in-place update:
--   @let x' = x with flat-slice <- elm@
--
-- Applies only when the pattern has exactly one array/memory
-- association; otherwise this equation is skipped.  Unlike the regular
-- in-place update, step (c) is performed unconditionally.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (FlatUpdate x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      -- (a) filter by the 3rd safety for @elm@ and @x'@
      let (actv, inhbt) = recordMemRefUses td_env bu_env stm
          -- (b) if @x'@ is in active coalesced table, then add an entry for @x@ as well
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> (actv, inhbt)
              Just info ->
                case withSourceArray info of
                  Just info' -> (M.insert m_x info' actv, inhbt)
                  Nothing -> markFailedCoal (actv, inhbt) m_x

      -- (c) this stm is also a potential source for coalescing, so process it
      actv'' <-
        mkCoalsHelper3PatternMatch
          stm
          lutab
          td_env {inhibited = inhbt'}
          bu_env {activeCoals = actv'}
      pure (bu_env {activeCoals = actv'', inhibit = inhbt'})
  where
    -- Make both @x'@ and @x@ point at the 'Coalesced' record of @x'@,
    -- rebuilt with freshly computed free-variable substitutions.  Any
    -- 'Nothing' result is treated by the caller as a failed coalescing.
    withSourceArray info = do
      -- this case should not happen, but if it can that
      -- just fail conservatively
      Coalesced k mblk@(MemBlock _ _ _ x_indfun) _ <-
        M.lookup (patElemName x') (vartab info)
      fv_subs <- freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun
      guard (isInScope td_env (dstmem info))
      let coal_etry_x = Coalesced k mblk fv_subs
          vartab' =
            M.insert x coal_etry_x $
              M.insert (patElemName x') coal_etry_x (vartab info)
      pure (info {vartab = vartab'})
--
-- An in-place 'Update' that reaches this equation is rejected outright:
-- the offending pattern is rendered with 'show' and reported via 'error'.
-- NOTE(review): presumably well-formed in-place updates are handled by an
-- earlier equation of this function -- confirm.
mkCoalsTabStm _ (Let pat _ (BasicOp Update {})) _ _ =
  error $ "In ArrayCoalescing.hs, fun mkCoalsTabStm, illegal pattern for in-place update: " ++ show pat
-- default handling
-- Statements whose expression is an 'Op': the representation-specific
-- handler is fetched from the reader environment ('onOp') and run first;
-- the statement itself is then offered as a new coalescing candidate.
mkCoalsTabStm lutab stm@(Let pat aux (Op op)) td_env bu_env = do
  -- Process body
  on_op <- asks onOp
  bu_env' <- on_op lutab pat (stmAuxCerts aux) op td_env bu_env
  -- The bottom-up environment produced by the op handler is the one used
  -- when looking for new coalescing candidates in this statement.
  activeCoals' <- mkCoalsHelper3PatternMatch stm lutab td_env bu_env'
  pure $ bu_env' {activeCoals = activeCoals'}
-- General case, for every statement not matched by a preceding equation.
mkCoalsTabStm lutab stm@(Let pat _ e) td_env bu_env = do
  --   i) Filter @activeCoals@ by the 3rd safety condition:
  --      this is now relaxed by use of LMAD eqs:
  --      the memory referenced in stm are added to memrefs::dstrefs
  --      in corresponding coal-tab entries.
  let (activeCoals', inhibit') = recordMemRefUses td_env bu_env stm

      --  ii) promote any of the entries in @activeCoals@ to @successCoals@ as long as
      --        - this statement defined a variable consumed in a coalesced statement
      --        - and safety conditions 2, 4, and 5 are satisfied.
      --      AND extend @activeCoals@ table for any definition of a variable that
      --      aliases a coalesced variable.
      safe_4 = createsNewArrOK e
      -- Strict left fold: the accumulator is inspected at every step anyway,
      -- so there is nothing to gain from building a chain of thunks.
      ((activeCoals'', inhibit''), successCoals') =
        L.foldl'
          (foldfun safe_4)
          ((activeCoals', inhibit'), successCoals bu_env)
          (getArrMemAssoc pat)

  -- iii) record a potentially coalesced statement in @activeCoals@
  activeCoals''' <-
    mkCoalsHelper3PatternMatch
      stm
      lutab
      td_env
      bu_env {successCoals = successCoals', activeCoals = activeCoals''}
  pure bu_env {activeCoals = activeCoals''', inhibit = inhibit'', successCoals = successCoals'}
  where
    -- Process one (array, memory binding) pair defined by the pattern.
    -- The accumulator is ((active table, inhibit table), success table).
    -- Arrays whose memory block has no entry in the active table leave the
    -- accumulator untouched.
    foldfun safe_4 ((a_acc, inhb), s_acc) (b, MemBlock tp shp mb _b_indfun) =
      case M.lookup mb a_acc of
        Nothing -> ((a_acc, inhb), s_acc)
        Just info@(CoalsEntry x_mem _ _ vtab _ _ certs) ->
          let failed = markFailedCoal (a_acc, inhb) mb
           in case M.lookup b vtab of
                Nothing ->
                  -- we hit the definition of some variable @b@ aliased with
                  --    the coalesced variable @x@, hence extend @activeCoals@, e.g.,
                  --       @let x = map f arr  @
                  --       @let b = alias x  @ <- current statement
                  --       @ ... use of b ...  @
                  --       @let c = alias b    @ <- currently fails
                  --       @let y[i] = x       @
                  -- where @alias@ can be @transpose@, @slice@, @reshape@.
                  -- We use the getDirAliasedIxfn helper function to track the
                  --    aliasing through the td_env, and to find the updated
                  --    ixfun of @b@:
                  case getDirAliasedIxfn td_env a_acc b of
                    Nothing -> (failed, s_acc)
                    Just (_, _, b_indfun') ->
                      -- Both the index function and the certificates must be
                      -- expressible through free-variable substitutions.
                      case ( freeVarSubstitutions (scope td_env) (scals bu_env) b_indfun',
                             freeVarSubstitutions (scope td_env) (scals bu_env) certs
                           ) of
                        (Just fv_subst, Just fv_subst') ->
                          let mem_info =
                                Coalesced
                                  TransitiveCoal
                                  (MemBlock tp shp x_mem b_indfun')
                                  (fv_subst <> fv_subst')
                              info' = info {vartab = M.insert b mem_info vtab}
                           in ((M.insert mb info' a_acc, inhb), s_acc)
                        _ -> (failed, s_acc)
                Just (Coalesced k mblk@(MemBlock _ _ _ new_indfun) _) ->
                  -- we are at the definition of the coalesced variable @b@
                  -- if 2,4,5 hold promote it to successful coalesced table,
                  -- or if e = transpose, etc. then postpone decision for later on
                  let safe_2 = isInScope td_env x_mem
                   in case ( freeVarSubstitutions (scope td_env) (scals bu_env) new_indfun,
                             freeVarSubstitutions (scope td_env) (scals bu_env) certs
                           ) of
                        (Just fv_subst, Just fv_subst')
                          | safe_2 ->
                              let mem_info = Coalesced k mblk (fv_subst <> fv_subst')
                                  info' = info {vartab = M.insert b mem_info vtab}
                               in if safe_4
                                    then -- array creation point, successful coalescing verified!
                                      let (a_acc', s_acc') =
                                            markSuccessCoal (a_acc, s_acc) mb info'
                                       in ((a_acc', inhb), s_acc')
                                    else -- this is an invertible alias case of the kind
                                    -- @ let b    = alias a @
                                    -- @ let x[i] = b @
                                    -- do not promote, but update the index function
                                      ((M.insert mb info' a_acc, inhb), s_acc)
                        _ -> (failed, s_acc) -- fail!

-- | Turn an index function into an 'AccessSummary': the first component of
-- the 'IxFun.IxFun' constructor (its LMAD) is wrapped in a singleton set.
-- The second component of the constructor is ignored.
ixfunToAccessSummary :: IxFun.IxFun (TPrimExp Int64 VName) -> AccessSummary
ixfunToAccessSummary (IxFun.IxFun lmad _) = Set $ S.singleton lmad

-- | Check safety conditions 2 and 5 and update new substitutions:
-- called on the pat-elements of loop and if-then-else expressions.
--
-- The safety conditions are: The allocation of merge target should dominate the
-- creation of the array we're trying to merge and the new index function of the
-- array can be translated at the definition site of b. The latter requires that
-- any variables used in the index function of the target array are available at
-- the definition site of b.
--
-- The pattern elements are processed left to right, threading the pair of
-- (active coalescing table, inhibit table). Pattern elements that are not
-- arrays, or whose memory block has no entry in the active table, leave the
-- pair unchanged; a failed check marks the entry via 'markFailedCoal'.
filterSafetyCond2and5 ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  InhibitTab ->
  ScalarTab ->
  TopdownEnv rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  (CoalsTab, InhibitTab)
filterSafetyCond2and5 act_coal inhb_coal scals_env td_env pes =
  L.foldl' helper (act_coal, inhb_coal) pes
  where
    -- For each pattern element in the input list
    helper (acc, inhb) patel =
      case (patElemName patel, patElemDec patel) of
        (b, (_, MemArray tp0 shp0 _ (ArrayIn m_b _idxfn_b))) ->
          -- If it is an array in memory block m_b
          case M.lookup m_b acc of
            Nothing -> (acc, inhb)
            Just info@(CoalsEntry x_mem _ _ vtab _ _ certs) ->
              -- And we are trying to coalesce m_b
              let failed = markFailedCoal (acc, inhb) m_b
               in -- It is not safe to short circuit if some other pattern
                  -- element is aliased to this one, as this indicates the
                  -- two pattern elements reference the same physical
                  -- value somehow.
                  if any ((`nameIn` aliasesOf patel) . patElemName) pes
                    then failed
                    else case M.lookup b vtab of
                      Nothing ->
                        case getDirAliasedIxfn td_env acc b of
                          Nothing -> failed
                          Just (_, _, b_indfun') ->
                            -- And we have the index function of b
                            case ( freeVarSubstitutions (scope td_env) scals_env b_indfun',
                                   freeVarSubstitutions (scope td_env) scals_env certs
                                 ) of
                              (Just fv_subst, Just fv_subst') ->
                                let mem_info =
                                      Coalesced
                                        TransitiveCoal
                                        (MemBlock tp0 shp0 x_mem b_indfun')
                                        (fv_subst <> fv_subst')
                                    info' = info {vartab = M.insert b mem_info vtab}
                                 in (M.insert m_b info' acc, inhb)
                              _ -> failed
                      Just (Coalesced k (MemBlock pt shp _ new_indfun) _) ->
                        -- b is already recorded: re-check condition 2 (the
                        -- target memory is in scope) and that the index
                        -- function and certificates can be substituted.
                        let safe_2 = isInScope td_env x_mem
                         in case ( freeVarSubstitutions (scope td_env) scals_env new_indfun,
                                   freeVarSubstitutions (scope td_env) scals_env certs
                                 ) of
                              (Just fv_subst, Just fv_subst')
                                | safe_2 ->
                                    let mem_info =
                                          Coalesced
                                            k
                                            (MemBlock pt shp x_mem new_indfun)
                                            (fv_subst <> fv_subst')
                                        info' = info {vartab = M.insert b mem_info vtab}
                                     in (M.insert m_b info' acc, inhb)
                              _ -> failed
        _ -> (acc, inhb)

-- | Pattern matches a potentially coalesced statement and
-- records a new association in @activeCoals@.
--
-- The candidate short-circuit points of the statement are obtained from
-- 'genCoalStmtInfo'. If there are none, the active table of the bottom-up
-- environment is returned unchanged; otherwise each candidate is folded
-- into that table by @processNewCoalesce@.
mkCoalsHelper3PatternMatch ::
  (Coalesceable rep inner) =>
  Stm (Aliases rep) ->
  LUTabFun ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep CoalsTab
mkCoalsHelper3PatternMatch stm lutab td_env bu_env = do
  clst <- genCoalStmtInfo lutab td_env (scope td_env) stm
  case clst of
    Nothing -> pure activeCoals_tab
    Just clst' -> pure $ L.foldl' processNewCoalesce activeCoals_tab clst'
  where
    successCoals_tab = successCoals bu_env
    activeCoals_tab = activeCoals bu_env
    processNewCoalesce acc (knd, alias_fn, x, m_x, ind_x, b, m_b, _, tp_b, shp_b, certs) =
      -- test whether we are in a transitive coalesced case, i.e.,
      --      @let b = scratch ...@
      --      @.....@
      --      @let x[j] = b@
      --      @let y[i] = x@
      -- and compose the index function of @x@ with that of @y@,
      -- and update aliasing of the @m_b@ entry to also contain @m_y@
      -- on top of @m_x@, i.e., transitively, any use of @m_y@ should
      -- be checked for the lifetime of @b@.
      let proper_coals_tab = case knd of
            InPlaceCoal -> activeCoals_tab
            _ -> successCoals_tab
          (m_yx, ind_yx, mem_yx_al, x_deps, certs') =
            case M.lookup m_x proper_coals_tab of
              Nothing ->
                (m_x, alias_fn ind_x, oneName m_x, M.empty, mempty)
              Just (CoalsEntry m_y ind_y y_al vtab x_deps0 _ certs'') ->
                -- Prefer the index function recorded for @x@ itself, and
                -- fall back on the one of the enclosing entry.
                let ind = case M.lookup x vtab of
                      Just (Coalesced _ (MemBlock _ _ _ ixf) _) ->
                        ixf
                      Nothing ->
                        ind_y
                 in (m_y, alias_fn ind, oneName m_x <> y_al, x_deps0, certs <> certs'')
          m_b_aliased_m_yx = areAnyAliased td_env m_b [m_yx] -- m_b \= m_yx
       in if not m_b_aliased_m_yx && isInScope td_env m_yx -- nameIn m_yx (alloc td_env)
          -- Finally update the @activeCoals@ table with a fresh
          --   binding for @m_b@; if such one exists then overwrite.
          -- Also, add all variables from the alias chain of @b@ to
          --   @vartab@, for example, in the case of a sequence:
          --   @ b0 = if cond then ... else ... @
          --   @ b1 = alias0 b0 @
          --   @ b  = alias1 b1 @
          --   @ x[j] = b @
          -- Then @b1@ and @b0@ should also be added to @vartab@ if
          --   @alias1@ and @alias0@ are invertible, otherwise fail early!
            then
              let mem_info = Coalesced knd (MemBlock tp_b shp_b m_yx ind_yx) M.empty
                  opts' =
                    if m_yx == m_x
                      then M.empty
                      else M.insert x m_x x_deps
                  vtab = M.singleton b mem_info
                  mvtab = addInvAliasesVarTab td_env vtab b

                  is_inhibited = case M.lookup m_b $ inhibited td_env of
                    Just nms -> m_yx `nameIn` nms
                    Nothing -> False
               in case (is_inhibited, mvtab) of
                    (True, _) -> acc -- fail due to inhibited
                    (_, Nothing) -> acc -- fail early due to non-invertible aliasing
                    (_, Just vtab') ->
                      -- successfully adding a new coalesced entry
                      -- NOTE(review): in the transitive case @certs'@ already
                      -- contains @certs@, so @certs@ is appended twice here --
                      -- confirm this is harmless for 'Certs'.
                      let coal_etry =
                            CoalsEntry
                              m_yx
                              ind_yx
                              mem_yx_al
                              vtab'
                              opts'
                              mempty
                              (certs <> certs')
                       in M.insert m_b coal_etry acc
            else acc

-- | Information about a particular short-circuit point
--
-- The components are positional; the names in the comments below are the
-- ones they are bound to where the tuple is consumed
-- (@processNewCoalesce@ in 'mkCoalsHelper3PatternMatch'), with @b@ being
-- the array that is short-circuited into the memory of @x@.
type SSPointInfo =
  ( CoalescedKind, -- knd: the kind of coalescing ('InPlaceCoal' is treated specially)
    IxFun -> IxFun, -- alias_fn: transformation applied to the destination index function
    VName, -- x: the destination array
    VName, -- m_x: the memory block of x
    IxFun, -- ind_x: the index function of x
    VName, -- b: the source array
    VName, -- m_b: the memory block of b (key of the resulting table entry)
    IxFun, -- ignored by the consumer; presumably the index function of b -- TODO confirm
    PrimType, -- tp_b: element type recorded for b
    Shape, -- shp_b: shape recorded for b
    Certs -- certs: certificates attached to the resulting entry
  )

-- | Given an op, return a list of potential short-circuit points
--
-- The arguments are, in order: the last-use table, the top-down
-- environment, the scope table, the pattern of the statement, the
-- certificates of the statement, and the op itself.
-- NOTE(review): 'Nothing' presumably means "no candidates for this op" --
-- confirm against the callers.
type GenSSPoint rep op =
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  op ->
  Maybe [SSPointInfo]

-- | Short-circuit points from ops in the sequential memory representation.
-- No op is treated as a short-circuit point here, so the result is always
-- 'Nothing'.
genSSPointInfoSeqMem ::
  GenSSPoint SeqMem (Op (Aliases SeqMem))
genSSPointInfoSeqMem _lutab _td_env _scopetab _pat _certs _op =
  Nothing

-- | For 'SegOp', we currently only handle 'SegMap', and only under the following
-- circumstances:
--
--  1. The 'SegMap' has only one return/pattern value, which is a 'Returns'.
--
--  2. The 'KernelBody' contains an 'Index' statement that is indexing an array using
--  only the values from the 'SegSpace'.
--
--  3. The array being indexed is last-used in that statement, is free in the
--  'SegMap', is unique or has been recently allocated (specifically, it should
--  not be a non-unique argument to the enclosing function), has elements with
--  the same bit-size as the pattern elements, and has the exact same 'IxFun' as
--  the pattern of the 'SegMap' statement.
--
-- There can be multiple candidate arrays, but the current implementation will
-- always just try the first one.
--
-- The first restriction could be relaxed by trying to match up arrays in the
-- 'KernelBody' with patterns of the 'SegMap', but the current implementation
-- should be enough to handle many common cases.
--
-- The result of the 'SegMap' is treated as the destination, while the candidate
-- array from inside the body is treated as the source.
genSSPointInfoSegOp ::
  (Coalesceable rep inner) => GenSSPoint rep (SegOp lvl (Aliases rep))
genSSPointInfoSegOp
  lutab
  td_env
  scopetab
  (Pat [PatElem dst (_, MemArray dst_pt _ _ (ArrayIn dst_mem dst_ixf))])
  certs
  (SegMap _ space _ kernel_body@KernelBody {kernelBodyResult = [Returns {}]})
    | Just (src, MemBlock src_pt shp src_mem src_ixf) <- listToMaybe candidates =
        Just [(MapCoal, id, dst, dst_mem, dst_ixf, src, src_mem, src_ixf, src_pt, shp, certs)]
    where
      -- All statements of the kernel body that qualify, in program order; only
      -- the first one is used.
      candidates =
        mapMaybe getPotentialMapShortCircuit . stmsToList $
          kernelBodyStms kernel_body

      -- The thread indices bound by the 'SegSpace'.
      iterators = map fst $ unSegSpace space

      -- Names that are free in the kernel body.
      frees = freeIn kernel_body

      -- Accept @let x = cand[inds]@ when every condition of item 3 above holds.
      getPotentialMapShortCircuit (Let (Pat [PatElem x _]) _ (BasicOp (Index cand slc)))
        | Just inds <- sliceIndices slc,
          -- NOTE(review): both sides are sorted before comparing, so the
          -- indices may be any permutation of the iterators -- confirm that
          -- accepting permuted indexing is intended.
          L.sort inds == L.sort (map Var iterators),
          Just last_uses <- M.lookup x lutab,
          cand `nameIn` last_uses,
          Just memblock@(MemBlock cand_pt _ cand_mem cand_ixf) <-
            getScopeMemInfo cand scopetab,
          cand_mem `nameIn` last_uses,
          -- The 'alloc' table contains allocated memory blocks, including
          -- unique memory blocks from the enclosing function. It does _not_
          -- include non-unique memory blocks from the enclosing function.
          cand_mem `M.member` alloc td_env,
          cand `nameIn` frees,
          cand_ixf == dst_ixf,
          primBitSize cand_pt == primBitSize dst_pt =
            Just (cand, memblock)
      getPotentialMapShortCircuit _ = Nothing
genSSPointInfoSegOp _ _ _ _ _ _ =
  Nothing

-- | Lift a short-circuit-point generator for the inner op of a representation
-- to one for its 'MemOp'.  An 'Inner' op is handed to the given generator
-- unchanged; every other 'MemOp' constructor yields no short-circuit point.
genSSPointInfoMemOp ::
  GenSSPoint rep (inner (Aliases rep)) ->
  GenSSPoint rep (MemOp inner (Aliases rep))
genSSPointInfoMemOp onOp lutab td_env scopetab pat certs = \case
  Inner op -> onOp lutab td_env scopetab pat certs op
  _ -> Nothing

-- | Short-circuit points from ops in the GPU memory representation.  Only
-- 'GPU.SegOp' host ops are considered, and they are delegated to
-- 'genSSPointInfoSegOp'; every other op yields 'Nothing'.
genSSPointInfoGPUMem ::
  GenSSPoint GPUMem (Op (Aliases GPUMem))
genSSPointInfoGPUMem = genSSPointInfoMemOp fromHostOp
  where
    fromHostOp lutab td_env scopetab pat certs (GPU.SegOp op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    fromHostOp _ _ _ _ _ _ = Nothing

-- | Short-circuit points from ops in the multicore memory representation.
-- Only an 'MC.ParOp' whose first component is 'Nothing' is considered, and its
-- 'SegOp' is delegated to 'genSSPointInfoSegOp'; every other op yields
-- 'Nothing'.
genSSPointInfoMCMem ::
  GenSSPoint MCMem (Op (Aliases MCMem))
genSSPointInfoMCMem = genSSPointInfoMemOp fromMCOp
  where
    fromMCOp lutab td_env scopetab pat certs (MC.ParOp Nothing op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    fromMCOp _ _ _ _ _ _ = Nothing

-- | Compute the potential short-circuit points introduced by one statement.
--
-- The arguments are the last-use table, the top-down environment (only
-- consulted for ops), the scope table used to find the memory information of
-- source arrays, and the statement itself.  The result is 'Nothing' when the
-- statement is not a supported short-circuit point, or when none of its
-- operands qualifies.  In cases a), b) and c) an operand qualifies only if it
-- is among the last uses recorded for the name bound by the statement and its
-- memory information is found in the scope table.
genCoalStmtInfo ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ShortCircuitM rep (Maybe [SSPointInfo])
-- CASE a) @let x <- copy(b^{lu})@
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (Replicate (Shape []) (Var b))))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat =
      pure $ do
        last_uses <- M.lookup x lutab
        MemBlock tpb shpb m_b ind_b <- getScopeMemInfo b scopetab
        guard $ b `nameIn` last_uses
        Just [(CopyCoal, id, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
-- CASE c) @let x[i] = b^{lu}@
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (Update _ x slice_x (Var b))))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat =
      pure $ do
        -- Last uses are looked up under the pattern name @x'@, while the
        -- resulting entry records the array being updated, @x@.
        last_uses <- M.lookup x' lutab
        MemBlock tpb shpb m_b ind_b <- getScopeMemInfo b scopetab
        guard $ b `nameIn` last_uses
        Just [(InPlaceCoal, (`updateIndFunSlice` slice_x), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
  where
    -- Slice an index function with a slice given in terms of 'SubExp's.
    updateIndFunSlice :: IxFun -> Slice SubExp -> IxFun
    updateIndFunSlice ind_fun =
      IxFun.slice ind_fun . Slice . map (fmap pe64) . unSlice
-- CASE c') as c), but for a flat update @let x[flat slice] = b^{lu}@
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (FlatUpdate x slice_x b)))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat =
      pure $ do
        last_uses <- M.lookup x' lutab
        MemBlock tpb shpb m_b ind_b <- getScopeMemInfo b scopetab
        guard $ b `nameIn` last_uses
        Just [(InPlaceCoal, (`updateIndFunSlice` slice_x), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)]
  where
    -- Flat-slice an index function with a flat slice given in terms of
    -- 'SubExp's.
    updateIndFunSlice :: IxFun -> FlatSlice SubExp -> IxFun
    updateIndFunSlice ind_fun (FlatSlice offset dims) =
      IxFun.flatSlice ind_fun $ FlatSlice (pe64 offset) $ map (fmap pe64) dims

-- CASE b) @let x = concat(a, b^{lu})@
genCoalStmtInfo lutab _ scopetab (Let pat aux (BasicOp (Concat concat_dim (b0 :| bs) _)))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat =
      pure $ case M.lookup x lutab of
        Nothing -> Nothing
        Just last_uses ->
          let zero = pe64 $ intConst Int64 0
              -- Full-extent unit-stride slices for the non-concatenated
              -- dimensions.
              fullDims = map (unitSlice zero . pe64)
              -- The state is (points so far, offset of the current operand
              -- along @concat_dim@, whether all operands so far had usable
              -- memory information).  Once the flag is 'False' the offset is
              -- no longer known, so the remaining operands are skipped.
              markConcatParts st@(_, _, False) _ = st
              markConcatParts (acc, offs, True) b
                | Just (MemBlock tpb shpb@(Shape dims@(_ : _)) m_b ind_b) <-
                    getScopeMemInfo b scopetab,
                  Just d <- maybeNth concat_dim dims =
                    let offs' = offs + pe64 d
                        slc =
                          Slice $
                            fullDims (take concat_dim dims)
                              <> [unitSlice offs (pe64 d)]
                              <> fullDims (drop (concat_dim + 1) dims)
                        point =
                          (ConcatCoal, (`IxFun.slice` slc), x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)
                     in if b `nameIn` last_uses
                          then (acc ++ [point], offs', True)
                          else (acc, offs', True)
                | otherwise = (acc, offs, False)
              -- Strict left fold: there is no reason to build a chain of
              -- thunks over the operands.
              (res, _, _) = L.foldl' markConcatParts ([], zero, True) (b0 : bs)
           in if null res then Nothing else Just res
-- case d) short-circuit points from ops. For instance, the result of a segmap
-- can be considered a short-circuit point.
genCoalStmtInfo lutab td_env scopetab (Let pat aux (Op op)) = do
  ss_op <- asks ssPointFromOp
  pure $ ss_op lutab td_env scopetab pat (stmAuxCerts aux) op
-- CASE other than a), b), c), or d) not supported
genCoalStmtInfo _ _ _ _ = pure Nothing

-- | A pattern element of a statement paired with the corresponding result of
-- one of its bodies, as produced by 'findMemBodyResult'.
data MemBodyResult = MemBodyResult
  { -- | Memory block of the pattern element.
    patMem :: VName,
    -- | Name of the pattern element.
    _patName :: VName,
    -- | Name of the variable returned by the body.
    bodyName :: VName,
    -- | Memory block of the variable returned by the body.
    bodyMem :: VName
  }

-- | Pair each pattern element that is a candidate for coalescing with the
--   corresponding body result, recording the name and memory block of both.
--
--   A pattern element @b@ in memory @m_b@ is a candidate when the matching
--   body result is a variable whose memory information can be found (in the
--   given scope extended with the statements of the body), @m_b@ has an entry
--   in the active coalescing table, and @b@ is present in the variable table
--   of that entry.  Pattern elements and body results are matched up
--   positionally.
findMemBodyResult ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  ScopeTab rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  Body (Aliases rep) ->
  [MemBodyResult]
findMemBodyResult activeCoals_tab scope_env patelms bdy =
  catMaybes $ zipWith candidate patelms $ map resSubExp $ bodyResult bdy
  where
    -- Body results may be bound by the statements of the body itself.
    body_scope = scope_env <> scopeOf (bodyStms bdy)

    candidate (PatElem b (_, MemArray _ _ _ (ArrayIn m_b _))) (Var r)
      | Just (MemBlock _ _ m_r _) <- getScopeMemInfo r body_scope,
        Just coal_etry <- M.lookup m_b activeCoals_tab,
        b `M.member` vartab coal_etry =
          Just $ MemBodyResult m_b b r m_r
    candidate _ _ = Nothing

-- | Transfers coalescing from the pattern of an if/match to the result
--   of one of its bodies, in the active coalescing table.
--
--   Given the pattern array @b@ (in memory @m_b@) and the body result
--   @r@ (in memory @m_r@), a 'Coalesced' record for @r@ is derived from
--   the record of @b@: @exist_subs@ is substituted into the index
--   function of @b@ and merged (left-biased) into its substitutions,
--   and the memory block is redirected to the 'dstmem' of @m_b@'s entry.
--
--   * If @m_r == m_b@ the two are already unified: the record for @r@
--     is added to the 'vartab' of the entry and @(b, m_b)@ to its
--     'optdeps'.
--
--   * Otherwise the entry stored for @m_r@ becomes a copy of @m_b@'s
--     entry whose 'vartab' holds only @r@ and whose 'optdeps'
--     additionally contain @(b, m_b)@, while @(r, m_r)@ is added to the
--     'optdeps' of @m_b@'s entry.  Meaning, ultimately, @m_b@ can be
--     merged if @m_r@ can be merged (and vice-versa).  This is checked
--     by a fix point iteration at the function-definition level.
transferCoalsToBody ::
  M.Map VName (TPrimExp Int64 VName) -> -- (PrimExp VName)
  CoalsTab ->
  MemBodyResult ->
  CoalsTab
transferCoalsToBody exist_subs activeCoals_tab (MemBodyResult m_b b r m_r) =
  case lookupPatternArray of
    Just (entry, Coalesced kind (MemBlock pt shape _ ixfun_b) subs_b) ->
      -- By definition of the if-stmt, @r@ and @b@ have the same basic
      -- type, shape and index function, hence no rebasing is needed.
      -- Whether the index function is translatable is checked at the
      -- definition point of @r@.
      let ixfun_r = IxFun.substituteInIxFun exist_subs ixfun_b
          r_info =
            Coalesced
              kind
              (MemBlock pt shape (dstmem entry) ixfun_r)
              (M.union exist_subs subs_b)
          deps_with_b = M.insert b m_b (optdeps entry)
       in if m_r == m_b
            then -- Already unified, just add the binding for @r@.
              let unified =
                    entry
                      { optdeps = deps_with_b,
                        vartab = M.insert r r_info (vartab entry)
                      }
               in M.insert m_r unified activeCoals_tab
            else -- Make both entries optimistically depend on each other.
            -- (Here the index function of @r_info@ should be translated
            -- across the existential introduced by the if-then-else.)
              let entry_for_r =
                    entry
                      { vartab = M.singleton r r_info,
                        optdeps = deps_with_b
                      }
                  entry_for_b =
                    entry {optdeps = M.insert r m_r (optdeps entry)}
               in M.insert m_b entry_for_b $
                    M.insert m_r entry_for_r activeCoals_tab
    -- A failed lookup cannot happen, because both lookups were already
    -- checked in @findMemBodyResult@.
    _ -> error "Impossible"
  where
    -- The entry of @m_b@ together with the record of @b@ inside it.
    lookupPatternArray = do
      entry <- M.lookup m_b activeCoals_tab
      info <- M.lookup b (vartab entry)
      pure (entry, info)

-- | Builds a substitution table that maps names bound by a pattern to
-- the corresponding body-result sub-expressions, as 64-bit integer
-- expressions.  Pattern elements and results are paired positionally
-- (surplus elements on either side are ignored).  A pair contributes an
-- entry only when
--
-- * the result is a variable and the pattern element is declared as a
--   primitive of type @i64@, or
--
-- * the result is an @i64@ constant.
mkSubsTab ::
  Pat (aliases, LetDecMem) ->
  [SubExp] ->
  M.Map VName (TPrimExp Int64 VName)
mkSubsTab pat res =
  M.fromList . catMaybes $ zipWith substFor (patElems pat) res
  where
    -- The substitution contributed by one (pattern element, result)
    -- pair, if any.
    substFor pe (Var v)
      | (_, MemPrim (IntType Int64)) <- patElemDec pe =
          Just (patElemName pe, le64 v)
    substFor pe se@(Constant (IntValue (Int64Value _))) =
      Just (patElemName pe, pe64 se)
    substFor _ _ = Nothing

-- | Computes, for one statement and everything nested inside it, a table
-- mapping scalar names to the 'PrimExp's that define them.
--
-- * A statement that binds exactly one name and whose expression can be
--   converted with 'primExpFromExp' contributes a single entry.
--
-- * For a loop, the tables of the statements of the loop body are
--   combined, with the scope extended by the loop parameters, the loop
--   form and the body statements.
--
-- * For a match, the tables of the statements of the default body and
--   of every case body are combined, each computed in the scope
--   extended by the statements of that same body.
--
-- * An 'Op' is delegated to the 'scalarTableOnOp' handler taken from
--   the reader environment.
--
-- * Any other statement contributes nothing.
computeScalarTable ::
  (Coalesceable rep inner) =>
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ScalarTableM rep (M.Map VName (PrimExp VName))
computeScalarTable scope_table (Let (Pat [pe]) _ e)
  | Just primexp <- primExpFromExp (vnameToPrimExp scope_table mempty) e =
      pure $ M.singleton (patElemName pe) primexp
computeScalarTable scope_table (Let _ _ (Loop loop_inits loop_form body)) =
  concatMapM
    ( computeScalarTable $
        scope_table
          <> scopeOfFParams (map fst loop_inits)
          <> scopeOfLoopForm loop_form
          <> scopeOf (bodyStms body)
    )
    (stmsToList $ bodyStms body)
computeScalarTable scope_table (Let _ _ (Match _ cases body _)) = do
  body_tab <- onBranch body
  cases_tab <- concatMapM (\(Case _ b) -> onBranch b) cases
  pure $ body_tab <> cases_tab
  where
    -- Scalar table of a single branch.  Both the scope extension and
    -- the traversed statements come from the *same* branch; previously
    -- the case branches extended the scope with their own statements
    -- but traversed the statements of the default body instead.
    onBranch branch =
      concatMapM
        (computeScalarTable $ scope_table <> scopeOf (bodyStms branch))
        (stmsToList $ bodyStms branch)
computeScalarTable scope_table (Let _ _ (Op op)) = do
  on_op <- asks scalarTableOnOp
  on_op scope_table op
computeScalarTable _ _ = pure mempty

-- | Lifts a scalar-table computation on an inner operation to one on
-- 'MemOp's: allocations contribute nothing, and an 'Inner' operation is
-- handed to the supplied function together with the current scope.
computeScalarTableMemOp ::
  ComputeScalarTable rep (inner (Aliases rep)) -> ComputeScalarTable rep (MemOp inner (Aliases rep))
computeScalarTableMemOp onInner scope_table = \case
  Alloc {} -> pure mempty
  Inner op -> onInner scope_table op

-- | Scalar table of a segmented operation: the combined tables of the
-- statements in its kernel body, computed in the given scope extended
-- with those statements and with the segment space of the operation.
computeScalarTableSegOp ::
  (Coalesceable rep inner) =>
  ComputeScalarTable rep (GPU.SegOp lvl (Aliases rep))
computeScalarTableSegOp scope_table segop =
  let kernel_stms = kernelBodyStms (segBody segop)
      inner_scope =
        scope_table
          <> scopeOf kernel_stms
          <> scopeOfSegSpace (segSpace segop)
   in concatMapM (computeScalarTable inner_scope) (stmsToList kernel_stms)

-- | Scalar-table computation for the host-level operations of 'GPUMem'.
-- Segmented operations are handled by 'computeScalarTableSegOp', the
-- statements of a 'GPU.GPUBody' are traversed with the scope extended
-- by those statements, and size operations and 'NoOp' contribute
-- nothing.
computeScalarTableGPUMem ::
  ComputeScalarTable GPUMem (GPU.HostOp NoOp (Aliases GPUMem))
computeScalarTableGPUMem scope_table = \case
  GPU.SegOp segop ->
    computeScalarTableSegOp scope_table segop
  GPU.SizeOp {} ->
    pure mempty
  GPU.OtherOp NoOp ->
    pure mempty
  GPU.GPUBody _ body ->
    let stms = bodyStms body
     in concatMapM
          (computeScalarTable (scope_table <> scopeOf stms))
          (stmsToList stms)

-- | Scalar-table computation for the operations of 'MCMem'.  'NoOp'
-- contributes nothing.  For a 'MC.ParOp', the table of the optional
-- first segmented operation (empty when absent) is combined with the
-- table of the second one, the former taking precedence in '<>'.
computeScalarTableMCMem ::
  ComputeScalarTable MCMem (MC.MCOp NoOp (Aliases MCMem))
computeScalarTableMCMem scope_table = \case
  MC.OtherOp NoOp ->
    pure mempty
  MC.ParOp par_op segop -> do
    par_tab <- maybe (pure mempty) (computeScalarTableSegOp scope_table) par_op
    seg_tab <- computeScalarTableSegOp scope_table segop
    pure (par_tab <> seg_tab)

-- | Monadic filter over the values of a map: keeps exactly those
-- bindings whose value satisfies the predicate.  The predicate is run
-- on the values in ascending key order.
filterMapM1 :: (Eq k, Monad m) => (v -> m Bool) -> M.Map k v -> m (M.Map k v)
filterMapM1 f m = do
  kept <- filterM (\(_, v) -> f v) (M.toAscList m)
  -- The surviving bindings are still in ascending key order.
  pure (M.fromAscList kept)