{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies #-}

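-- | Perform array short-circuiting: using per-function coalescing tables,
-- rewrite the memory annotations of coalesced arrays to refer to their
-- destination memory blocks, and remove the allocations that this makes
-- redundant.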
module Futhark.Optimise.ArrayShortCircuiting
  ( optimiseSeqMem,
    optimiseGPUMem,
    optimiseMCMem,
  )
where

import Control.Monad.Reader
import Data.Function ((&))
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Futhark.Analysis.Alias qualified as AnlAls
import Futhark.IR.Aliases
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.Mem.IxFun (substituteInIxFun)
import Futhark.IR.SeqMem
import Futhark.Optimise.ArrayShortCircuiting.ArrayCoalescing
import Futhark.Optimise.ArrayShortCircuiting.DataStructs
import Futhark.Pass (Pass (..))
import Futhark.Pass qualified as Pass
import Futhark.Util

-- | Read-only context used while rewriting the body of a single function.
data Env inner = Env
  { -- | Coalescing table for the function being rewritten.
    envCoalesceTab :: CoalsTab,
    -- | Rewriter applied to representation-specific inner operations
    -- (the 'Inner' case of an 'Op' expression).
    onInner :: inner -> UpdateM inner inner,
    -- | Names whose binding statements are dropped by 'removeAllocsInStms'.
    memAllocsToRemove :: Names
  }

-- | The rewriting monad: read-only access to an 'Env' over the given inner
-- operation type.  @UpdateM inner a@ expands to @Reader (Env inner) a@.
type UpdateM inner = Reader (Env inner)

-- | Array short-circuiting for the 'SeqMem' representation.  Inner
-- operations are returned unchanged ('pure').
optimiseSeqMem :: Pass SeqMem SeqMem
optimiseSeqMem =
  pass
    "short-circuit"
    "Array Short-Circuiting"
    mkCoalsTab
    pure
    replaceInParams

-- | Array short-circuiting for the 'GPUMem' representation.  Host
-- operations are rewritten with 'replaceInHostOp'.
optimiseGPUMem :: Pass GPUMem GPUMem
optimiseGPUMem =
  pass
    "short-circuit-gpu"
    "Array Short-Circuiting (GPU)"
    mkCoalsTabGPU
    replaceInHostOp
    replaceInParams

-- | Array short-circuiting for the 'MCMem' (multicore) representation.
-- Multicore operations are rewritten with 'replaceInMCOp'.
optimiseMCMem :: Pass MCMem MCMem
optimiseMCMem =
  pass
    "short-circuit-mc"
    "Array Short-Circuiting (MC)"
    mkCoalsTabMC
    replaceInMCOp
    replaceInParams

-- | Rewrite function parameters according to a coalescing table.
--
-- * A memory parameter whose name has an entry in the table is renamed to
--   that entry's destination memory, and the destination memory is added
--   to the returned name set (which 'pass' installs as the allocations to
--   remove from the function body).
-- * An array parameter whose memory block has an entry is moved to that
--   entry's destination memory, keeping its index function.
-- * Every other parameter is returned unchanged.
--
-- The order of the parameters is preserved.
replaceInParams :: CoalsTab -> [Param FParamMem] -> (Names, [Param FParamMem])
replaceInParams coalstab fparams = (mconcat to_remove, fparams')
  where
    -- Each parameter is handled independently; the per-parameter name sets
    -- are unioned afterwards, so no accumulator or reversal is needed.
    (to_remove, fparams') = unzip $ map replaceInParam fparams

    replaceInParam p@(Param attrs name dec) =
      case dec of
        MemMem _
          | Just entry <- M.lookup name coalstab ->
              (oneName $ dstmem entry, Param attrs (dstmem entry) dec)
        MemArray pt shp u (ArrayIn m ixf)
          | Just entry <- M.lookup m coalstab ->
              ( mempty,
                Param attrs name $ MemArray pt shp u $ ArrayIn (dstmem entry) ixf
              )
        _ -> (mempty, p)

-- | Drop every statement whose first bound name is in 'memAllocsToRemove'.
--
-- Statements that bind no names are kept (the previous version called the
-- partial 'head' on the pattern names and would crash on them).
--
-- NOTE(review): the test is purely by name; it does not check that the
-- dropped statement is actually an allocation.
removeAllocsInStms :: Stms rep -> UpdateM inner (Stms rep)
removeAllocsInStms stms = do
  to_remove <- asks memAllocsToRemove
  let keep stm =
        case patNames $ stmPat stm of
          v : _ -> not $ v `nameIn` to_remove
          [] -> True
  stmsToList stms
    & filter keep
    & stmsFromList
    & pure

-- | Build an array short-circuiting pass for a memory representation.
--
-- The pass runs alias analysis on the program and computes one coalescing
-- table per function with @mk@. It then rewrites each function: its
-- parameters with @on_fparams@, and its body with 'updateStms', where
-- representation-specific inner operations are handled by @on_inner@.
-- Top-level constants are left unchanged.
pass ::
  (Mem rep inner, LetDec rep ~ LetDecMem, AliasableRep rep) =>
  String ->
  String ->
  (Prog (Aliases rep) -> Pass.PassM (M.Map Name CoalsTab)) ->
  (inner rep -> UpdateM (inner rep) (inner rep)) ->
  (CoalsTab -> [FParam (Aliases rep)] -> (Names, [FParam (Aliases rep)])) ->
  Pass rep rep
pass flag desc mk on_inner on_fparams =
  Pass flag desc $ \prog -> do
    coaltabs <- mk $ AnlAls.aliasAnalysis prog
    Pass.intraproceduralTransformationWithConsts pure (onFun coaltabs) prog
  where
    onFun coaltabs _ f = do
      -- A function without an entry has nothing to coalesce; use an empty
      -- table instead of crashing on a missing key.
      let coaltab = M.findWithDefault mempty (funDefName f) coaltabs
          (mem_allocs_to_remove, new_fparams) =
            on_fparams coaltab $ funDefParams f
      pure $
        f
          { funDefBody = onBody coaltab mem_allocs_to_remove $ funDefBody f,
            funDefParams = new_fparams
          }

    onBody coaltab mem_allocs_to_remove body =
      body
        { bodyStms =
            runReader
              (updateStms $ bodyStms body)
              (Env coaltab on_inner mem_allocs_to_remove),
          bodyResult = map (replaceResMem coaltab) $ bodyResult body
        }

-- | If a body result is a variable with an entry in the coalescing table,
-- replace it by that entry's destination memory; otherwise keep it as is.
replaceResMem :: CoalsTab -> SubExpRes -> SubExpRes
replaceResMem coaltab res =
  case subExpResVName res >>= (`M.lookup` coaltab) of
    Just entry -> res {resSubExp = Var $ dstmem entry}
    Nothing -> res

-- | Rewrite every statement with 'replaceInStm', then drop the statements
-- selected by 'removeAllocsInStms'.
updateStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stms rep ->
  UpdateM (inner rep) (Stms rep)
updateStms stms = mapM replaceInStm stms >>= removeAllocsInStms

-- | Rewrite one statement.
--
-- Array pattern elements are first moved to their coalesced memory (via
-- 'lookupAndReplace'). The bound expression is then rewritten with
-- 'replaceInExp', which receives the already-updated pattern elements.
replaceInStm ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Stm rep ->
  UpdateM (inner rep) (Stm rep)
replaceInStm (Let (Pat elems) d e) = do
  elems' <- mapM replaceInPatElem elems
  e' <- replaceInExp elems' e
  pure $ Let (Pat elems') d e'
  where
    replaceInPatElem :: PatElem LetDecMem -> UpdateM inner (PatElem LetDecMem)
    replaceInPatElem p@(PatElem vname (MemArray _ _ u _)) =
      fromMaybe p <$> lookupAndReplace vname PatElem u
    -- Non-array pattern elements have no memory annotation to update.
    replaceInPatElem p = pure p

-- | Rewrite an expression whose (already rewritten) pattern elements are
-- given as the first argument.
--
-- * 'BasicOp', 'WithAcc' and 'Apply' are returned unchanged.
-- * 'Match': every branch body is rewritten. The branch return types are
--   recomputed with 'generalizeIxfun' from the pattern elements.
-- * 'DoLoop': the loop parameters, body statements and body results are
--   rewritten; the initial values and the loop form are kept.
-- * 'Op': inner operations are delegated to 'onInner'; others are kept.
replaceInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  [PatElem LetDecMem] ->
  Exp rep ->
  UpdateM (inner rep) (Exp rep)
replaceInExp _ e@(BasicOp _) = pure e
replaceInExp pat_elems (Match cond_ses cases defbody dec) = do
  defbody' <- replaceInIfBody defbody
  cases' <- mapM (\(Case p b) -> Case p <$> replaceInIfBody b) cases
  case_rets <- zipWithM (generalizeIxfun pat_elems) pat_elems $ matchReturns dec
  let dec' = dec {matchReturns = case_rets}
  pure $ Match cond_ses cases' defbody' dec'
replaceInExp _ (DoLoop loop_inits loop_form (Body dec stms res)) = do
  loop_params' <- mapM (replaceInFParam . fst) loop_inits
  stms' <- updateStms stms
  coalstab <- asks envCoalesceTab
  let res' = map (replaceResMem coalstab) res
  pure $
    DoLoop (zip loop_params' $ map snd loop_inits) loop_form $
      Body dec stms' res'
replaceInExp _ (Op op) =
  case op of
    Inner i -> do
      on_op <- asks onInner
      Op . Inner <$> on_op i
    _ -> pure $ Op op
replaceInExp _ e@WithAcc {} = pure e
replaceInExp _ e@Apply {} = pure e

-- | Rewrite the kernel body statements of a 'SegOp' with 'updateStms'.
-- The level, segment space, operators, result types and kernel results are
-- left unchanged.
replaceInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  UpdateM (inner rep) (SegOp lvl rep)
replaceInSegOp segop =
  case segop of
    SegMap lvl sp tps body ->
      SegMap lvl sp tps <$> onKernelBody body
    SegRed lvl sp binops tps body ->
      SegRed lvl sp binops tps <$> onKernelBody body
    SegScan lvl sp binops tps body ->
      SegScan lvl sp binops tps <$> onKernelBody body
    SegHist lvl sp hist_ops tps body ->
      SegHist lvl sp hist_ops tps <$> onKernelBody body
  where
    -- Shared by all four constructors: only the statements are rewritten.
    -- The type variables here are local to this signature.
    onKernelBody ::
      (Mem rep inner, LetDec rep ~ LetDecMem) =>
      KernelBody rep ->
      UpdateM (inner rep) (KernelBody rep)
    onKernelBody body = do
      stms <- updateStms $ kernelBodyStms body
      pure body {kernelBodyStms = stms}

-- | Rewrite a GPU host operation: 'SegOp's are rewritten with
-- 'replaceInSegOp'; every other host operation is returned unchanged.
replaceInHostOp :: HostOp NoOp GPUMem -> UpdateM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
replaceInHostOp (SegOp op) = SegOp <$> replaceInSegOp op
replaceInHostOp op = pure op

-- | Rewrite a multicore operation: both the optional and the mandatory
-- 'SegOp' of a 'ParOp' are rewritten with 'replaceInSegOp'; every other
-- operation is returned unchanged.
replaceInMCOp :: MCOp NoOp MCMem -> UpdateM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
replaceInMCOp (ParOp par_op op) =
  ParOp <$> traverse replaceInSegOp par_op <*> replaceInSegOp op
replaceInMCOp op = pure op

-- | Compute the return annotation of one 'Match' result.
--
-- The arguments are all pattern elements, the pattern element for this
-- result, and the current return annotation. If the pattern element is an
-- array in block @mem@ with index function @ixf@, and its name appears in
-- the coalesced-variable map ('vartab') of some entry of the coalescing
-- table, the return becomes @'ReturnsInBlock' mem@. Its index function is
-- @ixf@ existentialised over the names of all pattern elements. In every
-- other case the original annotation is returned unchanged.
generalizeIxfun :: [PatElem dec] -> PatElem LetDecMem -> BodyReturns -> UpdateM inner BodyReturns
generalizeIxfun
  pat_elems
  (PatElem vname (MemArray _ _ _ (ArrayIn mem ixf)))
  m@(MemArray pt shp u _) = do
    coaltab <- asks envCoalesceTab
    if any (M.member vname . vartab) coaltab
      then
        existentialiseIxFun (map patElemName pat_elems) ixf
          & ReturnsInBlock mem
          & MemArray pt shp u
          & pure
      else pure m
generalizeIxfun _ _ m = pure m

-- | Rewrite a branch body: its statements with 'updateStms', and its
-- results with 'replaceResMem'.
replaceInIfBody :: (Mem rep inner, LetDec rep ~ LetDecMem) => Body rep -> UpdateM (inner rep) (Body rep)
replaceInIfBody b@(Body _ stms res) = do
  coaltab <- asks envCoalesceTab
  stms' <- updateStms stms
  pure $ b {bodyStms = stms', bodyResult = map (replaceResMem coaltab) res}

-- | Move an array loop parameter to its coalesced memory, if any (via
-- 'lookupAndReplace'); non-array parameters and uncoalesced arrays are
-- returned unchanged.
--
-- The parameter's attributes are preserved. The previous version rebuilt
-- the parameter with 'mempty' attributes and silently dropped them, unlike
-- 'replaceInParams', which keeps them.
replaceInFParam :: Param FParamMem -> UpdateM inner (Param FParamMem)
replaceInFParam p@(Param attrs vname (MemArray _ _ u _)) =
  fromMaybe p <$> lookupAndReplace vname (Param attrs) u
replaceInFParam p = pure p

-- | Look up @vname@ among the coalesced variables of every entry in the
-- coalescing table.
--
-- If it is found, a new array annotation is built with the given
-- uniqueness @u@. It uses the element type, shape and destination block
-- recorded for the variable. The index function has the recorded
-- free-variable substitutions applied repeatedly until it stops changing
-- ('fixPoint'). The result is @Just (f vname annotation)@, or 'Nothing'
-- when @vname@ is not coalesced.
lookupAndReplace ::
  VName ->
  (VName -> MemBound u -> a) ->
  u ->
  UpdateM inner (Maybe a)
lookupAndReplace vname f u = do
  coaltab <- asks envCoalesceTab
  -- 'foldMap' unions all per-entry maps; 'Map's union is left-biased, so if
  -- the name occurs in several entries the one with the smallest key wins.
  case M.lookup vname $ foldMap vartab coaltab of
    Just (Coalesced _ (MemBlock pt shp mem ixf) subs) ->
      ixf
        & fixPoint (substituteInIxFun subs)
        & ArrayIn mem
        & MemArray pt shp u
        & f vname
        & Just
        & pure
    Nothing -> pure Nothing