{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | The simplification engine is only willing to hoist allocations
-- out of loops if the memory block resulting from the allocation is
-- dead at the end of the loop.  If it is not, we may cause data
-- hazards.
--
-- This module rewrites loops with memory block merge parameters such
-- that each memory block is copied at the end of the iteration, thus
-- ensuring that any allocation inside the loop is dead at the end of
-- the loop.  This is only possible for allocations whose size is
-- loop-invariant, although the initial size may differ from the size
-- produced by the loop result.
--
-- Additionally, inside parallel kernels we also copy the initial
-- value.  This has the effect of making the memory block of the array
-- returned by the loop non-existential, which is important for later
-- memory expansion to work.
module Futhark.Optimise.DoubleBuffer (doubleBufferGPU, doubleBufferMC) where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.GPUMem as GPU
import Futhark.IR.MCMem as MC
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations (arraySizeInBytesExp)
import Futhark.Pass.ExplicitAllocations.GPU ()
import Futhark.Util (maybeHead)

-- | The double buffering pass for GPU programs.  Loops are only
-- rewritten when they occur inside a 'SegOp'; see 'optimiseGPUOp'.
doubleBufferGPU :: Pass GPUMem GPUMem
doubleBufferGPU = doubleBuffer optimiseGPUOp

-- | The double buffering pass for multicore programs.  Loops are only
-- rewritten when they occur inside a 'ParOp'; see 'optimiseMCOp'.
doubleBufferMC :: Pass MCMem MCMem
doubleBufferMC = doubleBuffer optimiseMCOp

-- | The double buffering pass definition, parameterised over how
-- operations ('Op') are traversed.
--
-- Outside of any operation, loops are left untouched (see
-- @doNotTouchLoop@ below); it is up to @onOp@ to switch on loop
-- rewriting in the contexts where double buffering should happen.
doubleBuffer :: Mem rep => OptimiseOp rep -> Pass rep rep
doubleBuffer onOp =
  Pass
    { passName = "Double buffer",
      passDescription = "Perform double buffering for merge parameters of sequential loops.",
      passFunction = intraproceduralTransformation optimise
    }
  where
    -- Optimise the statements of one function, threading the global
    -- name source through so that any fresh names remain unique.
    optimise scope stms = modifyNameSource $ \src ->
      let m =
            runDoubleBufferM $
              localScope scope $
                stmsFromList <$> optimiseStms (stmsToList stms)
       in runState (runReaderT m env) src

    env = Env mempty doNotTouchLoop onOp

    -- The identity loop transformation: no extra statements are
    -- emitted, and the merge parameters and body are kept as-is.
    doNotTouchLoop ctx val body = pure (mempty, ctx, val, body)

-- | A transformation of a sequential loop.  It receives the context
-- and value merge parameters (each paired with its initial value) and
-- the loop body.  It returns statements to be placed before the loop,
-- together with the new context parameters, value parameters, and
-- body of the loop.
type OptimiseLoop rep =
  [(FParam rep, SubExp)] ->
  [(FParam rep, SubExp)] ->
  Body rep ->
  DoubleBufferM
    rep
    ( [Stm rep],
      [(FParam rep, SubExp)],
      [(FParam rep, SubExp)],
      Body rep
    )

-- | A transformation applied to every 'Op' encountered while
-- traversing expressions.
type OptimiseOp rep =
  Op rep -> DoubleBufferM rep (Op rep)

-- | The read-only environment of 'DoubleBufferM'.
data Env rep = Env
  { -- | The variables in scope at the current point of the traversal.
    envScope :: Scope rep,
    -- | How sequential loops are transformed.  At the top level this
    -- leaves loops unchanged; operation handlers may replace it.
    envOptimiseLoop :: OptimiseLoop rep,
    -- | How operations are transformed.
    envOptimiseOp :: OptimiseOp rep
  }

-- | The monad in which the traversal runs: read-only access to an
-- 'Env', plus a source of fresh names.
newtype DoubleBufferM rep a = DoubleBufferM
  { runDoubleBufferM :: ReaderT (Env rep) (State VNameSource) a
  }
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (Env rep),
      MonadFreshNames
    )

-- | The current scope is the 'envScope' field of the environment.
instance ASTRep rep => HasScope rep (DoubleBufferM rep) where
  askScope = asks envScope

-- | Extend the scope for the duration of a sub-computation.
--
-- 'Scope' is a 'M.Map', whose '<>' is a left-biased union.  The newly
-- supplied bindings are therefore placed on the left so that they
-- take precedence over existing ones, as lexical shadowing requires.
instance ASTRep rep => LocalScope rep (DoubleBufferM rep) where
  localScope scope = local $ \env -> env {envScope = scope <> envScope env}

