{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Either (rights)
import Data.List (find, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Error
import Futhark.IR
import qualified Futhark.IR.GPU.Simplify as GPU
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- The program constants are transformed first, starting from an empty
-- scope.  Every function is then transformed with the (transformed)
-- constants in scope.  A failure inside the expansion monad is turned
-- into a compiler limitation by 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \prog -> do
      consts' <-
        modifyNameSource $
          limitationOnLeft
            . runStateT (runReaderT (transformStms (progConsts prog)) mempty)
      -- The functions see the scope of the already-transformed constants.
      funs' <- mapM (transformFunDef $ scopeOf consts') (progFuns prog)
      pure $ prog {progConsts = consts', progFuns = funs'}

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which are not fixed by renamer, and size
-- keys must currently be globally unique).

-- | The monad in which the expansion runs: a reader over the current
-- 'Scope', a threaded 'VNameSource' used for generating fresh names,
-- and @Either String@ for failures, which the entry points report via
-- 'limitationOnLeft'.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))

-- | Unwrap the result of running an 'ExpandM' computation.  A 'Right'
-- value is returned unchanged; the message of a 'Left' is handed to
-- 'compilerLimitationS', which (having type @String -> a@) does not
-- return normally.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id

-- | Transform the body of a single function definition and then run
-- copy propagation on the result.
--
-- The first argument is the scope that the function body should see in
-- addition to its own parameters (the caller passes the scope of the
-- program constants).  Failures in the expansion monad are reported
-- through 'limitationOnLeft'.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    -- The reader starts out empty (see 'mempty' above); the enclosing
    -- scope and the function's own bindings are added here.
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $
            funDefBody fundec

-- | Transform the statements of a body.  The body result is kept as-is.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res

-- | Transform the body of a lambda with the lambda parameters brought
-- into scope.  The parameters and the return types are kept as-is.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret

-- | Transform a sequence of statements.  All names bound by the
-- statements are brought into scope, each statement is transformed by
-- 'transformStm' (which may produce several statements), and the
-- results are concatenated in order.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)

-- | Transform a single statement, producing the statements that
-- replace it (the transformed statement itself, preceded by any
-- statements returned by 'transformExp').
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- all versions fail do we propagate the error.
-- FIXME: this can remove safety checks if the default branch fails!
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  -- Turn a failure while transforming one branch into a 'Left' value,
  -- so that the remaining branches are still attempted.
  let onCase (Case vs body) =
        (Right . Case vs <$> transformBody body) `catchError` (pure . Left)
  cases' <- rights <$> mapM onCase cases
  defbody' <- (Right <$> transformBody defbody) `catchError` (pure . Left)
  case (cases', defbody') of
    -- Every version failed: propagate the default branch's error.
    ([], Left e) ->
      throwError e
    -- The default branch failed but some case survived: the last
    -- surviving case takes over the role of the default branch.
    (c : cs, Left _) ->
      let (cases'', fallback) = promoteLast c cs
       in pure $
            oneStm $
              Let pat aux $
                Match cond cases'' fallback (MatchDec ts MatchEquiv)
    -- The default branch succeeded: keep whichever cases survived.
    (_, Right defbody'') ->
      pure $
        oneStm $
          Let pat aux $
            Match cond cases' defbody'' (MatchDec ts MatchEquiv)
  where
    -- Split a non-empty list of cases (given as head and tail) into
    -- all cases but the last, and the body of the last case.  This is
    -- a total replacement for the 'init'/'last' pair.
    promoteLast :: Case body -> [Case body] -> ([Case body], body)
    promoteLast c [] = ([], caseBody c)
    promoteLast c (c' : cs) =
      let (front, fallback) = promoteLast c' cs
       in (c : front, fallback)
transformStm (Let pat aux e) = do
  -- First transform any bodies nested inside the expression, then the
  -- expression itself.
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Only bodies are rewritten; each is transformed with the scope
    -- supplied by the mapper added to the environment.
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
_, KernelBody GPUMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space [] KernelBody GPUMem
kbody
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (Op GPUMem -> Exp GPUMem) -> Op GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
reds [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
lams, KernelBody GPUMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp GPUMem -> Lambda GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp GPUMem]
reds) KernelBody GPUMem
kbody
  let reds' :: [SegBinOp GPUMem]
reds' = (SegBinOp GPUMem -> Lambda GPUMem -> SegBinOp GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem] -> [SegBinOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp GPUMem
red Lambda GPUMem
lam -> SegBinOp GPUMem
red {segBinOpLambda :: Lambda GPUMem
segBinOpLambda = Lambda GPUMem
lam}) [SegBinOp GPUMem]
reds [Lambda GPUMem]
lams
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (Op GPUMem -> Exp GPUMem) -> Op GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
reds' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
lams, KernelBody GPUMem
kbody')) <-
    SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space ((SegBinOp GPUMem -> Lambda GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda [SegBinOp GPUMem]
scans) KernelBody GPUMem
kbody
  let scans' :: [SegBinOp GPUMem]
scans' = (SegBinOp GPUMem -> Lambda GPUMem -> SegBinOp GPUMem)
-> [SegBinOp GPUMem] -> [Lambda GPUMem] -> [SegBinOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\SegBinOp GPUMem
red Lambda GPUMem
lam -> SegBinOp GPUMem
red {segBinOpLambda :: Lambda GPUMem
segBinOpLambda = Lambda GPUMem
lam}) [SegBinOp GPUMem]
scans [Lambda GPUMem]
lams
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (Op GPUMem -> Exp GPUMem) -> Op GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [SegBinOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
transformExp (Op (Inner (SegOp (SegHist SegLevel
lvl SegSpace
space [HistOp GPUMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody)))) = do
  (Stms GPUMem
alloc_stms, ([Lambda GPUMem]
lams', KernelBody GPUMem
kbody')) <- SegLevel
-> SegSpace
-> [Lambda GPUMem]
-> KernelBody GPUMem
-> ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed SegLevel
lvl SegSpace
space [Lambda GPUMem]
lams KernelBody GPUMem
kbody
  let ops' :: [HistOp GPUMem]
ops' = (HistOp GPUMem -> Lambda GPUMem -> HistOp GPUMem)
-> [HistOp GPUMem] -> [Lambda GPUMem] -> [HistOp GPUMem]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith HistOp GPUMem -> Lambda GPUMem -> HistOp GPUMem
forall {rep} {rep}. HistOp rep -> Lambda rep -> HistOp rep
onOp [HistOp GPUMem]
ops [Lambda GPUMem]
lams'
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( Stms GPUMem
alloc_stms,
      Op GPUMem -> Exp GPUMem
forall rep. Op rep -> Exp rep
Op (Op GPUMem -> Exp GPUMem) -> Op GPUMem -> Exp GPUMem
forall a b. (a -> b) -> a -> b
$ HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall inner. inner -> MemOp inner
Inner (HostOp GPUMem () -> MemOp (HostOp GPUMem ()))
-> HostOp GPUMem () -> MemOp (HostOp GPUMem ())
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPUMem -> HostOp GPUMem ())
-> SegOp SegLevel GPUMem -> HostOp GPUMem ()
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [HistOp GPUMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody GPUMem
-> SegOp SegLevel GPUMem
forall lvl rep.
lvl
-> SegSpace
-> [HistOp rep]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegHist SegLevel
lvl SegSpace
space [HistOp GPUMem]
ops' [TypeBase (ShapeBase SubExp) NoUniqueness]
ts KernelBody GPUMem
kbody'
    )
  where
    lams :: [Lambda GPUMem]
lams = (HistOp GPUMem -> Lambda GPUMem)
-> [HistOp GPUMem] -> [Lambda GPUMem]
forall a b. (a -> b) -> [a] -> [b]
map HistOp GPUMem -> Lambda GPUMem
forall rep. HistOp rep -> Lambda rep
histOp [HistOp GPUMem]
ops
    onOp :: HistOp rep -> Lambda rep -> HistOp rep
onOp HistOp rep
op Lambda rep
lam = HistOp rep
op {histOp :: Lambda rep
histOp = Lambda rep
lam}
transformExp (WithAcc [WithAccInput GPUMem]
inputs Lambda GPUMem
lam) = do
  Lambda GPUMem
lam' <- Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda Lambda GPUMem
lam
  ([Stms GPUMem]
input_alloc_stms, [WithAccInput GPUMem]
inputs') <- [(Stms GPUMem, WithAccInput GPUMem)]
-> ([Stms GPUMem], [WithAccInput GPUMem])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms GPUMem, WithAccInput GPUMem)]
 -> ([Stms GPUMem], [WithAccInput GPUMem]))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     [(Stms GPUMem, WithAccInput GPUMem)]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     ([Stms GPUMem], [WithAccInput GPUMem])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (WithAccInput GPUMem
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, WithAccInput GPUMem))
-> [WithAccInput GPUMem]
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     [(Stms GPUMem, WithAccInput GPUMem)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM WithAccInput GPUMem
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, WithAccInput GPUMem)
forall {b} {b}.
(ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
onInput [WithAccInput GPUMem]
inputs
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( [Stms GPUMem] -> Stms GPUMem
forall a. Monoid a => [a] -> a
mconcat [Stms GPUMem]
input_alloc_stms,
      [WithAccInput GPUMem] -> Lambda GPUMem -> Exp GPUMem
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPUMem]
inputs' Lambda GPUMem
lam'
    )
  where
    onInput :: (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
onInput (ShapeBase SubExp
shape, b
arrs, Maybe (Lambda GPUMem, b)
Nothing) =
      (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPUMem
forall a. Monoid a => a
mempty, (ShapeBase SubExp
shape, b
arrs, Maybe (Lambda GPUMem, b)
forall a. Maybe a
Nothing))
    onInput (ShapeBase SubExp
shape, b
arrs, Just (Lambda GPUMem
op_lam, b
nes)) = do
      Names
bound_outside <- (Scope GPUMem -> Names)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) Names
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((Scope GPUMem -> Names)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) Names)
-> (Scope GPUMem -> Names)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) Names
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names)
-> (Scope GPUMem -> [VName]) -> Scope GPUMem -> Names
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope GPUMem -> [VName]
forall k a. Map k a -> [k]
M.keys
      let -- XXX: fake a SegLevel, which we don't have here.  We will not
          -- use it for anything, as we will not allow irregular
          -- allocations inside the update function.
          lvl :: SegLevel
