{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

-- | The bulk of the short-circuiting implementation.
module Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
  ( mkCoalsTab,
    CoalsTab,
    mkCoalsTabGPU,
    mkCoalsTabMC,
  )
where

import Control.Exception.Base qualified as Exc
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.List qualified as L
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence (Seq (..))
import Data.Set qualified as S
import Futhark.Analysis.LastUse
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Aliases
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SeqMem
import Futhark.MonadFreshNames
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Optimise.ArrayShortCircuiting.MemRefAggreg
import Futhark.Optimise.ArrayShortCircuiting.TopdownAnalysis
import Futhark.Util

-- | A helper type describing representations that can be short-circuited.
--
-- This is a constraint synonym over a representation @rep@ and the type
-- constructor @inner@ of its representation-specific operations.  The
-- equality @Op rep ~ MemOp inner rep@ requires those operations to be wrapped
-- in 'MemOp', and @LetDec rep ~ LetDecMem@ fixes the let-binding decoration.
-- The remaining constraints are stated on the alias-annotated form
-- (@Aliases rep@) of the representation.
type Coalesceable rep inner =
  ( Mem rep inner,
    ASTRep rep,
    CanBeAliased inner,
    AliasableRep rep,
    Op rep ~ MemOp inner rep,
    HasMemBlock (Aliases rep),
    LetDec rep ~ LetDecMem,
    TopDownHelper (inner (Aliases rep))
  )

-- | The type of a function that, given a scope table and a value of type
-- @op@, produces (inside 'ScalarTableM') a map from variable names to the
-- 'PrimExp' associated with each of them.
type ComputeScalarTable rep op =
  ScopeTab rep -> op -> ScalarTableM rep (M.Map VName (PrimExp VName))

-- | Helper type for computing scalar tables on ops.
--
-- Wraps the representation-specific 'ComputeScalarTable' function that is
-- applied to the 'Op's of the alias-annotated representation.
newtype ComputeScalarTableOnOp rep = ComputeScalarTableOnOp
  { scalarTableOnOp :: ComputeScalarTable rep (Op (Aliases rep))
  }

-- | The monad in which scalar tables are computed: a 'Reader' whose
-- environment is the 'ComputeScalarTableOnOp' handler for the representation.
type ScalarTableM rep a = Reader (ComputeScalarTableOnOp rep) a

-- | The read-only environment of 'ShortCircuitM': the two
-- representation-specific callbacks used when the analysis encounters an
-- 'Op'.
data ShortCircuitReader rep = ShortCircuitReader
  { -- | Short-circuit handler for an 'Op' of the representation.  It receives
    -- the last-use table, the pattern and certificates of the statement
    -- binding the op, the op itself, and the current top-down and bottom-up
    -- environments, and returns the updated bottom-up environment.
    onOp ::
      LUTabFun ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      TopdownEnv rep ->
      BotUpEnv ->
      ShortCircuitM rep BotUpEnv,
    -- | Produces the 'SSPointInfo's for an 'Op', if any.
    -- NOTE(review): the meaning of 'Nothing' versus an empty list is not
    -- visible from this definition -- confirm against the call sites.
    ssPointFromOp ::
      LUTabFun ->
      TopdownEnv rep ->
      ScopeTab rep ->
      Pat (VarAliases, LetDecMem) ->
      Certs ->
      Op (Aliases rep) ->
      Maybe [SSPointInfo]
  }

-- | The monad in which short-circuiting is performed: a reader over the
-- representation-specific callbacks ('ShortCircuitReader') on top of a state
-- monad carrying the 'VNameSource' used for generating fresh names.
newtype ShortCircuitM rep a = ShortCircuitM (ReaderT (ShortCircuitReader rep) (State VNameSource) a)
  deriving (Functor, Applicative, Monad, MonadReader (ShortCircuitReader rep), MonadState VNameSource)

-- | Fresh names are drawn from the 'VNameSource' held in the underlying state
-- monad, so the name source is simply the monad's state.
instance MonadFreshNames (ShortCircuitM rep) where
  putNameSource = put
  getNameSource = get

-- | A 'TopdownEnv' in which every field is empty ('mempty').  Intended as a
-- base value whose relevant fields are overridden with record-update syntax.
emptyTopdownEnv :: TopdownEnv rep
emptyTopdownEnv =
  TopdownEnv
    { alloc = mempty,
      scope = mempty,
      inhibited = mempty,
      v_alias = mempty,
      m_alias = mempty,
      nonNegatives = mempty,
      scalarTable = mempty,
      knownLessThan = mempty,
      td_asserts = mempty
    }

-- | A 'BotUpEnv' in which every field is empty ('mempty'): no known scalars,
-- no active or successful coalescing entries, and nothing inhibited.
emptyBotUpEnv :: BotUpEnv
emptyBotUpEnv =
  BotUpEnv
    { scals = mempty,
      activeCoals = mempty,
      successCoals = mempty,
      inhibit = mempty
    }

--------------------------------------------------------------------------------
--- Main Coalescing Transformation computes a successful coalescing table    ---
--------------------------------------------------------------------------------

-- | Given a 'Prog' in 'SeqMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for that
-- function.  The scalar-table handler supplied for ops ignores both of its
-- arguments and always yields an empty table.
mkCoalsTab :: (MonadFreshNames m) => Prog (Aliases SeqMem) -> m (M.Map Name CoalsTab)
mkCoalsTab prog =
  mkCoalsTabProg
    (lastUseSeqMem prog)
    (ShortCircuitReader shortCircuitSeqMem genSSPointInfoSeqMem)
    (ComputeScalarTableOnOp $ const $ const $ pure mempty)
    prog

-- | Given a 'Prog' in 'GPUMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for that
-- function.  Ops are handled by 'shortCircuitGPUMem' and
-- 'genSSPointInfoGPUMem'; scalar tables for ops are computed by
-- 'computeScalarTableGPUMem', lifted over 'MemOp' with
-- 'computeScalarTableMemOp'.
mkCoalsTabGPU :: (MonadFreshNames m) => Prog (Aliases GPUMem) -> m (M.Map Name CoalsTab)
mkCoalsTabGPU prog =
  mkCoalsTabProg
    (lastUseGPUMem prog)
    (ShortCircuitReader shortCircuitGPUMem genSSPointInfoGPUMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableGPUMem))
    prog

-- | Given a 'Prog' in 'MCMem' representation, compute the coalescing table
-- by folding over each function.
--
-- The result maps each function name to the 'CoalsTab' computed for that
-- function.  Ops are handled by 'shortCircuitMCMem' and
-- 'genSSPointInfoMCMem'; scalar tables for ops are computed by
-- 'computeScalarTableMCMem', lifted over 'MemOp' with
-- 'computeScalarTableMemOp'.
mkCoalsTabMC :: (MonadFreshNames m) => Prog (Aliases MCMem) -> m (M.Map Name CoalsTab)
mkCoalsTabMC prog =
  mkCoalsTabProg
    (lastUseMCMem prog)
    (ShortCircuitReader shortCircuitMCMem genSSPointInfoMCMem)
    (ComputeScalarTableOnOp (computeScalarTableMemOp computeScalarTableMCMem))
    prog

-- | Given a program, compute the coalescing table of each of its functions.
--
-- The arguments are the program-wide last-use information (only its second
-- component, the per-function last-use tables, is used), the
-- representation-specific 'Op' callbacks, the handler for computing scalar
-- tables on ops, and the program itself.  The result maps each function name
-- to the 'CoalsTab' produced by 'fixPointCoalesce' for that function.
mkCoalsTabProg ::
  (MonadFreshNames m, Coalesceable rep inner) =>
  LUTabProg ->
  ShortCircuitReader rep ->
  ComputeScalarTableOnOp rep ->
  Prog (Aliases rep) ->
  m (M.Map Name CoalsTab)
mkCoalsTabProg (_, lutab_prog) r computeScalarOnOp =
  fmap M.fromList . mapM onFun . progFuns
  where
    onFun fun@(FunDef _ _ fname _ fpars body) = do
      -- First compute last-use information
      let unique_mems = getUniqueMemFParam fpars
          -- 'M.!' is partial: this relies on the last-use table having an
          -- entry for every function of the program.
          lutab = lutab_prog M.! fname
          -- The scalar table is built from the top-level statements of the
          -- function body, in a scope covering the function itself and those
          -- statements.
          scalar_table =
            runReader
              ( concatMapM
                  (computeScalarTable $ scopeOf fun <> scopeOf (bodyStms body))
                  (stmsToList $ bodyStms body)
              )
              computeScalarOnOp
          topenv =
            emptyTopdownEnv
              { scope = scopeOfFParams fpars,
                alloc = unique_mems,
                scalarTable = scalar_table,
                nonNegatives = foldMap paramSizes fpars
              }
          ShortCircuitM m = fixPointCoalesce lutab fpars body topenv
      -- Run the analysis with the ambient name source and write the updated
      -- name source back, so fresh names stay unique across functions.
      (fname,) <$> modifyNameSource (runState (runReaderT m r))

-- | The names occurring free in the shape of an array parameter.  Parameters
-- that are not 'MemArray's contribute no names.
paramSizes :: Param FParamMem -> Names
paramSizes (Param _ _ (MemArray _ shp _ _)) = freeIn shp
paramSizes _ = mempty

-- | Short-circuit handler for a 'SeqMem' 'Op'.
--
-- Because 'SeqMem' doesn't have any special operations, simply return the
-- input 'BotUpEnv' unchanged; every other argument is ignored.
shortCircuitSeqMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases SeqMem) ->
  TopdownEnv SeqMem ->
  BotUpEnv ->
  ShortCircuitM SeqMem BotUpEnv
shortCircuitSeqMem _ _ _ _ _ = pure

