{-# LANGUAGE TypeFamilies #-}

-- | Expand allocations inside of maps when possible.
module Futhark.Pass.ExpandAllocations (expandAllocations) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Either (rights)
import Data.List (find, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Error
import Futhark.IR
import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)

-- | The memory expansion pass definition.
--
-- The top-level constants are transformed first, starting from an
-- empty scope.  Their resulting bindings then form the initial scope
-- for every function definition.  Any failure to expand allocations
-- is reported as a compiler limitation.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations = Pass "expand allocations" "Expand allocations" expandProg
  where
    expandProg :: Prog GPUMem -> PassM (Prog GPUMem)
    expandProg prog = do
      newConsts <- modifyNameSource $ \src ->
        limitationOnLeft $
          runStateT (runReaderT (transformStms $ progConsts prog) mempty) src
      newFuns <- traverse (transformFunDef $ scopeOf newConsts) (progFuns prog)
      pure prog {progConsts = newConsts, progFuns = newFuns}

-- Cannot use intraproceduralTransformation because it might create
-- duplicate size keys (which are not fixed by renamer, and size
-- keys must currently be globally unique).

-- | The monad in which allocations are expanded:
--
-- * a read-only 'Scope' of the bindings currently in scope (its keys
--   are used to decide which names are bound outside a kernel),
--
-- * a 'VNameSource' state for generating fresh names,
--
-- * failure with a 'String' message, which is either caught with
--   'catchError' (to prune failing 'Match' branches) or turned into a
--   compiler limitation by 'limitationOnLeft'.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))

-- | Extract the result of a computation that may have failed with a
-- message.  A failure is reported through 'compilerLimitationS'.
limitationOnLeft :: Either String a -> a
limitationOnLeft result =
  case result of
    Left msg -> compilerLimitationS msg
    Right x -> x

-- | Expand allocations in the body of a function definition, given the
-- scope of the top-level constants.  The body is transformed with both
-- that scope and the function's own bindings in scope.  Afterwards,
-- copy propagation is run on the function, using a symbol table built
-- from the constants' scope.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  newBody <- modifyNameSource $ \src ->
    limitationOnLeft $ runStateT (runReaderT expandBody mempty) src
  let symtab = ST.fromScope $ addScopeWisdom scope
  copyPropagateInFun simpleGPUMem symtab $ fundec {funDefBody = newBody}
  where
    expandBody :: ExpandM (Body GPUMem)
    expandBody =
      localScope scope . inScopeOf fundec . transformBody $ funDefBody fundec

-- | Expand allocations in every statement of a body.  The result
-- subexpressions are left untouched.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = do
  stms' <- transformStms stms
  pure $ Body () stms' res

-- | Expand allocations in the body of a lambda, with the lambda's
-- parameters in scope.  Parameters and return types are kept as-is.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) = do
  body' <- localScope (scopeOfLParams params) $ transformBody body
  pure $ Lambda params body' ret

-- | Expand allocations in a sequence of statements.  All of the
-- statements are brought into scope before any of them is transformed.
-- The statements produced for each one are concatenated in the
-- original order.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms = inScopeOf stms $ do
  expanded <- traverse transformStm $ stmsToList stms
  pure $ mconcat expanded

-- | Expand allocations inside a single statement.  Any statements
-- produced by expanding allocations (e.g. hoisted allocations) are
-- returned in front of the rewritten statement.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
-- It is possible that we are unable to expand allocations in some
-- code versions.  If so, we can remove the offending branch.  Only if
-- all versions fail do we propagate the error.
-- FIXME: this can remove safety checks if the default branch fails!
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  cases' <- rights <$> mapM (attempt . onCase) cases
  defbody' <- attempt $ transformBody defbody
  -- Matching on the reversed list of surviving cases gives total access
  -- to the last one, rather than calling the partial 'init'/'last'.
  case (reverse cases', defbody') of
    ([], Left e) ->
      throwError e
    (last_case : earlier_rev, Left _) ->
      -- The default branch failed: promote the last surviving case to
      -- be the new default.
      rebuild (reverse earlier_rev) (caseBody last_case)
    (_, Right defbody'') ->
      rebuild cases' defbody''
  where
    onCase (Case vs body) = Case vs <$> transformBody body

    -- Run an action, turning a thrown error into a 'Left' instead of
    -- propagating it.
    attempt :: ExpandM a -> ExpandM (Either String a)
    attempt m = (Right <$> m) `catchError` (pure . Left)

    rebuild new_cases new_default =
      pure . oneStm . Let pat aux $
        Match cond new_cases new_default (MatchDec ts MatchEquiv)
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Recurse into every nested body, with that body's scope added.
    transform :: Mapper GPUMem GPUMem ExpandM
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }

