{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor
import Data.Either (rights)
import Data.List (find, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Analysis.Alias as Alias
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Error
import Futhark.IR
import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Transform.Substitute
import Futhark.Util (mapAccumLM)
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- Constants are transformed first, and the scope they bind is then
-- made available while transforming every function.  If expansion of
-- the constants fails, the failure is reported as a compiler
-- limitation (see 'limitationOnLeft').
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $ \prog -> do
    consts' <-
      modifyNameSource $
        limitationOnLeft
          . runStateT (runReaderT (transformStms (progConsts prog)) mempty)
    funs' <- mapM (transformFunDef $ scopeOf consts') (progFuns prog)
    pure $ prog {progConsts = consts', progFuns = funs'}

-- We cannot use intraproceduralTransformation here because it might
-- create duplicate size keys, which the renamer does not fix, and
-- size keys must currently be globally unique.

-- | The monad in which expansion runs.  It reads the scope currently
-- in effect, threads a name source, and can fail with a 'String'
-- message (which 'limitationOnLeft' turns into a compiler limitation).
type ExpandM =
  ReaderT
    (Scope GPUMem)
    (StateT VNameSource (Either String))

-- | Unwrap the result of an expansion.  A 'Left' message is reported
-- through 'compilerLimitationS'; a 'Right' value is returned as is.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Expand allocations in the body of a function, then run copy
-- propagation over the resulting function.
--
-- During expansion, the given scope (in 'expandAllocations' this is
-- the scope of the program's constants) and the scope introduced by
-- @fundec@ itself are in effect.  The copy-propagation symbol table
-- is built from the given scope only.  Expansion failure is reported
-- via 'limitationOnLeft'.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <-
    modifyNameSource $
      limitationOnLeft . runStateT (runReaderT expand mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    expand :: ExpandM (Body GPUMem)
    expand =
      localScope scope . inScopeOf fundec . transformBody $
        funDefBody fundec

-- | Expand allocations in the statements of a body.  The body's
-- result is passed through unchanged.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = do
  stms' <- transformStms stms
  pure $ Body () stms' res

-- | Expand allocations in the body of a lambda, with the lambda's
-- parameters in scope.  Parameters and return types are unchanged.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params ret body) =
  Lambda params ret
    <$> localScope (scopeOfLParams params) (transformBody body)

-- | Expand allocations in a sequence of statements.
--
-- Each statement may expand into several statements, so the
-- per-statement results are concatenated in order.  The bindings of
-- all of @stms@ are in scope while transforming.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Expand allocations in a single statement.  The result holds any
-- statements produced by the expansion, followed by the rewritten
-- statement itself.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- all versions fail do we propagate the error.
-- FIXME: this can remove safety checks if the default branch fails!
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  cases' <- rights <$> mapM onCase cases
  defbody' <- attempt $ transformBody defbody
  case (cases', defbody') of
    (_, Right defbody'') ->
      pure $ mkMatch cases' defbody''
    ([], Left e) ->
      -- Every version failed; report the default branch's error.
      throwError e
    (c : cs, Left _) -> do
      -- The default branch failed, so the last surviving case becomes
      -- the new default and its guard is dropped.  This is the
      -- behaviour the FIXME above refers to.
      let (cs_init, c_last) = splitLast c cs
      pure $ mkMatch cs_init $ caseBody c_last
  where
    -- Run an action, capturing a thrown error as 'Left' so that one
    -- failing branch does not abort the whole statement.
    attempt :: ExpandM a -> ExpandM (Either String a)
    attempt m = (Right <$> m) `catchError` (pure . Left)

    onCase (Case vs body) = fmap (Case vs) <$> attempt (transformBody body)

    -- Rebuild the statement with the given cases and default body.
    mkMatch cases'' defbody'' =
      oneStm $ Let pat aux $ Match cond cases'' defbody'' $ MatchDec ts MatchEquiv

    -- Split a non-empty list, given as its head and tail, into its
    -- initial segment and its final element.  Unlike init/last, this
    -- is total.
    splitLast x [] = ([], x)
    splitLast x (y : ys) =
      let (zs, z) = splitLast y ys
       in (x : zs, z)
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Recurse into every nested body, extending the scope with the
    -- bindings the mapper supplies for that body.
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, (SegLevel
lvl', [Lambda GPUMem]
_, KernelBody GPUMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM
     (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space [] KernelBody GPUMem
kbody
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall a.
a -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      OpC GPUMem GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (OpC GPUMem GPUMem -> Exp GPUMem)
-> OpC GPUMem GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem
forall (inner :: * -> *) rep. inner rep -> MemOp inner rep
Inner (HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem)
-> HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp NoOp GPUMem
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPUMem -> HostOp NoOp GPUMem)
-> SegOp SegLevel GPUMem -> HostOp NoOp GPUMem
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
lvl' SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
reds [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, (SegLevel
lvl', [Lambda GPUMem]
lams, KernelBody GPUMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM
     (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp GPUMem -> Lambda GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp GPUMem]
reds) KernelBody GPUMem
kbody
  let reds' :: [SegBinOp GPUMem]
reds' = (SegBinOp GPUMem -> Lambda GPUMem -> SegBinOp GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem] -> [SegBinOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp GPUMem
red Lambda GPUMem
lam -> SegBinOp GPUMem
red {segBinOpLambda = lam}) [SegBinOp GPUMem]
reds [Lambda GPUMem]
lams
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall a.
a -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      OpC GPUMem GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (OpC GPUMem GPUMem -> Exp GPUMem)
-> OpC GPUMem GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem
forall (inner :: * -> *) rep. inner rep -> MemOp inner rep
Inner (HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem)
-> HostOp NoOp GPUMem -> MemOp (HostOp NoOp) GPUMem
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp NoOp GPUMem
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPUMem -> HostOp NoOp GPUMem)
-> SegOp SegLevel GPUMem -> HostOp NoOp GPUMem
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegLevel
lvl' SegSpace
space [SegBinOp GPUMem]
reds' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
-- Scans: allocations are expanded through the operator lambda of every
-- 'SegBinOp', and each rewritten lambda is then put back into its operator.
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (alloc_stms, (lvl', scan_lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let replaceLambda scan lam = scan {segBinOpLambda = lam}
      scans' = zipWith replaceLambda scans scan_lams
  pure (alloc_stms, Op $ Inner $ SegOp $ SegScan lvl' space scans' ts kbody')
-- Histograms: allocations are expanded through the update lambda
-- ('histOp') of every 'HistOp', which is then replaced by its rewritten
-- version, preserving the order of the operators.
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lvl', hist_lams, kbody')) <-
    transformScanRed lvl space (map histOp ops) kbody
  let ops' = zipWith (\op lam -> op {histOp = lam}) ops hist_lams
  pure (alloc_stms, Op $ Inner $ SegOp $ SegHist lvl' space ops' ts kbody')
-- Accumulators: the body lambda is transformed recursively, and the
-- update operator of each input (if any) has its invariant allocations
-- expanded over the accumulator shape.
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (input_alloc_stms, inputs') <- mapAndUnzipM expandInput inputs
  pure (mconcat input_alloc_stms, WithAcc inputs' lam')
  where
    -- An input without an update operator needs no work.
    expandInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    expandInput (shape, arrs, Just (op_lam, nes)) = do
      bound_outside <- asks $ namesFromList . M.keys
      let -- XXX: fake a SegLevel, which we don't have here.  We will not
          -- use it for anything, as we will not allow irregular
          -- allocations inside the update function.
          lvl = SegThread SegNoVirt Nothing
          (op_lam', lam_allocs) =
            extractLambdaAllocations (lvl, [0]) bound_outside mempty op_lam
          isVariant (_, Var v, _) = v `notNameIn` bound_outside
          isVariant _ = False
          (variant_allocs, invariant_allocs) = M.partition isVariant lam_allocs
          -- The leading (shapeRank shape) parameters of the operator are
          -- used as the indices for the expanded allocations.
          is = map paramName $ take (shapeRank shape) $ lambdaParams op_lam

      -- Allocations whose size is a variable not bound outside the
      -- operator cannot be expanded.
      case M.elems variant_allocs of
        [] -> pure ()
        (_, v, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: "
              ++ prettyString v
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."

      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations
          (\_ _ -> (shape, map le64 is))
          invariant_allocs

      scope <- askScope
      let scope' = scopeOf op_lam <> scope <> scopeOf alloc_stms
      offset_result <- runOffsetM scope' $ do
        op_lam'' <- offsetMemoryInLambda alloc_offsets op_lam'
        pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
      either throwError pure offset_result
-- Every other expression is returned unchanged, with no extra statements.
transformExp e = pure (mempty, e)

-- | Make sure a 'SegLevel' carries an explicit 'KernelGrid'.
--
-- A thread or block level that already has a grid is returned unchanged,
-- together with that grid and no statements.  Otherwise, statements are
-- generated that obtain the number of thread blocks and the thread block
-- size through fresh @GetSize@ operations, and the level is rebuilt with
-- the resulting grid.  'SegThreadInBlock' is rejected with an internal
-- error.
ensureGridKnown :: SegLevel -> ExpandM (Stms GPUMem, SegLevel, KernelGrid)
ensureGridKnown lvl = case lvl of
  SegThread _ (Just grid) -> alreadyKnown grid
  SegBlock _ (Just grid) -> alreadyKnown grid
  SegThread virt Nothing -> freshGrid (SegThread virt)
  SegBlock virt Nothing -> freshGrid (SegBlock virt)
  SegThreadInBlock {} -> error "ensureGridKnown: SegThreadInBlock"
  where
    alreadyKnown grid = pure (mempty, lvl, grid)

    -- Build a grid from two fresh size parameters (number of blocks
    -- first, then block size) and attach it to the level.
    freshGrid mkLevel = do
      (grid, stms) <- runBuilder $ do
        num_tblocks <- sizeParam "num_tblocks" SizeGrid
        tblock_size <- sizeParam "tblock_size" SizeThreadBlock
        pure $ KernelGrid (Count num_tblocks) (Count tblock_size)
      pure (stms, mkLevel (Just grid), grid)

    -- Bind a @GetSize@ operation with a freshly named size key.
    sizeParam desc size_class = do
      size_key <- nameFromString . prettyString <$> newVName desc
      letSubExp desc $ Op $ Inner $ SizeOp $ GetSize size_key size_class

-- | Expand the allocations of a segmented operation whose operators are
-- given as lambdas (used for scans and histograms in this module).
--
-- Allocations are extracted from both the kernel body and the operator
-- lambdas.  An allocation is /variant/ if its size is a variable that is
-- not bound outside the operation, and /invariant/ otherwise.  Variant
-- sizes must be bound in the segment space or the kernel body, and variant
-- allocations are not supported at all at 'SegBlock' level; both cases are
-- reported with 'throwError'.
--
-- If nothing was extracted, the level, lambdas and body are returned
-- unchanged with no statements.  Otherwise the level is given an explicit
-- grid via 'ensureGridKnown', and the memory of the body and the lambdas
-- is rebased onto the expanded allocations produced by 'allocsForBody'.
-- The returned statements are the grid statements followed by the
-- allocation statements.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  -- A variant size must be bound in the segment space or the kernel body;
  -- otherwise it cannot be sliced and we give up.  Report only the size
  -- itself, as the message promises.
  case find badVariant $ M.elems variant_allocs of
    Just (_, size, _) ->
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ prettyString size
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  -- The guard tests the *variant* allocations, so the message says so.
  case lvl of
    SegBlock {}
      | not $ null variant_allocs ->
          throwError "Cannot handle variant allocations in SegBlock."
    _ ->
      pure ()

  if null variant_allocs && null invariant_allocs
    then pure (mempty, (lvl, ops, kbody))
    else do
      (lvl_stms, lvl', grid) <- ensureGridKnown lvl
      allocsForBody variant_allocs invariant_allocs grid space kbody kbody' $
        \offsets alloc_stms kbody'' -> do
          ops'' <- forM ops' $ \op' ->
            localScope (scopeOf op') $ offsetMemoryInLambda offsets op'
          pure (lvl_stms <> alloc_stms, (lvl', ops'', kbody''))
  where
    -- Names bound by the segment space or by the kernel body statements.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | The set of names bound by the statements of a kernel body.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  namesFromList $ M.keys $ scopeOf $ kernelBodyStms kbody

-- | Prepend the given statements to the statements of a kernel body,
-- leaving everything else in the body untouched.
addStmsToKernelBody :: Stms GPUMem -> KernelBody GPUMem -> KernelBody GPUMem
addStmsToKernelBody new_stms body =
  let old_stms = kernelBodyStms body
   in body {kernelBodyStms = new_stms <> old_stms}

-- | Compute the memory requirements of a kernel body and run a
-- continuation with the result.
--
-- The allocation statements produced by 'memoryRequirements' (computed
-- from the statements of @kbody@) are split in two:
--
-- * allocations in @"shared"@ space are prepended to the rewritten
--   @kbody'@;
-- * all other allocation statements are handed to the continuation.
--
-- @kbody'@ is rewritten with 'offsetMemoryInKernelBody' before being
-- passed on. The rewrite and the continuation both run in 'OffsetM',
-- under a scope built from three parts: the segment space, the current
-- scope, and the allocation statements. An error reported by
-- 'runOffsetM' is rethrown with 'throwError'.
allocsForBody ::
  Extraction ->
  Extraction ->
  KernelGrid ->
  SegSpace ->
  KernelBody GPUMem ->
  KernelBody GPUMem ->
  (RebaseMap -> Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs grid space kbody kbody' m = do
  (offsets, all_alloc_stms) <-
    memoryRequirements grid space (kernelBodyStms kbody) variant_allocs invariant_allocs

  -- We assume that any shared memory allocations can be inserted back
  -- into kbody'. This would not work if we had SegRed/SegScan
  -- operations that performed shared memory allocations. We don't
  -- currently, and if we would in the future, we would need to be
  -- more careful about summarising the allocations in
  -- transformScanRed.
  let (shared_alloc_stms, other_alloc_stms) =
        Seq.partition isSharedAlloc all_alloc_stms

  outer_scope <- askScope
  let offset_scope =
        scopeOfSegSpace space <> outer_scope <> scopeOf all_alloc_stms

  result <- runOffsetM offset_scope $ do
    rebased_kbody <- offsetMemoryInKernelBody offsets kbody'
    m offsets other_alloc_stms $
      addStmsToKernelBody shared_alloc_stms rebased_kbody
  either throwError pure result
  where
    -- Is this statement an allocation in shared memory?
    isSharedAlloc :: Stm GPUMem -> Bool
    isSharedAlloc (Let _ _ (Op (Alloc _ (Space "shared")))) = True
    isSharedAlloc _ = False

-- | Compute the memory needed for the extracted allocations of a kernel.
--
-- First binds @num_threads@ to the number of blocks of @grid@ times its
-- block size, using an 'Int64' multiplication with 'OverflowUndef'.
-- Then, in a scope that contains @num_threads@:
--
-- * the invariant allocations are expanded with
--   'expandedInvariantAllocations';
-- * the variant allocations are expanded with
--   'expandedVariantAllocations', together with @kstms@.
--
-- Returns two things:
--
-- * The union of the two rebase maps. On a key collision the invariant
--   entry wins, because map union is left-biased.
-- * The statements computing @num_threads@, then the invariant
--   allocations, then the variant allocations, in that order.
memoryRequirements ::
  KernelGrid ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements grid space kstms variant_allocs invariant_allocs = do
  (num_threads, num_threads_stms) <-
    runBuilder $
      letSubExp "num_threads" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) num_blocks block_size

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations
        num_threads
        (gridNumBlocks grid)
        (gridBlockSize grid)
        invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations num_threads space kstms variant_allocs

  let offsets = invariant_alloc_offsets <> variant_alloc_offsets
      stms = num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
  pure (offsets, stms)
  where
    num_blocks = unCount $ gridNumBlocks grid
    block_size = unCount $ gridBlockSize grid

-- | An 'Int64'-typed primitive expression over variable names; used
-- here for thread-ID expressions.
type Exp64 = TPrimExp Int64 VName

-- | Identifying the spot where an allocation occurs in terms of its
-- level and thread IDs. When 'extractStmAllocations' descends into a
-- nested 'SegOp', the level is replaced by that 'SegOp''s level and its
-- flat thread ID ('segFlat') is appended to the list.
type User = (SegLevel, [Exp64])

-- | A description of allocations that have been extracted, and how
-- much memory (and which space) is needed. Keyed by the name of the
-- memory block bound by the extracted allocation. Each entry holds the
-- 'User' that performed it, the size given in the 'Alloc', and the
-- memory space.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Extract allocations from the statements of a kernel body.
--
-- This is 'extractGenericBodyAllocations' specialised to 'KernelBody'.
-- It reads, and then replaces, the body's 'kernelBodyStms'.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations lvl bound_outside bound_kernel kernelBodyStms replaceStms
  where
    replaceStms :: Stms GPUMem -> KernelBody GPUMem -> KernelBody GPUMem
    replaceStms new_stms kbody = kbody {kernelBodyStms = new_stms}

-- | Extract allocations from the statements of an ordinary body.
--
-- This is 'extractGenericBodyAllocations' specialised to 'Body'. It
-- reads, and then replaces, the body's 'bodyStms'.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms replaceStms
  where
    replaceStms :: Stms GPUMem -> Body GPUMem -> Body GPUMem
    replaceStms new_stms body = body {bodyStms = new_stms}

-- | Extract allocations from the body of a lambda.
--
-- Returns the lambda with its body replaced by the result of
-- 'extractBodyAllocations', together with the allocations extracted.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  first (\body' -> lam {lambdaBody = body'}) $
    extractBodyAllocations user bound_outside bound_kernel (lambdaBody lam)

-- | Extract allocations from any body-like value, given a getter and a
-- setter for its statements.
--
-- Each statement is passed through 'extractStmAllocations'. A statement
-- for which that returns 'Nothing' (an extracted allocation) is dropped
-- from the body. The extractions of all statements are combined and
-- returned alongside the updated body.
--
-- While processing the statements, the names bound by the body's own
-- statements ('boundByStms') are added to @bound_kernel@.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  (set_stms (stmsFromList kept) body, allocs)
  where
    old_stms = get_stms body
    bound_kernel' = bound_kernel <> boundByStms old_stms
    (kept, allocs) =
      runWriter $
        catMaybes
          <$> mapM
            (extractStmAllocations user bound_outside bound_kernel')
            (stmsToList old_stms)

-- | Whether an allocation in the given space, made by the given user,
-- may be expanded. Two kinds are never expandable:
--
-- * allocations in @"shared"@ space made at 'SegBlock' level;
-- * allocations in a 'ScalarSpace'.
--
-- Everything else is expandable.
expandable :: User -> Space -> Bool
expandable (lvl, _) space =
  case (lvl, space) of
    (SegBlock {}, Space "shared") -> False
    (_, ScalarSpace {}) -> False
    _ -> True

-- | 'True' for every memory space except a 'ScalarSpace'.
notScalar :: Space -> Bool
notScalar space =
  case space of
    ScalarSpace {} -> False
    _ -> True

-- | Try to extract a single statement.
--
-- A single-element 'Alloc' statement that satisfies @shouldExtract@ is
-- removed: the result is 'Nothing', and the allocation is written to the
-- 'Extraction', tagged with @user@.
--
-- Any other statement is kept. Its nested bodies are still processed
-- recursively:
--
-- * ordinary bodies go through 'extractBodyAllocations';
-- * the lambdas and kernel bodies of nested 'SegOp's go through
--   'extractBodyAllocations' and 'extractKernelBodyAllocations'.
--
-- For nested 'SegOp's the user is refined with the 'SegOp''s level and
-- flat thread ID.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel stm =
  case stm of
    Let (Pat [pe]) _ (Op (Alloc size space))
      | shouldExtract size space -> do
          tell $ M.singleton (patElemName pe) (user, size, space)
          pure Nothing
    _ -> do
      e' <- mapExpM (mapperFor user) $ stmExp stm
      pure $ Just (stm {stmExp = e'})
  where
    shouldExtract size space =
      (expandable user space && sizeIsKnown size)
        -- FIXME: the '&& notScalar space' part is a hack because we
        -- don't otherwise hoist the sizes out far enough, and we
        -- promise to be super-duper-careful about not having variant
        -- scalar allocations.
        || (sizeBoundInKernel size && notScalar space)

    -- A constant, or a variable bound either outside or inside the kernel.
    sizeIsKnown (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    sizeIsKnown Constant {} = True

    sizeBoundInKernel (Var v) = v `nameIn` bound_kernel
    sizeBoundInKernel Constant {} = False

    mapperFor user' =
      (identityMapper @GPUMem)
        { mapOnBody = \_ -> onBody user',
          mapOnOp = onOp user'
        }

    onBody :: User -> Body GPUMem -> Writer Extraction (Body GPUMem)
    onBody user' = writer . extractBodyAllocations user' bound_outside bound_kernel

    onKernelBody :: User -> KernelBody GPUMem -> Writer Extraction (KernelBody GPUMem)
    onKernelBody user' = writer . extractKernelBodyAllocations user' bound_outside bound_kernel

    onLambda :: User -> Lambda GPUMem -> Writer Extraction (Lambda GPUMem)
    onLambda user' lam =
      (\body' -> lam {lambdaBody = body'}) <$> onBody user' (lambdaBody lam)

    -- Only nested SegOps are traversed; every other operation is left as is.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper nested_user) op
      where
        nested_user = (segLevel op, user_ids ++ [le64 $ segFlat $ segSpace op])
    onOp _ op = pure op

    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

genericExpandedInvariantAllocations ::
  (User -> Space -> (Shape, [Exp64])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations :: ((SegLevel, [TPrimExp Int64 VName])
 -> Space -> (ShapeBase SubExp, [TPrimExp Int64 VName]))
-> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations (SegLevel, [TPrimExp Int64 VName])
-> Space -> (ShapeBase SubExp, [TPrimExp Int64 VName])
getNumUsers Extraction
invariant_allocs = do
  -- We expand the invariant allocations by adding an inner dimension
  -- equal to the number of kernel threads.
  ([RebaseMap]
rebases, Stms GPUMem
alloc_stms) <- Builder GPUMem [RebaseMap]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     ([RebaseMap], Stms GPUMem)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder (Builder GPUMem [RebaseMap]
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      ([RebaseMap], Stms GPUMem))
-> Builder GPUMem [RebaseMap]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     ([RebaseMap], Stms GPUMem)
forall a b. (a -> b) -> a -> b
$ ((VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))
 -> BuilderT GPUMem (State VNameSource) RebaseMap)
-> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
-> Builder GPUMem [RebaseMap]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))
-> BuilderT GPUMem (State VNameSource) RebaseMap
expand ([(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
 -> Builder GPUMem [RebaseMap])
-> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
-> Builder GPUMem [RebaseMap]
forall a b. (a -> b) -> a -> b
$ Extraction
-> [(VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))]
forall k a. Map k a -> [(k, a)]
M.toList Extraction
invariant_allocs

  (Stms GPUMem, RebaseMap) -> ExpandM (Stms GPUMem, RebaseMap)
forall a.
a -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPUMem
alloc_stms, [RebaseMap] -> RebaseMap
forall a. Monoid a => [a] -> a
mconcat [RebaseMap]
rebases)
  where
    expand :: (VName, ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space))
-> BuilderT GPUMem (State VNameSource) RebaseMap
expand (VName
mem, ((SegLevel, [TPrimExp Int64 VName])
user, SubExp
per_thread_size, Space
space)) = do
      let num_users :: ShapeBase SubExp
num_users = (ShapeBase SubExp, [TPrimExp Int64 VName]) -> ShapeBase SubExp
forall a b. (a, b) -> a
fst ((ShapeBase SubExp, [TPrimExp Int64 VName]) -> ShapeBase SubExp)
-> (ShapeBase SubExp, [TPrimExp Int64 VName]) -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$ (SegLevel, [TPrimExp Int64 VName])
-> Space -> (ShapeBase SubExp, [TPrimExp Int64 VName])
getNumUsers (SegLevel, [TPrimExp Int64 VName])
user Space
space
          allocpat :: Pat (MemInfo SubExp NoUniqueness MemBind)
allocpat = [PatElem (MemInfo SubExp NoUniqueness MemBind)]
-> Pat (MemInfo SubExp NoUniqueness MemBind)
forall dec. [PatElem dec] -> Pat dec
Pat [VName
-> MemInfo SubExp NoUniqueness MemBind
-> PatElem (MemInfo SubExp NoUniqueness MemBind)
forall dec. VName -> dec -> PatElem dec
PatElem VName
mem (MemInfo SubExp NoUniqueness MemBind
 -> PatElem (MemInfo SubExp NoUniqueness MemBind))
-> MemInfo SubExp NoUniqueness MemBind
-> PatElem (MemInfo SubExp NoUniqueness MemBind)
forall a b. (a -> b) -> a -> b
$ Space -> MemInfo SubExp NoUniqueness MemBind
forall d u ret. Space -> MemInfo d u ret
MemMem Space
space]
      VName
total_size <-
        String
-> Exp (Rep (BuilderT GPUMem (State VNameSource)))
-> BuilderT GPUMem (State VNameSource) VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"total_size" (Exp GPUMem -> BuilderT GPUMem (State VNameSource) VName)
-> ([TPrimExp Int64 VName]
    -> BuilderT GPUMem (State VNameSource) (Exp GPUMem))
-> [TPrimExp Int64 VName]
-> BuilderT GPUMem (State VNameSource) VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< TPrimExp Int64 VName
-> BuilderT
     GPUMem
     (State VNameSource)
     (Exp (Rep (BuilderT GPUMem (State VNameSource))))
TPrimExp Int64 VName
-> BuilderT GPUMem (State VNameSource) (Exp GPUMem)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (TPrimExp Int64 VName
 -> BuilderT GPUMem (State VNameSource) (Exp GPUMem))
-> ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName]
-> BuilderT GPUMem (State VNameSource) (Exp GPUMem)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. Num a => [a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName]
 -> BuilderT GPUMem (State VNameSource) VName)
-> [TPrimExp Int64 VName]
-> BuilderT GPUMem (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$
          SubExp -> TPrimExp Int64 VName
pe64 SubExp
per_thread_size TPrimExp Int64 VName
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. a -> [a] -> [a]
: (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 (ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
num_users)
      Pat (LetDec (Rep (BuilderT GPUMem (State VNameSource))))
-> Exp (Rep (BuilderT GPUMem (State VNameSource)))
-> BuilderT GPUMem (State VNameSource) ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (BuilderT GPUMem (State VNameSource))))
Pat (MemInfo SubExp NoUniqueness MemBind)
allocpat (Exp (Rep (BuilderT GPUMem (State VNameSource)))
 -> BuilderT GPUMem (State VNameSource) ())
-> Exp (Rep (BuilderT GPUMem (State VNameSource)))
-> BuilderT GPUMem (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPUMem (State VNameSource)))
-> Exp (Rep (BuilderT GPUMem (State VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPUMem (State VNameSource)))
 -> Exp (Rep (BuilderT GPUMem (State VNameSource))))
-> Op (Rep (BuilderT GPUMem (State VNameSource)))
-> Exp (Rep (BuilderT GPUMem (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ SubExp -> Space -> MemOp (HostOp NoOp) GPUMem
forall (inner :: * -> *) rep. SubExp -> Space -> MemOp inner rep
Alloc (VName -> SubExp
Var VName
total_size) Space
space
      RebaseMap -> BuilderT GPUMem (State VNameSource) RebaseMap
forall a. a -> BuilderT GPUMem (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (RebaseMap -> BuilderT GPUMem (State VNameSource) RebaseMap)
-> RebaseMap -> BuilderT GPUMem (State VNameSource) RebaseMap
forall a b. (a -> b) -> a -> b
$ VName
-> ([TPrimExp Int64 VName]
    -> (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> RebaseMap
forall k a. k -> a -> Map k a
M.singleton VName
mem (([TPrimExp Int64 VName]
  -> (TPrimExp Int64 VName, TPrimExp Int64 VName))
 -> RebaseMap)
-> ([TPrimExp Int64 VName]
    -> (TPrimExp Int64 VName, TPrimExp Int64 VName))
-> RebaseMap
forall a b. (a -> b) -> a -> b
$ (SegLevel, [TPrimExp Int64 VName])
-> Space
-> [TPrimExp Int64 VName]
-> (TPrimExp Int64 VName, TPrimExp Int64 VName)
newBase (SegLevel, [TPrimExp Int64 VName])
user Space
space

    -- Rebase a memory block for one user of an expanded allocation.
    --
    -- 'getNumUsers' yields the shape of all users sharing the allocation
    -- together with this user's ids within that shape. The result pairs
    -- the user's flattened index within that shape with the total number
    -- of users (the product of the shape). The old shape of the memory
    -- block is not consulted.
    newBaseThread user space _old_shape =
      let (users_shape, user_ids) = getNumUsers user space
          dims = map pe64 (shapeDims users_shape)
       in (flattenIndex dims user_ids, product dims)

    -- Every handled segment level is rebased in exactly the same way, so
    -- all of them delegate to 'newBaseThread' rather than repeating it.
    newBase user@(SegThreadInBlock {}, _) space = newBaseThread user space
    newBase user@(SegThread {}, _) space = newBaseThread user space
    newBase user@(SegBlock {}, _) space = newBaseThread user space

-- | Expand the allocations in the given 'Extraction' for a kernel with
-- @num_threads@ threads in total, organised as @num_tblocks@ thread
-- blocks of @tblock_size@ threads each.
--
-- The heavy lifting is done by 'genericExpandedInvariantAllocations'. This
-- function only supplies the rule that says, for a given user (segment
-- level plus its thread ids) and memory space, how many users share the
-- allocation (as a shape) and which ids identify the current user.
expandedInvariantAllocations ::
  SubExp ->
  Count NumBlocks SubExp ->
  Count BlockSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_tblocks) (Count tblock_size) =
  genericExpandedInvariantAllocations usersOf
  where
    -- One slot per thread of every block, addressed by (block id, local id).
    perBlock gid ltid = (Shape [num_tblocks, tblock_size], [gid, ltid])

    usersOf (lvl, ids) space =
      case (lvl, ids, space) of
        -- A single global thread id: one slot per kernel thread.
        (SegThread {}, [tid], _) -> (Shape [num_threads], [tid])
        (SegThread {}, [gid, ltid], _) -> perBlock gid ltid
        (SegThreadInBlock {}, [tid], _) -> (Shape [num_threads], [tid])
        -- Shared memory only needs one slot per thread of a single block.
        (SegThreadInBlock {}, [_, ltid], Space "shared") ->
          (Shape [tblock_size], [ltid])
        (SegThreadInBlock {}, [gid, ltid], Space "device") -> perBlock gid ltid
        -- Block-level users: one slot per thread block.
        (SegBlock {}, [gid], _) -> (Shape [num_tblocks], [gid])
        _ -> error $ "getNumUsers: unhandled " ++ show ((lvl, ids), space)

-- | Expand the allocations in the given 'Extraction' whose sizes vary
-- between the threads of a kernel.
--
-- The sizes are extracted from the kernel statements with
-- 'sliceKernelSizes'. The resulting statements are given explicit
-- allocations, simplified, and then run through 'transformStms'
-- themselves. Each memory block is then replaced by one allocation of
-- the summed size computed for its size class. The block is rebased so
-- that it is indexed by the flat thread id of @kspace@, with
-- @num_threads@ users.
--
-- Returns no statements and an empty 'RebaseMap' when there is nothing
-- to expand.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = pure (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs
      variant_sizes = map fst sizes_to_blocks

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- Note the recursive call to expand allocations inside the newly
  -- produced kernels.
  slice_stms_tmp <- simplifyStms =<< explicitAllocationsInStms slice_stms
  slice_stms' <- transformStms slice_stms_tmp

  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        concat $
          zipWith
            memInfo
            (map snd sizes_to_blocks)
            (zip offsets size_sums)
      memInfo blocks (offset, total_size) =
        [(mem, (Var offset, Var total_size, space)) | (mem, space) <- blocks]

      -- We expand the variant allocations by allocating, for each memory
      -- block, the sum of the sizes required by the different threads.
      -- Building these statements needs no effects, so this is a plain map.
      (alloc_stms, rebases) = unzip $ map expand variant_allocs'

  pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- The offset is not used here: the rebasing only depends on the
    -- thread id and the number of threads.
    expand (mem, (_offset, total_size, space)) =
      ( Let (Pat [PatElem mem $ MemMem space]) (defAux ()) $
          Op $
            Alloc total_size space,
        M.singleton mem newBase
      )

    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- For the variant allocations, we add an inner dimension,
    -- which is then offset by a thread-specific amount.
    newBase _old_shape =
      (gtid, num_threads')

-- | How a single user addresses an expanded memory block. In this module
-- the first component is the user's index (for example a flattened
-- thread id) and the second is the total number of such users.
type Expansion = (Exp64, Exp64)

-- | A map from memory block names to index function embeddings. Each
-- embedding receives the block's old shape and yields its 'Expansion'.
type RebaseMap = M.Map VName ([Exp64] -> Expansion)

--- Modifying the index functions of code.

-- | The monad in which memory accesses are rewritten. It is a 'GPUMem'
-- builder running over a fresh-name state, and it may fail with a
-- 'String' error.
newtype OffsetM a
  = OffsetM (BuilderT GPUMem (StateT VNameSource (Either String)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      MonadError String,
      HasScope GPUMem,
      LocalScope GPUMem
    )

-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder OffsetM where
  type Rep OffsetM = GPUMem
  mkExpDecM pat = OffsetM . mkExpDecM pat
  mkBodyM stms = OffsetM . mkBodyM stms
  mkLetNamesM names = OffsetM . mkLetNamesM names

  addStms stms = OffsetM (addStms stms)
  collectStms (OffsetM inner) = OffsetM (collectStms inner)

-- | Run an 'OffsetM' computation in the given scope, drawing fresh names
-- from the ambient name source.
--
-- Statements added at the top level of the computation are discarded.
-- On failure the error is returned and the name source is left as it
-- was. On success the updated name source is kept.
runOffsetM ::
  (MonadFreshNames m) =>
  Scope GPUMem ->
  OffsetM a ->
  m (Either String a)
runOffsetM scope (OffsetM m) = modifyNameSource $ \src ->
  either (\err -> (Left err, src)) (first (Right . fst)) $
    runStateT (runBuilderT m scope) src

-- | Look up the rebasing function of a memory block, if it has one, and
-- apply it to the block's old shape.
lookupNewBase :: VName -> [Exp64] -> RebaseMap -> Maybe Expansion
lookupNewBase name x offsets =
  fmap (\rebase -> rebase x) (M.lookup name offsets)

-- | Rewrite every statement of a kernel body with 'offsetMemoryInStm'.
-- All other parts of the body are left unchanged.
offsetMemoryInKernelBody :: RebaseMap -> KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody offsets kbody =
  fmap withStms . collectStms_ . forM_ (kernelBodyStms kbody) $ \stm ->
    addStm =<< offsetMemoryInStm offsets stm
  where
    withStms new_stms = kbody {kernelBodyStms = new_stms}

-- | Rebuild a body with every statement rewritten by 'offsetMemoryInStm',
-- keeping the original result unchanged.
offsetMemoryInBody :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody offsets (Body _ stms res) =
  buildBody_ $ res <$ forM_ stms (addStm <=< offsetMemoryInStm offsets)

-- | Compute the context needed to pass the given subexpressions on.
--
-- Each array contributes its memory block followed by the existentialized
-- components of its LMAD, with every component bound to a fresh @ctx@
-- variable. Non-array subexpressions contribute nothing.
argsContext :: [SubExp] -> OffsetM [SubExp]
argsContext ses = concat <$> traverse ctxOf ses
  where
    ctxOf se = do
      info <- subExpMemInfo se
      case info of
        MemArray _ _ _ (ArrayIn mem lmad) ->
          (Var mem :)
            <$> traverse (letSubExp "ctx" <=< toExp) (LMAD.existentialized lmad)
        _ -> pure []

-- | Like rewriting a body with 'offsetMemoryInStm', but the result is also
-- extended with the memory and LMAD context of every returned array, as
-- computed by 'argsContext'.
offsetMemoryInBodyReturnCtx :: RebaseMap -> Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBodyReturnCtx offsets (Body _ stms res) =
  buildBody_ $ do
    forM_ stms $ \stm -> addStm =<< offsetMemoryInStm offsets stm
    (res <>) . subExpsRes <$> argsContext (map resSubExp res)

-- | Build an LMAD for an array of the given shape from a flat list of
-- components.
--
-- The first component becomes the LMAD's offset.  Each remaining component
-- is paired positionally with the corresponding shape dimension via
-- 'LMAD.LMADDim' (as @LMADDim stride dim@).  Callers in this module size
-- the component list with 'LMAD.existentialized', i.e. one offset plus one
-- component per dimension.
--
-- Extra components beyond the length of the shape are ignored, as with
-- 'zipWith'.  An empty component list is a compiler bug, because there is
-- no offset; it is reported with an explicit 'error' rather than through
-- the partial 'head'.
lmadFrom :: LMAD.Shape num -> [num] -> LMAD.LMAD num
lmadFrom shape (offset : strides) =
  LMAD.LMAD offset $ zipWith LMAD.LMADDim strides shape
lmadFrom _ [] =
  error "lmadFrom: empty component list; expected an offset followed by one stride per dimension"

-- | Append pattern elements corresponding to memory and index
-- function components for every array bound in the pattern.
--
-- Each array element stored in memory @mem@ is rebound to a fresh memory
-- block (named after @mem@ with an @_ext@ suffix).  Its LMAD is built
-- entirely from fresh @int64@ pattern elements, one per component returned
-- by 'LMAD.existentialized'.  For each array, the fresh memory element
-- followed by its component elements is appended after all original
-- elements, in pattern order.  Non-array elements are left untouched.
addPatternContext :: Pat LetDecMem -> OffsetM (Pat LetDecMem)
addPatternContext pat@(Pat pes) =
  localScope (scopeOfPat pat) $ do
    (ctx, pes') <- mapAccumLM expand [] pes
    pure $ Pat (pes' <> ctx)
  where
    expand ctx (PatElem v (MemArray pt shape u (ArrayIn mem lmad))) = do
      space <- lookupMemSpace mem
      mem' <- newVName $ baseString mem <> "_ext"
      comps <-
        replicateM (length $ LMAD.existentialized lmad) $
          PatElem <$> newVName "ext" <*> pure (MemPrim int64)
      let lmad' = lmadFrom (LMAD.shape lmad) [le64 (patElemName c) | c <- comps]
          info' = MemArray pt shape u (ArrayIn mem' lmad')
      pure (ctx <> (PatElem mem' (MemMem space) : comps), PatElem v info')
    expand ctx pe = pure (ctx, pe)

-- | Append parameters corresponding to memory and index function
-- components for every array-typed parameter.
--
-- Each array parameter stored in memory @mem@ is rebound to a fresh memory
-- parameter (named after @mem@ with an @_ext@ suffix).  Its LMAD is built
-- entirely from fresh @int64@ parameters, one per component returned by
-- 'LMAD.existentialized'.  The new parameters carry no attributes.  For
-- each array, the fresh memory parameter followed by its component
-- parameters is appended after all original parameters, in order.
-- Non-array parameters are left untouched.
addParamsContext :: [Param FParamMem] -> OffsetM [Param FParamMem]
addParamsContext ps =
  localScope (scopeOfFParams ps) $ do
    (ctx, ps') <- mapAccumLM expand [] ps
    pure (ps' <> ctx)
  where
    expand ctx (Param attrs v (MemArray pt shape u (ArrayIn mem lmad))) = do
      space <- lookupMemSpace mem
      mem' <- newVName $ baseString mem <> "_ext"
      comps <-
        replicateM (length $ LMAD.existentialized lmad) $
          Param mempty <$> newVName "ext" <*> pure (MemPrim int64)
      let lmad' = lmadFrom (LMAD.shape lmad) [le64 (paramName c) | c <- comps]
          info' = MemArray pt shape u (ArrayIn mem' lmad')
      pure (ctx <> (Param mempty mem' (MemMem space) : comps), Param attrs v info')
    expand ctx p = pure (ctx, p)

-- | Give every array result of a branch a fresh memory block with a fully
-- existential LMAD.
--
-- This applies to each pattern element that is an array and whose branch
-- type is also an array.
--
-- * The element is rebound to a fresh memory block, named after its
--   original memory with an @_ext@ suffix.  That block's LMAD is built from
--   fresh @int64@ pattern elements, one per component of
--   'LMAD.existentialized' on the branch type's LMAD.
--
-- * Its branch type becomes 'ReturnsNewBlock', pointing at the position of
--   the new memory result.  Its LMAD is existential in the component
--   results that directly follow that memory result.
--
-- The new memory and component elements, with types 'MemMem' and
-- 'MemPrim' 'int64', are appended after all original results.  Other
-- elements and types are left unchanged.
offsetBranch ::
  Pat LetDecMem ->
  [BranchTypeMem] ->
  OffsetM (Pat LetDecMem, [BranchTypeMem])
offsetBranch (Pat pes) ts = do
  (ctx, vals) <- mapAccumLM expand [] (zip pes ts)
  let (ctx_pes, ctx_ts) = unzip ctx
      (val_pes, val_ts) = unzip vals
  pure (Pat (val_pes <> ctx_pes), val_ts <> ctx_ts)
  where
    expand
      ctx
      ( PatElem pe_v (MemArray _ pe_shape pe_u (ArrayIn pe_mem pe_lmad)),
        MemArray pt shape u meminfo
        ) = do
        (space, ret_lmad) <- case meminfo of
          ReturnsInBlock mem l -> do
            sp <- lookupMemSpace mem
            pure (sp, l)
          ReturnsNewBlock sp _ l -> pure (sp, l)
        pe_mem' <- newVName $ baseString pe_mem <> "_ext"
        comps <-
          replicateM (length $ LMAD.existentialized ret_lmad) $
            PatElem <$> newVName "ext" <*> pure (MemPrim int64)
        -- The new memory result is placed after every value result and
        -- after all context results accumulated so far.  Its LMAD
        -- components follow it immediately.
        let mem_idx = length ts + length ctx
            toExt (Free se) = Free <$> pe64 se
            toExt (Ext i) = le64 (Ext i)
            pe_lmad' = lmadFrom (LMAD.shape pe_lmad) $ map (le64 . patElemName) comps
            ret_lmad' = fmap toExt $ LMAD.mkExistential (shapeDims shape) (mem_idx + 1)
            new_ctx =
              (PatElem pe_mem' (MemMem space), MemMem space)
                : map (,MemPrim int64) comps
            new_pe = PatElem pe_v $ MemArray pt pe_shape pe_u $ ArrayIn pe_mem' pe_lmad'
            new_t = MemArray pt shape u $ ReturnsNewBlock space mem_idx ret_lmad'
        pure (ctx <> new_ctx, (new_pe, new_t))
    expand ctx x = pure (ctx, x)

-- | Update the memory annotations of a pattern to agree with the
-- 'ExpReturns' of the expression it binds.
--
-- For an array pattern element whose corresponding 'ExpReturns' carries an
-- LMAD (via 'ReturnsNewBlock' or 'ReturnsInBlock'), that LMAD is adopted
-- (keeping the element's own memory block).  Existential references are
-- resolved to the name of the pattern element at that position.  Every
-- other element is passed through 'offsetMemoryInMemBound'.
offsetMemoryInPat :: RebaseMap -> Pat LetDecMem -> [ExpReturns] -> Pat LetDecMem
offsetMemoryInPat offsets (Pat pes) rets =
  Pat $ zipWith onPE pes rets
  where
    onPE
      (PatElem name (MemArray pt shape u (ArrayIn mem _)))
      (MemArray _ _ _ info)
        | Just lmad <- getLMAD info =
            PatElem name . MemArray pt shape u . ArrayIn mem $
              fmap (fmap unExt) lmad
    onPE pe _ =
      offsetMemoryInMemBound offsets <$> pe

    -- An existential index refers to the pattern element at that position.
    -- An out-of-range index is an internal invariant violation; report it
    -- explicitly instead of via the uninformative failure of '!!'.
    unExt (Ext i)
      | i >= 0,
        pe : _ <- drop i pes =
          patElemName pe
      | otherwise =
          error $
            "offsetMemoryInPat: existential index "
              <> show i
              <> " out of range for pattern of "
              <> show (length pes)
              <> " elements."
    unExt (Free v) = v

    getLMAD (Just (ReturnsNewBlock _ _ lmad)) = Just lmad
    getLMAD (Just (ReturnsInBlock _ lmad)) = Just lmad
    getLMAD _ = Nothing

-- | Rebase the memory annotation carried by a parameter, using
-- 'offsetMemoryInMemBound' on its decoration.
offsetMemoryInParam :: RebaseMap -> Param (MemBound u) -> Param (MemBound u)
offsetMemoryInParam offsets param =
  offsetMemoryInMemBound offsets <$> param

-- | If an array lives in a memory block for which 'lookupNewBase' (given
-- the array's LMAD shape) finds an entry in the 'RebaseMap', expand its
-- LMAD with the returned pair via 'LMAD.expand'.  Any other memory
-- annotation is returned unchanged.
offsetMemoryInMemBound :: RebaseMap -> MemBound u -> MemBound u
offsetMemoryInMemBound offsets info =
  case info of
    MemArray pt shape u (ArrayIn mem lmad)
      | Just (o, p) <- lookupNewBase mem (LMAD.shape lmad) offsets ->
          MemArray pt shape u . ArrayIn mem $ LMAD.expand o p lmad
    _ -> info

-- | Branch-return counterpart of 'offsetMemoryInMemBound'.  Only a
-- 'ReturnsInBlock' whose LMAD is static (per 'isStaticLMAD'), and whose
-- block is found by 'lookupNewBase', is rewritten.  The expansion terms are
-- lifted into the existential LMAD as 'Free' variables.  Everything else is
-- returned unchanged.
offsetMemoryInBodyReturns :: RebaseMap -> BodyReturns -> BodyReturns
offsetMemoryInBodyReturns offsets br =
  case br of
    MemArray pt shape u (ReturnsInBlock mem lmad)
      | Just staticLMAD <- isStaticLMAD lmad,
        Just (o, p) <- lookupNewBase mem (LMAD.shape staticLMAD) offsets ->
          let extO = fmap Free o
              extP = fmap Free p
           in MemArray pt shape u . ReturnsInBlock mem $
                LMAD.expand extO extP lmad
    _ -> br

-- | Rebase the memory annotations of a lambda's parameters and rewrite its
-- body with 'offsetMemoryInBody'.  The body is processed in the scope of
-- the original lambda.
offsetMemoryInLambda :: RebaseMap -> Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda offsets lam = do
  let params' = offsetMemoryInParam offsets <$> lambdaParams lam
  body' <- inScopeOf lam . offsetMemoryInBody offsets $ lambdaBody lam
  pure $ lam {lambdaParams = params', lambdaBody = body'}

-- | Prepare the merge parameters of a loop for rebasing and pass the
-- result to a continuation.
--
-- A loop may have memory parameters, and those memory blocks may
-- be expanded.  We assume (but do not check - FIXME) that if the
-- initial value of a loop parameter is an expanded memory block,
-- then so will the result be.
--
-- A merge parameter whose initial argument is a variable with an entry in
-- the 'RebaseMap' inherits that entry under the parameter's own name.  The
-- parameters are passed through 'addParamsContext', and the output of
-- 'argsContext' is appended to the original arguments.  The extended map
-- and the re-zipped merge list are then handed to the continuation.
offsetMemoryInLoopParams ::
  RebaseMap ->
  [(FParam GPUMem, SubExp)] ->
  (RebaseMap -> [(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams offsets merge f = do
  let (params, args) = unzip merge
  allParams <- addParamsContext params
  extraArgs <- argsContext args
  f extendedOffsets $ zip allParams (args <> extraArgs)
  where
    extendedOffsets = foldl' inherit offsets merge

    -- Copy the argument's rebase entry (if any) to the parameter.
    inherit rm (param, Var arg)
      | Just base <- M.lookup arg rm = M.insert (paramName param) base rm
    inherit rm _ = rm

-- | Handles only the expressions where we do not change the number of
-- results; meaning anything except Loop, Match, and nonscalar Apply.
--
-- Nested bodies are rewritten with 'offsetMemoryInBody' in their own
-- scope, and branch types with 'offsetMemoryInBodyReturns'.  For a
-- 'SegOp', the kernel body and lambdas are rewritten within the scope of
-- the operation's segment space.  Every other operation is left as is.
offsetMemoryInExp :: RebaseMap -> Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp offsets = mapExpM mapper
  where
    mapper =
      (identityMapper @GPUMem)
        { mapOnBody = \bscope body ->
            localScope bscope (offsetMemoryInBody offsets body),
          mapOnBranchType = pure . offsetMemoryInBodyReturns offsets,
          mapOnOp = offsetOp
        }

    offsetOp (Inner (SegOp op)) = do
      op' <-
        localScope (scopeOfSegSpace (segSpace op)) $
          mapSegOpM kernelMapper op
      pure $ Inner (SegOp op')
    offsetOp op = pure op

    kernelMapper =
      identitySegOpMapper
        { mapOnSegOpBody = offsetMemoryInKernelBody offsets,
          mapOnSegOpLambda = offsetMemoryInLambda offsets
        }

-- | Rewrite a single statement so that it refers to expanded memory.
--
-- * 'Match': every case body and the default body are rewritten with
--   'offsetMemoryInBodyReturnCtx'; the pattern and branch types are
--   updated with 'offsetBranch'.
-- * 'Loop': merge parameters go through 'offsetMemoryInLoopParams'; the
--   body is rewritten in the scope of the new parameters and the loop form;
--   the pattern is extended with 'addPatternContext'.
-- * Anything else: the expression is rewritten with 'offsetMemoryInExp',
--   and the pattern is recomputed from the rewritten expression's
--   'ExpReturns'.
--
-- Throws @"offsetMemoryInStm: ill-typed"@ when 'expReturns' yields
-- 'Nothing' for the rewritten expression.
offsetMemoryInStm :: RebaseMap -> Stm GPUMem -> OffsetM (Stm GPUMem)
offsetMemoryInStm offsets (Let pat dec (Match cond cases defbody (MatchDec ts kind))) = do
  cases' <- forM cases $ \(Case vs body) ->
    Case vs <$> offsetMemoryInBodyReturnCtx offsets body
  defbody' <- offsetMemoryInBodyReturnCtx offsets defbody
  (pat', ts') <- offsetBranch pat ts
  pure $ Let pat' dec $ Match cond cases' defbody' $ MatchDec ts' kind
offsetMemoryInStm offsets (Let pat dec (Loop merge form body)) = do
  loop' <-
    offsetMemoryInLoopParams offsets merge $ \offsets' merge' -> do
      body' <-
        localScope
          (scopeOfFParams (map fst merge') <> scopeOfLoopForm form)
          (offsetMemoryInBodyReturnCtx offsets' body)
      pure $ Loop merge' form body'
  pat' <- addPatternContext pat
  pure $ Let pat' dec loop'
offsetMemoryInStm offsets (Let pat dec e) = do
  e' <- offsetMemoryInExp offsets e
  -- 'expReturns' only needs a 'LocalScope' monad, so querying it once in
  -- the current scope gives the same answer that a separate
  -- @runReader (expReturns e') =<< askScope@ would; reuse it for both
  -- steps below instead of recomputing.
  rts <- maybe (throwError "offsetMemoryInStm: ill-typed") pure =<< expReturns e'
  let pat' = offsetMemoryInPat offsets pat rts
      -- Try to recompute the index function.  Fall back to creating rebase
      -- operations with the RebaseMap.
      pat'' = Pat $ zipWith pick (patElems pat') rts
  pure $ Let pat'' dec e'
  where
    -- Adopt the returned LMAD when the array is returned in a known block
    -- and its LMAD has no existential parts; otherwise keep the element.
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extlmad)))
        | Just lmad <- instantiateLMAD extlmad =
            PatElem name (MemArray pt s u (ArrayIn m lmad))
    pick p _ = p

    -- Succeeds only if every variable in the LMAD is 'Free'.
    instantiateLMAD :: ExtLMAD -> Maybe LMAD
    instantiateLMAD = traverse (traverse inst)
      where
        inst Ext {} = Nothing
        inst (Free x) = pure x

---- Slicing allocation sizes out of a kernel.

unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms = Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
False
  where
    unAllocBody :: Body GPUMem -> Either String (Body GPU)
unAllocBody (Body BodyDec GPUMem
dec Stms GPUMem
stms Result
res) =
      BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body BodyDec GPU
BodyDec GPUMem
dec (Stms GPU -> Result -> Body GPU)
-> Either String (Stms GPU) -> Either String (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
True Stms GPUMem
stms Either String (Result -> Body GPU)
-> Either String Result -> Either String (Body GPU)
forall a b.
Either String (a -> b) -> Either String a -> Either String b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> Either String Result
forall a. a -> Either String a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res

    unAllocKernelBody :: KernelBody GPUMem -> Either String (KernelBody GPU)
unAllocKernelBody (KernelBody BodyDec GPUMem
dec Stms GPUMem
stms [KernelResult]
res) =
      BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody BodyDec GPU
BodyDec GPUMem
dec (Stms GPU -> [KernelResult] -> KernelBody GPU)
-> Either String (Stms GPU)
-> Either String ([KernelResult] -> KernelBody GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
True Stms GPUMem
stms Either String ([KernelResult] -> KernelBody GPU)
-> Either String [KernelResult] -> Either String (KernelBody GPU)
forall a b.
Either String (a -> b) -> Either String a -> Either String b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [KernelResult] -> Either String [KernelResult]
forall a. a -> Either String a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [KernelResult]
res

    unAllocStms :: Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
nested = (Stm GPUMem -> Either String (Stm GPU))
-> Stms GPUMem -> Either String (Stms GPU)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Seq a -> m (Seq b)
mapM (Bool -> Stm GPUMem -> Either String (Stm GPU)
unAllocStm Bool
nested)

    -- Translate a single statement from GPUMem to GPU.
    --
    -- An allocation statement is rejected when 'nested' is set.  Otherwise
    -- it is replaced by a binding of the unit constant.  Any other
    -- expression has its memory information stripped via 'unAlloc''.
    unAllocStm nested stm@(Let pat dec e) =
      case e of
        Op Alloc {}
          | nested ->
              throwError $ "Cannot handle nested allocation: " <> prettyString stm
          | otherwise -> do
              pat' <- unAllocPat pat
              pure $ Let pat' dec unitExp
        _ -> do
          pat' <- unAllocPat pat
          e' <- mapExpM unAlloc' e
          pure $ Let pat' dec e'
      where
        unitExp = BasicOp $ SubExp $ Constant UnitValue

    -- Strip memory information from a lambda.  The parameters and the body
    -- are translated; the return types are kept as they are.
    unAllocLambda (Lambda lam_params lam_ret lam_body) = do
      lam_body' <- unAllocBody lam_body
      pure $ Lambda (fmap unParam lam_params) lam_ret lam_body'

    -- Replace the memory annotation of every pattern element with the
    -- plain type obtained from 'unMem'.  This never fails.
    unAllocPat (Pat elems) = do
      elems' <- traverse (rephrasePatElem (\info -> Right (unMem info))) elems
      Right (Pat elems')

    -- Translate a GPUMem operation back into a GPU operation.  Only size
    -- operations and segmented operations are supported; allocations,
    -- 'OtherOp' and 'GPUBody' produce an error message.
    unAllocOp op =
      case op of
        Alloc {} -> Left "unAllocOp: unhandled Alloc"
        Inner OtherOp {} -> Left "unAllocOp: unhandled OtherOp"
        Inner GPUBody {} -> Left "unAllocOp: unhandled GPUBody"
        Inner (SizeOp size_op) -> Right $ SizeOp size_op
        Inner (SegOp seg_op) -> SegOp <$> mapSegOpM segMapper seg_op
      where
        -- Translates the lambdas and kernel bodies of a segmented
        -- operation, leaving everything else untouched.
        segMapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    -- Replace the memory annotation of a parameter with its plain type.
    unParam = (unMem <$>)

    -- Forget memory information and wrap the result in 'Right', for use
    -- as a 'Mapper' field.
    unT info = Right (unMem info)

    -- Expression mapper that removes memory information from every part
    -- of an expression.  Sub-expressions and variable names pass through
    -- unchanged; the scope argument to 'mapOnBody' is ignored.
    unAlloc' =
      Mapper
        { mapOnSubExp = Right,
          mapOnVName = Right,
          mapOnBody = \_ body -> unAllocBody body,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = \fparam -> Right $ unParam fparam,
          mapOnLParam = \lparam -> Right $ unParam lparam,
          mapOnOp = unAllocOp
        }

-- | Forget the memory information in a 'MemInfo', keeping only the
-- underlying type.  Array memory bindings are dropped, and a memory block
-- itself is mapped to the unit primitive type.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem info =
  case info of
    MemPrim t -> Prim t
    MemArray t shp uniq _ -> Array t shp uniq
    MemAcc acc_name acc_shape acc_ts uniq -> Acc acc_name acc_shape acc_ts uniq
    MemMem {} -> Prim Unit

-- | Strip memory information from every entry of a GPUMem scope, producing
-- the corresponding GPU scope.  Index names carry no memory information
-- and are kept as they are.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope scope = M.map unInfo scope
  where
    unInfo info =
      case info of
        LetName dec -> LetName (unMem dec)
        FParamName dec -> FParamName (unMem dec)
        LParamName dec -> LParamName (unMem dec)
        IndexName it -> IndexName it

-- | Group the extracted memory blocks by their size, pairing each block
-- with its space.  The extraction is traversed in ascending key order and
-- each block is prepended to its size group, so within a group the blocks
-- appear in reverse traversal order.
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  let grouped =
        foldl'
          (\acc (mem, (_, size, space)) -> M.insertWith (++) size [(mem, space)] acc)
          mempty
          (M.toList extraction)
   in M.toList grouped

-- | Make the given statements work on copies of the variables they
-- consume.  Each name that 'Alias.analyseStms' reports in its final
-- 'Names' component is copied with a zero-dimensional 'Replicate' (bound
-- to a name with a @_copy@ suffix).  The statements are then added with
-- those names substituted by their copies.  The collected statements are
-- returned via 'collectStms_'.
copyConsumed :: (MonadBuilder m, AliasableRep (Rep m)) => Stms (Rep m) -> m (Stms (Rep m))
copyConsumed stms = collectStms_ $ do
  copies <- forM consumed $ \v ->
    letExp (baseString v <> "_copy") $ BasicOp $ Replicate mempty $ Var v
  addStms $ substituteNames (M.fromList $ zip consumed copies) stms
  where
    (_, (_, consumed_names)) = Alias.analyseStms mempty stms
    consumed = namesToList consumed_names

-- Important for edge cases (#1838) that the Stms here still have the
-- Allocs we are actually trying to get rid of.
sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes :: SubExp
-> [SubExp]
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads [SubExp]
sizes SegSpace
space Stms GPUMem
kstms = do
  Stms GPU
kstms' <- (String
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> (Stms GPU
    -> ReaderT
         (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a.
String
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms GPU
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a.
a -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either String (Stms GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms Stms GPUMem
kstms
  let num_sizes :: Int
num_sizes = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
sizes
      i64s :: [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s = Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> a -> [a]
replicate Int
num_sizes (TypeBase (ShapeBase SubExp) NoUniqueness
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope GPU
kernels_scope <- (Scope GPUMem -> Scope GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Scope GPU)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope GPUMem -> Scope GPU
unAllocScope

  (Lambda GPU
max_lam, Stms GPU
_) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"y" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a.
Scope GPU
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Result, Stms GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$
      BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a.
BuilderT
  GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (a,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Rep
            (BuilderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$
        [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
  Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
   Param (TypeBase (ShapeBase SubExp) NoUniqueness))
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       SubExpRes)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x, Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y) ->
          (SubExp -> SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall a b.
(a -> b)
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> SubExpRes
subExpRes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   SubExp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> (BasicOp
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExp)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"z" (Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> (BasicOp -> Exp GPU)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall a b. (a -> b) -> a -> b
$
            BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y)
    Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a.
a
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$ [LParam GPU]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Body GPU
-> Lambda GPU
forall rep.
[LParam rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Body rep
-> Lambda rep
Lambda ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
stms Result
zs)

  Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam <- String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flat_gtid" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))

  Lambda GPU
size_lam' <- Scope GPUMem
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
forall a.
Scope GPUMem
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPUMem
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (ReaderT
   (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU))
-> (BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU)
    -> ReaderT
         (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((Lambda GPU, Stms GPU) -> Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
forall a b.
(a -> b)
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Lambda GPU, Stms GPU) -> Lambda GPU
forall a b. (a, b) -> a
fst (ReaderT
   (Scope GPUMem)
   (StateT VNameSource (Either String))
   (Lambda GPU, Stms GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU))
-> (BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU)
    -> ReaderT
         (Scope GPUMem)
         (StateT VNameSource (Either String))
         (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Lambda GPU)
forall a b. (a -> b) -> a -> b
$
    Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (m :: * -> *).
(HasScope GPU m, MonadFreshNames m) =>
Lambda GPU -> m (Lambda GPU)
GPU.simplifyLambda (Lambda GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> (BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [LParam
   (Rep
      (BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
LParam
  (Rep
     (BuilderT
        GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
flat_gtid_lparam] (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$ do
      -- Even though this SegRed is one-dimensional, we need to
      -- provide indexes corresponding to the original potentially
      -- multi-dimensional construct.
      let ([VName]
kspace_gtids, [SubExp]
kspace_dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
          new_inds :: [TPrimExp Int64 VName]
new_inds =
            [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
              ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
kspace_dims)
              (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam)
      ([VName]
 -> Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> [[VName]]
-> [Exp GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
[VName]
-> Exp GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([Exp GPU]
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Exp GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int64 VName
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Exp GPU))
-> [TPrimExp Int64 VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Exp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM TPrimExp Int64 VName
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
TPrimExp Int64 VName
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp GPU)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp [TPrimExp Int64 VName]
new_inds
      (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm
  (Rep
     (BuilderT
        GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms
  (Rep
     (BuilderT
        GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
(MonadBuilder m, AliasableRep (Rep m)) =>
Stms (Rep m) -> m (Stms (Rep m))
copyConsumed Stms
  (Rep
     (BuilderT
        GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
Stms GPU
kstms'
      Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a.
a
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
sizes

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms GPU
slice_stms) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  ([VName], [VName])
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat <-
      [Ident] -> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
basicPat ([Ident] -> Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase (ShapeBase SubExp) NoUniqueness -> m Ident
newIdent String
"max_per_thread" (TypeBase (ShapeBase SubExp) NoUniqueness
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Ident)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"size_slice_w"
        (Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> [SubExp]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> [SubExp]
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"thread_space_iota" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp
 -> Exp
      (Rep
         (BuilderT
            GPU
            (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
-> BasicOp
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
    let red_op :: SegBinOp GPU
red_op =
          Commutativity
-> Lambda GPU -> [SubExp] -> ShapeBase SubExp -> SegBinOp GPU
forall rep.
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp
            Commutativity
Commutative
            Lambda GPU
max_lam
            (Int -> SubExp -> [SubExp]
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- String
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SegLevel
forall (m :: * -> *) (inner :: * -> *).
(MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
String -> m SegLevel
segThread String
"segred"

    Stms
  (Rep
     (BuilderT
        GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
      (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stm GPU))
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b.
Monad m =>
(a -> m b) -> Seq a -> m (Seq b)
mapM Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stm GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
      (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegOpLevel GPU
SegLevel
lvl Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat SubExp
w [SegBinOp GPU
red_op] Lambda GPU
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Pat (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat) ((VName
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       VName)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [VName])
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"size_sum" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp
 -> Exp
      (Rep
         (BuilderT
            GPU
            (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
-> BasicOp
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
forall a b. (a -> b) -> a -> b
$
          BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads

    ([VName], [VName])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
forall a.
a
-> BuilderT
     GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Pat (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat, [VName]
size_sums)

  (Stms GPU, [VName], [VName])
-> ExpandM (Stms GPU, [VName], [VName])
forall a.
a -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)