-- | Optimise every statement in a body.  Only the statements are
-- replaced; the rest of the body is kept unchanged.
optimiseBody :: ASTRep rep => Body rep -> DoubleBufferM rep (Body rep)
optimiseBody body = do
  stms' <- optimiseStms $ stmsToList $ bodyStms body
  pure body {bodyStms = stmsFromList stms'}

-- | Optimise a sequence of statements in order.  A single statement
-- may expand into several.  Whatever those statements bind is brought
-- into scope before the remaining statements are processed.
optimiseStms :: ASTRep rep => [Stm rep] -> DoubleBufferM rep [Stm rep]
optimiseStms [] = pure []
optimiseStms (stm : stms) = do
  stm_stms <- optimiseStm stm
  stms' <- localScope (castScope $ scopeOf stm_stms) $ optimiseStms stms
  pure $ stm_stms ++ stms'

-- | Optimise a single statement, possibly producing several.
--
-- Loops are handled specially.  Their body is optimised first, with
-- the loop form and merge parameters in scope.  The current
-- 'envOptimiseLoop' may then rewrite the loop and emit statements to
-- be placed before it.
--
-- All other expressions are traversed generically: nested bodies are
-- optimised in whatever scope the traversal supplies, and
-- 'envOptimiseOp' is applied to operations.
optimiseStm :: forall rep. ASTRep rep => Stm rep -> DoubleBufferM rep [Stm rep]
optimiseStm (Let pat aux (DoLoop ctx val form body)) = do
  body' <-
    localScope (scopeOf form <> scopeOfFParams (map fst $ ctx ++ val)) $
      optimiseBody body
  opt_loop <- asks envOptimiseLoop
  (stms, ctx', val', body'') <- opt_loop ctx val body'
  pure $ stms ++ [Let pat aux $ DoLoop ctx' val' form body'']
optimiseStm (Let pat aux e) = do
  onOp <- asks envOptimiseOp
  pure . Let pat aux <$> mapExpM (optimise onOp) e
  where
    optimise onOp =
      identityMapper
        { -- Respect any bindings the mapper reports for the nested body,
          -- so loop rewriting inside it can look them up.
          mapOnBody = \scope x ->
            localScope scope (optimiseBody x) :: DoubleBufferM rep (Body rep),
          mapOnOp = onOp
        }

-- | Operation handler for GPU programs.  Inside a 'SegOp', loops are
-- rewritten by 'optimiseLoop', and the operation's lambdas and kernel
-- bodies are traversed.  Every other operation is returned unchanged.
optimiseGPUOp :: OptimiseOp GPUMem
optimiseGPUOp (Inner (SegOp op)) =
  local inSegOp $ Inner . SegOp <$> mapSegOpM mapper op
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Enable loop rewriting for everything nested in this SegOp.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseGPUOp op = pure op

-- | Operation handler for multicore programs.  Inside a 'ParOp',
-- loops are rewritten by 'optimiseLoop'.  Both the optional
-- nested-parallel 'SegOp' and the main 'SegOp' are traversed.  Every
-- other operation is returned unchanged.
optimiseMCOp :: OptimiseOp MCMem
optimiseMCOp (Inner (ParOp par_op op)) =
  local inSegOp $
    Inner
      <$> (ParOp <$> traverse (mapSegOpM mapper) par_op <*> mapSegOpM mapper op)
  where
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = optimiseLambda,
          mapOnSegOpBody = optimiseKernelBody
        }
    -- Enable loop rewriting for everything nested in this ParOp.
    inSegOp env = env {envOptimiseLoop = optimiseLoop}
optimiseMCOp op = pure op

-- | Optimise every statement in a kernel body.  Only the statements
-- are replaced; the kernel results are kept unchanged.
optimiseKernelBody ::
  ASTRep rep =>
  KernelBody rep ->
  DoubleBufferM rep (KernelBody rep)