-- | Expand allocations in a single expression.
--
-- For 'SegMap', 'SegRed', 'SegScan' and 'SegHist' kernels, the work is
-- delegated to 'transformScanRed'.  The operation is then rebuilt with
-- the level, operator lambdas and kernel body it returns.
--
-- For 'WithAcc', the lambda is transformed.  In each accumulator update
-- operator, invariant allocations are expanded using the accumulator
-- shape, indexed by the operator's leading parameters.  Variant
-- allocations there are rejected.
--
-- All other expressions are returned unchanged.  The statements in the
-- first component of the result must be placed before the expression.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (allocs, (newLvl, _, newBody)) <- transformScanRed lvl space [] kbody
  pure (allocs, Op . Inner . SegOp $ SegMap newLvl space ts newBody)
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (allocs, (newLvl, newLams, newBody)) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  let newReds = [red {segBinOpLambda = l} | (red, l) <- zip reds newLams]
  pure (allocs, Op . Inner . SegOp $ SegRed newLvl space newReds ts newBody)
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (allocs, (newLvl, newLams, newBody)) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let newScans = [scan {segBinOpLambda = l} | (scan, l) <- zip scans newLams]
  pure (allocs, Op . Inner . SegOp $ SegScan newLvl space newScans ts newBody)
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (allocs, (newLvl, newLams, newBody)) <-
    transformScanRed lvl space (map histOp ops) kbody
  let newOps = [op {histOp = l} | (op, l) <- zip ops newLams]
  pure (allocs, Op . Inner . SegOp $ SegHist newLvl space newOps ts newBody)
