{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This pass tries to rewrite loops with memory parameters.
-- Specifically, it takes loops of this form:
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...) -- stores A_out
--   in {..., A_out_mem, ..., A_out, ...}
-- }
-- @
--
-- and turns them into
--
-- @
-- let A_in_mem = alloc(...)
-- let A_out_mem = alloc(...)
-- let A_in = copy A -- in A_in_mem
-- loop {..., A_in_mem, A_out_mem, ..., A=A_in, ...} ... do {
--   ...
--   in {..., A_out_mem, A_mem, ..., A_out, ...}
-- }
-- @
--
-- The result is essentially "pointer swapping" between the two
-- initial memory blocks @A_mem@ and @A_out_mem@.  The invariant is that the
-- array is always stored in the "first" memory block at the beginning
-- of the loop (and also in the final result).  We do need to add an
-- extra element to the pattern, however.  The initial copy of @A@
-- could be elided if @A@ is unique (thus @A_in_mem=A_mem@).  This is
-- because only then is it safe to use @A_mem@ to store loop results.
-- We don't currently do this.
--
-- Unfortunately, not all loops fit the pattern above.  In particular,
-- a nested loop that has been transformed as such does not!
-- Therefore we also have another double buffering strategy, that
-- turns
--
-- @
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   in {..., A_out_mem, ..., A, ...}
-- }
-- @
--
-- into
--
-- @
-- let A_res_mem = alloc(...)
-- loop {..., A_mem, ..., A, ...} ... do {
--   ...
--   let A_out_mem = alloc(...)
--   -- A in A_out_mem
--   let A' = copy A
--   -- A' in A_res_mem
--   in {..., A_res_mem, ..., A, ...}
-- }
-- @
--
-- The allocation of A_out_mem can then be hoisted out because it is
-- dead at the end of the loop.  This always works as long as
-- A_out_mem has a loop-invariant allocation size, but requires a copy
-- per iteration (and an initial one, elided above).
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)

-- | The type of a function that rewrites a single loop.  It is given
-- the pattern binding the loop result, the loop parameters paired
-- with their initial values, and the loop body.  It produces
-- statements that are placed in front of the loop, together with the
-- (possibly changed) pattern, loop parameters, and body.
type OptimiseLoop rep =
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( Stms rep,
      Pat (LetDec rep),
      [(FParam rep, SubExp)],
      Body rep
    )

-- | The type of a function that processes the representation-specific
-- operation of an expression inside the 'DoubleBufferM' monad.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)

-- | The read-only environment of the double buffering pass.
data Env rep = Env
  { -- | The bindings that are currently in scope.
    envScope :: Scope rep,
    -- | How to rewrite a loop after its body has been processed.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | How to process representation-specific operations
    -- encountered inside expressions.
    envOptimiseOp :: OptimiseOp rep
  }

-- | The monad in which the pass runs: a reader over the pass
-- environment ('Env') on top of a state monad that carries the
-- 'VNameSource'.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving (Functor, Applicative, Monad, MonadReader (Env rep), MonadFreshNames)

-- | The current scope is the one stored in the environment.
instance (ASTRep rep) => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope

-- | Run an action with the given scope appended to the scope stored
-- in the environment.
instance (ASTRep rep) => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = envScope env <> scope}

