{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExpandAllocations (expandAllocations) where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Data.Either (rights)
import Data.List (find, foldl')
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Rephrase
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Error
import Futhark.IR
import Futhark.IR.GPU.Simplify qualified as GPU
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify.Rep (addScopeWisdom)
import Futhark.Pass
import Futhark.Pass.ExplicitAllocations.GPU (explicitAllocationsInStms)
import Futhark.Pass.ExtractKernels.BlockedKernel (nonSegRed)
import Futhark.Pass.ExtractKernels.ToGPU (segThread)
import Futhark.Tools
import Futhark.Transform.CopyPropagate (copyPropagateInFun)
import Futhark.Transform.Rename (renameStm)
import Futhark.Util (mapAccumLM)
import Futhark.Util.IntegralExp
import Prelude hiding (quot)
-- | The pass itself.  The program constants are rewritten with
-- 'transformStms' (starting from an empty scope), and every function
-- is then rewritten with 'transformFunDef' in the scope of the
-- rewritten constants.  A 'Left' produced by the rewriting is handed
-- to 'limitationOnLeft'.
expandAllocations :: Pass GPUMem GPUMem
expandAllocations =
  Pass "expand allocations" "Expand allocations" $
    \prog -> do
      consts' <-
        modifyNameSource $
          limitationOnLeft
            . runStateT (runReaderT (transformStms (progConsts prog)) mempty)
      funs' <- mapM (transformFunDef $ scopeOf consts') (progFuns prog)
      pure $ prog {progConsts = consts', progFuns = funs'}
-- | The monad used by the transformation: a reader over the current
-- 'Scope', a state holding the 'VNameSource', and 'Either' 'String'
-- for failures carrying a textual message.
type ExpandM = ReaderT (Scope GPUMem) (StateT VNameSource (Either String))
-- | Unwrap a 'Right'; a 'Left' message is passed to
-- 'compilerLimitationS' instead of being returned.
limitationOnLeft :: Either String a -> a
limitationOnLeft = either compilerLimitationS id
-- | Transform the body of a single function definition.  The first
-- argument is the scope the function lives in (the caller passes the
-- scope of the program constants).  After the body has been
-- rewritten, the function is run through 'copyPropagateInFun'.
transformFunDef ::
  Scope GPUMem ->
  FunDef GPUMem ->
  PassM (FunDef GPUMem)
transformFunDef scope fundec = do
  body' <- modifyNameSource $ limitationOnLeft . runStateT (runReaderT m mempty)
  copyPropagateInFun
    simpleGPUMem
    (ST.fromScope (addScopeWisdom scope))
    fundec {funDefBody = body'}
  where
    -- The body transformation, run with both the outer scope and the
    -- scope contributed by the function definition itself.
    m :: ExpandM (Body GPUMem)
    m =
      localScope scope $
        inScopeOf fundec $
          transformBody $
            funDefBody fundec
-- | Transform the statements of a body; the result is kept as is.
transformBody :: Body GPUMem -> ExpandM (Body GPUMem)
transformBody (Body () stms res) = Body () <$> transformStms stms <*> pure res
-- | Transform the body of a lambda with its parameters in scope.
-- Parameters and return types are left untouched.
transformLambda :: Lambda GPUMem -> ExpandM (Lambda GPUMem)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret
-- | Transform every statement in order, with all of the statements
-- in scope, and concatenate the resulting statement sequences.
transformStms :: Stms GPUMem -> ExpandM (Stms GPUMem)
transformStms stms =
  inScopeOf stms $ mconcat <$> mapM transformStm (stmsToList stms)
-- | Transform a single statement, possibly producing additional
-- statements that must precede it.
--
-- A 'Match' tagged 'MatchEquiv' is handled specially: each case and
-- the default body are transformed independently, and a branch whose
-- transformation fails is dropped.  If the default body fails but at
-- least one case survives, the last surviving case becomes the new
-- default body.  Only if no case survives and the default body fails
-- is the default body's error rethrown.
--
-- Any other statement has its sub-bodies transformed via 'mapExpM'
-- and is then passed to 'transformExp'.
transformStm :: Stm GPUMem -> ExpandM (Stms GPUMem)
transformStm (Let pat aux (Match cond cases defbody (MatchDec ts MatchEquiv))) = do
  cases' <- rights <$> mapM onCase cases
  defbody' <- attempt $ transformBody defbody
  case (cases', defbody') of
    ([], Left e) ->
      throwError e
    (_ : _, Left _) ->
      -- cases' is non-empty in this alternative, so 'init' and 'last'
      -- cannot fail here.
      pure . oneStm . Let pat aux $
        Match cond (init cases') (caseBody $ last cases') (MatchDec ts MatchEquiv)
    (_, Right defbody'') ->
      pure . oneStm . Let pat aux $
        Match cond cases' defbody'' (MatchDec ts MatchEquiv)
  where
    -- Run an action, turning a thrown error into a 'Left'.
    attempt :: ExpandM a -> ExpandM (Either String a)
    attempt act = (Right <$> act) `catchError` (pure . Left)

    onCase (Case vs body) = attempt $ Case vs <$> transformBody body
transformStm (Let pat aux e) = do
  (stms, e') <- transformExp =<< mapExpM transform e
  pure $ stms <> oneStm (Let pat aux e')
  where
    -- Only bodies are rewritten; each body is transformed with the
    -- scope supplied by the mapper added to the local scope.
    transform :: Mapper GPUMem GPUMem ExpandM
    transform =
      identityMapper
        { mapOnBody = \scope -> localScope scope . transformBody
        }
-- | Transform an expression, returning the statements that must be
-- inserted before it together with the rewritten expression.
--
-- * 'SegMap', 'SegRed', 'SegScan' and 'SegHist' are delegated to
--   'transformScanRed', passing along the operator lambdas (none for
--   'SegMap') and putting the rewritten lambdas back afterwards.
--
-- * 'WithAcc' has its lambda transformed with 'transformLambda', and
--   each input that carries an update operator is processed by the
--   local @onInput@.
--
-- * Every other expression is returned unchanged with no statements.
transformExp :: Exp GPUMem -> ExpandM (Stms GPUMem, Exp GPUMem)
transformExp (Op (Inner (SegOp (SegMap lvl space ts kbody)))) = do
  (alloc_stms, (_, kbody')) <- transformScanRed lvl space [] kbody
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegMap lvl space ts kbody'
    )
transformExp (Op (Inner (SegOp (SegRed lvl space reds ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda reds) kbody
  let reds' = zipWith (\red lam -> red {segBinOpLambda = lam}) reds lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegRed lvl space reds' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegScan lvl space scans ts kbody)))) = do
  (alloc_stms, (lams, kbody')) <-
    transformScanRed lvl space (map segBinOpLambda scans) kbody
  let scans' = zipWith (\red lam -> red {segBinOpLambda = lam}) scans lams
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegScan lvl space scans' ts kbody'
    )
transformExp (Op (Inner (SegOp (SegHist lvl space ops ts kbody)))) = do
  (alloc_stms, (lams', kbody')) <- transformScanRed lvl space lams kbody
  let ops' = zipWith onOp ops lams'
  pure
    ( alloc_stms,
      Op $ Inner $ SegOp $ SegHist lvl space ops' ts kbody'
    )
  where
    lams = map histOp ops
    onOp op lam = op {histOp = lam}
transformExp (WithAcc inputs lam) = do
  lam' <- transformLambda lam
  (input_alloc_stms, inputs') <- unzip <$> mapM onInput inputs
  pure
    ( mconcat input_alloc_stms,
      WithAcc inputs' lam'
    )
  where
    -- Inputs without an update operator need no work.
    onInput (shape, arrs, Nothing) =
      pure (mempty, (shape, arrs, Nothing))
    onInput (shape, arrs, Just (op_lam, nes)) = do
      bound_outside <- asks $ namesFromList . M.keys
      let -- A placeholder level built from zero counts.
          -- NOTE(review): presumably never consulted, since any
          -- variant allocation is rejected just below -- confirm.
          lvl = SegThread (Count $ intConst Int64 0) (Count $ intConst Int64 0) SegNoVirt
          (op_lam', lam_allocs) =
            extractLambdaAllocations (lvl, [0]) bound_outside mempty op_lam
          -- An allocation is variant when its size is a variable that
          -- is not bound outside the operator.
          variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
          variantAlloc _ = False
          (variant_allocs, invariant_allocs) = M.partition variantAlloc lam_allocs

      case M.elems variant_allocs of
        (_, v, _) : _ ->
          throwError $
            "Cannot handle un-sliceable allocation size: "
              ++ prettyString v
              ++ "\nLikely cause: irregular nested operations inside accumulator update operator."
        [] ->
          pure ()

      -- The first @shapeRank shape@ parameters of the operator are
      -- used as the indexes when expanding the invariant allocations.
      let num_is = shapeRank shape
          is = map paramName $ take num_is $ lambdaParams op_lam
      (alloc_stms, alloc_offsets) <-
        genericExpandedInvariantAllocations (const (shape, map le64 is)) invariant_allocs

      scope <- askScope
      let scope' = scopeOf op_lam <> scope
      either throwError pure $
        runOffsetM scope' alloc_offsets $ do
          op_lam'' <- offsetMemoryInLambda op_lam'
          pure (alloc_stms, (shape, arrs, Just (op_lam'', nes)))
transformExp e =
  pure (mempty, e)
-- | Shared worker for the segmented operations.  Allocations are
-- extracted from the kernel body and from the operator lambdas, and
-- split into /variant/ ones (size is a variable not bound outside
-- the construct) and /invariant/ ones (everything else).  The
-- resulting allocation statements are returned together with the
-- rewritten lambdas and kernel body.
--
-- Fails if a variant allocation has a size variable that is not
-- bound in the kernel (neither by the segment space nor by the
-- kernel body statements), or if the level is 'SegGroup' and there
-- are any variant allocations.
transformScanRed ::
  SegLevel ->
  SegSpace ->
  [Lambda GPUMem] ->
  KernelBody GPUMem ->
  ExpandM (Stms GPUMem, ([Lambda GPUMem], KernelBody GPUMem))
transformScanRed lvl space ops kbody = do
  bound_outside <- asks $ namesFromList . M.keys
  let user = (lvl, [le64 $ segFlat space])
      (kbody', kbody_allocs) =
        extractKernelBodyAllocations user bound_outside bound_in_kernel kbody
      (ops', ops_allocs) =
        unzip $ map (extractLambdaAllocations user bound_outside mempty) ops
      variantAlloc (_, Var v, _) = v `notNameIn` bound_outside
      variantAlloc _ = False
      (variant_allocs, invariant_allocs) =
        M.partition variantAlloc $ kbody_allocs <> mconcat ops_allocs
      badVariant (_, Var v, _) = v `notNameIn` bound_in_kernel
      badVariant _ = False

  case find badVariant $ M.elems variant_allocs of
    Just v ->
      throwError $
        "Cannot handle un-sliceable allocation size: "
          ++ prettyString v
          ++ "\nLikely cause: irregular nested operations inside parallel constructs."
    Nothing ->
      pure ()

  -- NOTE(review): the guard tests variant_allocs while the message
  -- says "invariant" -- confirm which of the two is intended.
  case lvl of
    SegGroup {}
      | not $ null variant_allocs ->
          throwError "Cannot handle invariant allocations in SegGroup."
    _ ->
      pure ()

  allocsForBody variant_allocs invariant_allocs lvl space kbody' $ \alloc_stms kbody'' -> do
    ops'' <- forM ops' $ \op' ->
      localScope (scopeOf op') $ offsetMemoryInLambda op'
    pure (alloc_stms, (ops'', kbody''))
  where
    -- Names bound by the segment space or by the kernel body.
    bound_in_kernel =
      namesFromList (M.keys $ scopeOfSegSpace space)
        <> boundInKernelBody kbody
-- | The names in the scope of the statements of a kernel body.
boundInKernelBody :: KernelBody GPUMem -> Names
boundInKernelBody = namesFromList . M.keys . scopeOf . kernelBodyStms
-- | Compute the allocation statements and offsets for a kernel body
-- via 'memoryRequirements', apply the offsets to the body with
-- 'offsetMemoryInKernelBody', and pass the allocation statements and
-- the rewritten body to the continuation.  The continuation runs in
-- the same 'OffsetM' computation, with the segment space added to the
-- current scope.  A 'Left' from 'runOffsetM' is rethrown.
allocsForBody ::
  Extraction ->
  Extraction ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  (Stms GPUMem -> KernelBody GPUMem -> OffsetM b) ->
  ExpandM b
allocsForBody variant_allocs invariant_allocs lvl space kbody' m = do
  (alloc_offsets, alloc_stms) <-
    memoryRequirements
      lvl
      space
      (kernelBodyStms kbody')
      variant_allocs
      invariant_allocs

  scope <- askScope
  let scope' = scopeOfSegSpace space <> scope
  either throwError pure $
    runOffsetM scope' alloc_offsets $ do
      kbody'' <- offsetMemoryInKernelBody kbody'
      m alloc_stms kbody''
-- | Produce the statements and the rebase map for the given variant
-- and invariant allocations.  A @num_threads@ value is first computed
-- as the product of the number of groups and the group size of the
-- level; the statements computing it come first in the returned
-- statements, followed by those for the invariant and then the
-- variant allocations.
memoryRequirements ::
  SegLevel ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  Extraction ->
  ExpandM (RebaseMap, Stms GPUMem)
memoryRequirements lvl space kstms variant_allocs invariant_allocs = do
  (num_threads, num_threads_stms) <-
    runBuilder . letSubExp "num_threads" . BasicOp $
      BinOp
        (Mul Int64 OverflowUndef)
        (unCount $ segNumGroups lvl)
        (unCount $ segGroupSize lvl)

  (invariant_alloc_stms, invariant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedInvariantAllocations
        num_threads
        (segNumGroups lvl)
        (segGroupSize lvl)
        invariant_allocs

  (variant_alloc_stms, variant_alloc_offsets) <-
    inScopeOf num_threads_stms $
      expandedVariantAllocations
        num_threads
        space
        kstms
        variant_allocs

  pure
    ( invariant_alloc_offsets <> variant_alloc_offsets,
      num_threads_stms <> invariant_alloc_stms <> variant_alloc_stms
    )
-- | The "user" of an allocation: the level of the segmented operation
-- it was found in, paired with a list of index expressions.  Where this
-- file builds such a value for a nested 'SegOp', it appends that op's
-- flat thread index to the list inherited from the enclosing user.
type User = (SegLevel, [TPrimExp Int64 VName])
-- | The allocations that have been pulled out of a body, keyed by the
-- name of the memory block that the removed allocation statement
-- bound.  Each entry records the 'User' the allocation was found
-- under, the requested size, and the memory space.
type Extraction = M.Map VName (User, SubExp, Space)
-- | Extract allocations from the statements of a 'KernelBody'.
--
-- This is 'extractGenericBodyAllocations' specialised to kernel
-- bodies: statements are read with 'kernelBodyStms' and written back
-- with a record update, so every other field of the body is kept.
-- The result is the rewritten body together with the extracted
-- allocations.
extractKernelBodyAllocations ::
  User ->
  Names ->
  Names ->
  KernelBody GPUMem ->
  ( KernelBody GPUMem,
    Extraction
  )
extractKernelBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel kernelBodyStms $
    \stms kbody -> kbody {kernelBodyStms = stms}
-- | Extract allocations from the statements of an ordinary 'Body'.
--
-- This is 'extractGenericBodyAllocations' specialised to 'Body':
-- statements are read with 'bodyStms' and written back with a record
-- update, leaving the body's other fields untouched.
extractBodyAllocations ::
  User ->
  Names ->
  Names ->
  Body GPUMem ->
  (Body GPUMem, Extraction)
extractBodyAllocations user bound_outside bound_kernel =
  extractGenericBodyAllocations user bound_outside bound_kernel bodyStms $
    \stms body -> body {bodyStms = stms}
-- | Extract allocations from the body of a 'Lambda'.
--
-- The lambda's body is processed with 'extractBodyAllocations'; the
-- lambda is returned with only its 'lambdaBody' replaced, alongside
-- the allocations that were extracted from it.
extractLambdaAllocations ::
  User ->
  Names ->
  Names ->
  Lambda GPUMem ->
  (Lambda GPUMem, Extraction)
extractLambdaAllocations user bound_outside bound_kernel lam =
  (lam {lambdaBody = body'}, allocs)
  where
    (body', allocs) =
      extractBodyAllocations user bound_outside bound_kernel $
        lambdaBody lam
-- | Extract allocations from any body-like value.
--
-- The body type is abstract: @get_stms@ reads its statements and
-- @set_stms@ writes a new statement sequence back.  Each statement is
-- passed through 'extractStmAllocations'; statements for which that
-- returns 'Nothing' are dropped from the rebuilt body, and everything
-- written to the 'Writer' is returned as the 'Extraction'.
extractGenericBodyAllocations ::
  User ->
  Names ->
  Names ->
  (body -> Stms GPUMem) ->
  (Stms GPUMem -> body -> body) ->
  body ->
  ( body,
    Extraction
  )
extractGenericBodyAllocations user bound_outside bound_kernel get_stms set_stms body =
  let -- Names bound by this body's own statements count as bound in
      -- the kernel when the individual statements are inspected.
      bound_kernel' = bound_kernel <> boundByStms (get_stms body)
      (stms, allocs) =
        runWriter $
          fmap catMaybes $
            mapM (extractStmAllocations user bound_outside bound_kernel') $
              stmsToList $
                get_stms body
   in (set_stms (stmsFromList stms) body, allocs)
-- | Predicates on memory spaces used when deciding whether an
-- allocation may be extracted.
--
-- 'expandable' is 'False' for the space named @"local"@ and for any
-- 'ScalarSpace', and 'True' for every other space.  'notScalar' is
-- 'False' only for 'ScalarSpace'.
expandable, notScalar :: Space -> Bool
expandable (Space "local") = False
expandable ScalarSpace {} = False
expandable _ = True
notScalar ScalarSpace {} = False
notScalar _ = True
-- | Inspect one statement and extract allocations from it.
--
-- * If the statement is a single-result 'Alloc' that qualifies (see
--   the guard below), the allocation is recorded in the 'Writer'
--   under the name it bound and 'Nothing' is returned.
--
-- * Any other statement is kept ('Just'), after its expression has
--   been traversed so that allocations inside nested bodies, lambdas
--   and segmented operations are extracted as well.
--
-- @bound_outside@ and @bound_kernel@ are the name sets that the size
-- of an allocation is checked against.
extractStmAllocations ::
  User ->
  Names ->
  Names ->
  Stm GPUMem ->
  Writer Extraction (Maybe (Stm GPUMem))
extractStmAllocations user bound_outside bound_kernel (Let (Pat [patElem]) _ (Op (Alloc size space)))
  -- Either the space is expandable and the size is a constant or a
  -- name from one of the two sets, or the size is bound in the kernel
  -- and the space is not a scalar space.
  | expandable space && expandableSize size
      || (boundInKernel size && notScalar space) = do
      tell $ M.singleton (patElemName patElem) (user, size, space)
      pure Nothing
  where
    expandableSize (Var v) = v `nameIn` bound_outside || v `nameIn` bound_kernel
    expandableSize Constant {} = True
    boundInKernel (Var v) = v `nameIn` bound_kernel
    boundInKernel Constant {} = False
extractStmAllocations user bound_outside bound_kernel stm = do
  e <- mapExpM (expMapper user) $ stmExp stm
  pure $ Just $ stm {stmExp = e}
  where
    -- Mapper that recurses into bodies and ops, leaving everything
    -- else as 'identityMapper' does.
    expMapper user' =
      identityMapper
        { mapOnBody = const $ onBody user',
          mapOnOp = onOp user'
        }

    onBody user' body = do
      let (body', allocs) = extractBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    -- Entering a nested SegOp changes the user: its level becomes the
    -- level of that op, and the op's flat index is appended to the
    -- index list inherited from the enclosing user.
    onOp (_, user_ids) (Inner (SegOp op)) =
      Inner . SegOp <$> mapSegOpM (opMapper user'') op
      where
        user'' =
          (segLevel op, user_ids ++ [le64 (segFlat (segSpace op))])
    onOp _ op = pure op

    opMapper user' =
      identitySegOpMapper
        { mapOnSegOpLambda = onLambda user',
          mapOnSegOpBody = onKernelBody user'
        }

    onKernelBody user' body = do
      let (body', allocs) = extractKernelBodyAllocations user' bound_outside bound_kernel body
      tell allocs
      pure body'

    onLambda user' lam = do
      body <- onBody user' $ lambdaBody lam
      pure lam {lambdaBody = body}
-- | Produce the expanded allocation statements for a set of extracted
-- allocations, together with the 'RebaseMap' describing how arrays in
-- each memory block must be re-indexed.
--
-- @getNumUsers@ maps a 'User' to the shape of the "user" dimensions
-- and the index expressions that select one user within that shape.
-- For every extracted allocation a single 'Alloc' is emitted whose
-- size is the original size multiplied by all user dimensions.
genericExpandedInvariantAllocations ::
  (User -> (Shape, [TPrimExp Int64 VName])) -> Extraction -> ExpandM (Stms GPUMem, RebaseMap)
genericExpandedInvariantAllocations getNumUsers invariant_allocs = do
  (rebases, alloc_stms) <- runBuilder $ mapM expand $ M.toList invariant_allocs

  pure (alloc_stms, mconcat rebases)
  where
    -- Emit the enlarged allocation for one memory block and return
    -- its entry in the rebase map.
    expand (mem, (user, per_thread_size, space)) = do
      let num_users = fst $ getNumUsers user
          allocpat = Pat [PatElem mem $ MemMem space]
      total_size <-
        letExp "total_size" <=< toExp . product $
          pe64 per_thread_size : map pe64 (shapeDims num_users)
      letBind allocpat $ Op $ Alloc (Var total_size) space
      pure $ M.singleton mem $ newBase user

    -- A full, unit-stride slice of a dimension of extent @d@.
    untouched d = DimSlice 0 d 1

    -- For SegThread users the user dimensions are laid out innermost
    -- in the root index function and then permuted to the front
    -- before being fixed to the user's indices.
    newBase user@(SegThread {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          num_dims = length old_shape
          perm =
            [num_dims .. num_dims + shapeRank users_shape - 1]
              ++ [0 .. num_dims - 1]
          root_ixfun = IxFun.iota (old_shape ++ map pe64 (shapeDims users_shape))
          permuted_ixfun = IxFun.permute root_ixfun perm
          offset_ixfun =
            IxFun.slice permuted_ixfun $
              Slice $
                map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun
    -- For SegGroup users the user dimensions are outermost in the
    -- root index function, so no permutation is applied.
    newBase user@(SegGroup {}, _) (old_shape, _) =
      let (users_shape, user_ids) = getNumUsers user
          root_ixfun = IxFun.iota $ map pe64 (shapeDims users_shape) ++ old_shape
          offset_ixfun =
            IxFun.slice root_ixfun . Slice $
              map DimFix user_ids ++ map untouched old_shape
       in offset_ixfun
-- | Expand extracted allocations, using 'genericExpandedInvariantAllocations'
-- with a fixed interpretation of the 'User' index list:
--
-- * 'SegThread' with one index: shape @[num_threads]@.
--
-- * 'SegThread' with two indices: shape @[num_groups, group_size]@.
--
-- * 'SegGroup' with one index: shape @[num_groups]@.
--
-- Any other combination calls 'error'.
expandedInvariantAllocations ::
  SubExp ->
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedInvariantAllocations num_threads (Count num_groups) (Count group_size) =
  genericExpandedInvariantAllocations getNumUsers
  where
    getNumUsers (SegThread {}, [gtid]) =
      (Shape [num_threads], [gtid])
    getNumUsers (SegThread {}, [gid, ltid]) =
      (Shape [num_groups, group_size], [gid, ltid])
    getNumUsers (SegGroup {}, [gid]) =
      (Shape [num_groups], [gid])
    getNumUsers user =
      error $ "getNumUsers: unhandled " ++ show user
-- | Expand extracted allocations whose sizes are computed by
-- @sliceKernelSizes@ rather than known up front.
--
-- With an empty 'Extraction' nothing is emitted.  Otherwise the
-- distinct sizes are passed to @sliceKernelSizes@, whose statements
-- are given explicit allocations, simplified, and run through
-- @transformStms@ again.  One 'Alloc' of the summed size is then
-- emitted per memory block, and each block's 'RebaseMap' entry is
-- built from the corresponding offset returned by @sliceKernelSizes@.
expandedVariantAllocations ::
  SubExp ->
  SegSpace ->
  Stms GPUMem ->
  Extraction ->
  ExpandM (Stms GPUMem, RebaseMap)
expandedVariantAllocations _ _ _ variant_allocs
  | null variant_allocs = pure (mempty, mempty)
expandedVariantAllocations num_threads kspace kstms variant_allocs = do
  let sizes_to_blocks = removeCommonSizes variant_allocs
      variant_sizes = map fst sizes_to_blocks

  (slice_stms, offsets, size_sums) <-
    sliceKernelSizes num_threads variant_sizes kspace kstms
  -- The size-computing statements are themselves passed through
  -- transformStms after allocations have been made explicit.
  slice_stms_tmp <- simplifyStms =<< explicitAllocationsInStms slice_stms
  slice_stms' <- transformStms slice_stms_tmp

  let variant_allocs' :: [(VName, (SubExp, SubExp, Space))]
      variant_allocs' =
        concat $
          zipWith
            memInfo
            (map snd sizes_to_blocks)
            (zip offsets size_sums)
      -- All blocks sharing one size get the same offset and total.
      memInfo blocks (offset, total_size) =
        [(mem, (Var offset, Var total_size, space)) | (mem, space) <- blocks]

  (alloc_stms, rebases) <- unzip <$> mapM expand variant_allocs'

  pure (slice_stms' <> stmsFromList alloc_stms, mconcat rebases)
  where
    -- Build the allocation statement and rebase entry for one block.
    expand (mem, (offset, total_size, space)) = do
      let allocpat = Pat [PatElem mem $ MemMem space]
      pure
        ( Let allocpat (defAux ()) $ Op $ Alloc total_size space,
          M.singleton mem $ newBase offset
        )

    num_threads' = pe64 num_threads
    gtid = le64 $ segFlat kspace

    -- The root index function is two-dimensional, with the thread
    -- index fixed in the inner dimension; the result is then coerced
    -- (rank-1 arrays) or reshaped (otherwise) to the original shape.
    -- NOTE(review): the outer slice has extent num_threads' although
    -- the outer root dimension is elems_per_thread; this is kept
    -- exactly as found.  Confirm that it is intended.
    newBase size_per_thread (old_shape, pt) =
      let elems_per_thread =
            pe64 size_per_thread `quot` primByteSize pt
          root_ixfun = IxFun.iota [elems_per_thread, num_threads']
          offset_ixfun =
            IxFun.slice root_ixfun . Slice $
              [DimSlice 0 num_threads' 1, DimFix gtid]
       in if length old_shape == 1
            then IxFun.coerce offset_ixfun old_shape
            else IxFun.reshape offset_ixfun old_shape
-- | For each memory block name, a function that builds a new index
-- function from a pair of a list of index expressions and an element
-- type.  The allocation-expansion code in this file applies such
-- functions to an array's original shape.
type RebaseMap = M.Map VName (([TPrimExp Int64 VName], PrimType) -> IxFun)
-- | The monad in which memory references are rewritten.  It is a
-- reader over the current 'Scope' layered on a reader over the
-- 'RebaseMap', with failure reported as an @'Either' 'String'@.
newtype OffsetM a
  = OffsetM
      ( ReaderT
          (Scope GPUMem)
          (ReaderT RebaseMap (Either String))
          a
      )
  deriving
    ( Applicative,
      Functor,
      Monad,
      HasScope GPUMem,
      LocalScope GPUMem,
      MonadError String
    )
-- | Run an 'OffsetM' action with the given initial scope and rebase
-- map, yielding either an error message or the result.
runOffsetM :: Scope GPUMem -> RebaseMap -> OffsetM a -> Either String a
runOffsetM scope offsets (OffsetM m) =
  runReaderT (runReaderT m scope) offsets
-- | Retrieve the current 'RebaseMap' from the inner reader layer.
askRebaseMap :: OffsetM RebaseMap
askRebaseMap = OffsetM $ lift ask
-- | Run an action with the 'RebaseMap' modified by the given function.
-- The scope seen by the action is the one in effect at the call site.
localRebaseMap :: (RebaseMap -> RebaseMap) -> OffsetM a -> OffsetM a
localRebaseMap f (OffsetM m) = OffsetM $ do
  -- The rebase map lives in the inner reader, so the outer (scope)
  -- layer is unwrapped by hand to reach it with 'local'.
  scope <- ask
  lift $ local f $ runReaderT m scope
-- | Look up a memory block in the current 'RebaseMap'.  If it has an
-- entry, apply the stored function to the given argument and return
-- the resulting index function; otherwise return 'Nothing'.
lookupNewBase :: VName -> ([TPrimExp Int64 VName], PrimType) -> OffsetM (Maybe IxFun)
lookupNewBase name x = do
  offsets <- askRebaseMap
  pure $ ($ x) <$> M.lookup name offsets
-- | Rewrite every statement of a kernel body with 'offsetMemoryInStm'.
--
-- Statements are processed in order.  The scope returned for one
-- statement is the scope under which the next one is processed.  Only
-- the body's statements are replaced.
offsetMemoryInKernelBody :: KernelBody GPUMem -> OffsetM (KernelBody GPUMem)
offsetMemoryInKernelBody kbody = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList $ kernelBodyStms kbody)
  pure kbody {kernelBodyStms = stms'}
-- | Rewrite every statement of a body with 'offsetMemoryInStm'.
--
-- Statements are processed in order.  The scope returned for one
-- statement is the scope under which the next one is processed.  The
-- body's decoration and result are kept unchanged.
offsetMemoryInBody :: Body GPUMem -> OffsetM (Body GPUMem)
offsetMemoryInBody (Body dec stms res) = do
  scope <- askScope
  stms' <-
    stmsFromList . snd
      <$> mapAccumLM
        (\scope' -> localScope scope' . offsetMemoryInStm)
        scope
        (stmsToList stms)
  pure $ Body dec stms' res
-- | Rewrite one statement and return it together with the scope
-- extended by the names it binds.
--
-- The expression is rewritten with @offsetMemoryInExp@ and the pattern
-- with @offsetMemoryInPat@.  Afterwards the memory annotation of each
-- array pattern element is replaced by the one that 'expReturns'
-- reports for the rewritten expression, whenever that report names a
-- memory block and its index function can be instantiated.
offsetMemoryInStm :: Stm GPUMem -> OffsetM (Scope GPUMem, Stm GPUMem)
offsetMemoryInStm (Let pat dec e) = do
  e' <- offsetMemoryInExp e
  pat' <- offsetMemoryInPat pat =<< expReturns e'
  scope <- askScope
  -- Recompute the return information of the rewritten expression
  -- under the current scope; 'pick' prefers it where applicable.
  rts <- runReaderT (expReturns e') scope
  let pat'' = Pat $ zipWith pick (patElems pat') rts
      stm = Let pat'' dec e'
  let scope' = scopeOf stm <> scope
  pure (scope', stm)
  where
    -- Use the recomputed memory binding for an array element when the
    -- expression is known to return in a specific block and the
    -- existential index function can be instantiated; otherwise keep
    -- the pattern element as it is.
    pick ::
      PatElem (MemInfo SubExp NoUniqueness MemBind) ->
      ExpReturns ->
      PatElem (MemInfo SubExp NoUniqueness MemBind)
    pick
      (PatElem name (MemArray pt s u _ret))
      (MemArray _ _ _ (Just (ReturnsInBlock m extixfun)))
        | Just ixfun <- instantiateIxFun extixfun =
            PatElem name (MemArray pt s u (ArrayIn m ixfun))
    pick p _ = p
-- | Turn an existential index function into an ordinary one.  This
-- succeeds only when every leaf is 'Free'; a single 'Ext' leaf makes
-- the whole result 'Nothing'.
instantiateIxFun :: ExtIxFun -> Maybe IxFun
instantiateIxFun = traverse (traverse fromFree)
  where
    fromFree (Free x) = Just x
    fromFree (Ext _) = Nothing
-- | Update the memory annotations of a pattern, given the return
-- information of the expression it binds.  Pattern elements and
-- returns are paired up positionally.
--
-- An array element bound with 'ArrayIn' whose corresponding return is
-- 'ReturnsNewBlock' keeps its memory block but takes the returned
-- index function, in which every existential @Ext i@ is replaced by
-- the name of the @i@th element of this pattern.  Every other element
-- has its annotation passed through 'offsetMemoryInMemBound'.
offsetMemoryInPat :: Pat LetDecMem -> [ExpReturns] -> OffsetM (Pat LetDecMem)
offsetMemoryInPat (Pat pes) rets =
  Pat <$> zipWithM onPE pes rets
  where
    onPE
      (PatElem name (MemArray pt shape u (ArrayIn mem _)))
      (MemArray _ _ _ (Just (ReturnsNewBlock _ _ ixfun))) =
        pure . PatElem name . MemArray pt shape u . ArrayIn mem $
          fmap (fmap unExt) ixfun
    onPE pe _ = do
      new_dec <- offsetMemoryInMemBound $ patElemDec pe
      pure pe {patElemDec = new_dec}

    -- Resolve an existential to the pattern element it points at.  The
    -- lookup is bounds-checked so that a malformed index produces a
    -- message naming this function instead of a bare list-index failure.
    unExt (Free v) = v
    unExt (Ext i)
      | i >= 0, pe : _ <- drop i pes = patElemName pe
      | otherwise =
          error $
            "offsetMemoryInPat: existential index "
              ++ show i
              ++ " does not refer to any of the "
              ++ show (length pes)
              ++ " pattern elements"
-- | Pass the memory annotation of a parameter through
-- 'offsetMemoryInMemBound', leaving the rest of the parameter
-- untouched.
offsetMemoryInParam :: Param (MemBound u) -> OffsetM (Param (MemBound u))
offsetMemoryInParam param =
  withDec <$> offsetMemoryInMemBound (paramDec param)
  where
    withDec dec = param {paramDec = dec}
-- | Rebase the index function of an array annotation when
-- 'lookupNewBase' supplies a new base for its memory block.
--
-- The lookup is keyed on the memory block together with the base shape
-- of the current index function and the element type.  When it yields
-- a new base, the index function is replaced by 'IxFun.rebase' of the
-- old one onto that base; the memory block itself is kept.  If the
-- lookup yields nothing, or the annotation is not an array in a memory
-- block, the annotation is returned unchanged.
offsetMemoryInMemBound :: MemBound u -> OffsetM (MemBound u)
offsetMemoryInMemBound bound = case bound of
  MemArray pt shape u (ArrayIn mem ixfun) -> do
    new_base <- lookupNewBase mem (IxFun.base ixfun, pt)
    pure $ case new_base of
      Nothing -> bound
      Just base -> MemArray pt shape u (ArrayIn mem (IxFun.rebase base ixfun))
  _ -> pure bound
-- | Rebase the index function of a branch return when
-- 'lookupNewBase' supplies a new base for its memory block.
--
-- Only array returns of the form 'ReturnsInBlock' whose index function
-- is accepted by 'isStaticIxFun' are considered.  The lookup uses the
-- base shape of that static index function and the element type; on
-- success the (still existential) index function is rebased onto the
-- new base, whose leaves are lifted with 'Free'.  In every other case
-- the return is passed through unchanged.
offsetMemoryInBodyReturns :: BodyReturns -> OffsetM BodyReturns
offsetMemoryInBodyReturns ret = case ret of
  MemArray pt shape u (ReturnsInBlock mem ext_ixfun)
    | Just static_ixfun <- isStaticIxFun ext_ixfun -> do
        new_base <- lookupNewBase mem (IxFun.base static_ixfun, pt)
        pure $ case new_base of
          Nothing -> ret
          Just base ->
            MemArray pt shape u . ReturnsInBlock mem $
              IxFun.rebase (fmap (fmap Free) base) ext_ixfun
  _ -> pure ret
-- | Apply 'offsetMemoryInBody' to the body of a lambda, with the names
-- the lambda brings into scope (via 'inScopeOf') available while doing
-- so.  Parameters and return types are left as they are.
offsetMemoryInLambda :: Lambda GPUMem -> OffsetM (Lambda GPUMem)
offsetMemoryInLambda lam =
  inScopeOf lam $
    withBody <$> offsetMemoryInBody (lambdaBody lam)
  where
    withBody body' = lam {lambdaBody = body'}
-- | Process the merge parameters of a loop and run a continuation on
-- the result.
--
-- The rebase map is extended locally (via 'localRebaseMap') for both
-- the parameter rewriting and the continuation: whenever the initial
-- value of a merge parameter is a variable that has an entry in the
-- map, the parameter's name is given that same entry.  The fold runs
-- left to right over the merge list, so an entry added for one
-- parameter is visible when later pairs are examined.  Each parameter
-- is then rewritten with 'offsetMemoryInParam' and paired back up with
-- its original initial value.
offsetMemoryInLoopParams ::
  [(FParam GPUMem, SubExp)] ->
  ([(FParam GPUMem, SubExp)] -> OffsetM a) ->
  OffsetM a
offsetMemoryInLoopParams merge f =
  localRebaseMap inherit $ do
    params' <- mapM offsetMemoryInParam params
    f (zip params' args)
  where
    (params, args) = unzip merge

    inherit rebases = foldl' alias rebases merge

    -- Copy the entry of the initial value (if it is a variable with an
    -- entry) over to the parameter name.
    alias acc (param, Var arg) =
      maybe acc (\entry -> M.insert (paramName param) entry acc) (M.lookup arg acc)
    alias acc _ = acc
-- | Rewrite the memory information inside an expression.
--
-- Loops are handled specially: the merge parameters go through
-- 'offsetMemoryInLoopParams', and the loop body is processed with the
-- rewritten parameters and the names bound by the loop form in scope.
--
-- Every other expression is traversed with 'mapExpM', using a mapper
-- that processes nested bodies (in the scope handed over by the
-- mapper), branch types (via 'offsetMemoryInBodyReturns') and
-- operations.  Among operations only 'SegOp's are entered: their
-- kernel body and lambdas are rewritten with the names of the segment
-- space in scope.  All other operations are returned untouched.
offsetMemoryInExp :: Exp GPUMem -> OffsetM (Exp GPUMem)
offsetMemoryInExp e = case e of
  DoLoop merge form body ->
    offsetMemoryInLoopParams merge $ \merge' -> do
      let loop_scope = scopeOfFParams (map fst merge') <> scopeOf form
      body' <- localScope loop_scope (offsetMemoryInBody body)
      pure (DoLoop merge' form body')
  _ -> mapExpM recurse e
  where
    recurse =
      identityMapper
        { mapOnBody = \bscope body -> localScope bscope (offsetMemoryInBody body),
          mapOnBranchType = offsetMemoryInBodyReturns,
          mapOnOp = onOp
        }

    onOp (Inner (SegOp op)) = do
      op' <- localScope (scopeOfSegSpace (segSpace op)) (mapSegOpM segOpMapper op)
      pure (Inner (SegOp op'))
      where
        segOpMapper =
          identitySegOpMapper
            { mapOnSegOpBody = offsetMemoryInKernelBody,
              mapOnSegOpLambda = offsetMemoryInLambda
            }
    onOp op = pure op
-- | Convert statements from the memory-annotated GPU representation
-- back to the plain GPU representation by stripping memory
-- information (see 'unMem').
--
-- Allocation statements among the statements passed in directly are
-- dropped.  An allocation inside any nested body, kernel body or
-- lambda is an error, reported as
-- @Left "Cannot handle nested allocation: ..."@.  Among operations
-- only 'SizeOp' and 'SegOp' can be converted; any other operation
-- reached through the expression mapper yields a 'Left' with an
-- @unAllocOp: unhandled ...@ message.
unAllocGPUStms :: Stms GPUMem -> Either String (Stms GPU.GPU)
unAllocGPUStms = unAllocStms False
  where
    unAllocBody :: Body GPUMem -> Either String (Body GPU.GPU)
    unAllocBody (Body dec stms res) = do
      stms' <- unAllocStms True stms
      pure (Body dec stms' res)

    unAllocKernelBody :: KernelBody GPUMem -> Either String (KernelBody GPU.GPU)
    unAllocKernelBody (KernelBody dec stms res) = do
      stms' <- unAllocStms True stms
      pure (KernelBody dec stms' res)

    -- The flag says whether we are inside a nested body, where
    -- allocations are not permitted.
    unAllocStms :: Bool -> Stms GPUMem -> Either String (Stms GPU.GPU)
    unAllocStms nested stms =
      stmsFromList . catMaybes <$> mapM (unAllocStm nested) (stmsToList stms)

    unAllocStm :: Bool -> Stm GPUMem -> Either String (Maybe (Stm GPU.GPU))
    unAllocStm nested stm@(Let _ _ (Op Alloc {}))
      | nested =
          throwError $ "Cannot handle nested allocation: " ++ prettyString stm
      | otherwise =
          pure Nothing
    unAllocStm _ (Let pat aux e) = do
      pat' <- unAllocPat pat
      e' <- mapExpM unAlloc' e
      pure (Just (Let pat' aux e'))

    unAllocLambda :: Lambda GPUMem -> Either String (Lambda GPU.GPU)
    unAllocLambda (Lambda params body ret) = do
      body' <- unAllocBody body
      pure (Lambda (map unParam params) body' ret)

    unAllocPat (Pat pes) =
      Pat <$> mapM (rephrasePatElem unT) pes

    unAllocOp Alloc {} = Left "unAllocOp: unhandled Alloc"
    unAllocOp (Inner OtherOp {}) = Left "unAllocOp: unhandled OtherOp"
    unAllocOp (Inner GPUBody {}) = Left "unAllocOp: unhandled GPUBody"
    unAllocOp (Inner (SizeOp op)) = pure (SizeOp op)
    unAllocOp (Inner (SegOp op)) = SegOp <$> mapSegOpM mapper op
      where
        mapper =
          identitySegOpMapper
            { mapOnSegOpLambda = unAllocLambda,
              mapOnSegOpBody = unAllocKernelBody
            }

    unParam = fmap unMem

    unT = Right . unMem

    unAlloc' =
      Mapper
        { mapOnBody = const unAllocBody,
          mapOnRetType = unT,
          mapOnBranchType = unT,
          mapOnFParam = Right . unParam,
          mapOnLParam = Right . unParam,
          mapOnOp = unAllocOp,
          mapOnSubExp = Right,
          mapOnVName = Right
        }
-- | Forget the memory part of a memory annotation, leaving an ordinary
-- type.  Primitives, arrays (minus their memory return information)
-- and accumulators map to the corresponding type constructor; a memory
-- block itself becomes the primitive 'Unit' type.
unMem :: MemInfo d u ret -> TypeBase (ShapeBase d) u
unMem info = case info of
  MemPrim pt -> Prim pt
  MemArray pt shape u _ -> Array pt shape u
  MemAcc acc ispace ts u -> Acc acc ispace ts u
  MemMem {} -> Prim Unit
-- | Strip memory information from every entry of a scope, using
-- 'unMem' on the annotation of let-bound names and of function and
-- lambda parameters.  Index names carry no annotation and are copied.
unAllocScope :: Scope GPUMem -> Scope GPU.GPU
unAllocScope scope = M.map strip scope
  where
    strip info = case info of
      LetName dec -> LetName (unMem dec)
      FParamName dec -> FParamName (unMem dec)
      LParamName dec -> LParamName (unMem dec)
      IndexName it -> IndexName it
-- | Group extracted memory blocks by their size.
--
-- The extraction is walked in ascending key order; each memory block
-- is filed, together with its space, under its size.  A block is
-- prepended to whatever is already filed under that size, so within a
-- group the blocks appear in reverse order of traversal.  The groups
-- themselves come out in ascending order of size.
removeCommonSizes :: Extraction -> [(SubExp, [(VName, Space)])]
removeCommonSizes extraction =
  M.toList (foldl' file mempty (M.toList extraction))
  where
    file by_size (mem, (_, size, space)) =
      M.insertWith (++) size [(mem, space)] by_size
sliceKernelSizes ::
SubExp ->
[SubExp] ->
SegSpace ->
Stms GPUMem ->
ExpandM (Stms GPU.GPU, [VName], [VName])
sliceKernelSizes :: SubExp
-> [SubExp]
-> SegSpace
-> Stms GPUMem
-> ExpandM (Stms GPU, [VName], [VName])
sliceKernelSizes SubExp
num_threads [SubExp]
sizes SegSpace
space Stms GPUMem
kstms = do
Stms GPU
kstms' <- forall a c b. (a -> c) -> (b -> c) -> Either a b -> c
either forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ Stms GPUMem -> Either String (Stms GPU)
unAllocGPUStms Stms GPUMem
kstms
let num_sizes :: Int
num_sizes = forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
sizes
i64s :: [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s = forall a. Int -> a -> [a]
replicate Int
num_sizes forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Scope GPU
kernels_scope <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Scope GPUMem -> Scope GPU
unAllocScope
(Lambda GPU
max_lam, Stms GPU
_) <- forall a b c. (a -> b -> c) -> b -> a -> c
flip forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope forall a b. (a -> b) -> a -> b
$ do
[Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
[Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"y" (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
(Result
zs, Stms GPU
stms) <- forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k} (rep :: k) dec.
(LParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfLParams forall a b. (a -> b) -> a -> b
$ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x, Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y) ->
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"z" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
SMax IntType
Int64) (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
x) (VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
y)
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
[LParam rep]
-> Body rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda rep
Lambda ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
xs forall a. [a] -> [a] -> [a]
++ [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ys) (forall {k} (rep :: k).
Buildable rep =>
Stms rep -> Result -> Body rep
mkBody Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s
Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flat_gtid" (forall shape u. PrimType -> TypeBase shape u
Prim (IntType -> PrimType
IntType IntType
Int64))
(Lambda GPU
size_lam', Stms GPU
_) <- forall a b c. (a -> b -> c) -> b -> a -> c
flip forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope forall a b. (a -> b) -> a -> b
$ do
[Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
(Result
zs, Stms GPU
stms) <- forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope
(forall {k} (rep :: k) dec.
(LParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
params forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k) dec.
(LParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfLParams [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam])
forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms
forall a b. (a -> b) -> a -> b
$ do
let ([VName]
kspace_gtids, [SubExp]
kspace_dims) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
new_inds :: [TPrimExp Int64 VName]
new_inds =
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex
(forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
kspace_dims)
(SubExp -> TPrimExp Int64 VName
pe64 forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam)
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames (forall a b. (a -> b) -> [a] -> [b]
map forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
kspace_gtids) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp [TPrimExp Int64 VName]
new_inds
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm Stms GPU
kstms'
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes [SubExp]
sizes
forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k} (rep :: k). SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *).
(HasScope GPU m, MonadFreshNames m) =>
Lambda GPU -> m (Lambda GPU)
GPU.simplifyLambda (forall {k} (rep :: k).
[LParam rep]
-> Body rep
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda rep
Lambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)
flat_gtid_lparam] (forall {k} (rep :: k).
BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms GPU
stms Result
zs) [TypeBase (ShapeBase SubExp) NoUniqueness]
i64s)
(([VName]
maxes_per_thread, [VName]
size_sums), Stms GPU
slice_stms) <- forall a b c. (a -> b -> c) -> b -> a -> c
flip forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT Scope GPU
kernels_scope forall a b. (a -> b) -> a -> b
$ do
Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat <-
[Ident] -> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
basicPat forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
num_sizes (forall (m :: * -> *).
MonadFreshNames m =>
String -> TypeBase (ShapeBase SubExp) NoUniqueness -> m Ident
newIdent String
"max_per_thread" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
SubExp
w <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"size_slice_w"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (SegSpace -> [SubExp]
segSpaceDims SegSpace
space)
VName
thread_space_iota <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"thread_space_iota" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
let red_op :: SegBinOp GPU
red_op =
forall {k} (rep :: k).
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp
Commutativity
Commutative
Lambda GPU
max_lam
(forall a. Int -> a -> [a]
replicate Int
num_sizes forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
forall a. Monoid a => a
mempty
SegLevel
lvl <- forall (m :: * -> *) inner.
(MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
String -> m SegLevel
segThread String
"segred"
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k) (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Stm rep -> m (Stm rep)
renameStm
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
nonSegRed SegLevel
lvl Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat SubExp
w [SegBinOp GPU
red_op] Lambda GPU
size_lam' [VName
thread_space_iota]
[VName]
size_sums <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat) forall a b. (a -> b) -> a -> b
$ \VName
threads_max ->
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"size_sum" forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Var VName
threads_max) SubExp
num_threads
forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall dec. Pat dec -> [VName]
patNames Pat (TypeBase (ShapeBase SubExp) NoUniqueness)
pat, [VName]
size_sums)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU
slice_stms, [VName]
maxes_per_thread, [VName]
size_sums)