{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | This module implements an optimization that tries to statically reuse
-- kernel-level allocations. The goal is to lower the static memory usage, which
-- might allow more programs to run using intra-group parallelism.
module Futhark.Optimise.MemoryBlockMerging (optimise) where

import Control.Exception
import Control.Monad.State.Strict
import Data.Function ((&))
import Data.Map (Map, (!))
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as S
import qualified Futhark.Analysis.Interference as Interference
import Futhark.Builder.Class
import Futhark.Construct
import Futhark.IR.GPUMem
import qualified Futhark.Optimise.MemoryBlockMerging.GreedyColoring as GreedyColoring
import Futhark.Pass (Pass (..), PassM)
import qualified Futhark.Pass as Pass
import Futhark.Util (invertMap)

-- | A mapping from allocation names to their size and space, i.e. the two
-- arguments of the @Alloc@ operation that binds the name.
type Allocs = Map VName (SubExp, Space)

-- | Collect the memory allocations bound by a single statement.
--
-- An @Alloc@ whose pattern binds exactly one name contributes that name
-- together with its size and space. For @Match@ and @DoLoop@ the statements of
-- the nested bodies are searched recursively. Every other statement (including
-- nested kernels) contributes nothing.
getAllocsStm :: Stm GPUMem -> Allocs
getAllocsStm (Let (Pat [PatElem name _]) _ (Op (Alloc se sp))) =
  M.singleton name (se, sp)
-- This function treats an @Alloc@ whose pattern is not a single name as a
-- broken invariant.
getAllocsStm (Let _ _ (Op (Alloc _ _))) = error "impossible"
getAllocsStm (Let _ _ (Match _ cases defbody _)) =
  foldMap (foldMap getAllocsStm . bodyStms) $ defbody : map caseBody cases
getAllocsStm (Let _ _ (DoLoop _ _ body)) =
  foldMap getAllocsStm (bodyStms body)
getAllocsStm _ = mempty

-- | Collect the allocations found among the kernel-body statements of a
-- @SegOp@, using 'getAllocsStm' on each statement. All four kinds of @SegOp@
-- are handled identically; only the kernel body is inspected.
getAllocsSegOp :: SegOp lvl GPUMem -> Allocs
getAllocsSegOp (SegMap _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegRed _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegScan _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)
getAllocsSegOp (SegHist _ _ _ _ body) =
  foldMap getAllocsStm (kernelBodyStms body)

-- | Rewrite the allocations in a statement according to the given mapping.
--
-- An @Alloc@ binding a single name that is a key of the mapping is replaced by
-- a plain @SubExp@ expression referring to the mapped value; any other
-- @Alloc@ is left untouched. The rewrite descends into nested kernels, into
-- all branches of a @Match@ and into the body of a @DoLoop@. All other
-- statements are returned unchanged.
setAllocsStm :: Map VName SubExp -> Stm GPUMem -> Stm GPUMem
setAllocsStm m stm@(Let (Pat [PatElem name _]) _ (Op (Alloc _ _)))
  | Just s <- M.lookup name m =
      stm {stmExp = BasicOp $ SubExp s}
-- Allocations without a replacement are kept as they are.
setAllocsStm _ stm@(Let _ _ (Op (Alloc _ _))) = stm
setAllocsStm m stm@(Let _ _ (Op (Inner (SegOp segop)))) =
  stm {stmExp = Op $ Inner $ SegOp $ setAllocsSegOp m segop}
setAllocsStm m stm@(Let _ _ (Match cond cases defbody dec)) =
  stm {stmExp = Match cond (map (fmap onBody) cases) (onBody defbody) dec}
  where
    onBody (Body () stms res) = Body () (setAllocsStm m <$> stms) res
setAllocsStm m stm@(Let _ _ (DoLoop merge form body)) =
  stm
    { stmExp =
        DoLoop merge form (body {bodyStms = setAllocsStm m <$> bodyStms body})
    }
setAllocsStm _ stm = stm

-- | Apply 'setAllocsStm' with the given mapping to every statement in the
-- kernel body of a @SegOp@. Every other component of the @SegOp@ (level,
-- space, operators, types) is passed through unchanged.
setAllocsSegOp ::
  Map VName SubExp ->
  SegOp lvl GPUMem ->
  SegOp lvl GPUMem
setAllocsSegOp m (SegMap lvl sp tps body) =
  SegMap lvl sp tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegRed lvl sp segbinops tps body) =
  SegRed lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegScan lvl sp segbinops tps body) =
  SegScan lvl sp segbinops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}