transformExp (WithAcc inputs lam) = do
  newLam <- transformLambda lam
  expanded <- mapM expandInput inputs
  let (inputAllocStms, newInputs) = unzip expanded
  pure (mconcat inputAllocStms, WithAcc newInputs newLam)
  where
    expandInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    expandInput (shape, arrs, Just (opLam, nes)) = do
      outside <- asks $ namesFromList . M.keys
      let -- XXX: fake a SegLevel, which we don't have here.  We will not
          -- use it for anything, as we will not allow irregular
          -- allocations inside the update function.
          fakeLvl = SegThread SegNoVirt Nothing
          (opLam', lamAllocs) =
            extractLambdaAllocations (fakeLvl, [0]) outside mempty opLam
          isVariant (_, Var v, _) = v `notNameIn` outside
          isVariant _ = False
          (variantAllocs, invariantAllocs) = M.partition isVariant lamAllocs

      case M.elems variantAllocs of
        [] ->
          pure ()
        (_, size, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: "
              ++ prettyString size
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."

      -- The leading operator parameters index into the accumulator shape.
      let idxs = map paramName . take (shapeRank shape) $ lambdaParams opLam
      (allocStms, allocOffsets) <-
        genericExpandedInvariantAllocations (const (shape, map le64 idxs)) invariantAllocs

      outerScope <- askScope
      let innerScope = scopeOf opLam <> outerScope
      case runOffsetM innerScope allocOffsets (offsetMemoryInLambda opLam') of
        Left err -> throwError err
        Right opLam'' -> pure (allocStms, (shape, arrs, Just (opLam'', nes)))
transformExp e =
  pure (mempty, e)

-- | Make sure a level carries an explicit 'KernelGrid'.
--
-- If the level already has a grid, it is returned unchanged with no
-- extra statements.  Otherwise, size operations for the number of
-- groups and the group size are generated, and the level is rebuilt
-- with the resulting grid.  Each size key is derived from a fresh name,
-- since size keys must be globally unique.  The returned statements
-- bind the generated sizes.  'SegThreadInGroup' is rejected with
-- 'error'.
ensureGridKnown :: SegLevel -> ExpandM (Stms GPUMem, SegLevel, KernelGrid)
ensureGridKnown lvl =
  case lvl of
    SegThread virt Nothing -> withFreshGrid $ SegThread virt
    SegGroup virt Nothing -> withFreshGrid $ SegGroup virt
    SegThread _ (Just grid) -> pure (mempty, lvl, grid)
    SegGroup _ (Just grid) -> pure (mempty, lvl, grid)
    SegThreadInGroup {} -> error "ensureGridKnown: SegThreadInGroup"
  where
    withFreshGrid mkLvl = do
      (grid, stms) <- runBuilder $ do
        groups <- sizeOp "num_groups" SizeNumGroups
        groupSize <- sizeOp "group_size" SizeGroup
        pure $ KernelGrid (Count groups) (Count groupSize)
      pure (stms, mkLvl (Just grid), grid)

    sizeOp desc cls = do
      key <- nameFromString . prettyString <$> newVName desc
      letSubExp desc . Op . Inner . SizeOp $ GetSize key cls

-- | Expand the allocations of a kernel body and of the given operator
-- lambdas.  The lambdas are reduction, scan or histogram operators, or
-- none at all for a 'SegMap'.
--
-- Allocations are extracted from the kernel body and the lambdas and
-- split into two groups:
--
-- * /variant/ allocations, whose size is a variable not bound in the
--   enclosing scope;
--
-- * /invariant/ allocations, which are all the others.
--
-- Two cases are rejected:
--
-- * a variant size that is also not bound inside the kernel, which
--   cannot be sliced;
--
-- * variant allocations at 'SegGroup' level.
--
-- If nothing was extracted, the inputs are returned unchanged.
-- Otherwise, the kernel grid is made explicit with 'ensureGridKnown',
-- and the body and lambdas are rewritten through 'allocsForBody' and
-- 'offsetMemoryInLambda'.
--
-- The returned statements consist of the grid-size statements followed
-- by the allocation statements.  They must be placed before the
-- rewritten operation.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, (SegLevel, [Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    -- Report just the offending size, as the message promises (and as
    -- the accumulator path in 'transformExp' does), rather than the
    -- whole (user, size, space) triple.
    Just (_, size, _) ->
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ prettyString size
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  case lvl of
    -- This check is on the variant allocations; the message previously
    -- said "invariant".
    SegGroup {}
      | not $ null variant_allocs ->
          throwError "Cannot handle variant allocations in SegGroup."
    _ ->
      pure ()

  if null variant_allocs && null invariant_allocs
    then pure (mempty, (lvl, ops, kbody))
    else do
      (lvl_stms, lvl', grid) <- ensureGridKnown lvl
      allocsForBody variant_allocs invariant_allocs grid space kbody' $ \alloc_stms kbody'' -> do
        ops'' <- forM ops' $ \op' ->
          localScope (scopeOf op') $ offsetMemoryInLambda op'
        pure (lvl_stms <> alloc_stms, (lvl', ops'', kbody''))
  where
    -- Names bound by the segment space or by statements in the kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody

-- | Every name bound by a statement of the given kernel body, i.e. the
-- keys of the scope introduced by the body's statements.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody kbody =
  let stms_scope = scopeOf $ kernelBodyStms kbody
   in namesFromList $ M.keys stms_scope

-- | Compute the memory requirements of the extracted allocations (see
-- 'memoryRequirements'), then run the continuation inside 'OffsetM'
-- on the kernel body after its memory references have been rebased
-- with 'offsetMemoryInKernelBody'.
--
-- The continuation receives the statements produced by
-- 'memoryRequirements' and the rewritten kernel body.  The 'OffsetM'
-- computation runs with the segment space added to the current scope.
-- An error reported by 'runOffsetM' is rethrown in 'ExpandM'.
allocsForBody ::
  Extraction ->
  Extraction ->
  KernelGrid ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs grid space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements grid space (kernelBodyStms kbody') variant_allocs invariant_allocs

  outer_scope <- askScope
  let kernel_scope = scopeOfSegSpace space <> outer_scope
      rebase = offsetMemoryInKernelBody kbody' >>= m alloc_stms

  case runOffsetM kernel_scope alloc_offsets rebase of
    Left err -> throwError err
    Right res -> pure res

-- | Produce the statements that allocate the expanded memory blocks
-- for both the invariant and the variant allocations of a kernel.
-- Also return the combined 'RebaseMap' that says how to index into
-- those blocks.
--
-- The returned statements start by binding @num_threads@ (number of
-- groups times group size). Both expansions receive that value and
-- run with those statements in scope.
memoryRequirements ::
  KernelGrid ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements grid space kstms variant_allocs invariant_allocs = do
  let num_groups = gridNumGroups grid
      group_size = gridGroupSize grid

  (num_threads, num_threads_stms) <-
    runBuilder $
      letSubExp "num_threads" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (unCount num_groups) (unCount group_size)

  -- Both expansions may refer to num_threads, so its binding must be
  -- in scope while they run.
  let withNumThreads = inScopeOf num_threads_stms

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    withNumThreads $
      expandedInvariantAllocations num_threads num_groups group_size invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    withNumThreads $
      expandedVariantAllocations num_threads space kstms variant_allocs

  pure
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      mconcat [num_threads_stms, invariant_alloc_stms, variant_alloc_stms]
    )

-- | Identifies the spot where an allocation occurs. It pairs the
-- 'SegLevel' of the enclosing operation with the list of index
-- expressions (e.g. a flat thread index, or a group index followed by
-- a local index) that tell one user of the memory apart from another.
type User = (SegLevel, [TPrimExp Int64 VName])

-- | A description of allocations that have been extracted. It maps the
-- name of each extracted memory block to the 'User' that allocated
-- it, the size of the allocation, and the memory 'Space' it lives in.
type Extraction = M.Map VName (User, SubExp, Space)

-- | Extract the expandable allocations from the statements of a kernel
-- body. Returns the body without those allocations, plus what was
-- extracted. See 'extractGenericBodyAllocations'.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel kernelBodyStms withStms
  where
    -- Replace the statements of a kernel body, keeping everything else.
    withStms stms kbody = kbody {kernelBodyStms = stms}

-- | Extract the expandable allocations from the statements of an
-- ordinary body. Returns the body without those allocations, plus
-- what was extracted. See 'extractGenericBodyAllocations'.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms withStms
  where
    -- Replace the statements of a body, keeping everything else.
    withStms stms body = body {bodyStms = stms}

-- | Extract the expandable allocations from the body of a lambda.
-- Returns the lambda with those allocations removed from its body,
-- plus what was extracted.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  let (new_body, extracted) =
        extractBodyAllocations user bound_outside bound_kernel (lambdaBody lam)
   in (lam {lambdaBody = new_body}, extracted)

-- | Walk the statements of a body, which are read with @get_stms@ and
-- written back with @set_stms@. Every statement for which
-- 'extractStmAllocations' returns 'Nothing' is dropped.
--
-- Returns the body holding the remaining statements, together with
-- the 'Extraction' collected along the way.  Names bound by this
-- body's own statements are added to @bound_kernel@ before the walk.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  (set_stms (stmsFromList kept_stms) body, allocs)
  where
    old_stms = get_stms body
    bound_kernel' = bound_kernel <> boundByStms old_stms
    (kept_stms, allocs) =
      runWriter $
        catMaybes
          <$> mapM (extractStmAllocations user bound_outside bound_kernel') (stmsToList old_stms)

-- | 'expandable' tells whether allocations in a memory space may be
-- expanded; neither @"local"@ memory nor scalar spaces qualify.
-- 'notScalar' is 'True' for every space except 'ScalarSpace'.
expandable, notScalar :: Space -> Bool
expandable space =
  case space of
    Space "local" -> False
    _ -> notScalar space
notScalar space =
  case space of
    ScalarSpace {} -> False
    _ -> True

-- | Process a single statement of a kernel body.
--
-- A single-result 'Alloc' is extracted when either of these holds:
--
-- * its space is 'expandable' and its size is a constant or a name
--   from @bound_outside@ or @bound_kernel@;
-- * its size is a name from @bound_kernel@ and its space is not scalar.
--
-- An extracted allocation is recorded in the 'Extraction' under the
-- current 'User' and dropped from the body, which is signalled by
-- returning 'Nothing'.
--
-- Every other statement is kept. The bodies, lambdas and nested
-- 'SegOp's inside its expression are traversed so that their
-- allocations are extracted too. Allocations inside a nested 'SegOp'
-- are attributed to that operation's level, with its flat index
-- appended to the user's index list.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pat [pe]) _ (Op (Alloc size space)))
  | should_extract = do
      tell $ M.singleton (patElemName pe) (user, size, space)
      pure Nothing
  where
    should_extract =
      (expandable space && sizeKnown size)
        -- FIXME: the '&& notScalar space' part is a hack because we
        -- don't otherwise hoist the sizes out far enough, and we
        -- promise to be super-duper-careful about not having variant
        -- scalar allocations.
        -- NOTE(review): this disjunct does not check 'expandable', so
        -- e.g. "local" allocations with kernel-bound sizes are also
        -- extracted here; confirm that this is intended.
        || (sizeInKernel size && notScalar space)
    sizeKnown (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    sizeKnown Constant {} = True
    sizeInKernel (Var v) = v `nameIn` bound_kernel
    sizeInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm =
  Just . withExp <$> mapExpM (expMapper user) (stmExp stm)
  where
    withExp e = stm {stmExp = e}

    expMapper user' =
      identityMapper
        { mapOnBody = \_ -> onBody user',
          mapOnOp = onOp user'
        }

    -- Extract from a nested body and emit what was found.
    onBody user' =
      writer . extractBodyAllocations user' bound_outside bound_kernel

    -- Extract from a nested kernel body and emit what was found.
    onKernelBody user' =
      writer . extractKernelBodyAllocations user' bound_outside bound_kernel

    onLambda user' lam = do
      body' <- onBody user' $ lambdaBody lam
      pure lam {lambdaBody = body'}

    -- A nested SegOp becomes the new user: its level, plus its flat
    -- index appended to the enclosing user's indices.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper nested_user) op
      where
        nested_user = (segLevel op, user_ids ++ [le64 $ segFlat $ segSpace op])
    onOp _ op = pure op

    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

-- | Expand the invariant allocations so that every user gets its own
-- slice of a larger block.
--
-- @getNumUsers@ maps a 'User' to the shape of the space of users and
-- to the indices of this particular user within it. Each block is
-- reallocated with size @per_thread_size * product users_shape@. Its
-- 'RebaseMap' entry builds an index function that fixes the user's
-- indices and leaves the array's original dimensions whole.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (rebases, alloc_stms) <-
    runBuilder $ forM (M.toList invariant_allocs) expand
  pure (alloc_stms, mconcat rebases)
  where
    expand (mem, (user, per_thread_size, space)) = do
      let user_dims = map pe64 $ shapeDims $ fst $ getNumUsers user
      total_size <-
        letExp "total_size" =<< toExp (product (pe64 per_thread_size : user_dims))
      letBind (Pat [PatElem mem $ MemMem space]) $ Op $ Alloc (Var total_size) space
      pure $ M.singleton mem $ newBase user

    -- Fix the user indices; keep every original dimension in full.
    userSlice user_ids old_shape =
      Slice $ map DimFix user_ids ++ map (\d -> DimSlice 0 d 1) old_shape

    newBase user@(lvl, _) =
      case lvl of
        SegThread {} -> newBaseThread user
        SegThreadInGroup {} -> newBaseThread user
        SegGroup {} -> newBaseGroup user

    -- Thread users: the iota places the user dimensions after the
    -- original ones, and a permutation then moves them to the front
    -- so that the slice can fix them.
    newBaseThread user (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          num_dims = length old_shape
          users_rank = shapeRank users_shape
          root_ixfun = IxFun.iota $ old_shape ++ map pe64 (shapeDims users_shape)
          perm = [num_dims .. num_dims + users_rank - 1] ++ [0 .. num_dims - 1]
       in IxFun.slice (IxFun.permute root_ixfun perm) $ userSlice user_ids old_shape

    -- Group users: the iota places the user dimensions first, so no
    -- permutation is needed.
    newBaseGroup user (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          root_ixfun = IxFun.iota $ map pe64 (shapeDims users_shape) ++ old_shape
       in IxFun.slice root_ixfun $ userSlice user_ids old_shape

-- | Expand the invariant allocations of a kernel with the given total
-- number of threads, number of groups and group size. How many users
-- there are depends on the 'User':
--
-- * a thread-level user with one index uses a space of @num_threads@;
-- * a thread-level user with two indices uses a space of @num_groups@
--   by @group_size@;
-- * a group-level user with one index uses a space of @num_groups@.
--
-- Any other 'User' is an internal error.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    getNumUsers user@(lvl, ids)
      | isThread lvl, [_] <- ids = (Shape [num_threads], ids)
      | isThread lvl, [_, _] <- ids = (Shape [num_groups, group_size], ids)
      | SegGroup {} <- lvl, [_] <- ids = (Shape [num_groups], ids)
      | otherwise = error $ "getNumUsers: unhandled " ++ show user

    isThread SegThread {} = True
    isThread SegThreadInGroup {} = True
    isThread _ = False

-- | Expand the variant allocations of a kernel, whose sizes depend on
-- the thread.
--
-- 'sliceKernelSizes' computes two values for each distinct size among
-- the extracted blocks. The first is passed to @newBase@ as the
-- per-thread size in bytes; the second is the size of the new
-- allocation. Every block sharing that size is then reallocated with
-- the second value, and each thread gets its own view into it.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = pure (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads (map fst sizes_to_blocks) kspace kstms

  -- The size-slicing statements get explicit allocations and are
  -- simplified. Allocations inside the kernels they produce are then
  -- expanded by a recursive call to transformStms.
  slice_stms' <-
    transformStms =<< simplifyStms =<< explicitAllocationsInStms slice_stms

  -- Every memory block sharing a size gets the same "offset" (used as
  -- the per-thread size) and the same total size.
  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        [ (mem, (Var offset, Var total_size, space))
          | ((_, blocks), offset, total_size) <- zip3 sizes_to_blocks offsets size_sums,
            (mem, space) <- blocks
        ]
      (alloc_stms, rebases) = unzip $ map expand variant_allocs'

  pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- Allocate the whole block, and record how each thread indexes it.
    expand (mem, (offset, total_size, space)) =
      ( Let (Pat [PatElem mem $ MemMem space]) (defAux ()) $ Op $ Alloc total_size space,
        M.singleton mem $ newBase offset
      )

    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- For the variant allocations we add an inner dimension of extent
    -- num_threads, and each thread fixes its own flat index in it.
    -- NOTE(review): the first dimension has extent elems_per_thread
    -- but is sliced with count num_threads'. The resulting shape is
    -- replaced by coerce/reshape below; confirm this is intended.
    newBase size_per_thread (old_shape, pt) =
      reshapeTo old_shape
      where
        elems_per_thread = pe64 size_per_thread `quot` primByteSize pt
        root_ixfun = IxFun.iota [elems_per_thread, num_threads']
        offset_ixfun =
          IxFun.slice root_ixfun $ Slice [DimSlice 0 num_threads' 1, DimFix gtid]
        reshapeTo shape
          | length shape == 1 = IxFun.coerce offset_ixfun shape
          | otherwise = IxFun.reshape offset_ixfun shape

-- | A map from memory block names to new index function bases. Each
-- entry takes the shape and element type of an array stored in the
-- original memory block. It returns the index function that locates
-- that array within the expanded allocation.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)

-- | The monad in which memory offsets are applied.
--
-- The outer reader layer holds the current type scope, which backs the
-- 'HasScope' and 'LocalScope' instances.  The inner reader layer holds
-- the 'RebaseMap'.  Failure is reported as a 'String'.
newtype OffsetM a
  = OffsetM (ReaderT (Scope GPUMem) (ReaderT RebaseMap (Either String)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadError String,
      HasScope GPUMem,
      LocalScope GPUMem
    )

-- | Execute an 'OffsetM' computation, supplying the initial scope and
-- the rebase map.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  flip runReaderT offsets $ runReaderT m scope

-- | Fetch the current 'RebaseMap'.
--
-- The map lives in the inner reader layer, so the outer (scope)
-- environment is ignored here.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM (ReaderT (const ask))

-- | Run a computation with the 'RebaseMap' transformed by @f@.
--
-- The scope is passed through unchanged.
localRebaseMap :: (RebaseMap -> RebaseMap) -> OffsetM a -> OffsetM a
localRebaseMap f (OffsetM m) =
  OffsetM . ReaderT $ \scope -> local f (runReaderT m scope)

-- | Find the rebasing function registered for memory block @name@.
--
-- If one exists, apply it to the given base shape and element type.
-- Returns 'Nothing' when the block has no entry in the 'RebaseMap'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x =
  fmap ($ x) . M.lookup name <$> askRebaseMap

-- | Apply memory offsets to every statement of a kernel body.
--
-- The kernel results are left untouched.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  stms' <- offsetMemoryInStmSeq $ kernelBodyStms kbody
  pure kbody {kernelBodyStms = stms'}

-- | Apply memory offsets to every statement of a body.
--
-- The body's decoration and result are left untouched.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  stms' <- offsetMemoryInStmSeq stms
  pure $ Body dec stms' res

-- | Apply memory offsets to a sequence of statements, in order.
--
-- Each statement is processed in the scope returned by
-- 'offsetMemoryInStm' for the statement before it.  The first statement
-- uses the current scope.  This is shared by 'offsetMemoryInKernelBody'
-- and 'offsetMemoryInBody', which previously duplicated this loop.
offsetMemoryInStmSeq :: Stms GPUMem -> OffsetM (Stms GPUMem)
offsetMemoryInStmSeq stms = do
  scope <- askScope
  stmsFromList . snd
    <$> mapAccumLM
      (\scope' -> localScope scope' . offsetMemoryInStm)
      scope
      (stmsToList stms)

-- | Apply memory offsets to a single statement.
--
-- Returns the current scope extended with the statement's bindings,
-- together with the rewritten statement.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  e' <- offsetMemoryInExp e
  -- 'expReturns' depends only on the expression and the current scope.
  -- The original code computed it twice in the same scope; compute it
  -- once and use it both for rebasing the pattern and for recomputing
  -- index functions below.
  rts <- expReturns e'
  pat' <- offsetMemoryInPat pat rts
  scope <- askScope
  -- Try to recompute the index function.  Fall back to creating rebase
  -- operations with the RebaseMap.
  let pat'' = Pat $ zipWith pick (patElems pat') rts
      stm = Let pat'' dec e'
      scope' = scopeOf stm <> scope
  pure (scope', stm)
  where
    -- Use the index function reported by 'expReturns' when it can be
    -- fully instantiated.  Otherwise keep the (possibly rebased) element.
    pick ::
      PatElem (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElem (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
            PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p

    -- Succeeds only if the index function mentions no 'Ext' values.
    instantiateIxFun :: ExtIxFun -> Maybe IxFun
    instantiateIxFun = traverse (traverse inst)
      where
        inst :: Ext a -> Maybe a
        inst Ext {} = Nothing
        inst (Free x) = pure x

-- | Rewrite the memory information of each pattern element, guided by
-- the corresponding 'ExpReturns' of the bound expression.
--
-- * An array whose return is 'ReturnsNewBlock' gets that index function.
--   Each @Ext i@ in it is replaced by the name of the @i@th pattern
--   element.
-- * Every other element is passed through 'offsetMemoryInMemBound'.
offsetMemoryInPat :: Pat LetDecMem -> [ExpReturns] -> OffsetM (Pat LetDecMem)
offsetMemoryInPat (Pat pes) rets =
  Pat <$> zipWithM onPE pes rets
  where
    onPE pe ret
      | MemArray pt shape u (ArrayIn mem _) <- patElemDec pe,
        MemArray _ _ _ (Just (ReturnsNewBlock _ _ ixfun)) <- ret =
          pure pe {patElemDec = MemArray pt shape u (ArrayIn mem (fmap (fmap resolve) ixfun))}
      | otherwise = do
          dec' <- offsetMemoryInMemBound (patElemDec pe)
          pure pe {patElemDec = dec'}

    -- Existential indices refer to positions in this very pattern.
    resolve (Ext i) = patElemName (pes !! i)
    resolve (Free v) = v

-- | Rebase the memory information of a parameter via
-- 'offsetMemoryInMemBound'.
--
-- All other parameter fields are kept as-is.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam fparam =
  setDec <$> offsetMemoryInMemBound (paramDec fparam)
  where
    setDec dec' = fparam {paramDec = dec'}

-- | Rebase an array's index function if its memory block has a new base.
--
-- If the array lives in a memory block with an entry in the
-- 'RebaseMap', the index function is rebased onto the base computed
-- from its current base shape and element type.  Anything else
-- (non-arrays, or blocks without an entry) is returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound summary@(MemArray pt shape u (ArrayIn mem ixfun)) =
  maybe summary rebased <$> lookupNewBase mem (IxFun.base ixfun, pt)
  where
    rebased new_base =
      MemArray pt shape u $ ArrayIn mem $ IxFun.rebase new_base ixfun
offsetMemoryInMemBound summary = pure summary

-- | Counterpart of 'offsetMemoryInMemBound' for declared body return
-- types.
--
-- Only a 'ReturnsInBlock' whose index function is accepted by
-- 'isStaticIxFun' is considered for rebasing.  The new base is lifted
-- into the existential index-function type with 'Free'.  All other
-- returns are left unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns br@(MemArray pt shape u (ReturnsInBlock mem ixfun))
  | Just ixfun' <- isStaticIxFun ixfun =
      maybe br rebased <$> lookupNewBase mem (IxFun.base ixfun', pt)
  where
    rebased new_base =
      MemArray pt shape u . ReturnsInBlock mem $
        IxFun.rebase (fmap (fmap Free) new_base) ixfun
offsetMemoryInBodyReturns br = pure br

-- | Apply memory offsets inside the body of a lambda.
--
-- The body is processed in the lambda's own scope, via 'inScopeOf'.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam =
  inScopeOf lam $
    (\body' -> lam {lambdaBody = body'}) <$> offsetMemoryInBody (lambdaBody lam)

-- | Rebase the memory of loop parameters, then run the continuation on
-- the rewritten (parameter, initial value) pairs.
--
-- A loop may have memory parameters, and those memory blocks may
-- be expanded.  We assume (but do not check - FIXME) that if the
-- initial value of a loop parameter is an expanded memory block,
-- then so will the result be.
--
-- Concretely, a parameter whose initial value is a variable with an
-- entry in the 'RebaseMap' is given that same entry.  The extended map
-- is in effect both while rebasing the parameters and while running
-- the continuation.
offsetMemoryInLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  ([(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams merge f =
  localRebaseMap extend $
    f . (`zip` args) =<< mapM offsetMemoryInParam params
  where
    (params, args) = unzip merge

    extend rm = foldl' inherit rm merge

    inherit rm (param, Var arg)
      | Just x <- M.lookup arg rm = M.insert (paramName param) x rm
    inherit rm _ = rm

-- | Apply memory offsets throughout an expression.
--
-- * Loops get special treatment: their parameters are rebased through
--   'offsetMemoryInLoopParams', and the body is processed with those
--   parameters and the loop form in scope.
-- * Every other expression is traversed with 'mapExpM'.  Nested bodies
--   and branch return types are rewritten.
-- * Segmented operations have their kernel bodies and lambdas rewritten
--   inside the scope of their segment space.  All other operations pass
--   through unchanged.
offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp (DoLoop merge form body) =
  offsetMemoryInLoopParams merge $ \merge' ->
    DoLoop merge' form
      <$> localScope
        (scopeOfFParams (map fst merge') <> scopeOf form)
        (offsetMemoryInBody body)
offsetMemoryInExp e = mapExpM offsetMapper e
  where
    offsetMapper =
      identityMapper
        { mapOnBody = \bscope -> localScope bscope . offsetMemoryInBody,
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }

    onOp (Inner (SegOp op)) =
      Inner . SegOp
        <$> localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpLambda = offsetMemoryInLambda,
              mapOnSegOpBody = offsetMemoryInKernelBody
            }
    onOp op = pure op

---- Slicing allocation sizes out of a kernel.

-- | Strip memory annotations from statements, mapping 'GPUMem' back to
-- 'GPU'.
--
-- * Allocation statements at the top level are dropped.
-- * An allocation statement inside a nested body (loop/branch body,
--   kernel body, or lambda) yields a 'Left' error.
-- * 'OtherOp' and 'GPUBody' operations also yield a 'Left' error.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    -- 'nested' says whether we are below the top level, where
    -- allocations can no longer simply be dropped.
    unAllocStms nested =
      fmap (stmsFromList . catMaybes)
        . traverse (unAllocStm nested)
        . stmsToList

    unAllocStm nested stm@(Let _ _ (Op Alloc {}))
      | nested = throwError $ "Cannot handle nested allocation: " ++ prettyString stm
      | otherwise = pure Nothing
    unAllocStm _ (Let pat dec e) = do
      pat' <- unAllocPat pat
      e' <- mapExpM unAlloc' e
      pure $ Just $ Let pat' dec e'

    unAllocBody (Body dec stms res) =
      (\stms' -> Body dec stms' res) <$> unAllocStms True stms

    unAllocKernelBody (KernelBody dec stms res) =
      (\stms' -> KernelBody dec stms' res) <$> unAllocStms True stms

    unAllocLambda (Lambda params body ret) =
      (\body' -> Lambda (map unParam params) body' ret) <$> unAllocBody body

    unAllocPat (Pat pes) =
      Pat <$> traverse (rephrasePatElem (Right . unMem)) pes

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp (Inner GPUBody {}) = Left "unAllocOp: unhandled GPUBody"
    unAllocOp (Inner (SizeOp op)) = pure $ SizeOp op
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM mapper op
      where
        mapper =
          identitySegOpMapper
            { mapOnSegOpBody = unAllocKernelBody,
              mapOnSegOpLambda = unAllocLambda
            }

    unParam = fmap unMem

    unT = Right . unMem

    unAlloc' =
      Mapper
        { mapOnSubExp = Right,
          mapOnVName = Right,
          mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = Right . unParam,
          mapOnLParam = Right . unParam,
          mapOnOp = unAllocOp
        }

-- | Drop the memory information from a 'MemInfo', keeping only the
-- underlying type.  Primitive, array and accumulator annotations become
-- the corresponding 'Prim', 'Array' and 'Acc' types (the memory return
-- annotation of an array is discarded); a 'MemMem' becomes @Prim Unit@.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem info =
  case info of
    MemPrim prim_t -> Prim prim_t
    MemArray elem_t arr_shape uniq _mem_ret -> Array elem_t arr_shape uniq
    MemAcc acc_name acc_ispace acc_ts uniq -> Acc acc_name acc_ispace acc_ts uniq
    MemMem {} -> Prim Unit

-- | Convert a 'GPUMem' scope into a 'GPU' scope by stripping the memory
-- annotation (via 'unMem') from every let-bound name and every function
-- and lambda parameter.  Index names carry only an 'IntType' and are
-- passed through unchanged.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope = fmap stripInfo
  where
    stripInfo info =
      case info of
        LetName let_dec -> LetName (unMem let_dec)
        FParamName fparam_dec -> FParamName (unMem fparam_dec)
        LParamName lparam_dec -> LParamName (unMem lparam_dec)
        IndexName int_t -> IndexName int_t

-- | Group the entries of an 'Extraction' by their size expression.  Each
-- distinct size is paired with every @(memory name, space)@ recorded
-- with that size; the first component of each extraction entry is
-- ignored.  Entries are combined with 'M.fromListWith' @(++)@, so within
-- a group later entries precede earlier ones.
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList $
    M.fromListWith
      (++)
      [ (size, [(mem, space)])
        | (mem, (_, size, space)) <- M.toList extraction
      ]

-- | Generate statements that compute, for each expression in @sizes@, its
-- maximum value over every point of the segment space @space@, together
-- with that maximum multiplied by @num_threads@.
--
-- The kernel statements @kstms@ are first stripped of memory information
-- (failing via 'throwError' if that is not possible).  They are then
-- replayed inside a one-dimensional 'SegRed' over the flattened segment
-- space, which reduces every size with a signed 64-bit maximum.
--
-- Returns the generated statements, the names bound to the maxima, and
-- the names bound to the maxima times @num_threads@.
sliceKernelSizes ::
  SubExp ->
  [SubExp] ->
  SegSpace ->
  Stms GPUMem ->
  ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes num_threads sizes space kstms = do
  kstms' <- either throwError pure $ unAllocGPUStms kstms
  let num_sizes = length sizes
      i64s = replicate num_sizes $ Prim int64

  kernels_scope <- asks unAllocScope

  -- Reduction operator: elementwise signed maximum of two groups of
  -- num_sizes i64 values.
  (max_lam, _) <- flip runBuilderT kernels_scope $ do
    xs <- replicateM num_sizes $ newParam "x" (Prim int64)
    ys <- replicateM num_sizes $ newParam "y" (Prim int64)
    (zs, stms) <- localScope (scopeOfLParams $ xs ++ ys) $
      collectStms $
        forM (zip xs ys) $ \(x, y) ->
          fmap subExpRes . letSubExp "z" . BasicOp $
            BinOp (SMax Int64) (Var $ paramName x) (Var $ paramName y)
    pure $ Lambda (xs ++ ys) (mkBody stms zs) i64s

  flat_gtid_lparam <- newParam "flat_gtid" (Prim (IntType Int64))

  -- Map function of the SegRed: from a flat index, recover the original
  -- thread indices, replay the kernel statements and return the sizes.
  -- The lambda binds only flat_gtid_lparam, so only it is put in scope.
  (size_lam', _) <- flip runBuilderT kernels_scope $ do
    (zs, stms) <-
      localScope (scopeOfLParams [flat_gtid_lparam]) $
        collectStms $ do
          -- Even though this SegRed is one-dimensional, we need to
          -- provide indexes corresponding to the original potentially
          -- multi-dimensional construct.
          let (kspace_gtids, kspace_dims) = unzip $ unSegSpace space
              new_inds =
                unflattenIndex
                  (map pe64 kspace_dims)
                  (pe64 $ Var $ paramName flat_gtid_lparam)
          zipWithM_ letBindNames (map pure kspace_gtids) =<< mapM toExp new_inds

          mapM_ addStm kstms'
          pure $ subExpsRes sizes

    localScope (scopeOfSegSpace space) $
      GPU.simplifyLambda (Lambda [flat_gtid_lparam] (Body () stms zs) i64s)

  ((maxes_per_thread, size_sums), slice_stms) <- flip runBuilderT kernels_scope $ do
    pat <-
      basicPat <$> replicateM num_sizes (newIdent "max_per_thread" $ Prim int64)

    -- Total number of points in the segment space (product of its dims).
    w <-
      letSubExp "size_slice_w"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (segSpaceDims space)

    thread_space_iota <-
      letExp "thread_space_iota" $
        BasicOp $
          Iota w (intConst Int64 0) (intConst Int64 1) Int64

    -- NOTE(review): 0 is a neutral element for SMax only if every size is
    -- non-negative; presumably true for allocation sizes, confirm.
    let red_op =
          SegBinOp
            Commutative
            max_lam
            (replicate num_sizes $ intConst Int64 0)
            mempty
    lvl <- segThread "segred"

    -- The statements are renamed before being added, presumably so that the
    -- lambdas' bound names stay unique in the enclosing program.
    addStms
      =<< mapM renameStm
      =<< nonSegRed lvl pat w [red_op] size_lam' [thread_space_iota]

    size_sums <- forM (patNames pat) $ \threads_max ->
      letExp "size_sum" $
        BasicOp $
          BinOp (Mul Int64 OverflowUndef) (Var threads_max) num_threads

    pure (patNames pat, size_sums)

  pure (slice_stms, maxes_per_thread, size_sums)