optimiseKernelBody kbody = do
  stms' <- optimiseStms $ stmsToList $ kernelBodyStms kbody
  pure kbody {kernelBodyStms = stmsFromList stms'}

-- | Optimise the body of a lambda, with the lambda's own bindings
-- brought into scope.
optimiseLambda ::
  ASTRep rep =>
  Lambda rep ->
  DoubleBufferM rep (Lambda rep)
optimiseLambda lam = do
  body <- localScope (castScope $ scopeOf lam) $ optimiseBody $ lambdaBody lam
  pure lam {lambdaBody = body}

type Constraints rep =
  ( ASTRep rep,
    FParamInfo rep ~ FParamMem,
    LParamInfo rep ~ LParamMem,
    RetType rep ~ RetTypeMem,
    LetDec rep ~ LetDecMem,
    BranchType rep ~ BranchTypeMem,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    OpReturns rep
  )

optimiseLoop :: (Constraints rep, Op rep ~ MemOp inner, BinderOps rep) => OptimiseLoop rep
optimiseLoop :: forall rep inner.
(Constraints rep, Op rep ~ MemOp inner, BinderOps rep) =>
OptimiseLoop rep
optimiseLoop [(FParam rep, SubExp)]
ctx [(FParam rep, SubExp)]
val Body rep
body = do
  -- We start out by figuring out which of the merge variables should
  -- be double-buffered.
  [DoubleBuffer]
buffered <-
    [(Param FParamMem, SubExp)]
-> [Param FParamMem] -> Names -> DoubleBufferM rep [DoubleBuffer]
forall (m :: * -> *).
MonadFreshNames m =>
[(Param FParamMem, SubExp)]
-> [Param FParamMem] -> Names -> m [DoubleBuffer]
doubleBufferMergeParams
      ([Param FParamMem] -> [SubExp] -> [(Param FParamMem, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
ctx) (Body rep -> [SubExp]
forall rep. BodyT rep -> [SubExp]
bodyResult Body rep
body))
      (((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(Param FParamMem, SubExp)]
merge)
      (Body rep -> Names
forall rep. Body rep -> Names
boundInBody Body rep
body)
  -- Then create the allocations of the buffers and copies of the
  -- initial values.
  ([(Param FParamMem, SubExp)]
merge', [Stm rep]
allocs) <- [(FParam rep, SubExp)]
-> [DoubleBuffer]
-> DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
forall rep inner.
(Constraints rep, Op rep ~ MemOp inner, BinderOps rep) =>
[(FParam rep, SubExp)]
-> [DoubleBuffer]
-> DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge [DoubleBuffer]
buffered
  -- Modify the loop body to copy buffered result arrays.
  let body' :: Body rep
body' = [FParam rep] -> [DoubleBuffer] -> Body rep -> Body rep
forall rep.
Constraints rep =>
[FParam rep] -> [DoubleBuffer] -> Body rep -> Body rep
doubleBufferResult (((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(Param FParamMem, SubExp)]
merge) [DoubleBuffer]
buffered Body rep
body
      ([(Param FParamMem, SubExp)]
ctx', [(Param FParamMem, SubExp)]
val') = Int
-> [(Param FParamMem, SubExp)]
-> ([(Param FParamMem, SubExp)], [(Param FParamMem, SubExp)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([(Param FParamMem, SubExp)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
ctx) [(Param FParamMem, SubExp)]
merge'
  -- Modify the initial merge p
  ([Stm rep], [(Param FParamMem, SubExp)],
 [(Param FParamMem, SubExp)], Body rep)
-> DoubleBufferM
     rep
     ([Stm rep], [(Param FParamMem, SubExp)],
      [(Param FParamMem, SubExp)], Body rep)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Stm rep]
allocs, [(Param FParamMem, SubExp)]
ctx', [(Param FParamMem, SubExp)]
val', Body rep
body')
  where
    merge :: [(Param FParamMem, SubExp)]
merge = [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
ctx [(Param FParamMem, SubExp)]
-> [(Param FParamMem, SubExp)] -> [(Param FParamMem, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
val

-- | The booleans indicate whether we should also play with the
-- initial merge values.
data DoubleBuffer
  = BufferAlloc VName (PrimExp VName) Space Bool
  | -- | First name is the memory block to copy to,
    -- second is the name of the array copy.
    BufferCopy VName IxFun VName Bool
  | NoBuffer
  deriving (Int -> DoubleBuffer -> ShowS
[DoubleBuffer] -> ShowS
DoubleBuffer -> String
(Int -> DoubleBuffer -> ShowS)
-> (DoubleBuffer -> String)
-> ([DoubleBuffer] -> ShowS)
-> Show DoubleBuffer
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
showList :: [DoubleBuffer] -> ShowS
$cshowList :: [DoubleBuffer] -> ShowS
show :: DoubleBuffer -> String
$cshow :: DoubleBuffer -> String
showsPrec :: Int -> DoubleBuffer -> ShowS
$cshowsPrec :: Int -> DoubleBuffer -> ShowS
Show)

doubleBufferMergeParams ::
  MonadFreshNames m =>
  [(Param FParamMem, SubExp)] ->
  [Param FParamMem] ->
  Names ->
  m [DoubleBuffer]
doubleBufferMergeParams :: forall (m :: * -> *).
MonadFreshNames m =>
[(Param FParamMem, SubExp)]
-> [Param FParamMem] -> Names -> m [DoubleBuffer]
doubleBufferMergeParams [(Param FParamMem, SubExp)]
ctx_and_res [Param FParamMem]
val_params Names
bound_in_loop =
  StateT (Map VName (VName, Bool)) m [DoubleBuffer]
-> Map VName (VName, Bool) -> m [DoubleBuffer]
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ((Param FParamMem
 -> StateT (Map VName (VName, Bool)) m DoubleBuffer)
-> [Param FParamMem]
-> StateT (Map VName (VName, Bool)) m [DoubleBuffer]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Param FParamMem -> StateT (Map VName (VName, Bool)) m DoubleBuffer
buffer [Param FParamMem]
val_params) Map VName (VName, Bool)
forall k a. Map k a
M.empty
  where
    loopVariant :: VName -> Bool
loopVariant VName
v =
      VName
v VName -> Names -> Bool
`nameIn` Names
bound_in_loop
        Bool -> Bool -> Bool
|| VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` ((Param FParamMem, SubExp) -> VName)
-> [(Param FParamMem, SubExp)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(Param FParamMem, SubExp)]
ctx_and_res

    loopInvariantSize :: SubExp -> Maybe (SubExp, Bool)
loopInvariantSize (Constant PrimValue
v) =
      (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (PrimValue -> SubExp
Constant PrimValue
v, Bool
True)
    loopInvariantSize (Var VName
v) =
      case ((Param FParamMem, SubExp) -> Bool)
-> [(Param FParamMem, SubExp)] -> Maybe (Param FParamMem, SubExp)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param FParamMem, SubExp) -> VName)
-> (Param FParamMem, SubExp)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(Param FParamMem, SubExp)]
ctx_and_res of
        Just (Param FParamMem
_, Constant PrimValue
val) ->
          (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (PrimValue -> SubExp
Constant PrimValue
val, Bool
False)
        Just (Param FParamMem
_, Var VName
v')
          | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
loopVariant VName
v' ->
            (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (VName -> SubExp
Var VName
v', Bool
False)
        Just (Param FParamMem, SubExp)
_ ->
          Maybe (SubExp, Bool)
forall a. Maybe a
Nothing
        Maybe (Param FParamMem, SubExp)
Nothing ->
          (SubExp, Bool) -> Maybe (SubExp, Bool)
forall a. a -> Maybe a
Just (VName -> SubExp
Var VName
v, Bool
True)

    sizeForMem :: VName -> Maybe (PrimExp VName, Bool)
sizeForMem VName
mem = [(PrimExp VName, Bool)] -> Maybe (PrimExp VName, Bool)
forall a. [a] -> Maybe a
maybeHead ([(PrimExp VName, Bool)] -> Maybe (PrimExp VName, Bool))
-> [(PrimExp VName, Bool)] -> Maybe (PrimExp VName, Bool)
forall a b. (a -> b) -> a -> b
$ (Param FParamMem -> Maybe (PrimExp VName, Bool))
-> [Param FParamMem] -> [(PrimExp VName, Bool)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (FParamMem -> Maybe (PrimExp VName, Bool)
arrayInMem (FParamMem -> Maybe (PrimExp VName, Bool))
-> (Param FParamMem -> FParamMem)
-> Param FParamMem
-> Maybe (PrimExp VName, Bool)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> FParamMem
forall dec. Param dec -> dec
paramDec) [Param FParamMem]
val_params
      where
        arrayInMem :: FParamMem -> Maybe (PrimExp VName, Bool)
arrayInMem (MemArray PrimType
pt ShapeBase SubExp
shape Uniqueness
_ (ArrayIn VName
arraymem IxFun
ixfun))
          | IxFun -> Bool
forall num. (Eq num, IntegralExp num) => IxFun num -> Bool
IxFun.isDirect IxFun
ixfun,
            Just ([SubExp]
dims, [Bool]
b) <-
              (SubExp -> Maybe (SubExp, Bool))
-> [SubExp] -> Maybe ([SubExp], [Bool])
forall (m :: * -> *) a b c.
Applicative m =>
(a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM SubExp -> Maybe (SubExp, Bool)
loopInvariantSize ([SubExp] -> Maybe ([SubExp], [Bool]))
-> [SubExp] -> Maybe ([SubExp], [Bool])
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
shape,
            VName
mem VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
arraymem =
            (PrimExp VName, Bool) -> Maybe (PrimExp VName, Bool)
forall a. a -> Maybe a
Just
              ( Type -> PrimExp VName
arraySizeInBytesExp (Type -> PrimExp VName) -> Type -> PrimExp VName
forall a b. (a -> b) -> a -> b
$
                  PrimType -> ShapeBase SubExp -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
pt ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims) NoUniqueness
NoUniqueness,
                [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or [Bool]
b
              )
        arrayInMem FParamMem
_ = Maybe (PrimExp VName, Bool)
forall a. Maybe a
Nothing

    buffer :: Param FParamMem -> StateT (Map VName (VName, Bool)) m DoubleBuffer
buffer Param FParamMem
fparam = case Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
fparam of
      Mem Space
space
        | Just (PrimExp VName
size, Bool
b) <- VName -> Maybe (PrimExp VName, Bool)
sizeForMem (VName -> Maybe (PrimExp VName, Bool))
-> VName -> Maybe (PrimExp VName, Bool)
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
fparam -> do
          -- Let us double buffer this!
          VName
bufname <- m VName -> StateT (Map VName (VName, Bool)) m VName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VName -> StateT (Map VName (VName, Bool)) m VName)
-> m VName -> StateT (Map VName (VName, Bool)) m VName
forall a b. (a -> b) -> a -> b
$ String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"double_buffer_mem"
          (Map VName (VName, Bool) -> Map VName (VName, Bool))
-> StateT (Map VName (VName, Bool)) m ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((Map VName (VName, Bool) -> Map VName (VName, Bool))
 -> StateT (Map VName (VName, Bool)) m ())
-> (Map VName (VName, Bool) -> Map VName (VName, Bool))
-> StateT (Map VName (VName, Bool)) m ()
forall a b. (a -> b) -> a -> b
$ VName
-> (VName, Bool)
-> Map VName (VName, Bool)
-> Map VName (VName, Bool)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
fparam) (VName
bufname, Bool
b)
          DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall (m :: * -> *) a. Monad m => a -> m a
return (DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer)
-> DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp VName -> Space -> Bool -> DoubleBuffer
BufferAlloc VName
bufname PrimExp VName
size Space
space Bool
b
      Array {}
        | MemArray PrimType
_ ShapeBase SubExp
_ Uniqueness
_ (ArrayIn VName
mem IxFun
ixfun) <- Param FParamMem -> FParamMem
forall dec. Param dec -> dec
paramDec Param FParamMem
fparam -> do
          Maybe (VName, Bool)
buffered <- (Map VName (VName, Bool) -> Maybe (VName, Bool))
-> StateT (Map VName (VName, Bool)) m (Maybe (VName, Bool))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((Map VName (VName, Bool) -> Maybe (VName, Bool))
 -> StateT (Map VName (VName, Bool)) m (Maybe (VName, Bool)))
-> (Map VName (VName, Bool) -> Maybe (VName, Bool))
-> StateT (Map VName (VName, Bool)) m (Maybe (VName, Bool))
forall a b. (a -> b) -> a -> b
$ VName -> Map VName (VName, Bool) -> Maybe (VName, Bool)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
mem
          case Maybe (VName, Bool)
buffered of
            Just (VName
bufname, Bool
b) -> do
              VName
copyname <- m VName -> StateT (Map VName (VName, Bool)) m VName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VName -> StateT (Map VName (VName, Bool)) m VName)
-> m VName -> StateT (Map VName (VName, Bool)) m VName
forall a b. (a -> b) -> a -> b
$ String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"double_buffer_array"
              DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall (m :: * -> *) a. Monad m => a -> m a
return (DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer)
-> DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> VName -> Bool -> DoubleBuffer
BufferCopy VName
bufname IxFun
ixfun VName
copyname Bool
b
            Maybe (VName, Bool)
Nothing ->
              DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall (m :: * -> *) a. Monad m => a -> m a
return DoubleBuffer
NoBuffer
      Type
_ -> DoubleBuffer -> StateT (Map VName (VName, Bool)) m DoubleBuffer
forall (m :: * -> *) a. Monad m => a -> m a
return DoubleBuffer
NoBuffer

allocStms ::
  (Constraints rep, Op rep ~ MemOp inner, BinderOps rep) =>
  [(FParam rep, SubExp)] ->
  [DoubleBuffer] ->
  DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms :: forall rep inner.
(Constraints rep, Op rep ~ MemOp inner, BinderOps rep) =>
[(FParam rep, SubExp)]
-> [DoubleBuffer]
-> DoubleBufferM rep ([(FParam rep, SubExp)], [Stm rep])
allocStms [(FParam rep, SubExp)]
merge = WriterT [Stm rep] (DoubleBufferM rep) [(Param FParamMem, SubExp)]
-> DoubleBufferM rep ([(Param FParamMem, SubExp)], [Stm rep])
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT [Stm rep] (DoubleBufferM rep) [(Param FParamMem, SubExp)]
 -> DoubleBufferM rep ([(Param FParamMem, SubExp)], [Stm rep]))
-> ([DoubleBuffer]
    -> WriterT
         [Stm rep] (DoubleBufferM rep) [(Param FParamMem, SubExp)])
-> [DoubleBuffer]
-> DoubleBufferM rep ([(Param FParamMem, SubExp)], [Stm rep])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Param FParamMem, SubExp)
 -> DoubleBuffer
 -> WriterT [Stm rep] (DoubleBufferM rep) (Param FParamMem, SubExp))
-> [(Param FParamMem, SubExp)]
-> [DoubleBuffer]
-> WriterT
     [Stm rep] (DoubleBufferM rep) [(Param FParamMem, SubExp)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (Param FParamMem, SubExp)
-> DoubleBuffer
-> WriterT [Stm rep] (DoubleBufferM rep) (Param FParamMem, SubExp)
forall {rep} {m :: * -> *} {t :: (* -> *) -> * -> *} {d} {ret}
       {rep} {inner}.
(HasScope rep m, OpReturns rep, AllocOp (Op rep),
 MonadFreshNames m, MonadTrans t, Typed (MemInfo d Uniqueness ret),
 MonadWriter [Stm rep] (t m), ASTRep rep, BinderOps rep,
 RetType rep ~ RetTypeMem, ExpDec rep ~ (),
 BranchType rep ~ BranchTypeMem, Op rep ~ MemOp inner,
 LParamInfo rep ~ LetDecMem, LParamInfo rep ~ LParamInfo rep,
 FParamInfo rep ~ FParamMem, FParamInfo rep ~ FParamInfo rep,
 LetDec rep ~ LetDecMem, LetDec rep ~ LetDec rep,
 LetDec rep ~ LetDecMem) =>
(Param (MemInfo d Uniqueness ret), SubExp)
-> DoubleBuffer -> t m (Param (MemInfo d Uniqueness ret), SubExp)
allocation [(FParam rep, SubExp)]
[(Param FParamMem, SubExp)]
merge
  where
    allocation :: (Param (MemInfo d Uniqueness ret), SubExp)
-> DoubleBuffer -> t m (Param (MemInfo d Uniqueness ret), SubExp)
allocation m :: (Param (MemInfo d Uniqueness ret), SubExp)
m@(Param VName
pname MemInfo d Uniqueness ret
_, SubExp
_) (BufferAlloc VName
name PrimExp VName
size Space
space Bool
b) = do
      Stms rep
stms <- m (Stms rep) -> t m (Stms rep)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (Stms rep) -> t m (Stms rep)) -> m (Stms rep) -> t m (Stms rep)
forall a b. (a -> b) -> a -> b
$
        Binder rep () -> m (Stms rep)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Binder rep a -> m (Stms rep)
runBinder_ (Binder rep () -> m (Stms rep)) -> Binder rep () -> m (Stms rep)
forall a b. (a -> b) -> a -> b
$ do
          SubExp
size' <- String -> PrimExp VName -> BinderT rep (State VNameSource) SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"double_buffer_size" PrimExp VName
size
          [VName]
-> Exp (Rep (BinderT rep (State VNameSource))) -> Binder rep ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
name] (Exp (Rep (BinderT rep (State VNameSource))) -> Binder rep ())
-> Exp (Rep (BinderT rep (State VNameSource))) -> Binder rep ()
forall a b. (a -> b) -> a -> b
$ Op rep -> ExpT rep
forall rep. Op rep -> ExpT rep
Op (Op rep -> ExpT rep) -> Op rep -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> Space -> MemOp inner
forall inner. SubExp -> Space -> MemOp inner
Alloc SubExp
size' Space
space
      [Stm rep] -> t m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([Stm rep] -> t m ()) -> [Stm rep] -> t m ()
forall a b. (a -> b) -> a -> b
$ Stms rep -> [Stm rep]
forall rep. Stms rep -> [Stm rep]
stmsToList Stms rep
stms
      if Bool
b
        then (Param (MemInfo d Uniqueness ret), SubExp)
-> t m (Param (MemInfo d Uniqueness ret), SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
-> MemInfo d Uniqueness ret -> Param (MemInfo d Uniqueness ret)
forall dec. VName -> dec -> Param dec
Param VName
pname (MemInfo d Uniqueness ret -> Param (MemInfo d Uniqueness ret))
-> MemInfo d Uniqueness ret -> Param (MemInfo d Uniqueness ret)
forall a b. (a -> b) -> a -> b
$ Space -> MemInfo d Uniqueness ret
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space, VName -> SubExp
Var VName
name)
        else (Param (MemInfo d Uniqueness ret), SubExp)
-> t m (Param (MemInfo d Uniqueness ret), SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (MemInfo d Uniqueness ret), SubExp)
m
    allocation (Param (MemInfo d Uniqueness ret)
f, Var VName
v) (BufferCopy VName
mem IxFun
_ VName
_ Bool
b) | Bool
b = do
      VName
v_copy <- m VName -> t m VName
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m VName -> t m VName) -> m VName -> t m VName
forall a b. (a -> b) -> a -> b
$ String -> m VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> m VName) -> String -> m VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
v String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_double_buffer_copy"
      (VName
_v_mem, IxFun
v_ixfun) <- m (VName, IxFun) -> t m (VName, IxFun)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (m (VName, IxFun) -> t m (VName, IxFun))
-> m (VName, IxFun) -> t m (VName, IxFun)
forall a b. (a -> b) -> a -> b
$ VName -> m (VName, IxFun)
forall rep (m :: * -> *).
(Mem rep, HasScope rep m, Monad m) =>
VName -> m (VName, IxFun)
lookupArraySummary VName
v
      let bt :: PrimType
bt = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType (Type -> PrimType) -> Type -> PrimType
forall a b. (a -> b) -> a -> b
$ Param (MemInfo d Uniqueness ret) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (MemInfo d Uniqueness ret)
f
          shape :: ShapeBase SubExp
shape = Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape (Type -> ShapeBase SubExp) -> Type -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo d Uniqueness ret) -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param (MemInfo d Uniqueness ret)
f
          bound :: LetDecMem
bound = PrimType
-> ShapeBase SubExp -> NoUniqueness -> MemBind -> LetDecMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
bt ShapeBase SubExp
shape NoUniqueness
NoUniqueness (MemBind -> LetDecMem) -> MemBind -> LetDecMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem IxFun
v_ixfun
      [Stm rep] -> t m ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
        [ Pattern rep -> StmAux (ExpDec rep) -> ExpT rep -> Stm rep
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElemT LetDecMem] -> [PatElemT LetDecMem] -> PatternT LetDecMem
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [VName -> LetDecMem -> PatElemT LetDecMem
forall dec. VName -> dec -> PatElemT dec
PatElem VName
v_copy LetDecMem
bound]) (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (ExpT rep -> Stm rep) -> ExpT rep -> Stm rep
forall a b. (a -> b) -> a -> b
$
            BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v
        ]
      -- It is important that we treat this as a consumption, to
      -- avoid the Copy from being hoisted out of any enclosing
      -- loops.  Since we re-use (=overwrite) memory in the loop,
      -- the copy is critical for initialisation.  See issue #816.
      let uniqueMemInfo :: MemInfo d Uniqueness ret -> MemInfo d Uniqueness ret
uniqueMemInfo (MemArray PrimType
pt ShapeBase d
pshape Uniqueness
_ ret
ret) =
            PrimType
-> ShapeBase d -> Uniqueness -> ret -> MemInfo d Uniqueness ret
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
pt ShapeBase d
pshape Uniqueness
Unique ret
ret
          uniqueMemInfo MemInfo d Uniqueness ret
info = MemInfo d Uniqueness ret
info
      (Param (MemInfo d Uniqueness ret), SubExp)
-> t m (Param (MemInfo d Uniqueness ret), SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (MemInfo d Uniqueness ret -> MemInfo d Uniqueness ret
forall {d} {ret}.
MemInfo d Uniqueness ret -> MemInfo d Uniqueness ret
uniqueMemInfo (MemInfo d Uniqueness ret -> MemInfo d Uniqueness ret)
-> Param (MemInfo d Uniqueness ret)
-> Param (MemInfo d Uniqueness ret)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Param (MemInfo d Uniqueness ret)
f, VName -> SubExp
Var VName
v_copy)
    allocation (Param (MemInfo d Uniqueness ret)
f, SubExp
se) DoubleBuffer
_ =
      (Param (MemInfo d Uniqueness ret), SubExp)
-> t m (Param (MemInfo d Uniqueness ret), SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (MemInfo d Uniqueness ret)
f, SubExp
se)

doubleBufferResult ::
  (Constraints rep) =>
  [FParam rep] ->
  [DoubleBuffer] ->
  Body rep ->
  Body rep
doubleBufferResult :: forall rep.
Constraints rep =>
[FParam rep] -> [DoubleBuffer] -> Body rep -> Body rep
doubleBufferResult [FParam rep]
valparams [DoubleBuffer]
buffered (Body BodyDec rep
_ Stms rep
bnds [SubExp]
res) =
  let ([SubExp]
ctx_res, [SubExp]
val_res) = Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
res Int -> Int -> Int
forall a. Num a => a -> a -> a
- [Param FParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam rep]
[Param FParamMem]
valparams) [SubExp]
res
      ([Maybe (Stm rep)]
copybnds, [SubExp]
val_res') =
        [(Maybe (Stm rep), SubExp)] -> ([Maybe (Stm rep)], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Maybe (Stm rep), SubExp)] -> ([Maybe (Stm rep)], [SubExp]))
-> [(Maybe (Stm rep), SubExp)] -> ([Maybe (Stm rep)], [SubExp])
forall a b. (a -> b) -> a -> b
$ (Param FParamMem
 -> DoubleBuffer -> SubExp -> (Maybe (Stm rep), SubExp))
-> [Param FParamMem]
-> [DoubleBuffer]
-> [SubExp]
-> [(Maybe (Stm rep), SubExp)]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Param FParamMem
-> DoubleBuffer -> SubExp -> (Maybe (Stm rep), SubExp)
buffer [FParam rep]
[Param FParamMem]
valparams [DoubleBuffer]
buffered [SubExp]
val_res
   in BodyDec rep -> Stms rep -> [SubExp] -> BodyT rep
forall rep. BodyDec rep -> Stms rep -> [SubExp] -> BodyT rep
Body () (Stms rep
bnds Stms rep -> Stms rep -> Stms rep
forall a. Semigroup a => a -> a -> a
<> [Stm rep] -> Stms rep
forall rep. [Stm rep] -> Stms rep
stmsFromList ([Maybe (Stm rep)] -> [Stm rep]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (Stm rep)]
copybnds)) ([SubExp] -> BodyT rep) -> [SubExp] -> BodyT rep
forall a b. (a -> b) -> a -> b
$ [SubExp]
ctx_res [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
val_res'
  where
    buffer :: Param FParamMem
-> DoubleBuffer -> SubExp -> (Maybe (Stm rep), SubExp)
buffer Param FParamMem
_ (BufferAlloc VName
bufname PrimExp VName
_ Space
_ Bool
_) SubExp
_ =
      (Maybe (Stm rep)
forall a. Maybe a
Nothing, VName -> SubExp
Var VName
bufname)
    buffer Param FParamMem
fparam (BufferCopy VName
bufname IxFun
ixfun VName
copyname Bool
_) (Var VName
v) =
      -- To construct the copy we will need to figure out its type
      -- based on the type of the function parameter.
      let t :: Type
t = Type -> Type
resultType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
fparam
          summary :: LetDecMem
summary = PrimType
-> ShapeBase SubExp -> NoUniqueness -> MemBind -> LetDecMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
t) NoUniqueness
NoUniqueness (MemBind -> LetDecMem) -> MemBind -> LetDecMem
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
bufname IxFun
ixfun
          copybnd :: Stm rep
copybnd =
            Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
forall rep.
Pattern rep -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let ([PatElemT LetDecMem] -> [PatElemT LetDecMem] -> PatternT LetDecMem
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] [VName -> LetDecMem -> PatElemT LetDecMem
forall dec. VName -> dec -> PatElemT dec
PatElem VName
copyname LetDecMem
summary]) (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp rep -> Stm rep) -> Exp rep -> Stm rep
forall a b. (a -> b) -> a -> b
$
              BasicOp -> Exp rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v
       in (Stm rep -> Maybe (Stm rep)
forall a. a -> Maybe a
Just Stm rep
copybnd, VName -> SubExp
Var VName
copyname)
    buffer Param FParamMem
_ DoubleBuffer
_ SubExp
se =
      (Maybe (Stm rep)
forall a. Maybe a
Nothing, SubExp
se)

    parammap :: Map VName SubExp
parammap = [(VName, SubExp)] -> Map VName SubExp
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, SubExp)] -> Map VName SubExp)
-> [(VName, SubExp)] -> Map VName SubExp
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param FParamMem -> VName) -> [Param FParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param FParamMem -> VName
forall dec. Param dec -> VName
paramName [FParam rep]
[Param FParamMem]
valparams) [SubExp]
res

    resultType :: Type -> Type
resultType Type
t = Type
t Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase (ShapeBase SubExp) u
`setArrayDims` (SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
substitute (Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims Type
t)

    substitute :: SubExp -> SubExp
substitute (Var VName
v)
      | Just SubExp
replacement <- VName -> Map VName SubExp -> Maybe SubExp
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName SubExp
parammap = SubExp
replacement
    substitute SubExp
se =
      SubExp
se