setAllocsSegOp m (SegHist lvl sp histops tps body) =
  SegHist lvl sp histops tps $
    body {kernelBodyStms = setAllocsStm m <$> kernelBodyStms body}

-- | Emit statements computing the maximum of all the sub-expressions in the
-- set and return a 'SubExp' holding the result.
--
-- The elements are combined pairwise, left to right, with the unsigned 64-bit
-- maximum operator (@UMax Int64@); each intermediate result is bound with
-- 'letSubExp'. A singleton set emits no statements and returns its element.
-- The set must not be empty: an empty set triggers a call to 'error'.
maxSubExp :: MonadBuilder m => Set SubExp -> m SubExp
maxSubExp = helper . S.toList
  where
    helper (s1 : s2 : sexps) = do
      z <- letSubExp "maxSubHelper" $ BasicOp $ BinOp (UMax Int64) s1 s2
      helper (z : sexps)
    helper [s] =
      pure s
    helper [] = error "impossible"

-- | Decide whether the size component of an allocation can be used from the
-- given scope: a variable size qualifies only if the variable is a member of
-- the scope, while any non-variable size always qualifies. The second
-- component of the pair is ignored.
isKernelInvariant :: Scope GPUMem -> (SubExp, space) -> Bool
isKernelInvariant scope (Var vname, _) = vname `M.member` scope
isKernelInvariant _ _ = True

-- | Is the space component of the pair a @ScalarSpace@? The first component
-- is ignored.
isScalarSpace :: (subExp, Space) -> Bool
isScalarSpace (_, ScalarSpace _ _) = True
isScalarSpace _ = False

-- | Run a monadic transformation on the kernel-body statements of a @SegOp@
-- and rebuild the @SegOp@ with the resulting statements. All other components
-- of the @SegOp@ are preserved as they were.
onKernelBodyStms ::
  MonadBuilder m =>
  SegOp lvl GPUMem ->
  (Stms GPUMem -> m (Stms GPUMem)) ->
  m (SegOp lvl GPUMem)
onKernelBodyStms (SegMap lvl space ts body) f =
  SegMap lvl space ts <$> rebuild body f
onKernelBodyStms (SegRed lvl space binops ts body) f =
  SegRed lvl space binops ts <$> rebuild body f
onKernelBodyStms (SegScan lvl space binops ts body) f =
  SegScan lvl space binops ts <$> rebuild body f
onKernelBodyStms (SegHist lvl space binops ts body) f =
  SegHist lvl space binops ts <$> rebuild body f
  where
    -- Transform the statements of a kernel body, keeping its other fields.
    rebuild kbody g = do
      stms <- g $ kernelBodyStms kbody
      pure $ kbody {kernelBodyStms = stms}

-- | This is the actual optimiser. Given an interference graph and a @SegOp@,
-- replace allocations and references to memory blocks inside with a (hopefully)
-- reduced number of allocations.
--
-- The steps are:
--
-- 1. Recursively optimise any kernels nested in the body of this one.
--
-- 2. Collect the allocations of the kernel body, keeping only those that pass
--    'isKernelInvariant' for the current scope and are not in a scalar space.
--
-- 3. Colour these allocations with 'GreedyColoring.colorGraph'.
--
-- 4. For every colour, compute the maximum size of the allocations that were
--    given that colour, and emit a single allocation of that size in the space
--    associated with the colour.
--
-- 5. Replace each original allocation with a reference to the allocation of
--    its colour, and put the new statements in front of the kernel body.
optimiseKernel ::
  (MonadBuilder m, Rep m ~ GPUMem) =>
  Interference.Graph VName ->
  SegOp lvl GPUMem ->
  m (SegOp lvl GPUMem)
