{-# LANGUAGE TypeFamilies #-}

-- | This pass attempts to lower allocations as far towards the bottom of their
-- body as possible.
module Futhark.Pass.LowerAllocations
  ( lowerAllocationsSeqMem,
    lowerAllocationsGPUMem,
    lowerAllocationsMCMem,
  )
where

import Control.Monad.Reader
import Data.Function ((&))
import Data.Map qualified as M
import Data.Sequence (Seq (..))
import Data.Sequence qualified as Seq
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.SeqMem
import Futhark.Pass (Pass (..))

-- | Apply allocation lowering to the body of every function in a program.
--
-- The first argument says how to process the representation-specific
-- inner operations encountered while traversing statements; it is stored
-- in the 'Env' that each function body is run with.
lowerInProg ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  (inner rep -> LowerM (inner rep) (inner rep)) ->
  Prog rep ->
  Prog rep
lowerInProg onOp prog =
  prog {progFuns = fmap onFun (progFuns prog)}
  where
    -- Rewrite a single function definition, touching only its body.
    onFun f = f {funDefBody = onBody (funDefBody f)}
    -- Each body is lowered in a fresh reader run with the same environment.
    onBody body = runReader (lowerAllocationsInBody body) (Env onOp)

-- | Allocation lowering for the 'SeqMem' representation.  Inner
-- operations are returned unchanged.
lowerAllocationsSeqMem :: Pass SeqMem SeqMem
lowerAllocationsSeqMem =
  Pass "lower allocations" "lower allocations" $
    pure . lowerInProg pure

-- | Allocation lowering for the 'GPUMem' representation.  Inner
-- operations are processed with 'lowerAllocationsInHostOp'.
lowerAllocationsGPUMem :: Pass GPUMem GPUMem
lowerAllocationsGPUMem =
  Pass "lower allocations gpu" "lower allocations gpu" $
    pure . lowerInProg lowerAllocationsInHostOp

-- | Allocation lowering for the 'MCMem' representation.  Inner
-- operations are processed with 'lowerAllocationsInMCOp'.
lowerAllocationsMCMem :: Pass MCMem MCMem
lowerAllocationsMCMem =
  Pass "lower allocations mc" "lower allocations mc" $
    pure . lowerInProg lowerAllocationsInMCOp

-- | The reader environment of the pass: the action used to process a
-- representation-specific inner operation when one is encountered.
newtype Env inner = Env
  {onInner :: inner -> LowerM inner inner}

-- | The monad the pass runs in: a plain 'Reader' over 'Env', giving access
-- to the handler for inner operations.
type LowerM inner a = Reader (Env inner) a

-- | Lower the allocations among the statements of a body.  Only the
-- statement sequence is replaced; every other field of the body is kept.
lowerAllocationsInBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Body rep ->
  LowerM (inner rep) (Body rep)
lowerAllocationsInBody body = do
  -- Start with no pending allocations and no emitted statements.
  stms <- lowerAllocationsInStms (bodyStms body) mempty mempty
  pure $ body {bodyStms = stms}

-- | Walk a statement sequence from top to bottom, holding back each
-- allocation until a later statement mentions its name among its free
-- variables.
--
-- * When the input is exhausted, any allocations still pending are
--   appended after all other statements.
--
-- * A statement binding a single name to an @Alloc@ is not emitted; it is
--   recorded as pending under that name.
--
-- * For any other statement, nested bodies (of @Match@ and @Loop@) and
--   inner operations are processed first, then the pending allocations the
--   original statement refers to are emitted, followed by the rewritten
--   statement itself.
lowerAllocationsInStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  -- | The input stms
  Stms rep ->
  -- | The allocations currently being lowered
  M.Map VName (Stm rep) ->
  -- | The other statements processed so far
  Stms rep ->
  LowerM (inner rep) (Stms rep)
lowerAllocationsInStms Empty allocs acc =
  pure $ acc <> Seq.fromList (M.elems allocs)
lowerAllocationsInStms (stm@(Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) :<| stms) allocs acc =
  lowerAllocationsInStms stms (M.insert vname stm allocs) acc
lowerAllocationsInStms (stm :<| stms) allocs acc = do
  stm' <- case stmExp stm of
    Op (Inner inner) -> do
      on_inner <- asks onInner
      inner' <- on_inner inner
      pure $ stm {stmExp = Op $ Inner inner'}
    Match cond_ses cases body dec -> do
      cases' <- mapM (\(Case pat b) -> Case pat <$> lowerAllocationsInBody b) cases
      body' <- lowerAllocationsInBody body
      pure $ stm {stmExp = Match cond_ses cases' body' dec}
    Loop params form body -> do
      body' <- lowerAllocationsInBody body
      pure $ stm {stmExp = Loop params form body'}
    _ -> pure stm
  -- The free variables are taken from the statement as it was before any
  -- nested rewriting.
  let (allocs', acc') = insertLoweredAllocs (freeIn stm) allocs acc
  lowerAllocationsInStms stms allocs' (acc' :|> stm')

-- | Emit the pending allocations that a statement depends on.
--
-- Given the free names of the statement about to be emitted, every pending
-- allocation whose name is among them is removed from the pending map and
-- appended (in the map's ascending key order) to the statements emitted so
-- far.  The result is the remaining pending allocations paired with the
-- extended statement sequence.
insertLoweredAllocs ::
  Names ->
  M.Map VName (Stm rep) ->
  Stms rep ->
  (M.Map VName (Stm rep), Stms rep)
insertLoweredAllocs frees alloc acc =
  (remaining, acc <> Seq.fromList (M.elems needed))
  where
    -- A single total pass over the map; no partial lookups are needed.
    (needed, remaining) =
      M.partitionWithKey (\name _ -> name `nameIn` frees) alloc

-- | Lower the allocations among the statements of a segmented operation's
-- kernel body.  All other components of the operation (level, space,
-- operators, types) are passed through unchanged.
lowerAllocationsInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  LowerM (inner rep) (SegOp lvl rep)
lowerAllocationsInSegOp (SegMap lvl sp tps body) =
  SegMap lvl sp tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegRed lvl sp binops tps body) =
  SegRed lvl sp binops tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegScan lvl sp binops tps body) =
  SegScan lvl sp binops tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegHist lvl sp histops tps body) =
  SegHist lvl sp histops tps <$> lowerAllocationsInKernelBody body

-- | Lower the allocations among the statements of a kernel body, starting
-- with no pending allocations; only the statement sequence is replaced.
lowerAllocationsInKernelBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  KernelBody rep ->
  LowerM (inner rep) (KernelBody rep)
lowerAllocationsInKernelBody body = do
  stms <- lowerAllocationsInStms (kernelBodyStms body) mempty mempty
  pure $ body {kernelBodyStms = stms}

-- | Handler for 'GPUMem' inner operations: a segmented operation has the
-- allocations in its kernel body lowered; any other host operation is
-- returned as is.
lowerAllocationsInHostOp ::
  HostOp NoOp GPUMem ->
  LowerM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
lowerAllocationsInHostOp (SegOp op) = SegOp <$> lowerAllocationsInSegOp op
lowerAllocationsInHostOp op = pure op

-- | Handler for 'MCMem' inner operations: for a 'ParOp', both the optional
-- first segmented operation and the mandatory second one have the
-- allocations in their kernel bodies lowered; any other operation is
-- returned as is.
lowerAllocationsInMCOp ::
  MCOp NoOp MCMem ->
  LowerM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
lowerAllocationsInMCOp (ParOp par op) =
  ParOp <$> traverse lowerAllocationsInSegOp par <*> lowerAllocationsInSegOp op
lowerAllocationsInMCOp op = pure op