lvl = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SubExp -> Count NumGroups SubExp
forall u e. e -> Count u e
Count (SubExp -> Count NumGroups SubExp)
-> SubExp -> Count NumGroups SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp -> Count GroupSize SubExp
forall u e. e -> Count u e
Count (SubExp -> Count GroupSize SubExp)
-> SubExp -> Count GroupSize SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SegVirt
SegNoVirt
          (Lambda GPUMem
op_lam', Extraction
lam_allocs) =
            (SegLevel, [TPrimExp Int64 VName])
-> Names -> Names -> Lambda GPUMem -> (Lambda GPUMem, Extraction)
extractLambdaAllocations (SegLevel
lvl, [TPrimExp Int64 VName
0]) Names
bound_outside Names
forall a. Monoid a => a
mempty Lambda GPUMem
op_lam
          variantAlloc :: ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space) -> Bool
variantAlloc ((SegLevel, [TPrimExp Int64 VName])
_, Var VName
v, Space
_) = VName
v VName -> Names -> Bool
`notNameIn` Names
bound_outside
          variantAlloc ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space)
_ = Bool
False
          (Extraction
variant_allocs, Extraction
invariant_allocs) = (((SegLevel, [TPrimExp Int64 VName]), SubExp, Space) -> Bool)
-> Extraction -> (Extraction, Extraction)
forall a k. (a -> Bool) -> Map k a -> (Map k a, Map k a)
M.partition ((SegLevel, [TPrimExp Int64 VName]), SubExp, Space) -> Bool
variantAlloc Extraction
lam_allocs

      case Extraction -> [((SegLevel, [TPrimExp Int64 VName]), SubExp, Space)]
forall k a. Map k a -> [a]
M.elems Extraction
variant_allocs of
        ((SegLevel, [TPrimExp Int64 VName])
_, SubExp
v, Space
_) : [((SegLevel, [TPrimExp Int64 VName]), SubExp, Space)]
_ ->
          String
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (String
 -> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ())
-> String
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall a b. (a -> b) -> a -> b
$
            String
"Cannot handle un-sliceable allocation size: "
              String -> String -> String
forall a. [a] -> [a] -> [a]
++ SubExp -> String
forall a. Pretty a => a -> String
pretty SubExp
v
              String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"\nLikely cause: irregular nested operations inside accumulator update operator."
        [] ->
          ()
-> ReaderT (Scope GPUMem) (StateT VNameSource (Either String)) ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

      let num_is :: Int
num_is = ShapeBase SubExp -> Int
forall a. ArrayShape a => a -> Int
shapeRank ShapeBase SubExp
shape
          is :: [VName]
is = (Param (MemInfo SubExp NoUniqueness MemBind) -> VName)
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall dec. Param dec -> VName
paramName ([Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName])
-> [Param (MemInfo SubExp NoUniqueness MemBind)] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. Int -> [a] -> [a]
take Int
num_is ([Param (MemInfo SubExp NoUniqueness MemBind)]
 -> [Param (MemInfo SubExp NoUniqueness MemBind)])
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
op_lam
      (Stms GPUMem
alloc_stms, RebaseMap
alloc_offsets) <-
        ((SegLevel, [TPrimExp Int64 VName])
 -> (ShapeBase SubExp, [TPrimExp Int64 VName]))
-> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations ((ShapeBase SubExp, [TPrimExp Int64 VName])
-> (SegLevel, [TPrimExp Int64 VName])
-> (ShapeBase SubExp, [TPrimExp Int64 VName])
forall a b. a -> b -> a
const (ShapeBase SubExp
shape, (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 [VName]
is)) Extraction
invariant_allocs

      Scope GPUMem
scope <- ReaderT
  (Scope GPUMem) (StateT VNameSource (Either String)) (Scope GPUMem)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
      let scope' :: Scope GPUMem
scope' = Lambda GPUMem -> Scope GPUMem
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda GPUMem
op_lam Scope GPUMem -> Scope GPUMem -> Scope GPUMem
forall a. Semigroup a => a -> a -> a
<> Scope GPUMem
scope
      (String
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))))
-> ((Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
    -> ReaderT
         (Scope GPUMem)
         (StateT VNameSource (Either String))
         (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))))
-> Either
     String
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either
   String
   (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))))
-> Either
     String
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall a b. (a -> b) -> a -> b
$
        Scope GPUMem
-> RebaseMap
-> OffsetM
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> Either
     String
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall a. Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM Scope GPUMem
scope' RebaseMap
alloc_offsets (OffsetM
   (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
 -> Either
      String
      (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b))))
-> OffsetM
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> Either
     String
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall a b. (a -> b) -> a -> b
$ do
          Lambda GPUMem