optimiseKernel graph segop0 = do
  segop <- onKernelBodyStms segop0 $ onKernels $ optimiseKernel graph
  scope_here <- askScope
  let mergeable alloc =
        isKernelInvariant scope_here alloc && not (isScalarSpace alloc)
      allocs = M.filter mergeable $ getAllocsSegOp segop
      (colorspaces, coloring) =
        GreedyColoring.colorGraph
          (fmap snd allocs)
          graph
  -- One maximum size per colour, in the key order of the inverted colouring.
  (maxes, maxstms) <-
    invertMap coloring
      & M.elems
      & mapM (maxSubExp . S.map (fst . (allocs !)))
      & collectStms
  -- One fresh allocation per colour. NOTE(review): indexing by position
  -- presumes the colours are numbered contiguously from 0; the assertion only
  -- checks that the counts agree -- confirm against GreedyColoring.
  (colors, stms) <-
    assert (length maxes == M.size colorspaces) maxes
      & zip [0 ..]
      & mapM (\(i, x) -> letSubExp "color" $ Op $ Alloc x $ colorspaces ! i)
      & collectStms
  let segop' = setAllocsSegOp (fmap (colors !!) coloring) segop
  -- The size computations and the new allocations must precede the original
  -- kernel-body statements that refer to them.
  onKernelBodyStms segop' $ pure . ((maxstms <> stms) <>)

-- | Helper function that modifies kernels found inside some statements.
--
-- The given function is applied to every @SegOp@ statement. The traversal
-- also descends into all branches of a @Match@ and into the body of a
-- @DoLoop@; all other statements are returned unchanged. The traversal runs
-- with the given statements added to the scope.
onKernels ::
  LocalScope GPUMem m =>
  (SegOp SegLevel GPUMem -> m (SegOp SegLevel GPUMem)) ->
  Stms GPUMem ->
  m (Stms GPUMem)
onKernels f orig_stms = inScopeOf orig_stms $ mapM helper orig_stms
  where
    helper stm@Let {stmExp = Op (Inner (SegOp segop))} = do
      exp' <- f segop
      pure $ stm {stmExp = Op $ Inner $ SegOp exp'}
    helper stm@Let {stmExp = Match c cases defbody dec} = do
      cases' <- mapM (traverse onBody) cases
      defbody' <- onBody defbody
      pure $ stm {stmExp = Match c cases' defbody' dec}
      where
        onBody (Body () stms res) =
          Body () <$> f `onKernels` stms <*> pure res
    helper stm@Let {stmExp = DoLoop merge form body} = do
      body_stms <- f `onKernels` bodyStms body
      pure $ stm {stmExp = DoLoop merge form (body {bodyStms = body_stms})}
    helper stm = pure stm

-- | Perform the reuse-allocations optimization.
--
-- The interference graph is computed once for the whole program with
-- 'Interference.analyseProgGPU'; afterwards every kernel found in the program
-- is rewritten by 'optimiseKernel'.
optimise :: Pass GPUMem GPUMem
optimise =
  Pass "memory block merging" "memory block merging allocations" $ \prog ->
    let graph = Interference.analyseProgGPU prog
     in Pass.intraproceduralTransformation (onStms graph) prog
  where
    onStms ::
      Interference.Graph VName ->
      Scope GPUMem ->
      Stms GPUMem ->
      PassM (Stms GPUMem)
    onStms graph scope stms = do
      let m = localScope scope $ optimiseKernel graph `onKernels` stms
      -- Run the builder against the pass's name source. Only the transformed
      -- statements are kept; the statements collected by the builder itself
      -- are discarded.
      fmap fst $ modifyNameSource $ runState (runBuilderT m mempty)