-- | Short-circuit handler for SegOp.
shortCircuitSegOp ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegOp lvl (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOp :: forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
(lvl -> Bool)
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegOp lvl (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOp lvl -> Bool
lvlOK InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs (SegMap lvl
lvl SegSpace
space [Type]
_ KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env BotUpEnv
bu_env =
  -- No special handling necessary for 'SegMap'. Just call the helper-function.
  Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
0 lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs SegSpace
space KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env
shortCircuitSegOp lvl -> Bool
lvlOK InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs (SegRed lvl
lvl SegSpace
space [SegBinOp (Aliases rep)]
binops [Type]
_ KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env BotUpEnv
bu_env =
  -- When handling 'SegRed', we first invalidate all active coalesce-entries
  -- where any of the variables in 'vartab' are also free in the list of
  -- 'SegBinOp'. In other words, anything that is used as part of the reduction
  -- step should probably not be coalesced.
  let to_fail :: CoalsTab
to_fail = (CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall a k. (a -> Bool) -> Map k a -> Map k a
M.filter (\CoalsEntry
entry -> [VName] -> Names
namesFromList (Map VName Coalesced -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName Coalesced -> [VName]) -> Map VName Coalesced -> [VName]
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry) Names -> Names -> Bool
`namesIntersect` (SegBinOp (Aliases rep) -> Names)
-> [SegBinOp (Aliases rep)] -> Names
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Lambda (Aliases rep) -> Names
forall a. FreeIn a => a -> Names
freeIn (Lambda (Aliases rep) -> Names)
-> (SegBinOp (Aliases rep) -> Lambda (Aliases rep))
-> SegBinOp (Aliases rep)
-> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp (Aliases rep)]
binops) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env
      (CoalsTab
active, InhibitTab
inh) =
        ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
to_fail
      bu_env' :: BotUpEnv
bu_env' = BotUpEnv
bu_env {activeCoals = active, inhibit = inh}
      num_reds :: Int
num_reds = [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
red_ts
   in Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
num_reds lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs SegSpace
space KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env'
  where
    segment_dims :: [SubExp]
segment_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
init ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [SubExp]
segSpaceDims SegSpace
space
    red_ts :: [Type]
red_ts = do
      SegBinOp (Aliases rep)
op <- [SegBinOp (Aliases rep)]
binops
      let shp :: ShapeBase SubExp
shp = [SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
segment_dims ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> SegBinOp (Aliases rep) -> ShapeBase SubExp
forall rep. SegBinOp rep -> ShapeBase SubExp
segBinOpShape SegBinOp (Aliases rep)
op
      (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> ShapeBase SubExp -> Type
`arrayOfShape` ShapeBase SubExp
shp) (Lambda (Aliases rep) -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (Lambda (Aliases rep) -> [Type]) -> Lambda (Aliases rep) -> [Type]
forall a b. (a -> b) -> a -> b
$ SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp (Aliases rep)
op)
shortCircuitSegOp lvl -> Bool
lvlOK InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs (SegScan lvl
lvl SegSpace
space [SegBinOp (Aliases rep)]
binops [Type]
_ KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env BotUpEnv
bu_env =
  -- Like in the handling of 'SegRed', we do not want to coalesce anything that
  -- is used in the 'SegBinOp'
  let to_fail :: CoalsTab
to_fail = (CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall a k. (a -> Bool) -> Map k a -> Map k a
M.filter (\CoalsEntry
entry -> [VName] -> Names
namesFromList (Map VName Coalesced -> [VName]
forall k a. Map k a -> [k]
M.keys (Map VName Coalesced -> [VName]) -> Map VName Coalesced -> [VName]
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry) Names -> Names -> Bool
`namesIntersect` (SegBinOp (Aliases rep) -> Names)
-> [SegBinOp (Aliases rep)] -> Names
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Lambda (Aliases rep) -> Names
forall a. FreeIn a => a -> Names
freeIn (Lambda (Aliases rep) -> Names)
-> (SegBinOp (Aliases rep) -> Lambda (Aliases rep))
-> SegBinOp (Aliases rep)
-> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp (Aliases rep) -> Lambda (Aliases rep)
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp (Aliases rep)]
binops) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env
      (CoalsTab
active, InhibitTab
inh) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
to_fail
      bu_env' :: BotUpEnv
bu_env' = BotUpEnv
bu_env {activeCoals = active, inhibit = inh}
   in Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
0 lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab Pat (VarAliases, LParamMem)
pat Certs
pat_certs SegSpace
space KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env'
-- 'SegHist' case.
--
-- Every active coalescing whose variables occur free in one of the histogram
-- operator lambdas is first marked as failed.  The rest of the work is
-- delegated to 'shortCircuitSegOpHelper' (passing 0 as the number of results
-- for which the last segment-space dimension is dropped).  Finally the
-- pattern elements are zipped with the flattened histogram destinations and
-- each pair is handed to @insertHistCoals@.
shortCircuitSegOp lvlOK lutab pat pat_certs (SegHist lvl space histops _ kernel_body) td_env bu_env = do
  -- Need to take zipped patterns and histDest (flattened) and insert
  -- transitive coalesces.
  let hist_lam_frees = foldMap (freeIn . histOp) histops
      -- Does this entry coalesce a variable that a histogram operator uses?
      usedInHistOp entry =
        namesFromList (M.keys $ vartab entry) `namesIntersect` hist_lam_frees
      to_fail = M.filter usedInHistOp $ activeCoals bu_env
      (active, inh) =
        foldl markFailedCoal (activeCoals bu_env, inhibit bu_env) $
          M.keys to_fail
      bu_env' = bu_env {activeCoals = active, inhibit = inh}
  bu_env'' <-
    shortCircuitSegOpHelper
      0
      lvlOK
      lvl
      lutab
      pat
      pat_certs
      space
      kernel_body
      td_env
      bu_env'
  pure $
    foldl insertHistCoals bu_env'' $
      zip (patElems pat) $
        concatMap histDest histops
  where
    -- Given a pattern element and the histogram destination it was zipped
    -- with: if both have array memory information in the top-down scope and
    -- the pattern element's memory block has an entry in 'successCoals',
    -- extend that entry's 'optdeps' and additionally register the entry in
    -- 'activeCoals' under the destination's memory block.  In every other
    -- situation the environment is returned unchanged.
    insertHistCoals acc (PatElem p _, hist_dest)
      | Just (MemBlock _ _ p_mem _) <- getScopeMemInfo p $ scope td_env,
        Just (MemBlock _ _ dest_mem _) <- getScopeMemInfo hist_dest $ scope td_env,
        Just entry <- M.lookup p_mem $ successCoals acc =
          -- Update this entry with an optdep for the memory block of hist_dest
          --
          -- NOTE(review): the dependency recorded here is @p -> p_mem@, and
          -- the entry stored under @dest_mem@ is the one *without* the new
          -- optdep; confirm both are intended.
          let entry' = entry {optdeps = M.insert p p_mem $ optdeps entry}
           in acc
                { successCoals = M.insert p_mem entry' $ successCoals acc,
                  activeCoals = M.insert dest_mem entry $ activeCoals acc
                }
      | otherwise = acc

-- | Short-circuit handler for 'GPUMem' 'Op'.
--
-- When the 'Op' is a 'SegOp', we handle it accordingly (only looking at
-- levels accepted by 'isSegThread').  A 'GPU.GPUBody' is analysed as if it
-- were a segmented operation over a one-element space executed at a
-- single-thread level.  Allocations, size operations and 'NoOp' leave the
-- bottom-up environment untouched.
shortCircuitGPUMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases GPUMem) ->
  TopdownEnv GPUMem ->
  BotUpEnv ->
  ShortCircuitM GPUMem BotUpEnv
shortCircuitGPUMem _ _ _ (Alloc _ _) _ bu_env = pure bu_env
shortCircuitGPUMem lutab pat certs (Inner (GPU.SegOp op)) td_env bu_env =
  shortCircuitSegOp isSegThread lutab pat certs op td_env bu_env
shortCircuitGPUMem lutab pat certs (Inner (GPU.GPUBody _ body)) td_env bu_env = do
  -- Two fresh names: one for the flat index of the synthetic space and one
  -- for its single dimension.
  fresh1 <- newNameFromString "gpubody"
  fresh2 <- newNameFromString "gpubody"
  let one = Constant $ IntValue $ Int64Value 1
      -- Construct a 'SegLevel' corresponding to a single thread
      single_thread =
        GPU.SegThread GPU.SegNoVirt $
          Just $
            GPU.KernelGrid (GPU.Count one) (GPU.Count one)
  shortCircuitSegOpHelper
    0
    isSegThread
    single_thread
    lutab
    pat
    certs
    (SegSpace fresh1 [(fresh2, one)])
    (bodyToKernelBody body)
    td_env
    bu_env
shortCircuitGPUMem _ _ _ (Inner (GPU.SizeOp _)) _ bu_env = pure bu_env
shortCircuitGPUMem _ _ _ (Inner (GPU.OtherOp NoOp)) _ bu_env = pure bu_env

-- | Short-circuit handler for 'MCMem' 'Op'.
--
-- Allocations and 'NoOp' leave the bottom-up environment untouched.  For a
-- 'MC.ParOp', the optional first 'SegOp' (when present) is analysed first and
-- the resulting environment is then threaded through the analysis of the
-- mandatory second 'SegOp'.  Every level is accepted (@const True@).
shortCircuitMCMem ::
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  Op (Aliases MCMem) ->
  TopdownEnv MCMem ->
  BotUpEnv ->
  ShortCircuitM MCMem BotUpEnv
shortCircuitMCMem _ _ _ (Alloc _ _) _ bu_env = pure bu_env
shortCircuitMCMem _ _ _ (Inner (MC.OtherOp NoOp)) _ bu_env = pure bu_env
shortCircuitMCMem lutab pat certs (Inner (MC.ParOp par_op op)) td_env bu_env =
  -- With no first 'SegOp', 'pure' just passes the environment along.
  maybe pure handleSegOp par_op bu_env >>= handleSegOp op
  where
    -- Analyse one 'SegOp' of the 'MC.ParOp' with the shared arguments.
    handleSegOp segop =
      shortCircuitSegOp (const True) lutab pat certs segop td_env

-- | Remove the last dimension from a 'SegSpace'.
--
-- Uses 'init', so forcing the resulting 'unSegSpace' fails when the given
-- space has no dimensions.
dropLastSegSpace :: SegSpace -> SegSpace
dropLastSegSpace space = space {unSegSpace = init $ unSegSpace space}

-- | Is this level built with the 'GPU.SegThread' constructor?
isSegThread :: GPU.SegLevel -> Bool
isSegThread GPU.SegThread {} = True
isSegThread _ = False

-- | Computes the slice written at the end of a thread in a 'SegOp'.
--
-- * For 'Returns', every index variable of the segment space becomes a
--   'DimFix' on that variable (as an 'Int64' leaf expression).
--
-- * For 'RegTileReturns', each tile description is paired with the
--   corresponding segment-space dimension (extra elements of the longer list
--   are ignored by 'zipWith') and turned into a 'DimSlice' whose arguments
--   are @x * block_tile * reg_tile@, @block_tile * reg_tile@ and @1@, where
--   @x@ is the dimension's index variable.
--
-- * Any other kind of result yields 'Nothing'.
threadSlice :: SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice space Returns {} =
  Just $
    Slice
      [ DimFix $ TPrimExp $ LeafExp gtid (IntType Int64)
        | (gtid, _) <- unSegSpace space
      ]
threadSlice space (RegTileReturns _ dims _) =
  Just $ Slice $ zipWith tileSlice dims $ unSegSpace space
  where
    -- The first component of the tile triple and the dimension size are not
    -- needed for the slice.
    tileSlice (_, block_tile_size0, reg_tile_size0) (x0, _) =
      let x = pe64 $ Var x0
          block_tile_size = pe64 block_tile_size0
          reg_tile_size = pe64 reg_tile_size0
       in DimSlice
            (x * block_tile_size * reg_tile_size)
            (block_tile_size * reg_tile_size)
            1
threadSlice _ _ = Nothing

-- | Repackage an ordinary 'Body' as a 'KernelBody'.
--
-- The body decoration and statements are carried over unchanged; each
-- 'SubExpRes' of the result becomes a 'Returns' with 'ResultNoSimplify',
-- keeping its certificates and sub-expression.
bodyToKernelBody :: Body (Aliases GPUMem) -> KernelBody (Aliases GPUMem)
bodyToKernelBody (Body dec stms res) =
  KernelBody dec stms $ map toKernelResult res
  where
    toKernelResult (SubExpRes cert subexp) =
      Returns ResultNoSimplify cert subexp

-- | A helper for all the different kinds of 'SegOp'.
--
-- Consists of four parts:
--
-- 1. Create coalescing relations between the pattern elements and the kernel
-- body results using 'makeSegMapCoals'.
--
-- 2. Process the statements of the 'KernelBody'.
--
-- 3. Check the overlap between the different threads.
--
-- 4. Mark active coalescings as finished, since a 'SegOp' is an array creation
-- point.
shortCircuitSegOpHelper ::
  (Coalesceable rep inner) =>
  -- | The number of returns for which we should drop the last seg space
  Int ->
  -- | Whether we should look at a segop with this lvl.
  (lvl -> Bool) ->
  lvl ->
  LUTabFun ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  SegSpace ->
  KernelBody (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper :: forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
Int
-> (lvl -> Bool)
-> lvl
-> InhibitTab
-> Pat (VarAliases, LParamMem)
-> Certs
-> SegSpace
-> KernelBody (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
shortCircuitSegOpHelper Int
num_reds lvl -> Bool
lvlOK lvl
lvl InhibitTab
lutab pat :: Pat (VarAliases, LParamMem)
pat@(Pat [PatElem (VarAliases, LParamMem)]
ps0) Certs
pat_certs SegSpace
space0 KernelBody (Aliases rep)
kernel_body TopdownEnv rep
td_env BotUpEnv
bu_env = do
  -- We need to drop the last element of the 'SegSpace' for pattern elements
  -- that correspond to reductions.
  let ps_space_and_res :: [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res =
        [PatElem (VarAliases, LParamMem)]
-> [SegSpace]
-> [KernelResult]
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (VarAliases, LParamMem)]
ps0 (Int -> SegSpace -> [SegSpace]
forall a. Int -> a -> [a]
replicate Int
num_reds (SegSpace -> SegSpace
dropLastSegSpace SegSpace
space0) [SegSpace] -> [SegSpace] -> [SegSpace]
forall a. Semigroup a => a -> a -> a
<> SegSpace -> [SegSpace]
forall a. a -> [a]
repeat SegSpace
space0) ([KernelResult]
 -> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)])
-> [KernelResult]
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
forall a b. (a -> b) -> a -> b
$
          KernelBody (Aliases rep) -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody (Aliases rep)
kernel_body
  -- Create coalescing relations between pattern elements and kernel body
  -- results
  let (CoalsTab
actv0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          (Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
pat)
      (CoalsTab
actv_return, InhibitTab
inhibit_return) =
        if Int
num_reds Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0
          then (CoalsTab
actv0, InhibitTab
inhibit0)
          else ((CoalsTab, InhibitTab)
 -> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((lvl -> Bool)
-> lvl
-> TopdownEnv rep
-> KernelBody (Aliases rep)
-> Certs
-> (CoalsTab, InhibitTab)
-> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> (CoalsTab, InhibitTab)
forall rep (inner :: * -> *) lvl.
Coalesceable rep inner =>
(lvl -> Bool)
-> lvl
-> TopdownEnv rep
-> KernelBody (Aliases rep)
-> Certs
-> (CoalsTab, InhibitTab)
-> (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> (CoalsTab, InhibitTab)
makeSegMapCoals lvl -> Bool
lvlOK lvl
lvl TopdownEnv rep
td_env KernelBody (Aliases rep)
kernel_body Certs
pat_certs) (CoalsTab
actv0, InhibitTab
inhibit0) [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res

  -- Start from empty references, we'll update with aggregates later.
  let actv0' :: CoalsTab
actv0' = (CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (\CoalsEntry
etry -> CoalsEntry
etry {memrefs = mempty}) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
actv0 CoalsTab -> CoalsTab -> CoalsTab
forall a. Semigroup a => a -> a -> a
<> CoalsTab
actv_return
  -- Process kernel body statements
  BotUpEnv
bu_env' <-
    InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (KernelBody (Aliases rep) -> Stms (Aliases rep)
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody (Aliases rep)
kernel_body) TopdownEnv rep
td_env (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$
      BotUpEnv
bu_env {activeCoals = actv0', inhibit = inhibit_return}

  let actv_coals_after :: CoalsTab
actv_coals_after =
        (VName -> CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall k a b. (k -> a -> b) -> Map k a -> Map k b
M.mapWithKey
          ( \VName
k CoalsEntry
etry ->
              CoalsEntry
etry
                { memrefs = memrefs etry <> maybe mempty memrefs (M.lookup k $ actv0 <> actv_return)
                }
          )
          (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env'

  -- Check partial overlap.
  let checkPartialOverlap :: BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv
checkPartialOverlap BotUpEnv
bu_env_f (VName
k, CoalsEntry
entry) = do
        let sliceThreadAccess :: (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> AccessSummary
sliceThreadAccess (PatElem (VarAliases, LParamMem)
p, SegSpace
space, KernelResult
res) =
              case VName -> Map VName Coalesced -> Maybe Coalesced
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p) (Map VName Coalesced -> Maybe Coalesced)
-> Map VName Coalesced -> Maybe Coalesced
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
entry of
                Just (Coalesced CoalescedKind
_ (MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ LMAD
ixf) FreeVarSubsts
_) ->
                  AccessSummary
-> (Slice (TPrimExp Int64 VName) -> AccessSummary)
-> Maybe (Slice (TPrimExp Int64 VName))
-> AccessSummary
forall b a. b -> (a -> b) -> Maybe a -> b
maybe
                    AccessSummary
Undeterminable
                    ( LMAD -> AccessSummary
ixfunToAccessSummary
                        (LMAD -> AccessSummary)
-> (Slice (TPrimExp Int64 VName) -> LMAD)
-> Slice (TPrimExp Int64 VName)
-> AccessSummary
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LMAD -> Slice (TPrimExp Int64 VName) -> LMAD
forall num.
(Eq num, IntegralExp num) =>
LMAD num -> Slice num -> LMAD num
LMAD.slice LMAD
ixf
                        (Slice (TPrimExp Int64 VName) -> LMAD)
-> (Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName))
-> Slice (TPrimExp Int64 VName)
-> LMAD
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [TPrimExp Int64 VName]
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice (LMAD -> [TPrimExp Int64 VName]
forall num. LMAD num -> Shape num
LMAD.shape LMAD
ixf)
                    )
                    (Maybe (Slice (TPrimExp Int64 VName)) -> AccessSummary)
-> Maybe (Slice (TPrimExp Int64 VName)) -> AccessSummary
forall a b. (a -> b) -> a -> b
$ SegSpace -> KernelResult -> Maybe (Slice (TPrimExp Int64 VName))
threadSlice SegSpace
space KernelResult
res
                Maybe Coalesced
Nothing -> AccessSummary
forall a. Monoid a => a
mempty
            thread_writes :: AccessSummary
thread_writes = ((PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> AccessSummary)
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> AccessSummary
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
-> AccessSummary
sliceThreadAccess [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res
            source_writes :: AccessSummary
source_writes = MemRefs -> AccessSummary
srcwrts (CoalsEntry -> MemRefs
memrefs CoalsEntry
entry) AccessSummary -> AccessSummary -> AccessSummary
forall a. Semigroup a => a -> a -> a
<> AccessSummary
thread_writes
        AccessSummary
destination_uses <-
          case MemRefs -> AccessSummary
dstrefs (CoalsEntry -> MemRefs
memrefs CoalsEntry
entry)
            AccessSummary -> AccessSummary -> AccessSummary
`accessSubtract` MemRefs -> AccessSummary
dstrefs (MemRefs -> (CoalsEntry -> MemRefs) -> Maybe CoalsEntry -> MemRefs
forall b a. b -> (a -> b) -> Maybe a -> b
maybe MemRefs
forall a. Monoid a => a
mempty CoalsEntry -> MemRefs
memrefs (Maybe CoalsEntry -> MemRefs) -> Maybe CoalsEntry -> MemRefs
forall a b. (a -> b) -> a -> b
$ VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
k (CoalsTab -> Maybe CoalsEntry) -> CoalsTab -> Maybe CoalsEntry
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env) of
            Set Set LMAD
s ->
              (LMAD -> ShortCircuitM rep AccessSummary)
-> [LMAD] -> ShortCircuitM rep AccessSummary
forall (m :: * -> *) b a.
(Monad m, Monoid b) =>
(a -> m b) -> [a] -> m b
concatMapM
                (Map VName (PrimExp VName)
-> [(VName, SubExp)] -> LMAD -> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> LMAD -> m AccessSummary
aggSummaryMapPartial (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) ([(VName, SubExp)] -> LMAD -> ShortCircuitM rep AccessSummary)
-> [(VName, SubExp)] -> LMAD -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0)
                (Set LMAD -> [LMAD]
forall a. Set a -> [a]
S.toList Set LMAD
s)
            AccessSummary
Undeterminable -> AccessSummary -> ShortCircuitM rep AccessSummary
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure AccessSummary
Undeterminable
        -- Do not allow short-circuiting from a segop-shared memory
        -- block (not in the topdown scope) to an outer memory block.
        if CoalsEntry -> VName
dstmem CoalsEntry
entry VName -> Map VName (NameInfo (Aliases rep)) -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`M.member` TopdownEnv rep -> Map VName (NameInfo (Aliases rep))
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env
          Bool -> Bool -> Bool
&& TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
forall rep.
AliasableRep rep =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env AccessSummary
destination_uses AccessSummary
source_writes
          then BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f
          else do
            let (CoalsTab
ac, InhibitTab
inh) = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env_f) VName
k
            BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ BotUpEnv
bu_env_f {activeCoals = ac, inhibit = inh}

  BotUpEnv
bu_env'' <-
    (BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> [(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
      BotUpEnv -> (VName, CoalsEntry) -> ShortCircuitM rep BotUpEnv
checkPartialOverlap
      (BotUpEnv
bu_env' {activeCoals = actv_coals_after})
      ([(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv)
-> [(VName, CoalsEntry)] -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [(VName, CoalsEntry)]
forall k a. Map k a -> [(k, a)]
M.toList CoalsTab
actv_coals_after

  let updateMemRefs :: CoalsEntry -> ShortCircuitM rep CoalsEntry
updateMemRefs CoalsEntry
entry = do
        AccessSummary
wrts <- Map VName (PrimExp VName)
-> [(VName, SubExp)]
-> AccessSummary
-> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0) (AccessSummary -> ShortCircuitM rep AccessSummary)
-> AccessSummary -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
srcwrts (MemRefs -> AccessSummary) -> MemRefs -> AccessSummary
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> MemRefs
memrefs CoalsEntry
entry
        AccessSummary
uses <- Map VName (PrimExp VName)
-> [(VName, SubExp)]
-> AccessSummary
-> ShortCircuitM rep AccessSummary
forall (m :: * -> *).
MonadFreshNames m =>
Map VName (PrimExp VName)
-> [(VName, SubExp)] -> AccessSummary -> m AccessSummary
aggSummaryMapTotal (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) (SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space0) (AccessSummary -> ShortCircuitM rep AccessSummary)
-> AccessSummary -> ShortCircuitM rep AccessSummary
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
dstrefs (MemRefs -> AccessSummary) -> MemRefs -> AccessSummary
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> MemRefs
memrefs CoalsEntry
entry

        -- Add destination uses from the pattern
        let uses' :: AccessSummary
uses' =
              (PatElem (VarAliases, LParamMem) -> AccessSummary)
-> [PatElem (VarAliases, LParamMem)] -> AccessSummary
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap
                ( \case
                    PatElem VName
_ (VarAliases
_, MemArray PrimType
_ ShapeBase SubExp
_ NoUniqueness
_ (ArrayIn VName
p_mem LMAD
p_ixf))
                      | VName
p_mem VName -> Names -> Bool
`nameIn` CoalsEntry -> Names
alsmem CoalsEntry
entry ->
                          LMAD -> AccessSummary
ixfunToAccessSummary LMAD
p_ixf
                    PatElem (VarAliases, LParamMem)
_ -> AccessSummary
forall a. Monoid a => a
mempty
                )
                [PatElem (VarAliases, LParamMem)]
ps0

        CoalsEntry -> ShortCircuitM rep CoalsEntry
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsEntry -> ShortCircuitM rep CoalsEntry
forall a b. (a -> b) -> a -> b
$ CoalsEntry
entry {memrefs = MemRefs (uses <> uses') wrts}

  CoalsTab
actv <- (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Map VName a -> m (Map VName b)
mapM CoalsEntry -> ShortCircuitM rep CoalsEntry
updateMemRefs (CoalsTab -> ShortCircuitM rep CoalsTab)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env''
  let bu_env''' :: BotUpEnv
bu_env''' = BotUpEnv
bu_env'' {activeCoals = actv}

  -- Process pattern and return values
  let mergee_writes :: [(PatElem (VarAliases, LParamMem), (VName, VName, LMAD))]
mergee_writes =
        ((PatElem (VarAliases, LParamMem), SegSpace, KernelResult)
 -> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, LMAD)))
-> [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
-> [(PatElem (VarAliases, LParamMem), (VName, VName, LMAD))]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe
          ( \(PatElem (VarAliases, LParamMem)
p, SegSpace
_, KernelResult
_) ->
              ((VName, VName, LMAD)
 -> (PatElem (VarAliases, LParamMem), (VName, VName, LMAD)))
-> Maybe (VName, VName, LMAD)
-> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, LMAD))
forall a b. (a -> b) -> Maybe a -> Maybe b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (PatElem (VarAliases, LParamMem)
p,) (Maybe (VName, VName, LMAD)
 -> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, LMAD)))
-> Maybe (VName, VName, LMAD)
-> Maybe (PatElem (VarAliases, LParamMem), (VName, VName, LMAD))
forall a b. (a -> b) -> a -> b
$
                TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
forall rep.
HasMemBlock (Aliases rep) =>
TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
getDirAliasedIxfn' TopdownEnv rep
td_env (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env''') (VName -> Maybe (VName, VName, LMAD))
-> VName -> Maybe (VName, VName, LMAD)
forall a b. (a -> b) -> a -> b
$
                  PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p
          )
          [(PatElem (VarAliases, LParamMem), SegSpace, KernelResult)]
ps_space_and_res

  -- Now, for each mergee write, we need to check that it doesn't overlap with any previous uses of the destination.
  let checkMergeeOverlap :: BotUpEnv
-> (PatElem (VarAliases, LParamMem), (VName, VName, LMAD))
-> ShortCircuitM rep BotUpEnv
checkMergeeOverlap BotUpEnv
bu_env_f (PatElem (VarAliases, LParamMem)
p, (VName
m_b, VName
_, LMAD
ixf)) =
        let as :: AccessSummary
as = LMAD -> AccessSummary
ixfunToAccessSummary LMAD
ixf
         in -- Should be @bu_env@ here, because we need to check overlap
            -- against previous uses.
            case VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b (CoalsTab -> Maybe CoalsEntry) -> CoalsTab -> Maybe CoalsEntry
forall a b. (a -> b) -> a -> b
$ BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env of
              Just CoalsEntry
coal_entry -> do
                let mrefs :: MemRefs
mrefs =
                      CoalsEntry -> MemRefs
memrefs CoalsEntry
coal_entry
                    res :: Bool
res = TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
forall rep.
AliasableRep rep =>
TopdownEnv rep -> AccessSummary -> AccessSummary -> Bool
noMemOverlap TopdownEnv rep
td_env AccessSummary
as (AccessSummary -> Bool) -> AccessSummary -> Bool
forall a b. (a -> b) -> a -> b
$ MemRefs -> AccessSummary
dstrefs MemRefs
mrefs
                    fail_res :: BotUpEnv
fail_res =
                      let (CoalsTab
ac, InhibitTab
inh) = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env_f) VName
m_b
                       in BotUpEnv
bu_env_f {activeCoals = ac, inhibit = inh}

                if Bool
res
                  then case VName -> Map VName Coalesced -> Maybe Coalesced
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElem (VarAliases, LParamMem) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (VarAliases, LParamMem)
p) (Map VName Coalesced -> Maybe Coalesced)
-> Map VName Coalesced -> Maybe Coalesced
forall a b. (a -> b) -> a -> b
$ CoalsEntry -> Map VName Coalesced
vartab CoalsEntry
coal_entry of
                    Maybe Coalesced
Nothing -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f
                    Just (Coalesced CoalescedKind
knd mbd :: ArrayMemBound
mbd@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ LMAD
ixfn) FreeVarSubsts
_) -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$
                      case Map VName (NameInfo (Aliases rep))
-> Map VName (PrimExp VName) -> LMAD -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> Map VName (NameInfo (Aliases rep))
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (TopdownEnv rep -> Map VName (PrimExp VName)
forall rep. TopdownEnv rep -> Map VName (PrimExp VName)
scalarTable TopdownEnv rep
td_env) LMAD
ixfn of
                        Just FreeVarSubsts
fv_subst ->
                          let entry :: CoalsEntry
entry =
                                CoalsEntry
coal_entry
                                  { vartab =
                                      M.insert
                                        (patElemName p)
                                        (Coalesced knd mbd fv_subst)
                                        (vartab coal_entry)
                                  }
                              (CoalsTab
ac, CoalsTab
suc) =
                                (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env_f, BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env_f) VName
m_b CoalsEntry
entry
                           in BotUpEnv
bu_env_f {activeCoals = ac, successCoals = suc}
                        Maybe FreeVarSubsts
Nothing ->
                          BotUpEnv
fail_res
                  else BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
fail_res
              Maybe CoalsEntry
_ -> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env_f

  (BotUpEnv
 -> (PatElem (VarAliases, LParamMem), (VName, VName, LMAD))
 -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv
-> [(PatElem (VarAliases, LParamMem), (VName, VName, LMAD))]
-> ShortCircuitM rep BotUpEnv
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM BotUpEnv
-> (PatElem (VarAliases, LParamMem), (VName, VName, LMAD))
-> ShortCircuitM rep BotUpEnv
checkMergeeOverlap BotUpEnv
bu_env''' [(PatElem (VarAliases, LParamMem), (VName, VName, LMAD))]
mergee_writes

-- | Given a pattern element and the corresponding kernel result, try to put the
-- kernel result directly in the memory block of the pattern element.
--
-- The function takes the current pair of active coalescing candidates and
-- inhibited memory blocks and returns the updated pair.  It has three
-- equations:
--
-- * A @Returns@ result naming an array variable: if all guards hold, a
--   'CoalsEntry' keyed by the memory block of the returned array is added to
--   the active table.  If any guard fails, matching falls through to the
--   catch-all equation.
--
-- * A @WriteReturns@ result: the memory block of the returned array is always
--   marked as a failed coalescing candidate.
--
-- * Any other result: every array that occurs free in the result and is found
--   in the top-down scope has its memory block marked as failed.
makeSegMapCoals ::
  (Coalesceable rep inner) =>
  (lvl -> Bool) ->
  lvl ->
  TopdownEnv rep ->
  KernelBody (Aliases rep) ->
  Certs ->
  (CoalsTab, InhibitTab) ->
  (PatElem (VarAliases, LetDecMem), SegSpace, KernelResult) ->
  (CoalsTab, InhibitTab)
makeSegMapCoals lvlOK lvl td_env kernel_body pat_certs (active, inhb) (PatElem pat_name (_, MemArray _ _ _ (ArrayIn pat_mem pat_ixf)), space, Returns _ _ (Var return_name))
  | -- The returned array must have known memory information, looked up in the
    -- top-down scope extended with the statements of the kernel body.
    Just (MemBlock tp return_shp return_mem _) <-
      getScopeMemInfo return_name $ scope td_env <> scopeOf (kernelBodyStms kernel_body),
    -- The caller-supplied predicate must accept this level.
    lvlOK lvl,
    -- Both the pattern memory and the returned memory must be memory blocks,
    -- and they must live in the same space.
    MemMem pat_space <- runReader (lookupMemInfo pat_mem) $ removeScopeAliases $ scope td_env,
    MemMem return_space <-
      scope td_env <> scopeOf (kernelBodyStms kernel_body) <> scopeOfSegSpace space
        & removeScopeAliases
        & runReader (lookupMemInfo return_mem),
    pat_space == return_space =
      case M.lookup pat_mem active of
        Nothing ->
          -- We are not in a transitive case
          case ( maybe False (pat_mem `nameIn`) (M.lookup return_mem inhb),
                 Coalesced
                   InPlaceCoal
                   (MemBlock tp return_shp pat_mem $ resultSlice pat_ixf)
                   mempty
                   & M.singleton return_name
                   & flip (addInvAliasesVarTab td_env) return_name
               ) of
            -- Not inhibited, and the variable table could be built: register
            -- a fresh entry for the returned memory block.
            (False, Just vtab) ->
              ( active
                  <> M.singleton
                    return_mem
                    (CoalsEntry pat_mem pat_ixf (oneName pat_mem) vtab mempty mempty pat_certs),
                inhb
              )
            _ -> (active, inhb)
        Just trans ->
          -- The pattern memory is itself an active candidate, so the new
          -- entry targets the destination of that existing entry.
          case ( maybe False (dstmem trans `nameIn`) $ M.lookup return_mem inhb,
                 let Coalesced _ (MemBlock _ _ trans_mem trans_ixf) _ =
                       fromMaybe (error "Impossible") $ M.lookup pat_name $ vartab trans
                  in Coalesced
                       TransitiveCoal
                       (MemBlock tp return_shp trans_mem $ resultSlice trans_ixf)
                       mempty
                       & M.singleton return_name
                       & flip (addInvAliasesVarTab td_env) return_name
               ) of
            (False, Just vtab) ->
              let -- Record the dependency on the pattern element unless the
                  -- existing entry already targets the pattern memory.
                  opts =
                    if dstmem trans == pat_mem
                      then mempty
                      else M.insert pat_name pat_mem $ optdeps trans
               in ( M.insert
                      return_mem
                      ( CoalsEntry
                          (dstmem trans)
                          (dstind trans)
                          (oneName pat_mem <> alsmem trans)
                          vtab
                          opts
                          mempty
                          (certs trans <> pat_certs)
                      )
                      active,
                    inhb
                  )
            _ -> (active, inhb)
  where
    -- One fixed index per dimension of the segment space, using the
    -- dimension's index variable as an @Int64@ leaf expression.
    thread_slice =
      unSegSpace space
        & map (DimFix . TPrimExp . flip LeafExp (IntType Int64) . fst)
        & Slice
    -- Slice an index function with 'thread_slice', padded by 'fullSlice' to
    -- the full rank of the index function.
    resultSlice ixf = LMAD.slice ixf $ fullSlice (LMAD.shape ixf) thread_slice
makeSegMapCoals _ _ td_env _ _ x (_, _, WriteReturns _ return_name _) =
  case getScopeMemInfo return_name $ scope td_env of
    Just (MemBlock _ _ return_mem _) -> markFailedCoal x return_mem
    Nothing -> error "Should not happen?"
makeSegMapCoals _ _ td_env _ _ x (_, _, result) =
  freeIn result
    & namesToList
    & mapMaybe (flip getScopeMemInfo $ scope td_env)
    & foldr (flip markFailedCoal . memName) x

-- | Pad a slice so that it covers every dimension of the given shape.
--
-- The dimension indices already present in the slice are kept as they are.
-- Each dimension of @shp@ beyond the length of the slice is appended as a
-- @DimSlice 0 d 1@, where @d@ is the size of that dimension.  If the slice
-- already has at least as many entries as @shp@, nothing is appended.
fullSlice :: [TPrimExp Int64 VName] -> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
fullSlice shp (Slice slc) =
  Slice $ slc ++ map (\d -> DimSlice 0 d 1) (drop (length slc) shp)

-- | Short-circuit a function body, iterating until the inhibit table
-- stops gaining new keys.
--
-- One iteration runs 'mkCoalsTabStms' over the statements of @bdy@,
-- starting from 'emptyBotUpEnv' whose @inhibit@ field is seeded with
-- @inhibited topenv@.  Afterwards:
--
--   * every array function parameter (as reported by
--     'getArrMemAssocFParam') is passed to @handleFunctionParams@, which
--     either promotes its active entry to the success table or marks it
--     as failed;
--
--   * success entries whose optimistic dependencies do not hold are
--     removed by @fixPointFilterDeps@, and the resulting inhibitions are
--     merged into the inhibit table.
--
-- If the merged inhibit table contains keys that are not already in
-- @inhibited topenv@, the whole analysis is rerun with the extended
-- table; otherwise the filtered success table is returned.
--
-- Calls 'error' if any coalescing is still active after the parameters
-- have been handled.
fixPointCoalesce ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  [Param FParamMem] ->
  Body (Aliases rep) ->
  TopdownEnv rep ->
  ShortCircuitM rep CoalsTab
fixPointCoalesce lutab fpar bdy topenv = do
  buenv <-
    mkCoalsTabStms
      lutab
      (bodyStms bdy)
      topenv
      (emptyBotUpEnv {inhibit = inhibited topenv})
  let succ_tab = successCoals buenv
      actv_tab = activeCoals buenv
      inhb_tab = inhibit buenv
      -- Allow short-circuiting function parameters that are unique and have
      -- matching index functions, otherwise mark as failed
      handleFunctionParams (a, i, s) (_, u, MemBlock _ _ m ixf) =
        case (u, M.lookup m a) of
          (Unique, Just entry)
            | dstind entry == ixf,
              Set dst_uses <- dstrefs (memrefs entry),
              dst_uses == mempty ->
                let (a', s') = markSuccessCoal (a, s) m entry
                 in (a', i, s')
          _ ->
            let (a', i') = markFailedCoal (a, i) m
             in (a', i', s)
      (actv_tab', inhb_tab', succ_tab') =
        L.foldl'
          handleFunctionParams
          (actv_tab, inhb_tab, succ_tab)
          (getArrMemAssocFParam fpar)

      (succ_tab'', failed_optdeps) = fixPointFilterDeps succ_tab' M.empty
      inhb_tab'' = M.unionWith (<>) failed_optdeps inhb_tab'
  if not $ M.null actv_tab'
    then error ("COALESCING ROOT: BROKEN INV, active not empty: " ++ show (M.keys actv_tab'))
    else
      -- 'M.difference' is key-based: we only iterate again when a memory
      -- block that was not inhibited before has become inhibited.
      if M.null $ inhb_tab'' `M.difference` inhibited topenv
        then pure succ_tab''
        else fixPointCoalesce lutab fpar bdy (topenv {inhibited = inhb_tab''})
  where
    -- Repeatedly apply 'filterDeps' to every key of the table until a
    -- pass leaves the number of entries unchanged.
    fixPointFilterDeps :: CoalsTab -> InhibitTab -> (CoalsTab, InhibitTab)
    fixPointFilterDeps coaltab inhbtab =
      let (coaltab', inhbtab') =
            L.foldl' filterDeps (coaltab, inhbtab) (M.keys coaltab)
       in if M.size coaltab == M.size coaltab'
            then (coaltab', inhbtab')
            else fixPointFilterDeps coaltab' inhbtab'

    -- Check the optimistic dependencies of memory block @mb@.  A block
    -- that is no longer in the table (dropped earlier in the same pass)
    -- is left alone.
    filterDeps :: (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
    filterDeps (coal, inhb) mb =
      case M.lookup mb coal of
        Nothing -> (coal, inhb)
        Just coal_etry
          -- all ok
          | M.null (M.filterWithKey (failedOptDep coal) (optdeps coal_etry)) ->
              (coal, inhb)
          -- optimistic dependencies failed for the current
          -- memblock; extend inhibited mem-block mergings.
          | otherwise ->
              markFailedCoal (coal, inhb) mb

    -- An optimistic dependency @r -> mr@ has failed when @mr@ has no entry
    -- in the table, or when its entry does not list @r@ in its 'vartab'.
    failedOptDep :: CoalsTab -> VName -> VName -> Bool
    failedOptDep coal r mr =
      maybe True (M.notMember r . vartab) (M.lookup mr coal)

-- | Perform short-circuiting on 'Stms'.
--
-- The top-down environment is threaded through the statements from first
-- to last, while the bottom-up environment is threaded in the opposite
-- direction: a statement is handed to 'mkCoalsTabStm' only after all the
-- statements following it have been processed.
mkCoalsTabStms ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stms (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStms lutab all_stms = go all_stms
  where
    -- Non-negative names collected from the patterns of /all/ statements
    -- in the sequence, computed once up front.
    pat_non_negs = foldMap (nonNegativesInPat . stmPat) all_stms

    go remaining td_env bu_env =
      case remaining of
        Empty -> pure bu_env
        stm :<| rest -> do
          -- Top-down: account for the current statement.
          let td_env' = updateTopdownEnv td_env stm
              -- Environment used only for analysing @stm@ itself.
              stm_env =
                td_env' {nonNegatives = nonNegatives td_env' <> pat_non_negs}
          -- Bottom-up: handle the statements after @stm@ first.
          bu_env' <- go rest td_env' bu_env
          mkCoalsTabStm lutab stm stm_env bu_env'

-- | Array (register) coalescing can have one of three shapes:
--      a) @let y    = copy(b^{lu})@
--      b) @let y    = concat(a, b^{lu})@
--      c) @let y[i] = b^{lu}@
--   The intent is to use the memory block of the left-hand side
--     for the right-hand side variable, meaning to store @b@ in
--     @m_y@ (rather than @m_b@).
--   The following five safety conditions are necessary:
--      1. the right-hand side is lastly-used in the current statement
--      2. the allocation of @m_y@ dominates the creation of @b@
--         ^ relax it by hoisting the allocation of @m_y@
--      3. there is no use of the left-hand side memory block @m_y@
--           during the liveness of @b@, i.e., in between its last use
--           and its creation.
--         ^ relax it by pointwise/interval-based checking
--      4. @b@ is a newly created array, i.e., does not alias anything
--         ^ relax it to support existential memory blocks for if-then-else
--      5. the new index function of @b@ corresponding to memory block @m_y@
--           can be translated at the definition of @b@, and the
--           same for all variables aliasing @b@.
--   Observation: during the live range of @b@, @m_b@ can only be used by
--                variables aliased with @b@, because @b@ is newly created.
--                relax it: in case @m_b@ is existential due to an if-then-else
--                          then the checks should be extended to the actual
--                          array-creation points.
mkCoalsTabStm ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  Stm (Aliases rep) ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep BotUpEnv
mkCoalsTabStm :: forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stm (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStm InhibitTab
_ (Let (Pat [PatElem (LetDec (Aliases rep))
pe]) StmAux (ExpDec (Aliases rep))
_ Exp (Aliases rep)
e) TopdownEnv rep
td_env BotUpEnv
bu_env
  | Just PrimExp VName
primexp <- (VName -> Maybe (PrimExp VName))
-> Exp (Aliases rep) -> Maybe (PrimExp VName)
forall (m :: * -> *) rep v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp (ScopeTab rep
-> Map VName (PrimExp VName) -> VName -> Maybe (PrimExp VName)
forall rep.
AliasableRep rep =>
ScopeTab rep
-> Map VName (PrimExp VName) -> VName -> Maybe (PrimExp VName)
vnameToPrimExp (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)) Exp (Aliases rep)
e =
      BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (BotUpEnv -> ShortCircuitM rep BotUpEnv)
-> BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a b. (a -> b) -> a -> b
$ BotUpEnv
bu_env {scals = M.insert (patElemName pe) primexp (scals bu_env)}
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
patt StmAux (ExpDec (Aliases rep))
_ (Match [SubExp]
_ [Case (Body (Aliases rep))]
cases Body (Aliases rep)
defbody MatchDec (BranchType (Aliases rep))
_)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (VarAliases, LParamMem)]
pat_val_elms = Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt
      -- ToDo: 1. we need to record existential memory blocks in alias table on the top-down pass.
      --       2. need to extend the scope table

      --  i) Filter @activeCoals@ by the 2ND AND 5th safety conditions:
      (CoalsTab
activeCoals0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (VarAliases, LParamMem)]
pat_val_elms

      -- ii) extend @activeCoals@ by transfering the pattern-elements bindings existent
      --     in @activeCoals@ to the body results of the then and else branches, but only
      --     if the current pattern element can be potentially coalesced and also
      --     if the current pattern element satisfies safety conditions 2 & 5.
      res_mem_def :: [MemBodyResult]
res_mem_def = CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (VarAliases, LParamMem)]
pat_val_elms Body (Aliases rep)
defbody
      res_mem_cases :: [[MemBodyResult]]
res_mem_cases = (Case (Body (Aliases rep)) -> [MemBodyResult])
-> [Case (Body (Aliases rep))] -> [[MemBodyResult]]
forall a b. (a -> b) -> [a] -> [b]
map (CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> ScopeTab rep
-> [PatElem (VarAliases, LParamMem)]
-> Body (Aliases rep)
-> [MemBodyResult]
findMemBodyResult CoalsTab
activeCoals0 (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) [PatElem (VarAliases, LParamMem)]
pat_val_elms (Body (Aliases rep) -> [MemBodyResult])
-> (Case (Body (Aliases rep)) -> Body (Aliases rep))
-> Case (Body (Aliases rep))
-> [MemBodyResult]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      subs_def :: FreeVarSubsts
subs_def = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp]) -> Result -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
defbody
      subs_cases :: [FreeVarSubsts]
subs_cases = (Case (Body (Aliases rep)) -> FreeVarSubsts)
-> [Case (Body (Aliases rep))] -> [FreeVarSubsts]
forall a b. (a -> b) -> [a] -> [b]
map (Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
patt ([SubExp] -> FreeVarSubsts)
-> (Case (Body (Aliases rep)) -> [SubExp])
-> Case (Body (Aliases rep))
-> FreeVarSubsts
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp])
-> (Case (Body (Aliases rep)) -> Result)
-> Case (Body (Aliases rep))
-> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult (Body (Aliases rep) -> Result)
-> (Case (Body (Aliases rep)) -> Body (Aliases rep))
-> Case (Body (Aliases rep))
-> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody) [Case (Body (Aliases rep))]
cases

      actv_def_i :: CoalsTab
actv_def_i = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_def) CoalsTab
activeCoals0 [MemBodyResult]
res_mem_def
      actv_cases_i :: [CoalsTab]
actv_cases_i = (FreeVarSubsts -> [MemBodyResult] -> CoalsTab)
-> [FreeVarSubsts] -> [[MemBodyResult]] -> [CoalsTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\FreeVarSubsts
subs [MemBodyResult]
res -> (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs) CoalsTab
activeCoals0 [MemBodyResult]
res) [FreeVarSubsts]
subs_cases [[MemBodyResult]]
res_mem_cases

      -- eliminate the original pattern binding of the if statement,
      -- @let x = if y[0,0] > 0 then map (+y[0,0]) a else map (+1) b@
      -- @let y[0] = x@
      -- should succeed because @m_y@ is used before @x@ is created.
      aux :: Map VName a -> MemBodyResult -> Map VName a
aux Map VName a
ac (MemBodyResult VName
m_b VName
_ VName
_ VName
m_r) = if VName
m_b VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
m_r then Map VName a
ac else VName -> Map VName a -> Map VName a
forall k a. Ord k => k -> Map k a -> Map k a
M.delete VName
m_b Map VName a
ac
      actv_def :: CoalsTab
actv_def = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> MemBodyResult -> CoalsTab
forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux CoalsTab
actv_def_i [MemBodyResult]
res_mem_def
      actv_cases :: [CoalsTab]
actv_cases = (CoalsTab -> [MemBodyResult] -> CoalsTab)
-> [CoalsTab] -> [[MemBodyResult]] -> [CoalsTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith ((CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> MemBodyResult -> CoalsTab
forall {a}. Map VName a -> MemBodyResult -> Map VName a
aux) [CoalsTab]
actv_cases_i [[MemBodyResult]]
res_mem_cases

  -- iii) process the then and else bodies
  BotUpEnv
res_def <- InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
defbody) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals = actv_def})
  [BotUpEnv]
res_cases <- (Case (Body (Aliases rep))
 -> CoalsTab -> ShortCircuitM rep BotUpEnv)
-> [Case (Body (Aliases rep))]
-> [CoalsTab]
-> ShortCircuitM rep [BotUpEnv]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\Case (Body (Aliases rep))
c CoalsTab
a -> InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms InhibitTab
lutab (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms (Body (Aliases rep) -> Stms (Aliases rep))
-> Body (Aliases rep) -> Stms (Aliases rep)
forall a b. (a -> b) -> a -> b
$ Case (Body (Aliases rep)) -> Body (Aliases rep)
forall body. Case body -> body
caseBody Case (Body (Aliases rep))
c) TopdownEnv rep
td_env (BotUpEnv
bu_env {activeCoals = a})) [Case (Body (Aliases rep))]
cases [CoalsTab]
actv_cases
  let (CoalsTab
actv_def0, CoalsTab
succ_def0, InhibitTab
inhb_def0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_def, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_def, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_def)

      -- iv) optimistically mark the pattern succesful:
      ((CoalsTab
activeCoals1, InhibitTab
inhibit1), CoalsTab
successCoals1) =
        (((CoalsTab, InhibitTab), CoalsTab)
 -> [MemBodyResult] -> ((CoalsTab, InhibitTab), CoalsTab))
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [[MemBodyResult]]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( [(CoalsTab, CoalsTab)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall {c}.
[(CoalsTab, Map VName c)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun
              ( (CoalsTab
actv_def0, CoalsTab
succ_def0)
                  (CoalsTab, CoalsTab)
-> [(CoalsTab, CoalsTab)] -> [(CoalsTab, CoalsTab)]
forall a. a -> [a] -> [a]
: [CoalsTab] -> [CoalsTab] -> [(CoalsTab, CoalsTab)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases) ((BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases)
              )
          )
          ((CoalsTab
activeCoals0, InhibitTab
inhibit0), BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
          ([[MemBodyResult]] -> [[MemBodyResult]]
forall a. [[a]] -> [[a]]
L.transpose ([[MemBodyResult]] -> [[MemBodyResult]])
-> [[MemBodyResult]] -> [[MemBodyResult]]
forall a b. (a -> b) -> a -> b
$ [MemBodyResult]
res_mem_def [MemBodyResult] -> [[MemBodyResult]] -> [[MemBodyResult]]
forall a. a -> [a] -> [a]
: [[MemBodyResult]]
res_mem_cases)

      --  v) unify coalescing results of all branches by taking the union
      --     of all entries in the current/then/else success tables.

      actv_res :: CoalsTab
actv_res = (CoalsTab -> CoalsTab -> CoalsTab)
-> CoalsTab -> [CoalsTab] -> CoalsTab
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
activeCoals1 ([CoalsTab] -> CoalsTab) -> [CoalsTab] -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
actv_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases

      succ_res :: CoalsTab
succ_res = (CoalsTab -> CoalsTab -> CoalsTab)
-> CoalsTab -> [CoalsTab] -> CoalsTab
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr ((CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a. Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
M.unionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry) CoalsTab
successCoals1 ([CoalsTab] -> CoalsTab) -> [CoalsTab] -> CoalsTab
forall a b. (a -> b) -> a -> b
$ CoalsTab
succ_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
successCoals [BotUpEnv]
res_cases

      -- vi) The step of filtering by 3rd safety condition is not
      --       necessary, because we perform index analysis of the
      --       source/destination uses, and they should have been
      --       filtered during the analysis of the then/else bodies.
      inhibit_res :: InhibitTab
inhibit_res =
        (Names -> Names -> Names) -> [InhibitTab] -> InhibitTab
forall (f :: * -> *) k a.
(Foldable f, Ord k) =>
(a -> a -> a) -> f (Map k a) -> Map k a
M.unionsWith
          Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
(<>)
          ( InhibitTab
inhibit1
              InhibitTab -> [InhibitTab] -> [InhibitTab]
forall a. a -> [a] -> [a]
: (CoalsTab -> InhibitTab -> InhibitTab)
-> [CoalsTab] -> [InhibitTab] -> [InhibitTab]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                ( \CoalsTab
actv InhibitTab
inhb ->
                    let failed :: CoalsTab
failed = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
actv (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$ (CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsTab -> CoalsTab -> CoalsTab
forall k a b c.
Ord k =>
(a -> b -> c) -> Map k a -> Map k b -> Map k c
M.intersectionWith CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry CoalsTab
actv CoalsTab
activeCoals0
                     in (CoalsTab, InhibitTab) -> InhibitTab
forall a b. (a, b) -> b
snd ((CoalsTab, InhibitTab) -> InhibitTab)
-> (CoalsTab, InhibitTab) -> InhibitTab
forall a b. (a -> b) -> a -> b
$ ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
failed, InhibitTab
inhb) (CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
failed)
                )
                (CoalsTab
actv_def0 CoalsTab -> [CoalsTab] -> [CoalsTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> CoalsTab) -> [BotUpEnv] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> CoalsTab
activeCoals [BotUpEnv]
res_cases)
                (InhibitTab
inhb_def0 InhibitTab -> [InhibitTab] -> [InhibitTab]
forall a. a -> [a] -> [a]
: (BotUpEnv -> InhibitTab) -> [BotUpEnv] -> [InhibitTab]
forall a b. (a -> b) -> [a] -> [b]
map BotUpEnv -> InhibitTab
inhibit [BotUpEnv]
res_cases)
          )
  BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    BotUpEnv
bu_env
      { activeCoals = actv_res,
        successCoals = succ_res,
        inhibit = inhibit_res
      }
  where
    foldfun :: [(CoalsTab, Map VName c)]
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [MemBodyResult]
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab, InhibitTab), CoalsTab)
_ [] =
      String -> ((CoalsTab, InhibitTab), CoalsTab)
forall a. HasCallStack => String -> a
error String
"Imposible Case 1!!!"
    foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab
act, InhibitTab
_), CoalsTab
_) [MemBodyResult]
mem_body_results
      | Maybe CoalsEntry
Nothing <- VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (MemBodyResult -> VName
patMem (MemBodyResult -> VName) -> MemBodyResult -> VName
forall a b. (a -> b) -> a -> b
$ [MemBodyResult] -> MemBodyResult
forall a. HasCallStack => [a] -> a
head [MemBodyResult]
mem_body_results) CoalsTab
act =
          String -> ((CoalsTab, InhibitTab), CoalsTab)
forall a. HasCallStack => String -> a
error String
"Imposible Case 2!!!"
    foldfun
      [(CoalsTab, Map VName c)]
acc
      ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc)
      mem_body_results :: [MemBodyResult]
mem_body_results@(MemBodyResult VName
m_b VName
_ VName
_ VName
_ : [MemBodyResult]
_)
        | Just CoalsEntry
info <- VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
act,
          Just [c]
_ <- (MemBodyResult -> Map VName c -> Maybe c)
-> [MemBodyResult] -> [Map VName c] -> Maybe [c]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (VName -> Map VName c -> Maybe c
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (VName -> Map VName c -> Maybe c)
-> (MemBodyResult -> VName)
-> MemBodyResult
-> Map VName c
-> Maybe c
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results ([Map VName c] -> Maybe [c]) -> [Map VName c] -> Maybe [c]
forall a b. (a -> b) -> a -> b
$ ((CoalsTab, Map VName c) -> Map VName c)
-> [(CoalsTab, Map VName c)] -> [Map VName c]
forall a b. (a -> b) -> [a] -> [b]
map (CoalsTab, Map VName c) -> Map VName c
forall a b. (a, b) -> b
snd [(CoalsTab, Map VName c)]
acc =
            -- Optimistically promote to successful coalescing and append!
            let info' :: CoalsEntry
info' =
                  CoalsEntry
info
                    { optdeps =
                        foldr
                          (\MemBodyResult
mbr -> VName -> VName -> Map VName VName -> Map VName VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (MemBodyResult -> VName
bodyName MemBodyResult
mbr) (MemBodyResult -> VName
bodyMem MemBodyResult
mbr))
                          (optdeps info)
                          mem_body_results
                    }
                (CoalsTab
act', CoalsTab
succc') = (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (CoalsTab
act, CoalsTab
succc) VName
m_b CoalsEntry
info'
             in ((CoalsTab
act', InhibitTab
inhb), CoalsTab
succc')
    foldfun
      [(CoalsTab, Map VName c)]
acc
      ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc)
      mem_body_results :: [MemBodyResult]
mem_body_results@(MemBodyResult VName
m_b VName
_ VName
_ VName
_ : [MemBodyResult]
_)
        | Just CoalsEntry
info <- VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_b CoalsTab
act,
          (MemBodyResult -> Bool) -> [MemBodyResult] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
(==) VName
m_b (VName -> Bool)
-> (MemBodyResult -> VName) -> MemBodyResult -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results,
          Just [CoalsEntry]
info' <- (MemBodyResult -> CoalsTab -> Maybe CoalsEntry)
-> [MemBodyResult] -> [CoalsTab] -> Maybe [CoalsEntry]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (VName -> CoalsTab -> Maybe CoalsEntry)
-> (MemBodyResult -> VName)
-> MemBodyResult
-> CoalsTab
-> Maybe CoalsEntry
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemBodyResult -> VName
bodyMem) [MemBodyResult]
mem_body_results ([CoalsTab] -> Maybe [CoalsEntry])
-> [CoalsTab] -> Maybe [CoalsEntry]
forall a b. (a -> b) -> a -> b
$ ((CoalsTab, Map VName c) -> CoalsTab)
-> [(CoalsTab, Map VName c)] -> [CoalsTab]
forall a b. (a -> b) -> [a] -> [b]
map (CoalsTab, Map VName c) -> CoalsTab
forall a b. (a, b) -> a
fst [(CoalsTab, Map VName c)]
acc =
            -- Treating special case resembling:
            -- @let x0 = map (+1) a                                  @
            -- @let x3 = if cond then let x1 = x0 with [0] <- 2 in x1@
            -- @                 else let x2 = x0 with [1] <- 3 in x2@
            -- @let z[1] = x3                                        @
            -- In this case the result active table should be the union
            -- of the @m_x@ entries of the then and else active tables.
            let info'' :: CoalsEntry
info'' =
                  (CoalsEntry -> CoalsEntry -> CoalsEntry)
-> CoalsEntry -> [CoalsEntry] -> CoalsEntry
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsEntry -> CoalsEntry -> CoalsEntry
unionCoalsEntry CoalsEntry
info [CoalsEntry]
info'
                act' :: CoalsTab
act' = VName -> CoalsEntry -> CoalsTab -> CoalsTab
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_b CoalsEntry
info'' CoalsTab
act
             in ((CoalsTab
act', InhibitTab
inhb), CoalsTab
succc)
    foldfun [(CoalsTab, Map VName c)]
_ ((CoalsTab
act, InhibitTab
inhb), CoalsTab
succc) (MemBodyResult
mbr : [MemBodyResult]
_) =
      -- one of the branches has failed coalescing,
      -- hence remove the coalescing of the result.

      ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) (MemBodyResult -> VName
patMem MemBodyResult
mbr), CoalsTab
succc)
mkCoalsTabStm InhibitTab
lutab (Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ (Loop [(FParam (Aliases rep), SubExp)]
arginis LoopForm
lform Body (Aliases rep)
body)) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  let pat_val_elms :: [PatElem (VarAliases, LParamMem)]
pat_val_elms = Pat (VarAliases, LParamMem) -> [PatElem (VarAliases, LParamMem)]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat

      --  i) Filter @activeCoals@ by the 2nd, 3rd AND 5th safety conditions. In
      --  other words, for each active coalescing target, the creation of the
      --  array we're trying to merge should happen before the allocation of the
      --  merge target and the index function should be translatable.
      (CoalsTab
actv0, InhibitTab
inhibit0) =
        CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
forall rep.
HasMemBlock (Aliases rep) =>
CoalsTab
-> InhibitTab
-> Map VName (PrimExp VName)
-> TopdownEnv rep
-> [PatElem (VarAliases, LParamMem)]
-> (CoalsTab, InhibitTab)
filterSafetyCond2and5
          (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
bu_env)
          (BotUpEnv -> InhibitTab
inhibit BotUpEnv
bu_env)
          (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env)
          TopdownEnv rep
td_env
          [PatElem (VarAliases, LParamMem)]
pat_val_elms
      -- ii) Extend @activeCoals@ by transferring the pattern-element bindings
      --     present in @activeCoals@ to the loop-body results, but only if:
      --       (a) the pattern element is a candidate for coalescing,        &&
      --       (b) the pattern element satisfies safety conditions 2 & 5,
      --           (conditions (a) and (b) have already been checked above), &&
      --       (c) the memory block of the corresponding body result is
      --           allocated outside the loop, i.e., non-existential,        &&
      --       (d) the init name is lastly-used in the initialization
      --           of the loop variant.
      --     Otherwise fail and remove from active-coalescing table!
      bdy_ress :: Result
bdy_ress = Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
body
      ([(VName, VName)]
patmems, [(VName, VName)]
argmems, [(VName, VName)]
inimems, [(VName, VName)]
resmems) =
        [((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
-> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
    [(VName, VName)])
forall a b c d. [(a, b, c, d)] -> ([a], [b], [c], [d])
L.unzip4 ([((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
 -> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
     [(VName, VName)]))
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ([(VName, VName)], [(VName, VName)], [(VName, VName)],
    [(VName, VName)])
forall a b. (a -> b) -> a -> b
$
          ((PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
  SubExp)
 -> Maybe
      ((VName, VName), (VName, VName), (VName, VName), (VName, VName)))
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (CoalsTab
-> (PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
    SubExp)
-> Maybe
     ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
mapmbFun CoalsTab
actv0) ([PatElem (VarAliases, LParamMem)]
-> [(Param FParamMem, SubExp)]
-> [SubExp]
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [PatElem (VarAliases, LParamMem)]
pat_val_elms [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis ([SubExp]
 -> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
      SubExp)])
-> [SubExp]
-> [(PatElem (VarAliases, LParamMem), (Param FParamMem, SubExp),
     SubExp)]
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
bdy_ress) -- td_env'

      -- remove the other pattern elements from the active coalescing table:
      coal_pat_names :: Names
coal_pat_names = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ ((VName, VName) -> VName) -> [(VName, VName)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, VName) -> VName
forall a b. (a, b) -> a
fst [(VName, VName)]
patmems
      (CoalsTab
actv1, InhibitTab
inhibit1) =
        ((CoalsTab, InhibitTab)
 -> (VName, ArrayMemBound) -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab)
-> [(VName, ArrayMemBound)]
-> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( \(CoalsTab
act, InhibitTab
inhb) (VName
b, MemBlock PrimType
_ ShapeBase SubExp
_ VName
m_b LMAD
_) ->
              if VName
b VName -> Names -> Bool
`nameIn` Names
coal_pat_names
                then (CoalsTab
act, InhibitTab
inhb) -- ok
                else (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
act, InhibitTab
inhb) VName
m_b -- remove from active
          )
          (CoalsTab
actv0, InhibitTab
inhibit0)
          (Pat (VarAliases, LParamMem) -> [(VName, ArrayMemBound)]
forall aliases.
Pat (aliases, LParamMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat)

      -- iii) Process the loop's body.
      --      If the memory blocks of the loop result and loop variant param differ
      --      then make the original memory block of the loop result conflict with
      --      the original memory block of the loop parameter. This is done in
      --      order to prevent the coalescing of @a1@, @a0@, @x@ and @db@ in the
      --      same memory block of @y@ in the example below:
      --      @loop(a1 = a0) = for i < n do @
      --      @    let x = map (stencil a1) (iota n)@
      --      @    let db = copy x          @
      --      @    in db                    @
      --      @let y[0] = a1                @
      --      Meaning the coalescing of @x@ in @let db = copy x@ should fail because
      --      @a1@ appears in the definition of @let x = map (stencil a1) (iota n)@.
      res_mem_bdy :: [MemBodyResult]
res_mem_bdy = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
resmems
      res_mem_arg :: [MemBodyResult]
res_mem_arg = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
argmems
      res_mem_ini :: [MemBodyResult]
res_mem_ini = ((VName, VName) -> (VName, VName) -> MemBodyResult)
-> [(VName, VName)] -> [(VName, VName)] -> [MemBodyResult]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(VName
b, VName
m_b) (VName
r, VName
m_r) -> VName -> VName -> VName -> VName -> MemBodyResult
MemBodyResult VName
m_b VName
b VName
r VName
m_r) [(VName, VName)]
patmems [(VName, VName)]
inimems

      actv2 :: CoalsTab
actv2 =
        let subs_res :: FreeVarSubsts
subs_res = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> SubExp) -> Result -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (Result -> [SubExp]) -> Result -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Body (Aliases rep) -> Result
forall rep. Body rep -> Result
bodyResult Body (Aliases rep)
body
            actv11 :: CoalsTab
actv11 = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_res) CoalsTab
actv1 [MemBodyResult]
res_mem_bdy
            subs_arg :: FreeVarSubsts
subs_arg = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> SubExp)
-> [(Param FParamMem, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> ((Param FParamMem, SubExp) -> VName)
-> (Param FParamMem, SubExp)
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis
            actv12 :: CoalsTab
actv12 = (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_arg) CoalsTab
actv11 [MemBodyResult]
res_mem_arg
            subs_ini :: FreeVarSubsts
subs_ini = Pat (VarAliases, LParamMem) -> [SubExp] -> FreeVarSubsts
forall aliases.
Pat (aliases, LParamMem) -> [SubExp] -> FreeVarSubsts
mkSubsTab Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat ([SubExp] -> FreeVarSubsts) -> [SubExp] -> FreeVarSubsts
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> SubExp)
-> [(Param FParamMem, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> SubExp
forall a b. (a, b) -> b
snd [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis
         in (CoalsTab -> MemBodyResult -> CoalsTab)
-> CoalsTab -> [MemBodyResult] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (FreeVarSubsts -> CoalsTab -> MemBodyResult -> CoalsTab
transferCoalsToBody FreeVarSubsts
subs_ini) CoalsTab
actv12 [MemBodyResult]
res_mem_ini

      -- The code below adds an aliasing relation to the loop-arg memory
      --   so as to prevent, e.g., the coalescing of an iterative stencil
      --   (you need a buffer for the result and a separate one for the stencil).
      -- @ let b =               @
      -- @    loop (a) for i<N do@
      -- @        stencil a      @
      -- @  ...                  @
      -- @  y[slc_y] = b         @
      -- This should fail coalescing because we are aliasing @m_a@ with
      --   the memory block of the result.
      insertMemAliases :: CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
tab (MemBodyResult VName
_ VName
_ VName
_ VName
m_r, MemBodyResult VName
_ VName
_ VName
_ VName
m_a) =
        if VName
m_r VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
m_a
          then CoalsTab
tab
          else case VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
m_r CoalsTab
tab of
            Maybe CoalsEntry
Nothing -> CoalsTab
tab
            Just CoalsEntry
etry ->
              VName -> CoalsEntry -> CoalsTab -> CoalsTab
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
m_r (CoalsEntry
etry {alsmem = alsmem etry <> oneName m_a}) CoalsTab
tab
      actv3 :: CoalsTab
actv3 = (CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab)
-> CoalsTab -> [(MemBodyResult, MemBodyResult)] -> CoalsTab
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl CoalsTab -> (MemBodyResult, MemBodyResult) -> CoalsTab
insertMemAliases CoalsTab
actv2 ([MemBodyResult]
-> [MemBodyResult] -> [(MemBodyResult, MemBodyResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [MemBodyResult]
res_mem_bdy [MemBodyResult]
res_mem_arg)
      -- analysing the loop body starts from a null memory-reference set;
      --  the results of the loop body iteration are aggregated later
      actv4 :: CoalsTab
actv4 = (CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (\CoalsEntry
etry -> CoalsEntry
etry {memrefs = mempty}) CoalsTab
actv3
  BotUpEnv
res_env_body <-
    InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
forall rep (inner :: * -> *).
Coalesceable rep inner =>
InhibitTab
-> Stms (Aliases rep)
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep BotUpEnv
mkCoalsTabStms
      InhibitTab
lutab
      (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)
      TopdownEnv rep
td_env'
      ( BotUpEnv
bu_env
          { activeCoals = actv4,
            inhibit = inhibit1
          }
      )
  let scals_loop :: Map VName (PrimExp VName)
scals_loop = BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
res_env_body
      (CoalsTab
res_actv0, CoalsTab
res_succ0, InhibitTab
res_inhb0) = (BotUpEnv -> CoalsTab
activeCoals BotUpEnv
res_env_body, BotUpEnv -> CoalsTab
successCoals BotUpEnv
res_env_body, BotUpEnv -> InhibitTab
inhibit BotUpEnv
res_env_body)
      -- iv) Aggregate memory references across loop and filter unsound coalescing
      -- a) Filter the active-table by the FIRST SOUNDNESS condition, namely:
      --     W_i does not overlap with Union_{j=i+1..n} U_j,
      --     where W_i corresponds to the Write set of src mem-block m_b,
      --     and U_j correspond to the uses of the destination
      --     mem-block m_y, in which m_b is coalesced into.
      --     W_i and U_j correspond to the accesses within the loop body.
      mb_loop_idx :: Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx = LoopForm
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mbLoopIndexRange LoopForm
lform
  CoalsTab
res_actv1 <- (CoalsEntry -> ShortCircuitM rep Bool)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall k (m :: * -> *) v.
(Eq k, Monad m) =>
(v -> m Bool) -> Map k v -> m (Map k v)
filterMapM1 (Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep Bool
loopSoundness1Entry Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv0

  -- b) Update the memory-reference summaries across loop:
  --   W = Union_{i=0..n-1} W_i Union W_{before-loop}
  --   U = Union_{i=0..n-1} U_i Union U_{before-loop}
  CoalsTab
res_actv2 <- (CoalsEntry -> ShortCircuitM rep CoalsEntry)
-> CoalsTab -> ShortCircuitM rep CoalsTab
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Map VName a -> m (Map VName b)
mapM (ScopeTab rep
-> Map VName (PrimExp VName)
-> Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> CoalsEntry
-> ShortCircuitM rep CoalsEntry
aggAcrossLoopEntry (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env' ScopeTab rep -> ScopeTab rep -> ScopeTab rep
forall a. Semigroup a => a -> a -> a
<> Stms (Aliases rep) -> ScopeTab rep
forall rep a. Scoped rep a => a -> Scope rep
scopeOf (Body (Aliases rep) -> Stms (Aliases rep)
forall rep. Body rep -> Stms rep
bodyStms Body (Aliases rep)
body)) Map VName (PrimExp VName)
scals_loop Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
mb_loop_idx) CoalsTab
res_actv1

  -- c) check soundness of the successful promotions for:
  --      - the entries that have been promoted to success during the loop-body pass
  --      - for all the entries of active table
  --    Filter the entries by the SECOND SOUNDNESS CONDITION, namely:
  --      Union_{i=1..n-1} W_i does not overlap the before-the-loop uses
  --        of the destination memory block.
  let res_actv3 :: CoalsTab
res_actv3 = (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
res_actv2

  let tmp_succ :: CoalsTab
tmp_succ =
        (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
forall {k} {a} {p}. Ord k => Map k a -> k -> p -> Bool
okLookup CoalsTab
actv3) (CoalsTab -> CoalsTab) -> CoalsTab -> CoalsTab
forall a b. (a -> b) -> a -> b
$
          CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_succ0 (BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env)
      ver_succ :: CoalsTab
ver_succ = (VName -> CoalsEntry -> Bool) -> CoalsTab -> CoalsTab
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey (CoalsTab -> VName -> CoalsEntry -> Bool
loopSoundness2Entry CoalsTab
actv3) CoalsTab
tmp_succ
  let suc_fail :: CoalsTab
suc_fail = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
tmp_succ CoalsTab
ver_succ
      (CoalsTab
res_succ, InhibitTab
res_inhb1) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_succ0, InhibitTab
res_inhb0) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
suc_fail
      --
      act_fail :: CoalsTab
act_fail = CoalsTab -> CoalsTab -> CoalsTab
forall k a b. Ord k => Map k a -> Map k b -> Map k a
M.difference CoalsTab
res_actv0 CoalsTab
res_actv3
      (CoalsTab
_, InhibitTab
res_inhb) = ((CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> [VName] -> (CoalsTab, InhibitTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
res_actv0, InhibitTab
res_inhb1) ([VName] -> (CoalsTab, InhibitTab))
-> [VName] -> (CoalsTab, InhibitTab)
forall a b. (a -> b) -> a -> b
$ CoalsTab -> [VName]
forall k a. Map k a -> [k]
M.keys CoalsTab
act_fail
      res_actv :: CoalsTab
res_actv =
        (VName -> CoalsEntry -> CoalsEntry) -> CoalsTab -> CoalsTab
forall k a b. (k -> a -> b) -> Map k a -> Map k b
M.mapWithKey (CoalsTab -> VName -> CoalsEntry -> CoalsEntry
forall {k}.
Ord k =>
Map k CoalsEntry -> k -> CoalsEntry -> CoalsEntry
addBeforeLoop CoalsTab
actv3) CoalsTab
res_actv3

      -- v) optimistically mark the pattern successful if there is any chance to succeed
      ((CoalsTab
fin_actv1, InhibitTab
fin_inhb1), CoalsTab
fin_succ1) =
        (((CoalsTab, InhibitTab), CoalsTab)
 -> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
 -> ((CoalsTab, InhibitTab), CoalsTab))
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ((CoalsTab, InhibitTab), CoalsTab)
-> ((VName, VName), (VName, VName), (VName, VName), (VName, VName))
-> ((CoalsTab, InhibitTab), CoalsTab)
foldFunOptimPromotion ((CoalsTab
res_actv, InhibitTab
res_inhb), CoalsTab
res_succ) ([((VName, VName), (VName, VName), (VName, VName), (VName, VName))]
 -> ((CoalsTab, InhibitTab), CoalsTab))
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall a b. (a -> b) -> a -> b
$
          [(VName, VName)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [(VName, VName)]
-> [((VName, VName), (VName, VName), (VName, VName),
     (VName, VName))]
forall a b c d. [a] -> [b] -> [c] -> [d] -> [(a, b, c, d)]
L.zip4 [(VName, VName)]
patmems [(VName, VName)]
argmems [(VName, VName)]
resmems [(VName, VName)]
inimems
      (CoalsTab
fin_actv2, InhibitTab
fin_inhb2) =
        ((CoalsTab, InhibitTab)
 -> VName -> CoalsEntry -> (CoalsTab, InhibitTab))
-> (CoalsTab, InhibitTab) -> CoalsTab -> (CoalsTab, InhibitTab)
forall a k b. (a -> k -> b -> a) -> a -> Map k b -> a
M.foldlWithKey
          ( \(CoalsTab, InhibitTab)
acc VName
k CoalsEntry
_ ->
              if VName
k VName -> Names -> Bool
`nameIn` [VName] -> Names
namesFromList (((Param FParamMem, SubExp) -> VName)
-> [(Param FParamMem, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(FParam (Aliases rep), SubExp)]
[(Param FParamMem, SubExp)]
arginis)
                then (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab, InhibitTab)
acc VName
k
                else (CoalsTab, InhibitTab)
acc
          )
          (CoalsTab
fin_actv1, InhibitTab
fin_inhb1)
          CoalsTab
fin_actv1
  BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env {activeCoals = fin_actv2, successCoals = fin_succ1, inhibit = fin_inhb2}
  where
    -- Allocation table visible inside the loop body: the allocations already
    -- recorded in @td_env'@, extended with every allocation statement found
    -- among the top-level statements of the loop body (see 'getAllocs').
    --
    -- A strict left fold is used so that the table is built incrementally
    -- rather than as a chain of suspended 'M.insert' applications.
    allocs_bdy = L.foldl' getAllocs (alloc td_env') (bodyStms body)
    -- The loop's top-down environment as seen from the end of the loop body:
    -- its allocation table is replaced by 'allocs_bdy' and its scope is
    -- extended with the bindings introduced by the body's statements.
    td_env_allocs =
      td_env'
        { scope = scope td_env' <> scopeOf (bodyStms body),
          alloc = allocs_bdy
        }
    -- Top-down environment used while analysing the loop body: the enclosing
    -- environment @td_env@ updated by 'updateTopdownEnvLoop' with the loop's
    -- merge parameters (and their initial values) and the loop form.
    td_env' = updateTopdownEnvLoop td_env arginis lform
    -- Fold step for building an allocation table.  A statement that binds
    -- exactly one pattern element to an 'Alloc' operation adds that element's
    -- name, mapped to the allocation's space, to the table; any other
    -- statement leaves the table untouched.
    getAllocs tab stm =
      case stm of
        Let (Pat [mem_pe]) _ (Op (Alloc _ space)) ->
          M.insert (patElemName mem_pe) space tab
        _ -> tab
    -- Predicate for 'M.filterWithKey': holds exactly when the key is bound in
    -- the given table.  The third argument (the value of the map being
    -- filtered) is ignored.
    okLookup tab m _ = M.member m tab
    --
    -- Decide whether one loop result can take part in short-circuiting and,
    -- if so, collect the names involved.
    --
    -- Arguments: the active coalescing table, and a triple made of the
    -- pattern element of the loop, the corresponding merge parameter paired
    -- with its initial value, and the corresponding loop-body result.
    --
    -- Returns @Just@ four (array name, memory block) pairs, in the order
    -- pattern element, merge parameter, initial value, body result, when ALL
    -- of the following hold, and @Nothing@ otherwise:
    --
    --   * the pattern element is an array residing in some memory block @m_b@;
    --   * no merge-parameter name occurs free in the declaration of this
    --     merge parameter;
    --   * the initial value and the body result are both variables;
    --   * @m_b@ has an entry in the active table and that entry's @vartab@
    --     mentions the pattern element;
    --   * memory information is available (in the scope of 'td_env_allocs')
    --     for the merge parameter, the initial value and the body result;
    --   * the last-use table has an entry for the merge parameter and that
    --     entry contains the initial value;
    --   * the memory block of the body result is a key of the allocation
    --     table of 'td_env_allocs'.
    mapmbFun actv0 (patel, (arg, ini), bdyres)
      | b <- patElemName patel,
        (_, MemArray _ _ _ (ArrayIn m_b _)) <- patElemDec patel,
        a <- paramName arg,
        -- Not safe to short-circuit if the index function of this
        -- parameter is variant to the loop.
        not (any ((`nameIn` freeIn (paramDec arg)) . paramName . fst) arginis),
        Var a0 <- ini,
        Var r <- bdyres,
        Just coal_etry <- M.lookup m_b actv0,
        -- Only membership matters here; the bound value is not needed.
        b `M.member` vartab coal_etry,
        Just (MemBlock _ _ m_a _) <- getScopeMemInfo a body_scope,
        Just (MemBlock _ _ m_a0 _) <- getScopeMemInfo a0 body_scope,
        Just (MemBlock _ _ m_r _) <- getScopeMemInfo r body_scope,
        Just nms <- M.lookup a lutab,
        a0 `nameIn` nms,
        -- Direct map membership instead of a linear scan over the key list.
        m_r `M.member` alloc td_env_allocs =
          Just ((b, m_b), (a, m_a), (a0, m_a0), (r, m_r))
      | otherwise = Nothing
      where
        body_scope = scope td_env_allocs
    -- Fold step over the candidate quadruples: threads the pair
    -- @((active coalescings, inhibit table), successful coalescings)@
    -- through one quadruple of (array, memory-block) pairs.
    --
    -- Three outcomes are possible:
    --
    --   1. @m_r == m_i@: if @m_i@ is active and 'addInvAliasesVarTab' can
    --      extend its variable table for @b_i@, that (updated) entry is
    --      stored under @m_b@; otherwise @m_b@ is marked as failed.
    --
    --   2. @m_b@, @m_a@ and @m_i@ are all active, @m_r@ is already in the
    --      success table, the variable table of @m_i@ can be extended for
    --      @b_i@, and the entries of @b@ and @a@ both translate via
    --      'translateIxFnInScope': the entries are cross-linked through
    --      @optdeps@, the @m_i@ entry is refreshed (with empty @memrefs@),
    --      and @m_b@, @m_a@ are promoted with 'markSuccessCoal'.
    --
    --   3. Anything else: all four memory blocks are marked as failed.
    --
    -- The folds are strict ('L.foldl'') so that no chain of suspended
    -- table updates is built up in the accumulator.
    foldFunOptimPromotion ::
      ((CoalsTab, InhibitTab), CoalsTab) ->
      ((VName, VName), (VName, VName), (VName, VName), (VName, VName)) ->
      ((CoalsTab, InhibitTab), CoalsTab)
    foldFunOptimPromotion ((act, inhb), succc) ((b, m_b), (a, m_a), (_r, m_r), (b_i, m_i))
      | m_r == m_i =
          -- Both sub-cases are only expected when all blocks coincide.
          Exc.assert (m_r == m_b && m_a == m_b) $
            case M.lookup m_i act of
              Just info
                | Just vtab_i <- addInvAliasesVarTab td_env (vartab info) b_i ->
                    ((M.insert m_b (info {vartab = vtab_i}) act, inhb), succc)
              _ -> (markFailedCoal (act, inhb) m_b, succc)
      | Just info_b0 <- M.lookup m_b act,
        Just info_a0 <- M.lookup m_a act,
        Just info_i <- M.lookup m_i act,
        M.member m_r succc,
        Just vtab_i <- addInvAliasesVarTab td_env (vartab info_i) b_i,
        [Just info_b, Just info_a] <-
          map translateIxFnInScope [(b, info_b0), (a, info_a0)] =
          let -- @b@ and @a@ each record an optimistic dependency on @b_i@ ...
              info_b' = info_b {optdeps = M.insert b_i m_i $ optdeps info_b}
              info_a' = info_a {optdeps = M.insert b_i m_i $ optdeps info_a}
              -- ... and @b_i@'s entry records one on @b@.
              info_i' =
                info_i
                  { optdeps = M.insert b m_b $ optdeps info_i,
                    memrefs = mempty,
                    vartab = vtab_i
                  }
              act' = M.insert m_i info_i' act
              (act1, succc1) =
                L.foldl'
                  (\acc (m, info) -> markSuccessCoal acc m info)
                  (act', succc)
                  [(m_b, info_b'), (m_a, info_a')]
           in -- ToDo: make sure that ixfun translates and update substitutions (?)
              ((act1, inhb), succc1)
    foldFunOptimPromotion ((act, inhb), succc) ((_, m_b), (_a, m_a), (_r, m_r), (_b_i, m_i)) =
      -- The @m_r == m_i@ case is fully handled by the first equation.
      Exc.assert
        (m_r /= m_i)
        (L.foldl' markFailedCoal (act, inhb) [m_b, m_a, m_r, m_i], succc)

    -- Re-derive the free-variable substitutions of the index function
    -- recorded for @x@ in @info@, using the enclosing scope extended with
    -- the parameters in @arginis@.  Returns the entry with @x@'s
    -- 'Coalesced' record updated, or 'Nothing' when @x@ has no record,
    -- the destination memory of @info@ is not in scope, or the
    -- substitutions cannot be computed.
    translateIxFnInScope :: (VName, CoalsEntry) -> Maybe CoalsEntry
    translateIxFnInScope (x, info) = do
      Coalesced knd mbd@(MemBlock _ _ _ ixfn) _ <- M.lookup x (vartab info)
      guard $ isInScope td_env (dstmem info)
      let params_scope = scopeOfFParams (map fst arginis)
      fv_subst <-
        freeVarSubstitutions (scope td_env <> params_scope) (scals bu_env) ixfn
      pure info {vartab = M.insert x (Coalesced knd mbd fv_subst) (vartab info)}
    -- The 64-bit integer constant zero, as a 'SubExp'.
    se0 :: SubExp
    se0 = intConst Int64 0
    -- For a for-loop, the loop variable paired with the bounds
    -- @(0, seN)@ as 64-bit primitive expressions; 'Nothing' for a
    -- while-loop.
    mbLoopIndexRange ::
      LoopForm ->
      Maybe (VName, (TPrimExp Int64 VName, TPrimExp Int64 VName))
    mbLoopIndexRange = \case
      WhileLoop {} -> Nothing
      ForLoop loop_var _ bound -> Just (loop_var, (pe64 se0, pe64 bound))
    -- If @m_b@ already has an entry in @actv_bef@, prepend that entry's
    -- @memrefs@ to those of @etry@; otherwise return @etry@ unchanged.
    addBeforeLoop actv_bef m_b etry =
      maybe etry withEarlierRefs (M.lookup m_b actv_bef)
      where
        withEarlierRefs earlier =
          etry {memrefs = memrefs earlier <> memrefs etry}
    -- Replace the @memrefs@ of @etry@ by their aggregation over the whole
    -- loop, computed with 'aggSummaryLoopTotal'.  The write summary
    -- (@srcwrts@) is aggregated first and the use summary (@dstrefs@)
    -- second; this order is kept because the computation runs in a monad.
    aggAcrossLoopEntry scope_loop scal_tab idx etry = do
      let refs = memrefs etry
          totalOver = aggSummaryLoopTotal (scope td_env) scope_loop scal_tab idx
      wrts <- totalOver (srcwrts refs)
      uses <- totalOver (dstrefs refs)
      pure etry {memrefs = MemRefs uses wrts}
    -- Check, via 'noMemOverlap' under @td_env'@, that the writes recorded
    -- in @etry@ (@srcwrts@) do not overlap the partial-loop aggregation
    -- ('aggSummaryLoopPartial') of its uses (@dstrefs@).  The scalar table
    -- passed in is extended with the one from @td_env@.
    loopSoundness1Entry scal_tab idx etry =
      let refs = memrefs etry
          scalars = scal_tab <> scalarTable td_env
       in noMemOverlap td_env' (srcwrts refs)
            <$> aggSummaryLoopPartial scalars idx (dstrefs refs)
    -- Trivially 'True' when @m_b@ has no entry in @old_actv@.  Otherwise
    -- check, via 'noMemOverlap', that the writes recorded in @etry@ do
    -- not overlap the uses recorded in the old entry.
    loopSoundness2Entry :: CoalsTab -> VName -> CoalsEntry -> Bool
    loopSoundness2Entry old_actv m_b etry =
      maybe True disjointFromEarlierUses (M.lookup m_b old_actv)
      where
        loop_writes = srcwrts (memrefs etry)
        disjointFromEarlierUses earlier =
          noMemOverlap td_env loop_writes (dstrefs (memrefs earlier))

-- The case of in-place update:
--   @let x' = x with slice <- elm@
--
-- Only applies when the pattern binds exactly one array with a memory
-- block (@m_x@); otherwise matching falls through to later equations.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (Update safety x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      let res_name = patElemName x'
          -- (a) filter by the 3rd safety for @elm@ and @x'@
          recorded@(actv, inhbt) = recordMemRefUses td_env bu_env stm
          failed = markFailedCoal recorded m_x
          -- (b) if @x'@ is in active coalesced table, then add an entry
          -- for @x@ as well; any failure along the way marks @m_x@ failed.
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> recorded
              Just info -> fromMaybe failed $ do
                Coalesced k mblk@(MemBlock _ _ _ x_indfun) _ <-
                  M.lookup res_name (vartab info)
                fv_subs <-
                  freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun
                guard $ isInScope td_env (dstmem info)
                let coal_etry_x = Coalesced k mblk fv_subs
                    vtab' =
                      M.insert x coal_etry_x $
                        M.insert res_name coal_etry_x (vartab info)
                pure (M.insert m_x (info {vartab = vtab'}) actv, inhbt)

      -- (c) this stm is also a potential source for coalescing, so
      -- process it -- but only when the update is 'Unsafe'.
      actv'' <-
        case safety of
          Unsafe ->
            mkCoalsHelper3PatternMatch
              stm
              lutab
              (td_env {inhibited = inhbt'})
              (bu_env {activeCoals = actv'})
          _ -> pure actv'
      pure bu_env {activeCoals = actv'', inhibit = inhbt'}

-- The case of flat in-place update:
--   @let x' = x with flat-slice <- elm@
--
-- Only applies when the pattern binds exactly one array with a memory
-- block (@m_x@); otherwise matching falls through to later equations.
mkCoalsTabStm lutab stm@(Let pat@(Pat [x']) _ (BasicOp (FlatUpdate x _ _elm))) td_env bu_env
  | [(_, MemBlock _ _ m_x _)] <- getArrMemAssoc pat = do
      let res_name = patElemName x'
          -- (a) filter by the 3rd safety for @elm@ and @x'@
          recorded@(actv, inhbt) = recordMemRefUses td_env bu_env stm
          failed = markFailedCoal recorded m_x
          -- (b) if @x'@ is in active coalesced table, then add an entry
          -- for @x@ as well.  A missing record for @x'@ should not
          -- happen, but if it does we just fail conservatively, as for
          -- any other failure along the way.
          (actv', inhbt') =
            case M.lookup m_x actv of
              Nothing -> recorded
              Just info -> fromMaybe failed $ do
                Coalesced k mblk@(MemBlock _ _ _ x_indfun) _ <-
                  M.lookup res_name (vartab info)
                fv_subs <-
                  freeVarSubstitutions (scope td_env) (scals bu_env) x_indfun
                guard $ isInScope td_env (dstmem info)
                let coal_etry_x = Coalesced k mblk fv_subs
                    vtab' =
                      M.insert x coal_etry_x $
                        M.insert res_name coal_etry_x (vartab info)
                pure (M.insert m_x (info {vartab = vtab'}) actv, inhbt)

      -- (c) this stm is also a potential source for coalescing, so process it
      actv'' <-
        mkCoalsHelper3PatternMatch
          stm
          lutab
          (td_env {inhibited = inhbt'})
          (bu_env {activeCoals = actv'})
      pure bu_env {activeCoals = actv'', inhibit = inhbt'}
--
-- Any other 'Update' statement (one not accepted by the in-place update
-- equation) is reported as an internal error, showing the offending
-- pattern.
mkCoalsTabStm _ (Let pat _ (BasicOp Update {})) _ _ =
  error
    ( "In ArrayCoalescing.hs, fun mkCoalsTabStm, illegal pattern for in-place update: "
        <> show pat
    )
-- default handling
--
-- Statements whose expression is an 'Op': the op is first processed by
-- the representation-specific @onOp@ handler taken from the reader
-- environment, and the statement itself is then considered as a
-- potential source of coalescing.
mkCoalsTabStm lutab stm@(Let pat aux (Op op)) td_env bu_env = do
  -- Process body
  handle_op <- asks onOp
  env_after_op <- handle_op lutab pat (stmAuxCerts aux) op td_env bu_env
  coals <- mkCoalsHelper3PatternMatch stm lutab td_env env_after_op
  pure env_after_op {activeCoals = coals}
mkCoalsTabStm InhibitTab
lutab stm :: Stm (Aliases rep)
stm@(Let Pat (LetDec (Aliases rep))
pat StmAux (ExpDec (Aliases rep))
_ Exp (Aliases rep)
e) TopdownEnv rep
td_env BotUpEnv
bu_env = do
  --   i) Filter @activeCoals@ by the 3rd safety condition:
  --      this is now relaxed by use of LMAD eqs:
  --      the memory referenced in stm are added to memrefs::dstrefs
  --      in corresponding coal-tab entries.
  let (CoalsTab
activeCoals', InhibitTab
inhibit') = TopdownEnv rep
-> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab)
forall rep (inner :: * -> *).
(AliasableRep rep, Op rep ~ MemOp inner rep,
 HasMemBlock (Aliases rep)) =>
TopdownEnv rep
-> BotUpEnv -> Stm (Aliases rep) -> (CoalsTab, InhibitTab)
recordMemRefUses TopdownEnv rep
td_env BotUpEnv
bu_env Stm (Aliases rep)
stm

      --  ii) promote any of the entries in @activeCoals@ to @successCoals@ as long as
      --        - this statement defined a variable consumed in a coalesced statement
      --        - and safety conditions 2, 4, and 5 are satisfied.
      --      AND extend @activeCoals@ table for any definition of a variable that
      --      aliases a coalesced variable.
      safe_4 :: Bool
safe_4 = Exp (Aliases rep) -> Bool
forall rep. Exp rep -> Bool
createsNewArrOK Exp (Aliases rep)
e
      ((CoalsTab
activeCoals'', InhibitTab
inhibit''), CoalsTab
successCoals') =
        (((CoalsTab, InhibitTab), CoalsTab)
 -> (VName, ArrayMemBound) -> ((CoalsTab, InhibitTab), CoalsTab))
-> ((CoalsTab, InhibitTab), CoalsTab)
-> [(VName, ArrayMemBound)]
-> ((CoalsTab, InhibitTab), CoalsTab)
forall b a. (b -> a -> b) -> b -> [a] -> b
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Bool
-> ((CoalsTab, InhibitTab), CoalsTab)
-> (VName, ArrayMemBound)
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun Bool
safe_4) ((CoalsTab
activeCoals', InhibitTab
inhibit'), BotUpEnv -> CoalsTab
successCoals BotUpEnv
bu_env) (Pat (VarAliases, LParamMem) -> [(VName, ArrayMemBound)]
forall aliases.
Pat (aliases, LParamMem) -> [(VName, ArrayMemBound)]
getArrMemAssoc Pat (VarAliases, LParamMem)
Pat (LetDec (Aliases rep))
pat)

  -- iii) record a potentially coalesced statement in @activeCoals@
  CoalsTab
activeCoals''' <- Stm (Aliases rep)
-> InhibitTab
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep CoalsTab
forall rep (inner :: * -> *).
Coalesceable rep inner =>
Stm (Aliases rep)
-> InhibitTab
-> TopdownEnv rep
-> BotUpEnv
-> ShortCircuitM rep CoalsTab
mkCoalsHelper3PatternMatch Stm (Aliases rep)
stm InhibitTab
lutab TopdownEnv rep
td_env BotUpEnv
bu_env {successCoals = successCoals', activeCoals = activeCoals''}
  BotUpEnv -> ShortCircuitM rep BotUpEnv
forall a. a -> ShortCircuitM rep a
forall (f :: * -> *) a. Applicative f => a -> f a
pure BotUpEnv
bu_env {activeCoals = activeCoals''', inhibit = inhibit'', successCoals = successCoals'}
  where
    foldfun :: Bool
-> ((CoalsTab, InhibitTab), CoalsTab)
-> (VName, ArrayMemBound)
-> ((CoalsTab, InhibitTab), CoalsTab)
foldfun Bool
safe_4 ((CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc) (VName
b, MemBlock PrimType
tp ShapeBase SubExp
shp VName
mb LMAD
_b_indfun) =
      case VName -> CoalsTab -> Maybe CoalsEntry
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
mb CoalsTab
a_acc of
        Maybe CoalsEntry
Nothing -> ((CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
        Just info :: CoalsEntry
info@(CoalsEntry VName
x_mem LMAD
_ Names
_ Map VName Coalesced
vtab Map VName VName
_ MemRefs
_ Certs
certs) ->
          let failed :: (CoalsTab, InhibitTab)
failed = (CoalsTab, InhibitTab) -> VName -> (CoalsTab, InhibitTab)
markFailedCoal (CoalsTab
a_acc, InhibitTab
inhb) VName
mb
           in case VName -> Map VName Coalesced -> Maybe Coalesced
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
b Map VName Coalesced
vtab of
                Maybe Coalesced
Nothing ->
                  -- we hit the definition of some variable @b@ aliased with
                  --    the coalesced variable @x@, hence extend @activeCoals@, e.g.,
                  --       @let x = map f arr  @
                  --       @let b = alias x  @ <- current statement
                  --       @ ... use of b ...  @
                  --       @let c = alias b    @ <- currently fails
                  --       @let y[i] = x       @
                  -- where @alias@ can be @transpose@, @slice@, @reshape@.
                  -- We use getTransitiveAlias helper function to track the aliasing
                  --    through the td_env, and to find the updated ixfun of @b@:
                  case TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
forall rep.
HasMemBlock (Aliases rep) =>
TopdownEnv rep -> CoalsTab -> VName -> Maybe (VName, VName, LMAD)
getDirAliasedIxfn TopdownEnv rep
td_env CoalsTab
a_acc VName
b of
                    Maybe (VName, VName, LMAD)
Nothing -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc)
                    Just (VName
_, VName
_, LMAD
b_indfun') ->
                      case ( ScopeTab rep
-> Map VName (PrimExp VName) -> LMAD -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) LMAD
b_indfun',
                             ScopeTab rep
-> Map VName (PrimExp VName) -> Certs -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) Certs
certs
                           ) of
                        (Just FreeVarSubsts
fv_subst, Just FreeVarSubsts
fv_subst') ->
                          let mem_info :: Coalesced
mem_info = CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
TransitiveCoal (PrimType -> ShapeBase SubExp -> VName -> LMAD -> ArrayMemBound
MemBlock PrimType
tp ShapeBase SubExp
shp VName
x_mem LMAD
b_indfun') (FreeVarSubsts
fv_subst FreeVarSubsts -> FreeVarSubsts -> FreeVarSubsts
forall a. Semigroup a => a -> a -> a
<> FreeVarSubsts
fv_subst')
                              info' :: CoalsEntry
info' = CoalsEntry
info {vartab = M.insert b mem_info vtab}
                           in ((VName -> CoalsEntry -> CoalsTab -> CoalsTab
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
mb CoalsEntry
info' CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
                        (Maybe FreeVarSubsts, Maybe FreeVarSubsts)
_ -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc)
                Just (Coalesced CoalescedKind
k mblk :: ArrayMemBound
mblk@(MemBlock PrimType
_ ShapeBase SubExp
_ VName
_ LMAD
new_indfun) FreeVarSubsts
_) ->
                  -- we are at the definition of the coalesced variable @b@
                  -- if 2,4,5 hold promote it to successful coalesced table,
                  -- or if e = transpose, etc. then postpone decision for later on
                  let safe_2 :: Bool
safe_2 = TopdownEnv rep -> VName -> Bool
forall rep. TopdownEnv rep -> VName -> Bool
isInScope TopdownEnv rep
td_env VName
x_mem
                   in case ( ScopeTab rep
-> Map VName (PrimExp VName) -> LMAD -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) LMAD
new_indfun,
                             ScopeTab rep
-> Map VName (PrimExp VName) -> Certs -> Maybe FreeVarSubsts
forall a rep.
FreeIn a =>
ScopeTab rep
-> Map VName (PrimExp VName) -> a -> Maybe FreeVarSubsts
freeVarSubstitutions (TopdownEnv rep -> ScopeTab rep
forall rep. TopdownEnv rep -> ScopeTab rep
scope TopdownEnv rep
td_env) (BotUpEnv -> Map VName (PrimExp VName)
scals BotUpEnv
bu_env) Certs
certs
                           ) of
                        (Just FreeVarSubsts
fv_subst, Just FreeVarSubsts
fv_subst')
                          | Bool
safe_2 ->
                              let mem_info :: Coalesced
mem_info = CoalescedKind -> ArrayMemBound -> FreeVarSubsts -> Coalesced
Coalesced CoalescedKind
k ArrayMemBound
mblk (FreeVarSubsts
fv_subst FreeVarSubsts -> FreeVarSubsts -> FreeVarSubsts
forall a. Semigroup a => a -> a -> a
<> FreeVarSubsts
fv_subst')
                                  info' :: CoalsEntry
info' = CoalsEntry
info {vartab = M.insert b mem_info vtab}
                               in if Bool
safe_4
                                    then -- array creation point, successful coalescing verified!

                                      let (CoalsTab
a_acc', CoalsTab
s_acc') = (CoalsTab, CoalsTab) -> VName -> CoalsEntry -> (CoalsTab, CoalsTab)
markSuccessCoal (CoalsTab
a_acc, CoalsTab
s_acc) VName
mb CoalsEntry
info'
                                       in ((CoalsTab
a_acc', InhibitTab
inhb), CoalsTab
s_acc')
                                    else -- this is an invertible alias case of the kind
                                    -- @ let b    = alias a @
                                    -- @ let x[i] = b @
                                    -- do not promote, but update the index function

                                      ((VName -> CoalsEntry -> CoalsTab -> CoalsTab
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
mb CoalsEntry
info' CoalsTab
a_acc, InhibitTab
inhb), CoalsTab
s_acc)
                        (Maybe FreeVarSubsts, Maybe FreeVarSubsts)
_ -> ((CoalsTab, InhibitTab)
failed, CoalsTab
s_acc) -- fail!

-- | Wrap a single index function (LMAD) as an 'AccessSummary': the result is
-- the 'Set' summary holding the singleton set that contains just that LMAD.
ixfunToAccessSummary :: LMAD.LMAD (TPrimExp Int64 VName) -> AccessSummary
ixfunToAccessSummary = Set . S.singleton

-- | Check safety conditions 2 and 5 and update new substitutions:
-- called on the pat-elements of loop and if-then-else expressions.
--
-- The safety conditions are: The allocation of merge target should dominate the
-- creation of the array we're trying to merge and the new index function of the
-- array can be translated at the definition site of b. The latter requires that
-- any variables used in the index function of the target array are available at
-- the definition site of b.
--
-- The pattern elements are folded over from left to right, threading the pair
-- of (active coalescing table, inhibit table).  Pattern elements that are not
-- arrays, or whose memory block is not a key of the active table, leave both
-- tables untouched.  Whenever a check fails, the pair is replaced by the result
-- of 'markFailedCoal' for the memory block of that pattern element.
filterSafetyCond2and5 ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  InhibitTab ->
  ScalarTab ->
  TopdownEnv rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  (CoalsTab, InhibitTab)
filterSafetyCond2and5 act_coal inhb_coal scals_env td_env pes =
  L.foldl' helper (act_coal, inhb_coal) pes
  where
    helper (acc, inhb) patel =
      -- For each pattern element in the input list
      case (patElemName patel, patElemDec patel) of
        (b, (_, MemArray tp0 shp0 _ (ArrayIn m_b _idxfn_b))) ->
          -- If it is an array in memory block m_b
          case M.lookup m_b acc of
            Nothing -> (acc, inhb)
            Just info@(CoalsEntry x_mem _ _ vtab _ _ certs) ->
              -- ... and m_b is a memory block we are trying to coalesce
              let failed = markFailedCoal (acc, inhb) m_b
                  -- Free-variable substitutions for an index function
                  -- combined with those for the certificates of the entry;
                  -- 'Nothing' as soon as either of the two is unavailable.
                  substsFor ixf =
                    (<>)
                      <$> freeVarSubstitutions (scope td_env) scals_env ixf
                      <*> freeVarSubstitutions (scope td_env) scals_env certs
                  -- Record @mem_info@ for @b@ in the entry of m_b.
                  withVar mem_info =
                    ( M.insert m_b (info {vartab = M.insert b mem_info vtab}) acc,
                      inhb
                    )
               in -- It is not safe to short circuit if some other pattern
                  -- element is aliased to this one, as this indicates the
                  -- two pattern elements reference the same physical
                  -- value somehow.
                  if any ((`nameIn` aliasesOf patel) . patElemName) pes
                    then failed
                    else case M.lookup b vtab of
                      Nothing ->
                        -- b is not yet in the variable table of the entry:
                        -- try to obtain its index function via aliasing.
                        case getDirAliasedIxfn td_env acc b of
                          Just (_, _, b_indfun')
                            | Just fv_substs <- substsFor b_indfun' ->
                                withVar $
                                  Coalesced
                                    TransitiveCoal
                                    (MemBlock tp0 shp0 x_mem b_indfun')
                                    fv_substs
                          _ -> failed
                      Just (Coalesced k (MemBlock pt shp _ new_indfun) _)
                        | Just fv_substs <- substsFor new_indfun,
                          -- safety condition 2: x_mem is in scope here
                          isInScope td_env x_mem ->
                            withVar $
                              Coalesced
                                k
                                (MemBlock pt shp x_mem new_indfun)
                                fv_substs
                      _ -> failed
        _ -> (acc, inhb)

-- | Pattern matches a potentially coalesced statement and records a new
-- association in @activeCoals@.
--
-- The candidate short-circuit points of the statement are obtained from
-- 'genCoalStmtInfo'.  When there are none, the active table of the bottom-up
-- environment is returned unchanged; otherwise every candidate is processed
-- in order and may add (or overwrite) an entry keyed by the memory block of
-- the source array.
mkCoalsHelper3PatternMatch ::
  (Coalesceable rep inner) =>
  Stm (Aliases rep) ->
  LUTabFun ->
  TopdownEnv rep ->
  BotUpEnv ->
  ShortCircuitM rep CoalsTab
mkCoalsHelper3PatternMatch stm lutab td_env bu_env =
  maybe activeCoals_tab (L.foldl' processNewCoalesce activeCoals_tab)
    <$> genCoalStmtInfo lutab td_env (scope td_env) stm
  where
    successCoals_tab = successCoals bu_env
    activeCoals_tab = activeCoals bu_env

    -- Test whether we are in a transitive coalesced case, i.e.,
    --      @let b = scratch ...@
    --      @.....@
    --      @let x[j] = b@
    --      @let y[i] = x@
    -- and compose the index function of @x@ with that of @y@,
    -- and update aliasing of the @m_b@ entry to also contain @m_y@
    -- on top of @m_x@, i.e., transitively, any use of @m_y@ should
    -- be checked for the lifetime of @b@.
    --
    -- Finally update the @activeCoals@ table with a fresh
    --   binding for @m_b@; if such one exists then overwrite.
    -- Also, add all variables from the alias chain of @b@ to
    --   @vartab@, for example, in the case of a sequence:
    --   @ b0 = if cond then ... else ... @
    --   @ b1 = alias0 b0 @
    --   @ b  = alias1 b1 @
    --   @ x[j] = b @
    -- Then @b1@ and @b0@ should also be added to @vartab@ if
    --   @alias1@ and @alias0@ are invertible, otherwise fail early!
    processNewCoalesce acc (knd, alias_fn, x, m_x, ind_x, b, m_b, _, tp_b, shp_b, certs)
      -- m_b must not be aliased with m_yx
      | areAnyAliased td_env m_b [m_yx] = acc
      -- the destination memory must be in scope here
      | not (isInScope td_env m_yx) = acc
      -- fail due to inhibited
      | is_inhibited = acc
      | otherwise =
          case addInvAliasesVarTab td_env (M.singleton b mem_info) b of
            -- fail early due to non-invertible aliasing
            Nothing -> acc
            -- successfully adding a new coalesced entry
            Just vtab' ->
              let coal_etry =
                    CoalsEntry
                      m_yx
                      ind_yx
                      mem_yx_al
                      vtab'
                      opts'
                      mempty
                      (certs <> certs')
               in M.insert m_b coal_etry acc
      where
        -- In-place coalescing is resolved against the active table,
        -- every other kind against the successful one.
        proper_coals_tab = case knd of
          InPlaceCoal -> activeCoals_tab
          _ -> successCoals_tab

        (m_yx, ind_yx, mem_yx_al, x_deps, certs') =
          case M.lookup m_x proper_coals_tab of
            Nothing ->
              (m_x, alias_fn ind_x, oneName m_x, M.empty, mempty)
            Just (CoalsEntry m_y ind_y y_al y_vtab x_deps0 _ certs'') ->
              -- Prefer the index function already recorded for @x@ in the
              -- entry; fall back to the index function of the entry itself.
              let ind = case M.lookup x y_vtab of
                    Just (Coalesced _ (MemBlock _ _ _ ixf) _) -> ixf
                    Nothing -> ind_y
               in (m_y, alias_fn ind, oneName m_x <> y_al, x_deps0, certs <> certs'')

        mem_info = Coalesced knd (MemBlock tp_b shp_b m_yx ind_yx) M.empty

        opts'
          | m_yx == m_x = M.empty
          | otherwise = M.insert x m_x x_deps

        is_inhibited =
          maybe False (m_yx `nameIn`) $ M.lookup m_b $ inhibited td_env

-- | Information about a particular short-circuit point.
--
-- The components are, in order:
--
--  1. the kind of coalescing;
--
--  2. a function applied to the index function of the destination before it
--     is recorded;
--
--  3. the destination array, 4. its memory block and 5. its index function;
--
--  6. the source array, 7. its memory block and 8. its index function;
--
--  9. the element type and 10. the shape of the source array;
--
--  11. the certificates attached to the short-circuit point.
type SSPointInfo =
  ( CoalescedKind,
    LMAD -> LMAD,
    VName,
    VName,
    LMAD,
    VName,
    VName,
    LMAD,
    PrimType,
    Shape,
    Certs
  )

-- | Given an op, return a list of potential short-circuit points.
--
-- The arguments are, in order: the last-use table, the top-down environment,
-- the scope table, the pattern that the op is bound to, the certificates, and
-- the op itself.  'Nothing' means the op offers no short-circuit points.
type GenSSPoint rep op =
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Pat (VarAliases, LetDecMem) ->
  Certs ->
  op ->
  Maybe [SSPointInfo]

-- | Short-circuit point generator for 'SeqMem' ops: it ignores all of its
-- arguments and never reports a short-circuit point.
genSSPointInfoSeqMem ::
  GenSSPoint SeqMem (Op (Aliases SeqMem))
genSSPointInfoSeqMem _ _ _ _ _ _ =
  Nothing

-- | For 'SegOp', we currently only handle 'SegMap', and only under the following
-- circumstances:
--
--  1. The 'SegMap' has only one return/pattern value, which is a 'Returns'.
--
--  2. The 'KernelBody' contains an 'Index' statement that is indexing an array using
--  only the values from the 'SegSpace'.
--
--  3. The array being indexed is last-used in that statement, is free in the
--  'SegMap', is unique or has been recently allocated (specifically, it should
--  not be a non-unique argument to the enclosing function), has elements with
--  the same bit-size as the pattern elements, and has the exact same 'LMAD' as
--  the pattern of the 'SegMap' statement.
--
-- There can be multiple candidate arrays, but the current implementation will
-- always just try the first one.
--
-- The first restriction could be relaxed by trying to match up arrays in the
-- 'KernelBody' with patterns of the 'SegMap', but the current implementation
-- should be enough to handle many common cases.
--
-- The result of the 'SegMap' is treated as the destination, while the candidate
-- array from inside the body is treated as the source.
genSSPointInfoSegOp ::
  (Coalesceable rep inner) => GenSSPoint rep (SegOp lvl (Aliases rep))
genSSPointInfoSegOp
  lutab
  td_env
  scopetab
  (Pat [PatElem dst (_, MemArray dst_pt _ _ (ArrayIn dst_mem dst_ixf))])
  certs
  (SegMap _ space _ kernel_body@KernelBody {kernelBodyResult = [Returns {}]})
    | (src, MemBlock src_pt shp src_mem src_ixf) : _ <-
        mapMaybe getPotentialMapShortCircuit $
          stmsToList $
            kernelBodyStms kernel_body =
        Just [(MapCoal, id, dst, dst_mem, dst_ixf, src, src_mem, src_ixf, src_pt, shp, certs)]
    where
      iterators = map fst $ unSegSpace space
      frees = freeIn kernel_body

      -- Returns the indexed array and its memory information when the
      -- statement is an 'Index' that qualifies as a short-circuit source.
      getPotentialMapShortCircuit (Let (Pat [PatElem x _]) _ (BasicOp (Index arr slc)))
        | Just inds <- sliceIndices slc,
          -- The indices are compared with the iterators as sorted lists.
          L.sort inds == L.sort (map Var iterators),
          Just last_uses <- M.lookup x lutab,
          arr `nameIn` last_uses,
          Just memblock@(MemBlock arr_pt _ arr_mem arr_ixf) <-
            getScopeMemInfo arr scopetab,
          arr_mem `nameIn` last_uses,
          -- The 'alloc' table contains allocated memory blocks, including
          -- unique memory blocks from the enclosing function. It does _not_
          -- include non-unique memory blocks from the enclosing function.
          arr_mem `M.member` alloc td_env,
          arr `nameIn` frees,
          arr_ixf == dst_ixf,
          primBitSize arr_pt == primBitSize dst_pt =
            Just (arr, memblock)
      getPotentialMapShortCircuit _ = Nothing
genSSPointInfoSegOp _ _ _ _ _ _ =
  Nothing

-- | Lift a short-circuit point generator for the inner op type to one for
-- 'MemOp': an 'Inner' op is handed to the supplied generator together with
-- all other arguments unchanged; every other 'MemOp' yields 'Nothing'.
genSSPointInfoMemOp ::
  GenSSPoint rep (inner (Aliases rep)) ->
  GenSSPoint rep (MemOp inner (Aliases rep))
genSSPointInfoMemOp onOp lutab td_env scopetab pat certs (Inner op) =
  onOp lutab td_env scopetab pat certs op
genSSPointInfoMemOp _ _ _ _ _ _ _ = Nothing

-- | Short-circuit point generator for 'GPUMem' ops.  Only a 'GPU.SegOp'
-- wrapped in 'Inner' is considered, and it is delegated to
-- 'genSSPointInfoSegOp'; every other op yields 'Nothing'.
genSSPointInfoGPUMem ::
  GenSSPoint GPUMem (Op (Aliases GPUMem))
genSSPointInfoGPUMem = genSSPointInfoMemOp f
  where
    f lutab td_env scopetab pat certs (GPU.SegOp op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    f _ _ _ _ _ _ = Nothing

-- | Short-circuit point generator for 'MCMem' ops.  Only an 'MC.ParOp' whose
-- first component is 'Nothing', wrapped in 'Inner', is considered; its
-- 'SegOp' is delegated to 'genSSPointInfoSegOp'.  Every other op yields
-- 'Nothing'.
genSSPointInfoMCMem ::
  GenSSPoint MCMem (Op (Aliases MCMem))
genSSPointInfoMCMem = genSSPointInfoMemOp f
  where
    f lutab td_env scopetab pat certs (MC.ParOp Nothing op) =
      genSSPointInfoSegOp lutab td_env scopetab pat certs op
    f _ _ _ _ _ _ = Nothing

-- | Recognise a statement as a short-circuit point and describe the
-- candidate coalescings it gives rise to.
--
-- Each returned 'SSPointInfo' is a tuple of: the kind of coalescing, a
-- transformation to apply to an LMAD, the destination array @x@ with its
-- memory block @m_x@ and index function @ind_x@, the source array @b@ with
-- its memory block @m_b@ and index function @ind_b@, the element type and
-- shape recorded for @b@, and the certificates of the statement.
--
-- The result is 'Nothing' when the statement is not one of the supported
-- forms, or when one of the guards of a supported form fails.  In every
-- array case the pattern must bind exactly one array, that name must have
-- an entry in the last-use table, and a source is only considered when it
-- is lastly used here and 'sameSpace' holds for the two memory blocks.
genCoalStmtInfo ::
  (Coalesceable rep inner) =>
  LUTabFun ->
  TopdownEnv rep ->
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ShortCircuitM rep (Maybe [SSPointInfo])
-- CASE a) @let x <- copy(b^{lu})@
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Replicate (Shape []) (Var b))))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $
        Just
          [ (CopyCoal, id, x, m_x, ind_x, b, m_b, ind_b, tpb, shpb, stmAuxCerts aux)
          ]
-- CASE c) @let x[i] = b^{lu}@
--
-- Note that last-use information is looked up for the pattern name @x'@,
-- whereas the tuple records @x@, the array operand of the update.
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Update _ x slice_x (Var b))))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $
        Just
          [ ( InPlaceCoal,
              (`updateIndFunSlice` slice_x),
              x,
              m_x,
              ind_x,
              b,
              m_b,
              ind_b,
              tpb,
              shpb,
              stmAuxCerts aux
            )
          ]
  where
    -- Slice an index function by the update slice, after converting the
    -- slice components to 64-bit primitive expressions.
    updateIndFunSlice :: LMAD -> Slice SubExp -> LMAD
    updateIndFunSlice ind_fun slc_x =
      let slc_x' = map (fmap pe64) $ unSlice slc_x
       in LMAD.slice ind_fun $ Slice slc_x'
-- Same as CASE c), but for a flat update: @b@ is a plain name here and the
-- index function is transformed with 'LMAD.flatSlice'.
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (FlatUpdate x slice_x b)))
  | Pat [PatElem x' (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x' lutab,
    Just (MemBlock tpb shpb m_b ind_b) <- getScopeMemInfo b scopetab,
    sameSpace td_env m_x m_b,
    b `nameIn` last_uses =
      pure $
        Just
          [ ( InPlaceCoal,
              (`updateIndFunSlice` slice_x),
              x,
              m_x,
              ind_x,
              b,
              m_b,
              ind_b,
              tpb,
              shpb,
              stmAuxCerts aux
            )
          ]
  where
    updateIndFunSlice :: LMAD -> FlatSlice SubExp -> LMAD
    updateIndFunSlice ind_fun (FlatSlice offset dims) =
      LMAD.flatSlice ind_fun $ FlatSlice (pe64 offset) $ map (fmap pe64) dims

-- CASE b) @let x = concat(a, b^{lu})@
genCoalStmtInfo lutab td_env scopetab (Let pat aux (BasicOp (Concat concat_dim (b0 :| bs) _)))
  | Pat [PatElem x (_, MemArray _ _ _ (ArrayIn m_x ind_x))] <- pat,
    Just last_uses <- M.lookup x lutab =
      pure $
        -- A strict left fold: the accumulator is rebuilt for every operand,
        -- so there is no reason to pile up thunks.
        let (res, _, _) =
              L.foldl'
                (markConcatParts last_uses x m_x ind_x)
                ([], zero, True)
                (b0 : bs)
         in if null res then Nothing else Just res
  where
    zero = pe64 $ intConst Int64 0

    -- The accumulator holds the candidates found so far, the running
    -- offset along @concat_dim@, and a flag that becomes 'False' as soon as
    -- an operand cannot be handled; from then on it is returned untouched.
    markConcatParts _ _ _ _ acc@(_, _, False) _ = acc
    markConcatParts last_uses x m_x ind_x (acc, offs, True) b
      | Just (MemBlock tpb shpb@(Shape dims@(_ : _)) m_b ind_b) <- getScopeMemInfo b scopetab,
        Just d <- maybeNth concat_dim dims,
        offs' <- offs + pe64 d =
          if b `nameIn` last_uses && sameSpace td_env m_x m_b
            then
              -- The slice of the result that this operand occupies:
              -- @unitSlice zero@ over each dimension other than
              -- @concat_dim@, and @unitSlice offs d@ along @concat_dim@.
              let slc =
                    Slice $
                      map (unitSlice zero . pe64) (take concat_dim dims)
                        <> [unitSlice offs (pe64 d)]
                        <> map (unitSlice zero . pe64) (drop (concat_dim + 1) dims)
               in ( acc
                      ++ [ ( ConcatCoal,
                             (`LMAD.slice` slc),
                             x,
                             m_x,
                             ind_x,
                             b,
                             m_b,
                             ind_b,
                             tpb,
                             shpb,
                             stmAuxCerts aux
                           )
                         ],
                    offs',
                    True
                  )
            else -- Not a candidate, but the offset must still advance.
              (acc, offs', True)
      | otherwise = (acc, offs, False)
-- case d) short-circuit points from ops. For instance, the result of a segmap
-- can be considered a short-circuit point.
genCoalStmtInfo lutab td_env scopetab (Let pat aux (Op op)) = do
  ss_op <- asks ssPointFromOp
  pure $ ss_op lutab td_env scopetab pat (stmAuxCerts aux) op
-- CASE other than a), b), c), or d) not supported
genCoalStmtInfo _ _ _ _ = pure Nothing

-- | Do the two named memory blocks live in the same memory space?
--
-- Both names are looked up in the scope of the top-down environment (with
-- alias information removed).  The answer is 'True' only when both are
-- found, both are memory blocks ('MemMem'), and their spaces are equal; in
-- every other situation the answer is 'False'.
sameSpace :: (Coalesceable rep inner) => TopdownEnv rep -> VName -> VName -> Bool
sameSpace td_env m_x m_b
  | Just (MemMem pat_space) <- nameInfoToMemInfo <$> M.lookup m_x scope',
    Just (MemMem return_space) <- nameInfoToMemInfo <$> M.lookup m_b scope' =
      pat_space == return_space
  | otherwise = False
  where
    scope' = removeScopeAliases $ scope td_env

-- | A pairing of a pattern element with the body result that produces it,
-- as constructed by @findMemBodyResult@.
data MemBodyResult = MemBodyResult
  { -- | Memory block of the pattern element.
    patMem :: VName,
    -- | Name of the pattern element.
    _patName :: VName,
    -- | Name of the corresponding variable in the body result.
    bodyName :: VName,
    -- | Memory block of the body-result variable.
    bodyMem :: VName
  }

-- | Pairs each pattern element with the corresponding body result, and
--   keeps those pairs that are candidates for coalescing.
--
--   A pair @(patel, res)@ is kept when all of the following hold:
--
--   * @patel@ is an array residing in some memory block @m_b@;
--
--   * @res@ is a variable @r@ for which array memory information is
--     available, either in the given scope or among the statements of
--     the body;
--
--   * @m_b@ has an entry in the active coalescing table, and the name of
--     @patel@ occurs in the variable table of that entry.
--
--   Pattern elements and results are matched positionally with 'zip', so
--   surplus elements on either side are ignored.
findMemBodyResult ::
  (HasMemBlock (Aliases rep)) =>
  CoalsTab ->
  ScopeTab rep ->
  [PatElem (VarAliases, LetDecMem)] ->
  Body (Aliases rep) ->
  [MemBodyResult]
findMemBodyResult activeCoals_tab scope_env patelms bdy =
  mapMaybe
    findMemBodyResult'
    (zip patelms $ map resSubExp $ bodyResult bdy)
  where
    -- The body result may be bound by the body's own statements.
    scope_env' = scope_env <> scopeOf (bodyStms bdy)

    findMemBodyResult' (patel, Var r)
      | (_, MemArray _ _ _ (ArrayIn m_b _)) <- patElemDec patel,
        b <- patElemName patel,
        Just (MemBlock _ _ m_r _) <- getScopeMemInfo r scope_env',
        Just coal_etry <- M.lookup m_b activeCoals_tab,
        b `M.member` vartab coal_etry =
          Just $ MemBodyResult m_b b r m_r
    findMemBodyResult' _ = Nothing

-- | Transfers coalescing from if-pattern to then|else body result
--   in the active coalesced table. The transfer involves, among
--   others, inserting @(r,m_r)@ in the optimistically-dependency
--   set of @m_b@'s entry and inserting @(b,m_b)@ in the opt-deps
--   set of @m_r@'s entry. Meaning, ultimately, @m_b@ can be merged
--   if @m_r@ can be merged (and vice-versa). This is checked by a
--   fix point iteration at the function-definition level.
--
--   The first argument is a set of substitutions that is applied to the
--   index function recorded for @b@ and merged (left-biased) into the
--   substitutions recorded for @b@, giving the information stored for @r@.
--
--   Calls 'error' if @m_b@ has no entry in the table, or if that entry has
--   no information about @b@.
transferCoalsToBody ::
  M.Map VName (TPrimExp Int64 VName) -> -- substitutions applied to @b@'s index function
  CoalsTab ->
  MemBodyResult ->
  CoalsTab
transferCoalsToBody exist_subs activeCoals_tab (MemBodyResult m_b b r m_r)
  | -- the @Nothing@ pattern for the two lookups cannot happen
    -- because they were already checked in @findMemBodyResult@
    Just etry <- M.lookup m_b activeCoals_tab,
    Just (Coalesced knd (MemBlock btp shp _ ind_b) subst_b) <- M.lookup b $ vartab etry =
      -- by definition of if-stmt, r and b have the same basic type, shape and
      -- index function, hence, for example, do not need to rebase
      -- We will check whether it is translatable at the definition point of r.
      let ind_r = LMAD.substitute exist_subs ind_b
          subst_r = M.union exist_subs subst_b
          mem_info = Coalesced knd (MemBlock btp shp (dstmem etry) ind_r) subst_r
       in if m_r == m_b -- already unified, just add binding for @r@
            then
              let etry' =
                    etry
                      { optdeps = M.insert b m_b (optdeps etry),
                        vartab = M.insert r mem_info (vartab etry)
                      }
               in M.insert m_r etry' activeCoals_tab
            else -- make them both optimistically depend on each other
              let opts_x_new = M.insert r m_r (optdeps etry)
                  -- Here we should translate the @ind_b@ field of @mem_info@
                  -- across the existential introduced by the if-then-else
                  coal_etry =
                    etry
                      { vartab = M.singleton r mem_info,
                        optdeps = M.insert b m_b (optdeps etry)
                      }
               in M.insert m_b (etry {optdeps = opts_x_new}) $
                    M.insert m_r coal_etry activeCoals_tab
  | otherwise = error "Impossible"

-- | Build a table of substitutions from pattern-element names to 64-bit
-- integer expressions, given the pattern and the corresponding results.
--
-- Pattern elements and results are matched positionally with 'zip', so
-- surplus elements on either side are ignored.  A pair contributes an
-- entry when
--
-- * the result is a variable and the pattern element is declared as a
--   64-bit integer scalar, or
--
-- * the result is a 64-bit integer constant (the declaration of the
--   pattern element is not inspected in this case).
--
-- All other pairs are skipped.
mkSubsTab ::
  Pat (aliases, LetDecMem) ->
  [SubExp] ->
  M.Map VName (TPrimExp Int64 VName)
mkSubsTab pat res =
  M.fromList $ mapMaybe mki64subst $ zip (patElems pat) res
  where
    mki64subst (a, Var v)
      | (_, MemPrim (IntType Int64)) <- patElemDec a = Just (patElemName a, le64 v)
    mki64subst (a, se@(Constant (IntValue (Int64Value _)))) = Just (patElemName a, pe64 se)
    mki64subst _ = Nothing

-- | Compute a table from scalar variable names to 'PrimExp's for the
-- scalars bound by a statement, descending into nested bodies.
--
-- * A statement that binds exactly one pattern element, and whose
--   expression 'primExpFromExp' can convert, yields a singleton table.
-- * For 'Loop', the statements of the loop body are traversed with the
--   scope extended by the loop parameters, the loop form and the body
--   statements.
-- * For 'Match', the default body and /every/ case body are traversed,
--   each one with the scope extended by its own statements.
-- * For 'Op', the handler stored in the reader environment
--   ('scalarTableOnOp') is applied.
-- * Any other statement contributes an empty table.
computeScalarTable ::
  (Coalesceable rep inner) =>
  ScopeTab rep ->
  Stm (Aliases rep) ->
  ScalarTableM rep (M.Map VName (PrimExp VName))
computeScalarTable scope_table (Let (Pat [pe]) _ e)
  | Just primexp <- primExpFromExp (vnameToPrimExp scope_table mempty) e =
      pure $ M.singleton (patElemName pe) primexp
computeScalarTable scope_table (Let _ _ (Loop loop_inits loop_form body)) =
  concatMapM
    ( computeScalarTable $
        scope_table
          <> scopeOfFParams (map fst loop_inits)
          <> scopeOfLoopForm loop_form
          <> scopeOf (bodyStms body)
    )
    (stmsToList $ bodyStms body)
computeScalarTable scope_table (Let _ _ (Match _ cases body _)) = do
  body_tab <- onBody body
  -- Each case is analysed over its *own* statements.  Previously the
  -- statements of the default body were traversed here (with the scope
  -- of the case body), so scalars bound inside case bodies were never
  -- recorded.
  cases_tab <- concatMapM (\(Case _ b) -> onBody b) cases
  pure $ body_tab <> cases_tab
  where
    -- Traverse the statements of one branch body, with the scope
    -- extended by the statements of that same body.
    onBody b =
      concatMapM
        (computeScalarTable $ scope_table <> scopeOf (bodyStms b))
        (stmsToList $ bodyStms b)
computeScalarTable scope_table (Let _ _ (Op op)) = do
  on_op <- asks scalarTableOnOp
  on_op scope_table op
computeScalarTable _ _ = pure mempty

-- | Lift a scalar-table computation on the inner operation to one on
-- 'MemOp'.  An 'Alloc' contributes an empty table; an 'Inner'
-- operation is passed on to the supplied handler unchanged.
computeScalarTableMemOp ::
  ComputeScalarTable rep (inner (Aliases rep)) -> ComputeScalarTable rep (MemOp inner (Aliases rep))
computeScalarTableMemOp onInner scope = \case
  Alloc {} -> pure mempty
  Inner inner_op -> onInner scope inner_op

-- | Compute the scalar table for a 'GPU.SegOp' by running
-- 'computeScalarTable' on every statement of its kernel body.  The
-- scope used for those statements is the incoming scope extended with
-- the kernel body statements and the segment space.
computeScalarTableSegOp ::
  (Coalesceable rep inner) =>
  ComputeScalarTable rep (GPU.SegOp lvl (Aliases rep))
computeScalarTableSegOp scope_table segop =
  concatMapM (computeScalarTable kernel_scope) (stmsToList kernel_stms)
  where
    kernel_stms = kernelBodyStms (segBody segop)
    kernel_scope =
      scope_table
        <> scopeOf kernel_stms
        <> scopeOfSegSpace (segSpace segop)

-- | Scalar-table computation for the host-level operations of 'GPUMem'.
-- Segmented operations are delegated to 'computeScalarTableSegOp', a
-- 'GPU.GPUBody' has its statements traversed with 'computeScalarTable',
-- and size operations and 'NoOp' contribute nothing.
computeScalarTableGPUMem ::
  ComputeScalarTable GPUMem (GPU.HostOp NoOp (Aliases GPUMem))
computeScalarTableGPUMem scope_table = \case
  GPU.SegOp segop ->
    computeScalarTableSegOp scope_table segop
  GPU.SizeOp _ ->
    pure mempty
  GPU.OtherOp NoOp ->
    pure mempty
  GPU.GPUBody _ body ->
    let stms = bodyStms body
     in concatMapM
          (computeScalarTable $ scope_table <> scopeOf stms)
          (stmsToList stms)

-- | Scalar-table computation for the operations of 'MCMem'.  For a
-- 'MC.ParOp' the tables of the optional first segmented operation and
-- of the mandatory second one are combined with '<>' (in that order);
-- 'NoOp' contributes nothing.
computeScalarTableMCMem ::
  ComputeScalarTable MCMem (MC.MCOp NoOp (Aliases MCMem))
computeScalarTableMCMem scope_table = \case
  MC.OtherOp NoOp ->
    pure mempty
  MC.ParOp par_op segop -> do
    par_tab <- case par_op of
      Nothing -> pure mempty
      Just par_segop -> computeScalarTableSegOp scope_table par_segop
    seg_tab <- computeScalarTableSegOp scope_table segop
    pure $ par_tab <> seg_tab

-- | Monadic filter over the values of a map: keep exactly those
-- entries whose value satisfies the predicate.  The predicate is run
-- on the values in ascending key order.
filterMapM1 :: (Eq k, Monad m) => (v -> m Bool) -> M.Map k v -> m (M.Map k v)
filterMapM1 keep table =
  M.fromAscList <$> filterM (keep . snd) (M.toAscList table)