op_lam'' <- Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda Lambda GPUMem
op_lam'
          (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
-> OffsetM
     (Stms GPUMem, (ShapeBase SubExp, b, Maybe (Lambda GPUMem, b)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPUMem
alloc_stms, (ShapeBase SubExp
shape, b
arrs, (Lambda GPUMem, b) -> Maybe (Lambda GPUMem, b)
forall a. a -> Maybe a
Just (Lambda GPUMem
op_lam'', b
nes)))
transformExp Exp GPUMem
e =
  (Stms GPUMem, Exp GPUMem) -> ExpandM (Stms GPUMem, Exp GPUMem)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPUMem
forall a. Monoid a => a
mempty, Exp GPUMem
e)

-- | Extract the allocations performed in the kernel body and in the
-- operator lambdas of a segmented operation, expand them, and rewrite
-- the body and the lambdas to use the expanded memory.
--
-- The result is the statements performing the expanded allocations,
-- paired with the rewritten lambdas and kernel body.
--
-- Fails with 'throwError' when a variant allocation has a size that is
-- not bound directly in the kernel, or when a variant allocation
-- occurs at 'SegGroup' level.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      -- An allocation is variant when its size is a variable that is
      -- not in scope outside of the kernel.
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      -- A variant allocation is unusable when its size is not bound by
      -- the segment space or by a top-level statement of the kernel
      -- body.
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just v ->
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ pretty v
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
          -- The condition is about *variant* allocations, so the
          -- message must say so.
          throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      pure ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $
    \alloc_stms kbody'' -> do
      ops'' <- forM ops' $ \op' ->
        localScope (scopeOf op') $ offsetMemoryInLambda op'
      pure (alloc_stms, (ops'', kbody''))
  where
    -- Names bound by the segment space together with the names bound
    -- by the top-level statements of the original kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | The names bound by the top-level statements of a kernel body.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody = namesFromList (M.keys stm_scope)
  where
    -- Scope introduced by the statements of the body.
    stm_scope :: Scope GPUMem
    stm_scope = scopeOf (kernelBodyStms kbody)

-- | Compute the expanded allocations for a kernel body, rewrite the
-- body's memory references accordingly, and pass the allocation
-- statements and the rewritten body to the continuation.
--
-- The continuation runs in 'OffsetM' with the segment space added to
-- the current scope.  A failure inside 'OffsetM' is re-thrown in
-- 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  let body_stms = kernelBodyStms kbody'
  (offsets, alloc_stms) <-
    memoryRequirements lvl space body_stms variant_allocs invariant_allocs

  outer_scope <- askScope
  let kernel_scope = scopeOfSegSpace space <> outer_scope
      rebased =
        runOffsetM kernel_scope offsets $
          offsetMemoryInKernelBody kbody' >>= m alloc_stms
  case rebased of
    Left err -> throwError err
    Right res -> pure res

-- | Produce the statements that perform the expanded allocations for a
-- kernel, together with the map describing how the original memory
-- blocks are rebased onto them.
--
-- The returned statements first bind @num_threads@ (the product of the
-- number of groups and the group size of the level), then the
-- invariant allocations, then the variant allocations.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  let num_groups = segNumGroups lvl
      group_size = segGroupSize lvl

  (num_threads, num_threads_stms) <-
    runBuilder $
      letSubExp "num_threads" $
        BasicOp $
          BinOp
            (Mul Int64 OverflowUndef)
            (unCount num_groups)
            (unCount group_size)

  -- Both expansions need the binding of num_threads in scope.
  let withNumThreads act = inScopeOf num_threads_stms act

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    withNumThreads $
      expandedInvariantAllocations num_threads num_groups group_size invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    withNumThreads $
      expandedVariantAllocations num_threads space kstms variant_allocs

  let all_offsets = invariant_alloc_offsets <> variant_alloc_offsets
      all_stms = num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
  pure (all_offsets, all_stms)

-- | Identifies the spot where an allocation occurs: the 'SegLevel' of
-- the enclosing construct, paired with the index expressions that
-- make up the unique ID of the allocating thread.
type User = (SegLevel, [TPrimExp Int64 VName])

-- | A description of the allocations that have been extracted.  Each
-- entry records the 'User' performing the allocation, how much memory
-- is needed, and in which 'Space' it is needed.
-- NOTE(review): the key is presumably the name of the extracted
-- memory block; confirm against the extraction code.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Extract allocations from the statements of a 'KernelBody'.
--
-- This is 'extractGenericBodyAllocations' instantiated for kernel
-- bodies: the statements are read with 'kernelBodyStms' and the
-- rewritten statements are stored back with a record update.
--
-- The arguments are the 'User' the allocations are attributed to, the
-- names bound outside the kernel, and the names bound inside the
-- kernel so far.  The result is the kernel body with its rewritten
-- statements, paired with the extracted allocations.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations lvl bound_outside bound_kernel =
  extractGenericBodyAllocations lvl bound_outside bound_kernel kernelBodyStms $
    \stms kbody -> kbody {kernelBodyStms = stms}

-- | Extract allocations from the statements of an ordinary 'Body'.
--
-- This is 'extractGenericBodyAllocations' instantiated for 'Body':
-- the statements are read with 'bodyStms' and the rewritten statements
-- are stored back with a record update.
--
-- The arguments are the 'User' the allocations are attributed to, the
-- names bound outside the kernel, and the names bound inside the
-- kernel so far.  The result is the body with its rewritten
-- statements, paired with the extracted allocations.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms $
    \stms body -> body {bodyStms = stms}

-- | Extract allocations from the body of a 'Lambda'.
--
-- The lambda body is processed with 'extractBodyAllocations'; the
-- returned lambda is the input lambda with only its 'lambdaBody'
-- replaced by the rewritten body.  The second component holds the
-- extracted allocations.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) =
      extractBodyAllocations user bound_outside bound_kernel $ lambdaBody lam

-- | Extract allocations from anything that contains a sequence of
-- statements.
--
-- The container is abstracted by a getter (@get_stms@) and a setter
-- (@set_stms@), so the same traversal serves both 'Body' and
-- 'KernelBody'.
--
-- Every statement is passed to 'extractStmAllocations'.  Statements
-- for which it yields 'Nothing' are dropped; the rest are kept in
-- their original order.  All names bound by the statements of this
-- body are added to @bound_kernel@ before any statement is processed,
-- so a statement is treated as kernel-bound regardless of its position
-- in the body.
--
-- Returns the container with its rewritten statements, paired with the
-- accumulated 'Extraction'.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  let bound_kernel' = bound_kernel <> boundByStms (get_stms body)
      (stms, allocs) =
        runWriter $
          catMaybes
            <$> mapM
              (extractStmAllocations user bound_outside bound_kernel')
              (stmsToList $ get_stms body)
   in (set_stms (stmsFromList stms) body, allocs)

-- | Whether allocations in the given space are candidates for
-- expansion.  This is 'False' for the space named @"local"@ and for
-- any 'ScalarSpace', and 'True' for every other space.
expandable :: Space -> Bool
expandable (Space "local") = False
expandable ScalarSpace {} = False
expandable _ = True

-- | 'False' exactly for 'ScalarSpace'; 'True' for every other space.
notScalar :: Space -> Bool
notScalar ScalarSpace {} = False
notScalar _ = True

-- | Process a single statement, extracting allocations from it.
--
-- * A statement that binds exactly one pattern element to an @Alloc@
--   and satisfies the guard below is removed: 'Nothing' is returned
--   and the allocation is recorded in the writer output under the name
--   of the pattern element, together with the current 'User', the
--   allocation size and the space.
--
-- * Any other statement is kept.  Its expression is traversed so that
--   allocations in nested bodies, and in the lambdas and kernel bodies
--   of nested 'SegOp's, are extracted as well.  The statement is
--   returned with the rewritten expression.
--
-- The 'Names' arguments are, in order, the names bound outside the
-- kernel and the names bound inside it.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pat [patElem]) _ (Op (Alloc size space)))
  | expandable space && expandableSize size
      -- FIXME: the '&& notScalar space' part is a hack because we
      -- don't otherwise hoist the sizes out far enough, and we
      -- promise to be super-duper-careful about not having variant
      -- scalar allocations.
      || (boundInKernel size && notScalar space) = do
      tell $ M.singleton (patElemName patElem) (user, size, space)
      pure Nothing
  where
    -- A size is expandable if it is a constant, or a variable bound
    -- either outside the kernel or inside it.
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True

    -- Whether the size is a variable bound inside the kernel.
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper user) $ stmExp stm
  pure $ Just $ stm {stmExp = e}
  where
    -- Expression mapper that recurses into bodies and operations,
    -- leaving everything else untouched.
    expMapper user' =
      identityMapper
        { mapOnBody = const $ onBody user',
          mapOnOp = onOp user'
        }

    -- Extract from a nested body and forward what was found.
    onBody user' body = do
      let (body', allocs) = extractBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    -- On entering a nested SegOp, allocations inside it are attributed
    -- to a new user: the level becomes that of the nested operation,
    -- and its flat index is appended to the list of user IDs.  Any
    -- other operation is returned unchanged.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper user'') op
      where
        user'' =
          (segLevel op, user_ids ++ [le64 (segFlat (segSpace op))])
    onOp _ op = pure op

    -- SegOp mapper that recurses into lambdas and kernel bodies.
    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

    -- Extract from a nested kernel body and forward what was found.
    onKernelBody user' body = do
      let (body', allocs) = extractKernelBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    -- Extract from the body of a lambda, keeping the rest of the
    -- lambda as it was.
    onLambda user' lam = do
      body <- onBody user' $ lambdaBody lam
      pure lam {lambdaBody = body}

-- | Produce the expanded allocations for a set of extracted
-- allocations, given a function describing how many users there are.
--
-- @getNumUsers@ maps a 'User' to the shape of the space of users and
-- the index expressions identifying this particular user within it.
--
-- For each entry of the 'Extraction', a statement computing the total
-- size (the per-user size multiplied by every dimension of the users
-- shape) is emitted, followed by an @Alloc@ of that size, in the
-- original space, bound to the original memory block name.
--
-- Returns the emitted statements and a 'RebaseMap' that maps each
-- memory block name to a function building the new base index
-- function from an array's old shape.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  -- We expand the invariant allocations by adding an inner dimension
  -- equal to the number of kernel threads.
  (rebases, alloc_stms) <- runBuilder $ mapM expand $ M.toList invariant_allocs

  pure (alloc_stms, mconcat rebases)
  where
    -- Emit the enlarged allocation for one memory block and return its
    -- entry in the rebase map.
    expand (mem, (user, per_thread_size, space)) = do
      let num_users = fst $ getNumUsers user
          allocpat = Pat [PatElem mem $ MemMem space]
      total_size <-
        letExp "total_size" <=< toExp . product $
          pe64 per_thread_size : map pe64 (shapeDims num_users)
      letBind allocpat $ Op $ Alloc (Var total_size) space
      pure $ M.singleton mem $ newBase user

    -- A slice covering an entire dimension of extent @d@.
    untouched d = DimSlice 0 d 1

    -- Thread level: the underlying row-major layout has the original
    -- dimensions first and the user dimensions last.  The permutation
    -- then moves the user dimensions to the front, so that they can be
    -- fixed to this user's IDs while the original dimensions are left
    -- whole.
    newBase user@(SegThread {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          num_dims = length old_shape
          perm = [num_dims .. num_dims + shapeRank users_shape - 1] ++ [0 .. num_dims - 1]
          root_ixfun = IxFun.iota (old_shape ++ map pe64 (shapeDims users_shape))
          permuted_ixfun = IxFun.permute root_ixfun perm
          offset_ixfun =
            IxFun.slice permuted_ixfun $
              Slice $
                map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun
    -- Group level: the user dimensions come first in the underlying
    -- layout, so no permutation is needed before fixing this user's
    -- IDs.
    newBase user@(SegGroup {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          root_ixfun = IxFun.iota $ map pe64 (shapeDims users_shape) ++ old_shape
          offset_ixfun =
            IxFun.slice root_ixfun . Slice $
              map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun

-- | Expand extracted allocations using the kernel's thread count,
-- group count and group size.
--
-- This is 'genericExpandedInvariantAllocations' with the number of
-- users determined as follows:
--
-- * thread level with a single ID: one dimension of @num_threads@;
--
-- * thread level with two IDs: two dimensions, @num_groups@ by
--   @group_size@;
--
-- * group level with a single ID: one dimension of @num_groups@.
--
-- In each case the IDs are passed through unchanged.  Any other
-- combination of level and number of IDs calls 'error'.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    getNumUsers (SegThread {}, [gtid]) = (Shape [num_threads], [gtid])
    getNumUsers (SegThread {}, [gid, ltid]) = (Shape [num_groups, group_size], [gid, ltid])
    getNumUsers (SegGroup {}, [gid]) = (Shape [num_groups], [gid])
    getNumUsers user = error $ "getNumUsers: unhandled " ++ show user

-- | Expand the allocations whose sizes vary between threads.
--
-- With no variant allocations, no statements and an empty 'RebaseMap'
-- are produced.  Otherwise 'sliceKernelSizes' is asked for statements
-- plus one offset and one total size for each size returned by
-- 'removeCommonSizes'.  Every extracted memory block is then allocated
-- once with the corresponding total size, and gets a 'RebaseMap' entry
-- built by @newBase@ from the corresponding offset.
--
-- The returned statements are the slice statements (after explicit
-- allocation, simplification and 'transformStms') followed by the new
-- allocations.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations num_threads kspace kstms variant_allocs
  | null variant_allocs = pure (mempty, mempty)
  | otherwise = do
      let (variant_sizes, blocks_per_size) =
            unzip $ removeCommonSizes variant_allocs

      (slice_stms, offsets, size_sums) <-
        sliceKernelSizes num_threads variant_sizes kspace kstms
      -- The final 'transformStms' is the recursive call that expands
      -- allocations inside the newly produced kernels.
      slice_stms' <-
        explicitAllocationsInStms slice_stms
          >>= simplifyStms
          >>= transformStms

      let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
          variant_allocs' =
            concat $ zipWith3 blockInfo blocks_per_size offsets size_sums
          -- All blocks that share a size also share offset and total size.
          blockInfo blocks offset total_size =
            [ (mem, (Var offset, Var total_size, space))
              | (mem, space) <- blocks
            ]
          -- One allocation statement and one rebase entry per block.
          (alloc_stms, rebases) = unzip $ map expand variant_allocs'

      pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- Allocate @mem@ with the total size, and map it to the index
    -- function base derived from its offset.
    expand :: (VName, (SubExp, SubExp, Space)) -> (Stm GPUMem, RebaseMap)
    expand (mem, (offset, total_size, space)) =
      ( Let allocpat (defAux ()) (Op (Alloc total_size space)),
        M.singleton mem (newBase offset)
      )
      where
        allocpat = Pat [PatElem mem (MemMem space)]

    num_threads' = pe64 num_threads
    gtid = le64 (segFlat kspace)

    -- For the variant allocations, the memory is viewed as a
    -- two-dimensional array with an inner dimension of @num_threads'@,
    -- which is fixed to the flat thread index.
    --
    -- NOTE(review): the outer slice has length @num_threads'@ although
    -- that dimension of @root_ixfun@ has extent @elems_per_thread@;
    -- presumably harmless because the result is immediately coerced or
    -- reshaped to @old_shape@ -- confirm.
    newBase :: SubExp -> ([TPrimExp Int64 VName], PrimType) -> IxFun
    newBase size_per_thread (old_shape, pt)
      | length old_shape == 1 = IxFun.coerce offset_ixfun old_shape
      | otherwise = IxFun.reshape offset_ixfun old_shape
      where
        -- The size is in bytes; convert it to a number of elements.
        elems_per_thread = pe64 size_per_thread `quot` primByteSize pt
        root_ixfun = IxFun.iota [elems_per_thread, num_threads']
        offset_ixfun =
          IxFun.slice root_ixfun $
            Slice [DimSlice 0 num_threads' 1, DimFix gtid]

-- | A map from memory block names to new index function bases.  Each
-- entry is a function that, given the shape and element type of an
-- existing index function base, constructs the replacement index
-- function.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)

-- | The monad in which memory annotations are rewritten: a reader over
-- the 'GPUMem' scope, on top of a reader over the 'RebaseMap', with
-- 'String' errors reported through 'Either'.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope GPUMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope GPUMem,
      LocalScope GPUMem,
      MonadError String
    )

-- | Run an 'OffsetM' action with the given initial scope and rebase
-- map, yielding either an error message or the result.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  -- Peel off the scope reader first, then the rebase-map reader.
  flip runReaderT offsets $ runReaderT m scope

-- | Retrieve the current 'RebaseMap' from the inner reader layer.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM (lift ask)

-- | Run an action with the 'RebaseMap' modified by the given function.
-- The scope seen by the action is unchanged.
localRebaseMap :: (RebaseMap -> RebaseMap) -> OffsetM a -> OffsetM a
localRebaseMap f (OffsetM m) =
  -- Apply 'local' to the inner (rebase-map) reader, leaving the outer
  -- (scope) reader untouched.
  OffsetM $ mapReaderT (local f) m

-- | Look up the memory block @name@ in the current 'RebaseMap'.  If it
-- has an entry, apply it to the given base shape and element type to
-- obtain the new index function; otherwise produce 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x = rebaseWith <$> askRebaseMap
  where
    rebaseWith rebases = fmap (\mkBase -> mkBase x) (M.lookup name rebases)

-- | Rewrite the statements of a kernel body with 'offsetMemoryInStm',
-- processing each statement in the scope returned for its predecessor
-- (starting from the current scope).  Only the statements are
-- replaced; the other fields of the kernel body are kept.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  outer_scope <- askScope
  (_, rewritten) <-
    mapAccumLM inScope outer_scope . stmsToList $ kernelBodyStms kbody
  pure kbody {kernelBodyStms = stmsFromList rewritten}
  where
    inScope cur_scope stm = localScope cur_scope $ offsetMemoryInStm stm

-- | Rewrite the statements of a body with 'offsetMemoryInStm',
-- processing each statement in the scope returned for its predecessor
-- (starting from the current scope).  The body decoration and result
-- are kept as they are.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  outer_scope <- askScope
  (_, rewritten) <- mapAccumLM inScope outer_scope (stmsToList stms)
  pure $ Body dec (stmsFromList rewritten) res
  where
    inScope cur_scope stm = localScope cur_scope $ offsetMemoryInStm stm

-- | Rewrite the memory annotations of a single statement.
--
-- The expression is rewritten with 'offsetMemoryInExp', and the pattern
-- with 'offsetMemoryInPat' using the return information of the
-- rewritten expression.  Afterwards, any array pattern element whose
-- return information names a concrete memory block with a fully
-- instantiated index function takes its memory binding from that
-- return information instead.
--
-- Returns the rewritten statement together with the current scope
-- extended by the names that statement binds.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  e' <- offsetMemoryInExp e
  -- The return information is needed twice below; it depends only on
  -- the rewritten expression and the current scope, so compute it once.
  rets <- expReturns e'
  pat' <- offsetMemoryInPat pat rets
  scope <- askScope
  -- Try to recompute the index function.  Fall back to the pattern
  -- produced via the RebaseMap when that is not possible.
  let pat'' = Pat $ zipWith pick (patElems pat') rets
      stm = Let pat'' dec e'
  pure (scopeOf stm <> scope, stm)
  where
    -- Prefer the memory block and index function from the return
    -- information when the index function has no existential parts.
    pick ::
      PatElem (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElem (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
            PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Succeeds only if every name in the index function is 'Free'.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst Ext {} = Nothing
        inst (Free x) = pure x

-- | Rewrite the memory annotations of a pattern, given the return
-- information of the expression it binds (one entry per pattern
-- element, matched positionally).
--
-- An array element whose return information is a 'ReturnsNewBlock'
-- keeps its memory block but takes the returned index function, with
-- each existential reference @Ext i@ replaced by the name of the
-- @i@th pattern element.  Every other element has its annotation
-- rewritten by 'offsetMemoryInMemBound'.
--
-- Fails with an error message if an existential index does not refer
-- to an element of the pattern.
offsetMemoryInPat :: Pat LetDecMem -> [ExpReturns] -> OffsetM (Pat LetDecMem)
offsetMemoryInPat (Pat pes) rets =
  Pat <$> zipWithM onPE pes rets
  where
    onPE
      (PatElem name (MemArray pt shape u (ArrayIn mem _)))
      (MemArray _ _ _ (Just (ReturnsNewBlock _ _ ixfun))) =
        PatElem name . MemArray pt shape u . ArrayIn mem
          <$> traverse (traverse unExt) ixfun
    onPE pe _ = do
      new_dec <- offsetMemoryInMemBound $ patElemDec pe
      pure pe {patElemDec = new_dec}

    -- Resolve an existential reference to the pattern element at that
    -- position, reporting a bad index as an error rather than crashing
    -- inside a partial list lookup.
    unExt :: Ext VName -> OffsetM VName
    unExt (Free v) = pure v
    unExt (Ext i) =
      case drop i pes of
        pe : _ | i >= 0 -> pure $ patElemName pe
        _ ->
          throwError $
            "offsetMemoryInPat: existential index "
              ++ show i
              ++ " out of bounds for pattern with "
              ++ show (length pes)
              ++ " elements"

-- | Transform the memory annotation attached to a parameter with
-- 'offsetMemoryInMemBound', keeping every other field of the
-- parameter as it was.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam =
  withDec <$> offsetMemoryInMemBound (paramDec fparam)
  where
    -- Put the rewritten annotation back into the original parameter.
    withDec dec = fparam {paramDec = dec}

-- | Rewrite a single memory annotation.
--
-- For an array annotation ('MemArray' with an 'ArrayIn' binding),
-- 'lookupNewBase' is asked about the memory block, passing the base
-- shape of the array's index function and its element type.  If it
-- answers with a new base index function, the array's index function
-- is rebased onto it with 'IxFun.rebase'; the memory block name,
-- element type, shape and uniqueness are left untouched.  If there is
-- no new base, or the annotation is not of that form, the annotation
-- is returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) =
  maybe summary rebasedOn <$> lookupNewBase mem (IxFun.base ixfun, pt)
  where
    -- The same array, but with its index function placed on top of
    -- the given base.
    rebasedOn new_base =
      MemArray pt shape u (ArrayIn mem (IxFun.rebase new_base ixfun))
offsetMemoryInMemBound summary = pure summary

-- | Rewrite the return annotation of a branch/body result.
--
-- Only array results of the form 'ReturnsInBlock' whose index function
-- is recognised by 'isStaticIxFun' are considered.  For those,
-- 'lookupNewBase' is consulted with the memory block, the base shape
-- of the static index function and the element type; when it yields a
-- new base, that base is lifted into the existential setting (every
-- name wrapped in 'Free') and the original index function is rebased
-- onto it.  All other annotations are returned unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just static_ixfun <- isStaticIxFun ixfun =
      maybe br rebasedOn <$> lookupNewBase mem (IxFun.base static_ixfun, pt)
  where
    -- Note that it is the original (existential) index function that
    -- gets rebased; the static one is used only for the lookup.
    rebasedOn new_base =
      MemArray pt shape u $
        ReturnsInBlock mem $
          IxFun.rebase (fmap Free <$> new_base) ixfun
offsetMemoryInBodyReturns br = pure br

-- | Apply 'offsetMemoryInBody' to the body of a lambda, with the scope
-- of the lambda itself (via 'inScopeOf') in effect while the body is
-- processed.  Only the body of the lambda is replaced.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam =
  inScopeOf lam $
    withBody <$> offsetMemoryInBody (lambdaBody lam)
  where
    withBody new_body = lam {lambdaBody = new_body}

-- Loops can carry memory blocks as merge parameters, and such blocks
-- may be among those being expanded.  Whenever the initial value of a
-- merge parameter is a variable with an entry in the rebase map, the
-- parameter is given the same entry for the duration of the
-- continuation.
--
-- FIXME: we assume, without checking, that if the initial value of a
-- loop parameter is an expanded memory block, then the loop result is
-- one too.
--
-- The continuation receives the merge list with the parameter
-- annotations rewritten by 'offsetMemoryInParam'; the initial values
-- are passed through unchanged.
offsetMemoryInLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  ([(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams merge f =
  localRebaseMap inheritFromInits $
    mapM offsetMemoryInParam loop_params >>= f . (`zip` init_vals)
  where
    (loop_params, init_vals) = unzip merge

    -- Walk the merge list left to right, growing the rebase map as we
    -- go, so a later parameter can see entries added for earlier ones.
    inheritFromInits rebase_map = foldl' inherit rebase_map merge

    -- A parameter initialised from a variable known to the map gets
    -- that variable's entry; anything else leaves the map as it is.
    inherit acc (param, Var init_var) =
      maybe acc (\entry -> M.insert (paramName param) entry acc) $
        M.lookup init_var acc
    inherit acc _ = acc

-- | Rewrite the memory annotations inside an expression.
--
-- Loops are handled explicitly: the merge parameters go through
-- 'offsetMemoryInLoopParams', and the loop body is then processed with
-- the (rewritten) merge parameters and the loop form in scope.
--
-- Every other expression is traversed with 'mapExpM', using a mapper
-- that recurses into bodies, rewrites branch return annotations with
-- 'offsetMemoryInBodyReturns', and descends into 'SegOp's.  Any other
-- operation is returned as it is.
offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp (DoLoop merge form body) =
  offsetMemoryInLoopParams merge $ \merge' ->
    let loop_scope = scopeOfFParams (map fst merge') <> scopeOf form
     in DoLoop merge' form <$> localScope loop_scope (offsetMemoryInBody body)
offsetMemoryInExp e = mapExpM recurse e
  where
    recurse =
      identityMapper
        { mapOnBody = \bscope inner_body ->
            localScope bscope (offsetMemoryInBody inner_body),
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }

    -- Kernel bodies and lambdas of a SegOp are processed with the
    -- names bound by its segment space in scope.
    onOp (Inner (SegOp op)) =
      fmap (Inner . SegOp) . localScope (scopeOfSegSpace (segSpace op)) $
        mapSegOpM segOpMapper op
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody,
              mapOnSegOpLambda = offsetMemoryInLambda
            }
    onOp op = pure op

---- Slicing allocation sizes out of a kernel.

unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms = Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
False
  where
    unAllocBody :: Body GPUMem -> Either String (Body GPU)
unAllocBody (Body BodyDec GPUMem
dec Stms GPUMem
stms Result
res) =
      BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body BodyDec GPU
BodyDec GPUMem
dec (Stms GPU -> Result -> Body GPU)
-> Either String (Stms GPU) -> Either String (Result -> Body GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
True Stms GPUMem
stms Either String (Result -> Body GPU)
-> Either String Result -> Either String (Body GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> Either String Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res

    unAllocKernelBody :: KernelBody GPUMem -> Either String (KernelBody GPU)
unAllocKernelBody (KernelBody BodyDec GPUMem
dec Stms GPUMem
stms [KernelResult]
res) =
      BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody BodyDec GPU
BodyDec GPUMem
dec (Stms GPU -> [KernelResult] -> KernelBody GPU)
-> Either String (Stms GPU)
-> Either String ([KernelResult] -> KernelBody GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
True Stms GPUMem
stms Either String ([KernelResult] -> KernelBody GPU)
-> Either String [KernelResult] -> Either String (KernelBody GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [KernelResult] -> Either String [KernelResult]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [KernelResult]
res

    unAllocStms :: Bool -> Stms GPUMem -> Either String (Stms GPU)
unAllocStms Bool
nested =
      ([Maybe (Stm GPU)] -> Stms GPU)
-> Either String [Maybe (Stm GPU)] -> Either String (Stms GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([Stm GPU] -> Stms GPU
forall rep. [Stm rep] -> Stms rep
stmsFromList ([Stm GPU] -> Stms GPU)
-> ([Maybe (Stm GPU)] -> [Stm GPU])
-> [Maybe (Stm GPU)]
-> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Maybe (Stm GPU)] -> [Stm GPU]
forall a. [Maybe a] -> [a]
catMaybes) (Either String [Maybe (Stm GPU)] -> Either String (Stms GPU))
-> (Stms GPUMem -> Either String [Maybe (Stm GPU)])
-> Stms GPUMem
-> Either String (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm GPUMem -> Either String (Maybe (Stm GPU)))
-> [Stm GPUMem] -> Either String [Maybe (Stm GPU)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Bool -> Stm GPUMem -> Either String (Maybe (Stm GPU))
unAllocStm Bool
nested) ([Stm GPUMem] -> Either String [Maybe (Stm GPU)])
-> (Stms GPUMem -> [Stm GPUMem])
-> Stms GPUMem
-> Either String [Maybe (Stm GPU)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms GPUMem -> [Stm GPUMem]
forall rep. Stms rep -> [Stm rep]
stmsToList

    unAllocStm :: Bool -> Stm GPUMem -> Either String (Maybe (Stm GPU))
unAllocStm Bool
nested stm :: Stm GPUMem
stm@(Let Pat (LetDec GPUMem)
_ StmAux (ExpDec GPUMem)
_ (Op Alloc {}))
      | Bool
nested = String -> Either String (Maybe (Stm GPU))
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (String -> Either String (Maybe (Stm GPU)))
-> String -> Either String (Maybe (Stm GPU))
forall a b. (a -> b) -> a -> b
$ String
"Cannot handle nested allocation: " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Stm GPUMem -> String
forall a. Pretty a => a -> String
pretty Stm GPUMem
stm
      | Bool
otherwise = Maybe (Stm GPU) -> Either String (Maybe (Stm GPU))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Maybe (Stm GPU)
forall a. Maybe a
Nothing
    unAllocStm Bool
_ (Let Pat (LetDec GPUMem)
pat StmAux (ExpDec GPUMem)
dec Exp GPUMem
e) =
      Stm GPU -> Maybe (Stm GPU)
forall a. a -> Maybe a
Just (Stm GPU -> Maybe (Stm GPU))
-> Either String (Stm GPU) -> Either String (Maybe (Stm GPU))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> StmAux () -> Exp GPU -> Stm GPU
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let (Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
 -> StmAux () -> Exp GPU -> Stm GPU)
-> Either String (Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
-> Either String (StmAux () -> Exp GPU -> Stm GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pat (MemInfo SubExp NoUniqueness MemBind)
-> Either String (Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
forall {d} {u} {ret} {a}.
Pat (MemInfo d u ret) -> Either a (Pat (TypeBase (ShapeBase d) u))
unAllocPat Pat (LetDec GPUMem)
Pat (MemInfo SubExp NoUniqueness MemBind)
pat Either String (StmAux () -> Exp GPU -> Stm GPU)
-> Either String (StmAux ()) -> Either String (Exp GPU -> Stm GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> StmAux () -> Either String (StmAux ())
forall (f :: * -> *) a. Applicative f => a -> f a
pure StmAux ()
StmAux (ExpDec GPUMem)
dec Either String (Exp GPU -> Stm GPU)
-> Either String (Exp GPU) -> Either String (Stm GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Mapper GPUMem GPU (Either String)
-> Exp GPUMem -> Either String (Exp GPU)
forall (m :: * -> *) frep trep.
Monad m =>
Mapper frep trep m -> Exp frep -> m (Exp trep)
mapExpM Mapper GPUMem GPU (Either String)
unAlloc' Exp GPUMem
e)

    unAllocLambda :: Lambda GPUMem -> Either String (Lambda GPU)
unAllocLambda (Lambda [LParam GPUMem]
params Body GPUMem
body [TypeBase (ShapeBase SubExp) NoUniqueness]
ret) =
      [LParam GPU]
-> Body GPU
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda GPU
forall rep.
[LParam rep]
-> Body rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda rep
Lambda ((Param (MemInfo SubExp NoUniqueness MemBind)
 -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
forall {d} {u} {ret}.
Param (MemInfo d u ret) -> Param (TypeBase (ShapeBase d) u)
unParam [LParam GPUMem]
[Param (MemInfo SubExp NoUniqueness MemBind)]
params) (Body GPU
 -> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Lambda GPU)
-> Either String (Body GPU)
-> Either
     String ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Lambda GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body GPUMem -> Either String (Body GPU)
unAllocBody Body GPUMem
body Either
  String ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Lambda GPU)
-> Either String [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Either String (Lambda GPU)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Either String [TypeBase (ShapeBase SubExp) NoUniqueness]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [TypeBase (ShapeBase SubExp) NoUniqueness]
ret

    unAllocPat :: Pat (MemInfo d u ret) -> Either a (Pat (TypeBase (ShapeBase d) u))
unAllocPat (Pat [PatElem (MemInfo d u ret)]
pes) =
      [PatElem (TypeBase (ShapeBase d) u)]
-> Pat (TypeBase (ShapeBase d) u)
forall dec. [PatElem dec] -> Pat dec
Pat ([PatElem (TypeBase (ShapeBase d) u)]
 -> Pat (TypeBase (ShapeBase d) u))
-> Either a [PatElem (TypeBase (ShapeBase d) u)]
-> Either a (Pat (TypeBase (ShapeBase d) u))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (PatElem (MemInfo d u ret)
 -> Either a (PatElem (TypeBase (ShapeBase d) u)))
-> [PatElem (MemInfo d u ret)]
-> Either a [PatElem (TypeBase (ShapeBase d) u)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((MemInfo d u ret -> Either a (TypeBase (ShapeBase d) u))
-> PatElem (MemInfo d u ret)
-> Either a (PatElem (TypeBase (ShapeBase d) u))
forall (m :: * -> *) from to.
Monad m =>
(from -> m to) -> PatElem from -> m (PatElem to)
rephrasePatElem (TypeBase (ShapeBase d) u -> Either a (TypeBase (ShapeBase d) u)
forall a b. b -> Either a b
Right (TypeBase (ShapeBase d) u -> Either a (TypeBase (ShapeBase d) u))
-> (MemInfo d u ret -> TypeBase (ShapeBase d) u)
-> MemInfo d u ret
-> Either a (TypeBase (ShapeBase d) u)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemInfo d u ret -> TypeBase (ShapeBase d) u
forall d u ret. MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem)) [PatElem (MemInfo d u ret)]
pes

    unAllocOp :: MemOp (HostOp GPUMem ()) -> Either String (HostOp GPU (SOAC GPU))
unAllocOp Alloc {} = String -> Either String (HostOp GPU (SOAC GPU))
forall a b. a -> Either a b
Left String
"unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = String -> Either String (HostOp GPU (SOAC GPU))
forall a b. a -> Either a b
Left String
"unAllocOp: unhandled OtherOp"
    unAllocOp (Inner GPUBody {}) = String -> Either String (HostOp GPU (SOAC GPU))
forall a b. a -> Either a b
Left String
"unAllocOp: unhandled GPUBody"
    unAllocOp (Inner (SizeOp SizeOp
op)) = HostOp GPU (SOAC GPU) -> Either String (HostOp GPU (SOAC GPU))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (HostOp GPU (SOAC GPU) -> Either String (HostOp GPU (SOAC GPU)))
-> HostOp GPU (SOAC GPU) -> Either String (HostOp GPU (SOAC GPU))
forall a b. (a -> b) -> a -> b
$ SizeOp -> HostOp GPU (SOAC GPU)
forall rep op. SizeOp -> HostOp rep op
SizeOp SizeOp
op
    unAllocOp (Inner (SegOp SegOp SegLevel GPUMem
op)) = SegOp SegLevel GPU -> HostOp GPU (SOAC GPU)
forall rep op. SegOp SegLevel rep -> HostOp rep op
SegOp (SegOp SegLevel GPU -> HostOp GPU (SOAC GPU))
-> Either String (SegOp SegLevel GPU)
-> Either String (HostOp GPU (SOAC GPU))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SegOpMapper SegLevel GPUMem GPU (Either String)
-> SegOp SegLevel GPUMem -> Either String (SegOp SegLevel GPU)
forall (m :: * -> *) lvl frep trep.
Monad m =>
SegOpMapper lvl frep trep m -> SegOp lvl frep -> m (SegOp lvl trep)
mapSegOpM SegOpMapper SegLevel GPUMem GPU (Either String)
mapper SegOp SegLevel GPUMem
op
      where
        mapper :: SegOpMapper SegLevel GPUMem GPU (Either String)
mapper =
          SegOpMapper SegLevel Any Any (Either String)
forall (m :: * -> *) lvl rep. Monad m => SegOpMapper lvl rep rep m
identitySegOpMapper
            { mapOnSegOpLambda :: Lambda GPUMem -> Either String (Lambda GPU)
mapOnSegOpLambda = Lambda GPUMem -> Either String (Lambda GPU)
unAllocLambda,
              mapOnSegOpBody :: KernelBody GPUMem -> Either String (KernelBody GPU)
mapOnSegOpBody = KernelBody GPUMem -> Either String (KernelBody GPU)
unAllocKernelBody
            }

    unParam :: Param (MemInfo d u ret) -> Param (TypeBase (ShapeBase d) u)
unParam = (MemInfo d u ret -> TypeBase (ShapeBase d) u)
-> Param (MemInfo d u ret) -> Param (TypeBase (ShapeBase d) u)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap MemInfo d u ret -> TypeBase (ShapeBase d) u
forall d u ret. MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem

    unT :: MemInfo d u ret -> Either a (TypeBase (ShapeBase d) u)
unT = TypeBase (ShapeBase d) u -> Either a (TypeBase (ShapeBase d) u)
forall a b. b -> Either a b
Right (TypeBase (ShapeBase d) u -> Either a (TypeBase (ShapeBase d) u))
-> (MemInfo d u ret -> TypeBase (ShapeBase d) u)
-> MemInfo d u ret
-> Either a (TypeBase (ShapeBase d) u)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. MemInfo d u ret -> TypeBase (ShapeBase d) u
forall d u ret. MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem

    -- Expression mapper from the GPUMem representation to GPU.  Bodies and
    -- operations are delegated to 'unAllocBody' and 'unAllocOp' (which may
    -- fail with a 'String' error); types and parameters merely have their
    -- memory information dropped; sub-expressions and names are returned
    -- unchanged.
    unAlloc' =
      Mapper
        { -- The scope argument is ignored.
          mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = Right . unParam,
          mapOnLParam = Right . unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = Right,
          mapOnVName = Right
        }

-- | Forget the memory information of a 'MemInfo', yielding the
-- corresponding plain type.
--
-- * Primitive and accumulator annotations map directly to 'Prim' and 'Acc'.
-- * Array annotations keep element type, shape and uniqueness; the
--   return/memory component is discarded.
-- * A memory block ('MemMem') has no counterpart here and is turned into
--   @'Prim' 'Unit'@.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem (MemPrim pt) = Prim pt
unMem (MemArray pt shape u _) = Array pt shape u
unMem (MemAcc acc ispace ts u) = Acc acc ispace ts u
unMem MemMem {} = Prim Unit

-- | Convert a scope of the GPUMem representation into a scope of the GPU
-- representation by applying 'unMem' to the decoration of every
-- let-bound name, function parameter and lambda parameter.  Index names
-- carry only an integer type and are passed through unchanged.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope = M.map unInfo
  where
    -- The constructor is rebuilt in each case because the result lives in
    -- a different representation than the argument.
    unInfo (LetName dec) = LetName $ unMem dec
    unInfo (FParamName dec) = FParamName $ unMem dec
    unInfo (LParamName dec) = LParamName $ unMem dec
    unInfo (IndexName it) = IndexName it

-- | Group the extracted allocations by their size: the result pairs each
-- distinct size with every memory block (and its space) that has that
-- size.  The first component of each extraction entry is not used.
removeCommonSizes ::
  Extraction ->
  [(SubExp, [(VName, Space)])]
removeCommonSizes = M.toList . foldl' comb mempty . M.toList
  where
    -- 'M.insertWith' passes the new value as the first argument of '(++)',
    -- so a block seen later is placed in front of the blocks already
    -- recorded for the same size.
    comb m (mem, (_, size, space)) = M.insertWith (++) size [(mem, space)] m

sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes :: SubExp
-> [SubExp]
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads [SubExp]
sizes SegSpace
space Stms GPUMem
kstms = do
  Stms GPU
kstms' <- (String
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> (Stms GPU
    -> ReaderT
         (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either String
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError Stms GPU
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Either String (Stms GPU)
 -> ReaderT
      (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU))
-> Either String (Stms GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Stms GPU)
forall a b. (a -> b) -> a -> b
$ Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms Stms GPUMem
kstms
  let num_sizes :: Int
num_sizes = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
sizes
      i64s :: [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s = Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> a -> [a]
replicate Int
num_sizes (TypeBase (ShapeBase SubExp) NoUniqueness
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

  Scope GPU
kernels_scope <- (Scope GPUMem -> Scope GPU)
-> ReaderT
     (Scope GPUMem) (StateT VNameSource (Either String)) (Scope GPU)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope GPUMem -> Scope GPU
unAllocScope

  (Lambda GPU
max_lam, Stms GPU
_) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"y" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall a b. (a -> b) -> a -> b
$ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Result, Stms GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$
      BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Rep
            (BuilderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$
        [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
  Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
   Param (TypeBase (ShapeBase SubExp) NoUniqueness))
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       SubExpRes)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> ((Param (TypeBase (ShapeBase SubExp) NoUniqueness),
     Param (TypeBase (ShapeBase SubExp) NoUniqueness))
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x, Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y) ->
          (SubExp -> SubExpRes)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> SubExpRes
subExpRes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   SubExp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> (BasicOp
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         SubExp)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"z" (Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> (BasicOp -> Exp GPU)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExpRes)
-> BasicOp
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExpRes
forall a b. (a -> b) -> a -> b
$
            BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x) (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y)
    Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$ [LParam GPU]
-> Body GPU
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda GPU
forall rep.
[LParam rep]
-> Body rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda rep
Lambda ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (Stms GPU -> Result -> Body GPU
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s

  Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam <- String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flat_gtid" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))

  (Lambda GPU
size_lam', Stms GPU
_) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  (Lambda GPU)
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (Lambda GPU, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (Lambda GPU, Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params <- Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
    (Result
zs, Stms GPU
stms) <- Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
      ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> Scope GPU
forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam])
      (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Result, Stms GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result, Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result, Stms GPU)
forall a b. (a -> b) -> a -> b
$ BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms
      (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Result,
       Stms
         (Rep
            (BuilderT
               GPU
               (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Result,
      Stms
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall a b. (a -> b) -> a -> b
$ do
        -- Even though this SegRed is one-dimensional, we need to
        -- provide indexes corresponding to the original potentially
        -- multi-dimensional construct.
        let ([VName]
kspace_gtids, [SubExp]
kspace_dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
            new_inds :: [TPrimExp Int64 VName]
new_inds =
              [TPrimExp Int64 VName]
-> TPrimExp Int64 VName -> [TPrimExp Int64 VName]
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
                ((SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
kspace_dims)
                (SubExp -> TPrimExp Int64 VName
pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam)
        ([VName]
 -> Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> [[VName]]
-> [Exp GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ [VName]
-> Exp GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames ((VName -> [VName]) -> [VName] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map VName -> [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) ([Exp GPU]
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Exp GPU]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (TPrimExp Int64 VName
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Exp GPU))
-> [TPrimExp Int64 VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Exp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TPrimExp Int64 VName
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp GPU)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp [TPrimExp Int64 VName]
new_inds

        (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stms GPU
kstms'
        Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Result)
-> Result
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
sizes

    Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   (Lambda GPU)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Lambda GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall a b. (a -> b) -> a -> b
$
      Lambda GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Lambda GPU)
forall (m :: * -> *).
(HasScope GPU m, MonadFreshNames m) =>
Lambda GPU -> m (Lambda GPU)
GPU.simplifyLambda ([LParam GPU]
-> Body GPU
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda GPU
forall rep.
[LParam rep]
-> Body rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda rep
Lambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
LParam GPU
flat_gtid_lparam] (BodyDec GPU -> Stms GPU -> Result -> Body GPU
forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s)

  (([VName]
maxes_per_thread, [VName]
size_sums), Stms GPU
slice_stms) <- (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> Scope GPU
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> Scope GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b c. (a -> b -> c) -> b -> a -> c
flip BuilderT
  GPU
  (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
  ([VName], [VName])
-> Scope GPU
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope (BuilderT
   GPU
   (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
   ([VName], [VName])
 -> ReaderT
      (Scope GPUMem)
      (StateT VNameSource (Either String))
      (([VName], [VName]), Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
-> ReaderT
     (Scope GPUMem)
     (StateT VNameSource (Either String))
     (([VName], [VName]), Stms GPU)
forall a b. (a -> b) -> a -> b
$ do
    Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat <-
      [Ident] -> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
basicPat ([Ident] -> Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Pat (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [Ident]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase (ShapeBase SubExp) NoUniqueness -> m Ident
newIdent String
"max_per_thread" (TypeBase (ShapeBase SubExp) NoUniqueness
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      Ident)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     Ident
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)

    SubExp
w <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"size_slice_w"
        (Exp GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      SubExp)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> SubExp
-> [SubExp]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Exp
        (Rep
           (BuilderT
              GPU
              (ReaderT (Scope GPUMem) (StateT VNameSource (Either String))))))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> [SubExp]
segSpaceDims SegSpace
space)

    VName
thread_space_iota <-
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"thread_space_iota" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp GPU) -> BasicOp -> Exp GPU
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
    let red_op :: SegBinOp GPU
red_op =
          Commutativity
-> Lambda GPU -> [SubExp] -> ShapeBase SubExp -> SegBinOp GPU
forall rep.
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp
            Commutativity
Commutative
            Lambda GPU
max_lam
            (Int -> SubExp -> [SubExp]
forall a. Int -> a -> [a]
replicate Int
num_sizes (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
            ShapeBase SubExp
forall a. Monoid a => a
mempty
    SegLevel
lvl <- String
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     SegLevel
forall (m :: * -> *) inner.
(MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
String -> m SegLevel
segThread String
"segred"

    Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
      (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      ())
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stm GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stm GPU))
-> Stms GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Stm GPU
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stm GPU)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
      (Stms GPU
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      (Stms GPU))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegOpLevel GPU
SegLevel
lvl Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat SubExp
w [SegBinOp GPU
red_op] Lambda GPU
size_lam' [VName
thread_space_iota]

    [VName]
size_sums <- [VName]
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Pat (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat) ((VName
  -> BuilderT
       GPU
       (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
       VName)
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      [VName])
-> (VName
    -> BuilderT
         GPU
         (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
         VName)
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     [VName]
forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
      String
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"size_sum" (Exp
   (Rep
      (BuilderT
         GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
 -> BuilderT
      GPU
      (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
      VName)
-> Exp
     (Rep
        (BuilderT
           GPU (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))))
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     VName
forall a b. (a -> b) -> a -> b
$
        BasicOp -> Exp GPU
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp GPU) -> BasicOp -> Exp GPU
forall a b. (a -> b) -> a -> b
$
          BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads

    ([VName], [VName])
-> BuilderT
     GPU
     (ReaderT (Scope GPUMem) (StateT VNameSource (Either String)))
     ([VName], [VName])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Pat (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat, [VName]
size_sums)

  (Stms GPU, [VName], [VName])
-> ExpandM (Stms GPU, [VName], [VName])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)