-- | Run the pass over the statements of a body.  Every other part of
-- the body is left as it was.
optimiseBody :: (ASTRep rep) => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure $ body {bodyStms = stms'}

-- | Run the pass over a sequence of statements, in order.  A single
-- statement may be replaced by several; whatever those bind is
-- brought into scope before the remaining statements are processed.
optimiseStms :: (ASTRep rep) => [Stm rep] -> DoubleBufferM rep (Stms rep)
optimiseStms [] = pure mempty
optimiseStms (e : es) = do
  e_es <- optimiseStm e
  es' <- localScope (castScope $ scopeOf e_es) $ optimiseStms es
  pure $ e_es <> es'

-- | Run the pass over a single statement, which may expand into
-- several statements.
--
-- For a loop, the body is processed first (with the loop form and the
-- loop parameters in scope), after which the loop is handed to the
-- loop optimiser found in the environment.  Any statements produced
-- by the optimiser are placed in front of the rewritten loop.
--
-- For any other expression, nested bodies are processed recursively
-- and operations are passed to the operation handler found in the
-- environment.
optimiseStm :: forall rep. (ASTRep rep) => Stm rep -> DoubleBufferM rep (Stms rep)
optimiseStm (Let pat aux (Loop merge form body)) = do
  body' <-
    localScope (scopeOfLoopForm form <> scopeOfFParams (map fst merge)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, pat', merge', body'') <- opt_loop pat merge body'
  pure $ stms <> oneStm (Let pat' aux $ Loop merge' form body'')
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  oneStm . Let pat aux <$> mapExpM (optimise onOp) e
  where
    -- An identity traversal, except that bodies are optimised and
    -- operations go through the supplied handler.
    optimise onOp =
      (identityMapper @rep)
        { mapOnBody = \_ x ->
            optimiseBody x :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }

-- | The operation handler for the 'GPUMem' representation.  The
-- lambdas and kernel body of a 'SegOp' are traversed with
-- @optimiseLoop@ installed as the loop optimiser.  Every other
-- operation is returned unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Within a SegOp, loops are rewritten with 'optimiseLoop'.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op

-- | The operation handler for the 'MCMem' representation.  Both the
-- optional and the mandatory 'SegOp' of a 'ParOp' are traversed with
-- @optimiseLoop@ installed as the loop optimiser.  Every other
-- operation is returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Within a SegOp, loops are rewritten with 'optimiseLoop'.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op

-- | Run the pass over the statements of a kernel body.  Every other
-- part of the kernel body is left as it was.
optimiseKernelBody ::
  (ASTRep rep) =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure $ kbody {kernelBodyStms = stms'}

-- | Run the pass over the body of a lambda, with whatever the lambda
-- itself brings into scope available while processing the body.
optimiseLambda ::
  (ASTRep rep) =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}

-- | Everything 'extractAllocOf' and 'optimiseLoop' require of a
-- representation: it must be a memory representation that supports
-- the builder interface, its expression and body decorations must be
-- unit, and its let-bound decoration must be 'LetDecMem'.
type Constraints rep inner =
  ( LetDec rep ~ LetDecMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    BuilderOps rep,
    Mem rep inner
  )

-- | @extractAllocOf bound needle stms@ looks through @stms@ for a
-- single-result 'Alloc' statement that binds the memory block
-- @needle@ and whose size is /invariant/: either a constant, or a
-- variable that is not in @bound@.
--
-- While walking the statements, every name bound by a statement that
-- is skipped is added to the set of bound names, so an allocation
-- whose size is computed by an earlier statement in @stms@ is not
-- considered invariant.
--
-- On success the result is the allocation statement together with
-- all the other statements, in their original order.  The result is
-- 'Nothing' when 'stmsHead' yields nothing before a matching
-- allocation has been seen.
extractAllocOf :: (Constraints rep inner) => Names -> VName -> Stms rep -> Maybe (Stm rep, Stms rep)
extractAllocOf bound needle stms = do
  (stm, stms') <- stmsHead stms
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size _))
      | patElemName pe == needle,
        invariant size ->
          Just (stm, stms')
    _ ->
      -- Not the allocation we want: remember what this statement
      -- binds, keep looking, and put the statement back in front of
      -- whatever remains.
      let bound' = namesFromList (patNames (stmPat stm)) <> bound
       in second (oneStm stm <>) <$> extractAllocOf bound' needle stms'
  where
    -- A size is invariant if it does not depend on anything in 'bound'.
    invariant Constant {} = True
    invariant (Var v) = v `notNameIn` bound

-- | Is this parameter an array whose memory annotation ('ArrayIn')
-- names the memory block @x@?  Any parameter that is not a
-- 'MemArray' yields 'False'.
isArrayIn :: VName -> Param FParamMem -> Bool
isArrayIn x (Param _ _ (MemArray _ _ _ (ArrayIn y _))) = x == y
isArrayIn _ _ = False

-- | Whether memory blocks in the given space are candidates for
-- double buffering.  Every space except 'ScalarSpace' is.
doubleBufferSpace :: Space -> Bool
doubleBufferSpace ScalarSpace {} = False
doubleBufferSpace _ = True

-- | Apply the "pointer swapping" double-buffering strategy described
-- in the module header to a single loop.
--
-- The arguments are the loop's result pattern, its merge parameters
-- (each paired with its initial value), and its body.  The result is
-- a tuple of
--
-- * statements to be inserted before the loop (the new allocations
--   and the copies of the initial array values),
--
-- * the new pattern, merge parameters, and body.
--
-- Every triple of (pattern element, merge parameter, body result) is
-- inspected by @check@.  A triple that does not satisfy all of the
-- conditions listed there is passed through unchanged.
optimiseLoop :: (Constraints rep inner) => OptimiseLoop rep
optimiseLoop (Pat pes) merge body@(Body _ body_stms body_res) = do
  ((pat', merge', body'), outer_stms) <- runBuilder $ do
    -- The accumulator threads through (a) a function that rewrites
    -- merge parameters that are affected by an earlier transformation,
    -- and (b) the body statements with hoisted allocations removed.
    ((param_changes, body_stms'), (pes', merge', body_res')) <-
      second unzip3 <$> mapAccumLM check (id, body_stms) (zip3 pes merge body_res)
    pure
      ( Pat $ mconcat pes',
        map param_changes $ mconcat merge',
        Body () body_stms' $ mconcat body_res'
      )
  pure (outer_stms, pat', merge', body')
  where
    -- Names of the merge parameters plus everything bound in the body.
    bound_in_loop =
      namesFromList (map (paramName . fst) merge) <> boundInBody body

    -- Find the LMAD of the array @v@ as bound by a top-level statement
    -- of the loop body, but only if that LMAD mentions no name bound
    -- in the loop.
    findLmadOfArray v = listToMaybe . mapMaybe onStm $ stmsToList body_stms
      where
        onStm = listToMaybe . mapMaybe onPatElem . patElems . stmPat
        onPatElem (PatElem pe_v (MemArray _ _ _ (ArrayIn _ lmad)))
          | v == pe_v,
            not $ bound_in_loop `namesIntersect` freeIn lmad =
              Just lmad
        onPatElem _ = Nothing

    -- Replace the merge parameter @p_needle@ (and its initial value)
    -- with @new@; leave every other parameter alone.
    changeParam p_needle new (p, p_initial) =
      if p == p_needle then new else (p, p_initial)

    check (param_changes, body_stms') (pe, (param, arg), res)
      | -- The parameter is a memory block in a space we double buffer...
        Mem space <- paramType param,
        doubleBufferSpace space,
        -- ...whose initial value is a variable.
        Var arg_v <- arg,
        -- Exactly one array merge parameter resides in this memory
        -- block, and both its initial value and its loop result are
        -- variables.
        --
        -- XXX: what happens if there are multiple arrays in the same
        -- memory block?
        [((arr_param, Var arr_param_initial), Var arr_v)] <-
          filter
            (isArrayIn (paramName param) . fst . fst)
            (zip merge $ map resSubExp body_res),
        MemArray pt shape _ (ArrayIn _ param_lmad) <- paramDec arr_param,
        -- The memory block returned by the body is a variable...
        Var arr_mem_out <- resSubExp res,
        -- ...the resulting array has a loop-invariant LMAD...
        Just arr_lmad <- findLmadOfArray arr_v,
        -- ...and the returned memory is allocated in the body with a
        -- loop-invariant size, so the allocation can be moved out.
        Just (arr_mem_out_alloc, body_stms'') <-
          extractAllocOf bound_in_loop arr_mem_out body_stms' = do
          -- Put the allocations outside the loop.
          num_bytes <-
            letSubExp "num_bytes" =<< toExp (primByteSize pt * (1 + LMAD.range arr_lmad))
          arr_mem_in <-
            letExp (baseString arg_v <> "_in") $ Op $ Alloc num_bytes space
          addStm arr_mem_out_alloc

          -- Construct additional pattern element and parameter for
          -- the memory block that is not used afterwards.
          pe_unused <-
            PatElem
              <$> newVName (baseString (patElemName pe) <> "_unused")
              <*> pure (MemMem space)
          param_out <-
            newParam (baseString (paramName param) <> "_out") (MemMem space)

          -- Copy the initial array value to the input memory, with
          -- the same index function as the result.
          arr_v_copy <- newVName $ baseString arr_v <> "_db_copy"
          let arr_initial_info =
                MemArray pt shape NoUniqueness $ ArrayIn arr_mem_in arr_lmad
              arr_initial_pe =
                PatElem arr_v_copy arr_initial_info
          addStm . Let (Pat [arr_initial_pe]) (defAux ()) . BasicOp $
            Replicate mempty (Var arr_param_initial)
          -- As a trick we must make the array parameter Unique to
          -- avoid unfortunate hoisting (see #1533) because we are
          -- invalidating the underlying memory.
          let arr_param' =
                Param mempty (paramName arr_param) $
                  MemArray pt shape Unique (ArrayIn (paramName param) param_lmad)

          -- We must also update the initial values of the parameters
          -- used in the index function of this array parameter, such
          -- that they match the result.
          let mkUpdate lmad_v =
                case L.find ((== lmad_v) . paramName . fst . fst) $
                  zip merge body_res of
                  Nothing -> id
                  Just ((p, _), p_res) -> changeParam p (p, resSubExp p_res)
              -- Compose all the per-name updates into one function.
              -- 'foldr' builds the same composition as a left fold
              -- would, since '.' is associative and 'id' is its unit.
              updateLmadParam =
                foldr (.) id $ map mkUpdate $ namesToList $ freeIn param_lmad

          pure
            ( ( updateLmadParam
                  . changeParam arr_param (arr_param', Var arr_v_copy)
                  . param_changes,
                -- Inside the loop, the hoisted block is now known by
                -- the name of the new "out" parameter.
                substituteNames (M.singleton arr_mem_out (paramName param_out)) body_stms''
              ),
              ( [pe, pe_unused],
                [(param, Var arr_mem_in), (param_out, Var arr_mem_out)],
                -- Swap the two memory blocks for the next iteration.
                [ res {resSubExp = Var $ paramName param_out},
                  subExpRes $ Var $ paramName param
                ]
              )
            )
      | otherwise =
          -- Not a candidate: pass the triple through untouched.
          pure
            ( (param_changes, body_stms'),
              ([pe], [(param, arg)], [res])
            )

-- | The double buffering pass definition.
--
-- The arguments are the pass name, the pass description, and the
-- representation-specific handler for operations ('OptimiseOp').  The
-- resulting pass runs 'optimiseStms' over the statements handed to it
-- by 'intraproceduralTransformation', threading the name source of
-- the enclosing 'PassM' through the 'DoubleBufferM' computation.
--
-- Note that the initial environment is built with @doNotTouchLoop@
-- as its loop handler, which returns the loop unchanged and produces
-- no extra statements.
doubleBuffer :: (Mem rep inner) => String -> String -> OptimiseOp rep -> Pass rep rep
doubleBuffer name desc onOp =
  Pass
    { passName = name,
      passDescription = desc,
      passFunction = intraproceduralTransformation optimise
    }
  where
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $
              localScope scope $
                optimiseStms $
                  stmsToList stms
       in runState (runReaderT m env) src

    -- Initial environment: empty scope, a loop handler that changes
    -- nothing, and the caller-supplied operation handler.
    env = Env mempty doNotTouchLoop onOp

    -- Loop handler that leaves the loop as it is and emits no
    -- statements to go in front of it.
    doNotTouchLoop pat merge body = pure (mempty, pat, merge, body)

-- | The pass for GPU kernels: 'doubleBuffer' instantiated with
-- 'optimiseGPUOp' as the operation handler.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU =
  doubleBuffer
    "Double buffer GPU"
    "Double buffer memory in sequential loops (GPU rep)."
    optimiseGPUOp

-- | The pass for multicore: 'doubleBuffer' instantiated with
-- 'optimiseMCOp' as the operation handler.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC =
  doubleBuffer
    "Double buffer MC"
    "Double buffer memory in sequential loops (MC rep)."
    optimiseMCOp