module Futhark.Optimise.ReduceDeviceSyncs (reduceDeviceSyncs) where
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State hiding (State)
import Data.Bifunctor (second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.List (transpose, zip4)
import Data.Map.Strict qualified as M
import Data.Sequence ((<|), (><), (|>))
import Data.Text qualified as T
import Futhark.Construct (fullSlice, mkBody, sliceDim)
import Futhark.Error
import Futhark.IR.GPU
import Futhark.MonadFreshNames
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
import Futhark.Pass
import Futhark.Transform.Substitute
-- | The pass described as "Move host statements to device to reduce
-- blocking memory operations."
--
-- The constants are analysed once with 'analyseConsts'. Each function is then
-- analysed with 'analyseFunDef' and the constants' migration table is merged
-- into the result. The statements are rewritten under the combined table.
-- Functions are processed with 'parPass'.
reduceDeviceSyncs :: Pass GPU GPU
reduceDeviceSyncs =
  Pass
    "reduce device synchronizations"
    "Move host statements to device to reduce blocking memory operations."
    $ \prog -> do
      let hof = hostOnlyFunDefs $ progFuns prog
          consts_mt = analyseConsts hof (progFuns prog) (progConsts prog)
      consts <- onConsts consts_mt $ progConsts prog
      funs <- parPass (onFun hof consts_mt) (progFuns prog)
      pure $ prog {progConsts = consts, progFuns = funs}
  where
    -- Rewrite the top-level constant statements under their migration table.
    onConsts :: MonadFreshNames m => MigrationTable -> Stms GPU -> m (Stms GPU)
    onConsts consts_mt stms =
      runReduceM consts_mt (optimizeStms stms)

    -- Rewrite one function. Its own migration table is combined with the
    -- constants' table so that decisions about constants are visible here.
    onFun ::
      MonadFreshNames m =>
      HostOnlyFuns ->
      MigrationTable ->
      FunDef GPU ->
      m (FunDef GPU)
    onFun hof consts_mt fd = do
      let mt = consts_mt <> analyseFunDef hof fd
      runReduceM mt (optimizeFunDef fd)
-- | Optimize the statements of a function body.
--
-- Only the body's statements are rewritten. The body result and the rest of
-- the definition (including its type signature) are left as they were.
optimizeFunDef :: FunDef GPU -> ReduceM (FunDef GPU)
optimizeFunDef fd = do
  let body = funDefBody fd
  stms' <- optimizeStms (bodyStms body)
  pure $ fd {funDefBody = body {bodyStms = stms'}}
-- | Optimize a body's statements, then rewrite its result with
-- 'resolveResult' so that it refers to the values bound by the rewritten
-- statements.
optimizeBody :: Body GPU -> ReduceM (Body GPU)
optimizeBody (Body _ stms res) = do
  stms' <- optimizeStms stms
  res' <- resolveResult res
  pure (Body () stms' res')
-- | Optimize a sequence of statements in order.
--
-- Each statement's rewrite is appended to the statements already produced,
-- starting from an empty sequence.
optimizeStms :: Stms GPU -> ReduceM (Stms GPU)
optimizeStms = foldM optimizeStm mempty
-- | Optimize a single statement and append the result to @out@, the
-- statements produced so far.
--
-- * If 'shouldMoveStm' holds for the statement under the current
--   'MigrationTable', it is handed to 'moveStm'.
-- * Otherwise the statement is rewritten so it can consume operands that
--   earlier rewrites placed on device.
-- * Pattern elements that get renamed by the rewrite are reported through
--   'migratedTo' (see 'addRead').
optimizeStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
optimizeStm out stm = do
  move <- asks (shouldMoveStm stm)
  if move
    then moveStm out stm
    else case stmExp stm of
      -- Single-element update with a scalar variable as the written value.
      -- Matching the last index via 'reverse' keeps this total: an Update
      -- with an empty slice falls through to the generic BasicOp case
      -- instead of hitting a pattern-match failure.
      BasicOp (Update safety arr slice (Var v))
        | Just is <- sliceIndices slice,
          i : rev_outer <- reverse is -> do
            dev <- storedScalar (Var v)
            case dev of
              Nothing -> pure (out |> stm)
              Just dst -> do
                -- The value has a device-side array (per 'storedScalar').
                -- Turn the element update into a one-element slice update
                -- that reads from that array.
                let one = intConst Int64 1
                    outer = map DimFix (reverse rev_outer)
                    slice' = Slice $ outer ++ [DimSlice i one one]
                    e = BasicOp (Update safety arr slice' (Var dst))
                pure (out |> stm {stmExp = e})
      BasicOp (Replicate (Shape dims) (Var v))
        | Pat [PatElem n arr_t] <- stmPat stm -> do
            v' <- resolveName v
            -- 'resolveName' returning a different name means v now lives
            -- on device under v'.
            let v_kept_on_device = v /= v'
            gpubody_ok <- gets stateGPUBodyOk
            case v_kept_on_device of
              False -> pure (out |> stm)
              -- All dimensions are 1 and GPUBody statements are allowed:
              -- compute the row inside a GPUBody and bind the original
              -- pattern to its result.
              True
                | all (== intConst Int64 1) dims,
                  Just t' <- peelArray 1 arr_t,
                  gpubody_ok -> do
                    let n' = VName (baseName n `withSuffix` "_inner") 0
                        pat' = Pat [PatElem n' t']
                        e' = BasicOp $ Replicate (Shape $ drop 1 dims) (Var v)
                        stm' = Let pat' (stmAux stm) e'
                    gpubody <- inGPUBody (rewriteStm stm')
                    pure (out |> gpubody {stmPat = stmPat stm})
              -- Innermost dimension is 1: replicate the device value over
              -- the outer dimensions directly. The reverse-based match is
              -- total, whereas 'last' would crash on empty dims; the general
              -- case below handles empty dims.
              True
                | d : rev_outer <- reverse dims,
                  d == intConst Int64 1 -> do
                    let e' =
                          BasicOp $ Replicate (Shape $ reverse rev_outer) (Var v')
                    pure (out |> stm {stmExp = e'})
              -- General case: replicate the device value into an array with
              -- an extra trailing dimension of size 1, then index element 0
              -- of that dimension to bind the original pattern.
              True -> do
                n' <- newName n
                let dims' = dims ++ [intConst Int64 1]
                    arr_t' = Array (elemType arr_t) (Shape dims') NoUniqueness
                    pat' = Pat [PatElem n' arr_t']
                    e' = BasicOp $ Replicate (Shape dims) (Var v')
                    repl = Let pat' (stmAux stm) e'
                    aux = StmAux mempty mempty ()
                    slice = map sliceDim (arrayDims arr_t)
                    slice' = slice ++ [DimFix $ intConst Int64 0]
                    idx = BasicOp $ Index n' (Slice slice')
                    index = Let (stmPat stm) aux idx
                pure (out |> repl |> index)
      BasicOp {} ->
        pure (out |> stm)
      Apply {} ->
        pure (out |> stm)
      Match ses cases defbody (MatchDec btypes sort) -> do
        -- Rewrite every branch.
        cases_stms <- mapM (optimizeStms . bodyStms . caseBody) cases
        let cases_res = map (bodyResult . caseBody) cases
        defbody_stms <- optimizeStms $ bodyStms defbody
        let defbody_res = bodyResult defbody

        -- For each pattern element: if any branch result has been moved to
        -- device (per 'resolveName'), store every branch's result with
        -- 'storeScalar' so all branches agree, and arrayize the pattern
        -- element and its branch type.
        let bmerge (acc, all_stms) (pe, reses, bt) = do
              let onHost (Var v) = (v ==) <$> resolveName v
                  onHost _ = pure True
              on_host <- and <$> mapM (onHost . resSubExp) reses
              if on_host
                then pure ((pe, reses, bt) : acc, all_stms)
                else do
                  (all_stms', arrs) <-
                    fmap unzip $
                      forM (zip all_stms reses) $ \(stms, res) ->
                        storeScalar stms (resSubExp res) (patElemType pe)
                  pe' <- arrayizePatElem pe
                  let bt' = staticShapes1 (patElemType pe')
                      reses' = zipWith SubExpRes (map resCerts reses) (map Var arrs)
                  pure ((pe', reses', bt') : acc, all_stms')
            pes = patElems (stmPat stm)
        (acc, ~(defbody_stms' : cases_stms')) <-
          foldM bmerge ([], defbody_stms : cases_stms) $
            zip3 pes (transpose $ defbody_res : cases_res) btypes
        let (pes', reses, btypes') = unzip3 (reverse acc)

        -- 'reses' has one entry per pattern element, each holding one result
        -- per branch (default body first). Regroup into one result list per
        -- branch. Unlike 'transpose', this always yields one list per branch,
        -- so case bodies are not lost when the Match binds no values.
        let n_branches = length cases + 1
            (defbody_res', cases_res') =
              case foldr (zipWith (:)) (replicate n_branches []) reses of
                r : rs -> (r, rs)
                [] -> ([], []) -- unreachable: n_branches >= 1
        let cases' =
              zipWith Case (map casePat cases) $
                zipWith mkBody cases_stms' cases_res'
            defbody' = mkBody defbody_stms' defbody_res'
            e' = Match ses cases' defbody' (MatchDec btypes' sort)
            stm' = Let (Pat pes') (stmAux stm) e'

        foldM addRead (out |> stm') (zip pes pes')
      DoLoop ps lf b -> do
        (params, lform, body) <- rewriteForIn (ps, lf, b)

        -- Parameters whose status is not 'StayOnHost' are migrated:
        --
        -- * the pattern element is arrayized;
        -- * the initial value is stored with 'storeScalar';
        -- * the parameter is renamed and given the arrayized type;
        -- * the renaming is recorded with 'migratedTo', which may add
        --   rebinding statements for the loop body.
        let lmerge (acc, stms, rebinds) (pe, param, StayOnHost) =
              pure ((pe, param) : acc, stms, rebinds)
            lmerge (acc, stms, rebinds) (pe, (Param _ pn pt, pval), _) = do
              pe' <- arrayizePatElem pe
              (stms', arr) <- storeScalar stms pval (fromDecl pt)
              pn' <- newName pn
              let pt' = toDecl (patElemType pe') Nonunique
                  param' = (Param mempty pn' pt', Var arr)
              rebinds' <- (pe {patElemName = pn}) `migratedTo` (pn', rebinds)
              pure ((pe', param') : acc, stms', rebinds')

        mt <- ask
        let pes = patElems (stmPat stm)
            mss = map (\(Param _ n _, _) -> statusOf n mt) params
        (zipped', out', rebinds) <-
          foldM lmerge ([], out, mempty) (zip3 pes params mss)
        let (pes', params') = unzip (reverse zipped')

        -- Rewrite the body, with any rebinding statements placed first.
        let body1 = body {bodyStms = rebinds >< bodyStms body}
        body2 <- optimizeBody body1

        -- Results of migrated parameters are stored with 'storeScalar',
        -- taken from the original, unrewritten body result.
        let zipped =
              zip4
                mss
                (bodyResult body2)
                (map resSubExp $ bodyResult body)
                (map patElemType pes)
        let rstore (bstms, acc) (StayOnHost, r, _, _) =
              pure (bstms, r : acc)
            rstore (bstms, acc) (_, SubExpRes certs _, se, t) = do
              (bstms', dev) <- storeScalar bstms se t
              pure (bstms', SubExpRes certs (Var dev) : acc)
        (bstms, res) <- foldM rstore (bodyStms body2, []) zipped
        let body3 = body2 {bodyStms = bstms, bodyResult = reverse res}

        let e' = DoLoop params' lform body3
            stm' = Let (Pat pes') (stmAux stm) e'

        foldM addRead (out' |> stm') (zip pes pes')
      WithAcc inputs lmd -> do
        let getAcc (Acc a _ _ _) = a
            getAcc _ =
              compilerBugS
                "Type error: WithAcc expression did not return accumulator."

        -- The first return types of the lambda are the accumulators, paired
        -- here with their inputs.
        let accs = zipWith (\t i -> (getAcc t, i)) (lambdaReturnType lmd) inputs
        inputs' <- mapM (uncurry optimizeWithAccInput) accs

        let body = lambdaBody lmd
        stms' <- optimizeStms (bodyStms body)

        -- A non-accumulator result that 'resolveSubExp' maps to a different
        -- SubExp has been migrated. Its pattern element and return type are
        -- arrayized to match.
        let rewrite (SubExpRes certs se, t, pe) = do
              se' <- resolveSubExp se
              if se == se'
                then pure (SubExpRes certs se, t, pe)
                else do
                  pe' <- arrayizePatElem pe
                  pure (SubExpRes certs se', patElemType pe', pe')

        -- Accumulator results come first, one per input. Trailing pattern
        -- elements correspond one-to-one with the non-accumulator results.
        let len = length inputs
            (res0, res1) = splitAt len (bodyResult body)
            (rts0, rts1) = splitAt len (lambdaReturnType lmd)
            pes = patElems (stmPat stm)
            (pes0, pes1) = splitAt (length pes - length res1) pes
        (res1', rts1', pes1') <- unzip3 <$> mapM rewrite (zip3 res1 rts1 pes1)

        let body' = Body () stms' (res0 ++ res1')
            lmd' = lmd {lambdaBody = body', lambdaReturnType = rts0 ++ rts1'}
            pes' = pes0 ++ pes1'
            e' = WithAcc inputs' lmd'
            stm' = Let (Pat pes') (stmAux stm) e'

        foldM addRead (out |> stm') (zip pes pes')
      Op op -> do
        op' <- optimizeHostOp op
        pure (out |> stm {stmExp = Op op'})
  where
    -- Compare an original pattern element with its rewritten counterpart.
    -- If the rewrite renamed it, record the renaming with 'migratedTo',
    -- which may append statements to @stms@.
    addRead :: Stms GPU -> (PatElem Type, PatElem dec) -> ReduceM (Stms GPU)
    addRead stms (pe@(PatElem n _), PatElem dev _)
      | n == dev = pure stms
      | otherwise = pe `migratedTo` (dev, stms)
-- | Rewrite a for-in loop so that element reads from source arrays can be
-- done inside the loop body.
--
-- While loops are returned unchanged. For a for-loop, each element binding
-- whose parameter is not 'usedOnHost' is handled as follows:
--
-- * it is removed from the loop form;
-- * it is replaced by an index statement at the start of the body, which
--   reads element @i@ of the source array.
rewriteForIn ::
  ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU) ->
  ReduceM ([(FParam GPU, SubExp)], LoopForm GPU, Body GPU)
rewriteForIn loop@(_, WhileLoop {}, _) = pure loop
rewriteForIn (params, ForLoop i t n elems, body) = do
  mt <- ask
  -- 'foldr' prepends the index statements, keeping them in the same order
  -- as the element bindings.
  let (elems', stms') = foldr (inline mt) ([], bodyStms body) elems
  pure (params, ForLoop i t n elems', body {bodyStms = stms'})
  where
    -- Either keep the element binding, or turn it into an index statement
    -- prepended to the body.
    inline ::
      Typed dec =>
      MigrationTable ->
      (Param dec, VName) ->
      ([(Param dec, VName)], Stms GPU) ->
      ([(Param dec, VName)], Stms GPU)
    inline mt (x, arr) (arrs, stms)
      | pn <- paramName x,
        not (usedOnHost pn mt) =
          let pt = typeOf x
              stm = bind (PatElem pn pt) (BasicOp $ index arr pt)
           in (arrs, stm <| stms)
      | otherwise =
          ((x, arr) : arrs, stms)

    -- Index @arr@ at the loop counter, taking full slices of the remaining
    -- dimensions of @of_type@.
    index :: VName -> TypeBase (ShapeBase SubExp) u -> BasicOp
    index arr of_type =
      Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims of_type)
-- | Optimize the combining operator of a single 'WithAcc' input bound to
-- the accumulator @acc@.
--
-- * An input without an operator is returned unchanged.
-- * If the migration table says @acc@ should be moved ('shouldMove'), reads
--   are added to the operator lambda with 'addReadsToLambda'.
-- * Otherwise only the statements of the operator body are optimized. This
--   runs under 'noGPUBody', so no 'GPUBody' statements are introduced, and
--   the lambda's parameters and results are left untouched.
optimizeWithAccInput :: VName -> WithAccInput GPU -> ReduceM (WithAccInput GPU)
optimizeWithAccInput acc (shape, arrs, maybe_op) =
  case maybe_op of
    Nothing -> pure (shape, arrs, Nothing)
    Just (op, nes) -> do
      op' <- rewriteOp op
      pure (shape, arrs, Just (op', nes))
  where
    rewriteOp op = do
      device_only <- asks (shouldMove acc)
      if device_only
        then addReadsToLambda op
        else do
          let body = lambdaBody op
          stms' <- noGPUBody $ optimizeStms (bodyStms body)
          pure op {lambdaBody = body {bodyStms = stms'}}
-- | Add device reads to the bodies and operators of a host operation.
--
-- Segmented operations get reads added to their operator lambdas (for
-- 'SegRed', 'SegScan' and 'SegHist') and then to their kernel body.
-- A 'GPUBody' gets reads added to its body. A 'SizeOp' is returned as is.
-- Encountering an 'OtherOp' is reported as a compiler bug.
optimizeHostOp :: HostOp GPU op -> ReduceM (HostOp GPU op)
optimizeHostOp hop = case hop of
  SegOp (SegMap lvl space types kbody) ->
    SegOp . SegMap lvl space types <$> addReadsToKernelBody kbody
  SegOp (SegRed lvl space ops types kbody) -> do
    ops' <- traverse addReadsToSegBinOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegRed lvl space ops' types kbody'
  SegOp (SegScan lvl space ops types kbody) -> do
    ops' <- traverse addReadsToSegBinOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegScan lvl space ops' types kbody'
  SegOp (SegHist lvl space ops types kbody) -> do
    ops' <- traverse addReadsToHistOp ops
    kbody' <- addReadsToKernelBody kbody
    pure $ SegOp $ SegHist lvl space ops' types kbody'
  SizeOp {} ->
    pure hop
  OtherOp {} ->
    compilerBugS "optimizeHostOp: unhandled OtherOp"
  GPUBody types body ->
    GPUBody types <$> addReadsToBody body
-- | Append the string @sfx@ to the textual form of @name@.
withSuffix :: Name -> String -> Name
withSuffix name sfx = nameFromText (nameToText name <> T.pack sfx)
-- | The monad the pass runs in. It reads the 'MigrationTable' and threads a
-- 'State', which carries the name source and the record of migrated
-- variables.
newtype ReduceM a = ReduceM (StateT State (Reader MigrationTable) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadState State
    , MonadReader MigrationTable
    )
-- | Run a 'ReduceM' computation against the migration table @mt@.
--
-- The computation draws fresh names from the caller's name source. The
-- final name source is handed back to the caller.
runReduceM :: (MonadFreshNames m) => MigrationTable -> ReduceM a -> m a
runReduceM mt (ReduceM m) = modifyNameSource run
  where
    run src =
      let (res, final_state) = runReader (runStateT m (initialState src)) mt
       in (res, stateNameSource final_state)
-- | Fresh names come from, and are written back to, 'stateNameSource'.
instance MonadFreshNames ReduceM where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\s -> s {stateNameSource = src})
-- | State threaded through 'ReduceM'.
data State = State
  { -- | Source of fresh 'VName's.
    stateNameSource :: VNameSource
  , -- | Variables of the original program that have been migrated to
    -- device, keyed by their 'baseTag'. Each entry holds:
    --
    --   * the 'baseName' of the original variable;
    --   * the type of the original variable;
    --   * the name of the array that holds the migrated value (read at
    --     index 0, see 'eIndex');
    --   * whether the original variable is still usable on host.
    stateMigrated :: IM.IntMap (Name, Type, VName, Bool)
  , -- | Whether 'GPUBody' statements may be introduced at the current
    -- location. 'noGPUBody' clears it temporarily.
    stateGPUBodyOk :: Bool
  }
-- | Starting state for the given name source: nothing has been migrated
-- yet, and 'GPUBody' statements are allowed.
initialState :: VNameSource -> State
initialState ns =
  State
    { stateGPUBodyOk = True
    , stateMigrated = mempty
    , stateNameSource = ns
    }
-- | Run @m@ with 'stateGPUBodyOk' set to 'False'. The previous value is
-- restored once @m@ has finished.
noGPUBody :: ReduceM a -> ReduceM a
noGPUBody m = do
  saved <- gets stateGPUBodyOk
  setGPUBodyOk False
  m <* setGPUBodyOk saved
  where
    setGPUBodyOk ok = modify $ \st -> st {stateGPUBodyOk = ok}
-- | Create a fresh pattern element that can hold the value of the given one
-- as a single-element array.
--
-- The new name's base is the original base name with @"_dev"@ appended.
-- The new type is the original type with an outer dimension of size 1.
arrayizePatElem :: PatElem Type -> ReduceM (PatElem Type)
arrayizePatElem (PatElem n t) = do
  dev <- newName $ VName (withSuffix (baseName n) "_dev") 0
  pure $ PatElem dev (arrayOfRow t (intConst Int64 1))
-- | @x \`movedTo\` arr@ records that the value of @x@ now lives on device as
-- @arr[0]@, and that @x@ itself is no longer usable on host.
movedTo :: Ident -> VName -> ReduceM ()
movedTo x arr = recordMigration False x arr
-- | @x \`aliasedBy\` arr@ records that the value of @x@ is also available on
-- device as @arr[0]@, while @x@ remains usable on host.
aliasedBy :: Ident -> VName -> ReduceM ()
aliasedBy x arr = recordMigration True x arr
-- | Record in 'stateMigrated' that the variable @x@ of type @t@ is held on
-- device by @arr@.
--
-- The flag states whether @x@ is still usable on host. Any previous entry
-- for @x@ (keyed by 'baseTag') is replaced.
recordMigration :: Bool -> Ident -> VName -> ReduceM ()
recordMigration still_on_host (Ident x t) arr =
  modify $ \st ->
    st
      { stateMigrated =
          IM.insert
            (baseTag x)
            (baseName x, t, arr, still_on_host)
            (stateMigrated st)
      }
-- | @pe \`migratedTo\` (dev, stms)@ records that the variable bound by @pe@
-- has been migrated to @dev[0]@.
--
-- If the migration table says the variable is used on host, it is recorded
-- as an alias, and a statement rebinding @pe@ to @dev[0]@ is appended to
-- @stms@. Otherwise it is recorded as moved, and @stms@ is returned
-- unchanged.
migratedTo :: PatElem Type -> (VName, Stms GPU) -> ReduceM (Stms GPU)
migratedTo pe (dev, stms) = do
  used_on_host <- asks (usedOnHost (patElemName pe))
  if used_on_host
    then do
      patElemIdent pe `aliasedBy` dev
      pure (stms |> bind pe (eIndex dev))
    else do
      patElemIdent pe `movedTo` dev
      pure stms
-- | @useScalar stms n@ returns a host variable that holds the value bound by
-- @n@ in the original program.
--
-- If @n@ was migrated to device and is no longer usable on host, a read of
-- its device array is bound to a fresh variable. That statement is appended
-- to @stms@ and the fresh name is returned. Otherwise @stms@ and @n@ are
-- returned unchanged.
useScalar :: Stms GPU -> VName -> ReduceM (Stms GPU, VName)
useScalar stms n = do
  entry <- gets (IM.lookup (baseTag n) . stateMigrated)
  case entry of
    Just (name, t, arr, False) -> do
      n' <- newName (VName name 0)
      pure (stms |> bind (PatElem n' t) (eIndex arr), n')
    _ -> pure (stms, n)
-- | The expression @arr[0]@.
eIndex :: VName -> Exp GPU
eIndex arr = BasicOp . Index arr $ Slice [DimFix (intConst Int64 0)]
-- | Build a statement that binds a single pattern element to an expression.
-- The statement has no certificates and no attributes.
bind :: PatElem Type -> Exp GPU -> Stm GPU
bind pe = Let (Pat [pe]) aux
  where
    aux = StmAux mempty mempty ()
-- | Look up the device array that holds a migrated scalar.
--
-- Returns 'Nothing' for constants and for variables that have not been
-- migrated.
storedScalar :: SubExp -> ReduceM (Maybe VName)
storedScalar Constant {} = pure Nothing
storedScalar (Var n) =
  gets $ \st -> do
    (_, _, arr, _) <- IM.lookup (baseTag n) (stateMigrated st)
    pure arr
-- | @storeScalar stms se t@ returns the name of a single-element array that
-- holds the value of the scalar @se@ of type @t@. It also returns @stms@,
-- extended with whatever statements are needed to create that array.
--
-- If @se@ is a variable that has already been migrated, its existing device
-- array is reused and no statements are added. Otherwise:
--
--   * a variable is copied inside a new 'GPUBody' when 'GPUBody' statements
--     are currently allowed ('stateGPUBodyOk', see 'noGPUBody');
--   * otherwise a variable is stored with a 'Replicate' of shape @[1]@;
--   * a constant is stored with an 'ArrayLit'.
storeScalar :: Stms GPU -> SubExp -> Type -> ReduceM (Stms GPU, VName)
storeScalar stms se t = do
  entry <- case se of
    Var n -> gets $ IM.lookup (baseTag n) . stateMigrated
    _ -> pure Nothing
  case entry of
    Just (_, _, arr, _) -> pure (stms, arr)
    Nothing -> do
      gpubody_ok <- gets stateGPUBodyOk
      case se of
        Var n | gpubody_ok -> do
          n' <- newName n
          let stm = bind (PatElem n' t) (BasicOp $ SubExp se)
          gpubody <- inGPUBody (pure stm)
          -- inGPUBody creates one pattern element for each pattern element
          -- of stm, so there is exactly one here. Match on it explicitly
          -- rather than using the partial 'head'.
          case patElems (stmPat gpubody) of
            dev_pe : _ -> pure (stms |> gpubody, patElemName dev_pe)
            [] -> compilerBugS "storeScalar: GPUBody without pattern elements"
        Var n -> do
          pe <- arrayizePatElem (PatElem n t)
          let shape = Shape [intConst Int64 1]
          let stm = bind pe (BasicOp $ Replicate shape se)
          pure (stms |> stm, patElemName pe)
        _ -> do
          let n = VName (nameFromString "const") 0
          pe <- arrayizePatElem (PatElem n t)
          let stm = bind pe (BasicOp $ ArrayLit [se] t)
          pure (stms |> stm, patElemName pe)
-- | Map a variable to the name that holds its value on device.
--
-- A variable that was migrated and is no longer usable on host resolves to
-- its device array. Every other variable resolves to itself.
resolveName :: VName -> ReduceM VName
resolveName n = do
  entry <- gets (IM.lookup (baseTag n) . stateMigrated)
  pure $ case entry of
    Just (_, _, arr, False) -> arr
    _ -> n
-- | Apply 'resolveName' to a variable. Constants are returned unchanged.
resolveSubExp :: SubExp -> ReduceM SubExp
resolveSubExp se = case se of
  Var n -> Var <$> resolveName n
  Constant {} -> pure se
-- | Apply 'resolveSubExp' to a result, keeping its certificates.
resolveSubExpRes :: SubExpRes -> ReduceM SubExpRes
resolveSubExpRes (SubExpRes certs se) = do
  se' <- resolveSubExp se
  pure (SubExpRes certs se')
-- | Apply 'resolveSubExpRes' to every element of a result, in order.
resolveResult :: Result -> ReduceM Result
resolveResult = traverse resolveSubExpRes
-- | Move a statement to device by wrapping it in a 'GPUBody', and append the
-- result to @out@.
--
-- A single-element array literal @[se]@ bound to one variable is handled
-- specially. The 'GPUBody' computes @se@, and its single-element array
-- result is bound directly to the original pattern.
--
-- For any other statement, each original pattern element is rebound from
-- the corresponding 'GPUBody' result, depending on that result's rank:
--
--   * rank 0: the result is bound directly to the original name;
--   * rank 1 with a unit type: the element is read back with @dev[0]@;
--   * rank 1 otherwise: the scalar is registered as migrated
--     ('migratedTo'), which reads it back only if it is used on host;
--   * higher ranks: the added outer dimension is indexed away at 0.
moveStm :: Stms GPU -> Stm GPU -> ReduceM (Stms GPU)
moveStm out stm = case stm of
  Let pat@(Pat [PatElem n _]) aux (BasicOp (ArrayLit [se] t')) -> do
    let inner = VName (baseName n `withSuffix` "_inner") 0
        stm' = Let (Pat [PatElem inner t']) aux (BasicOp (SubExp se))
    gpubody <- inGPUBody (rewriteStm stm')
    pure (out |> gpubody {stmPat = pat})
  _ -> do
    gpubody <- inGPUBody (rewriteStm stm)
    let pairs = zip (patElems (stmPat stm)) (patElems (stmPat gpubody))
    foldM addRead (out |> gpubody) pairs
  where
    addRead stms (pe@(PatElem _ t), PatElem dev dev_t) =
      case arrayRank dev_t of
        0 ->
          pure $ stms |> bind pe (BasicOp $ SubExp $ Var dev)
        1
          | t == Prim Unit -> pure $ stms |> bind pe (eIndex dev)
          | otherwise -> pe `migratedTo` (dev, stms)
        _ ->
          let slice = fullSlice dev_t [DimFix $ intConst Int64 0]
           in pure $ stms |> bind pe (BasicOp $ Index dev slice)
-- | Run a statement-producing rewrite and wrap the resulting statement in a
-- 'GPUBody'.
--
-- The body of the 'GPUBody' is the rewrite's prologue followed by the
-- statement. Its results are the statement's pattern variables. The
-- returned statement binds those results to fresh single-element arrays,
-- created with 'arrayizePatElem'.
inGPUBody :: RewriteM (Stm GPU) -> ReduceM (Stm GPU)
inGPUBody m = do
  (stm, st) <- runStateT m initialRState
  let pes = patElems (stmPat stm)
      results = map (SubExpRes mempty . Var . patElemName) pes
      body = Body () (rewritePrologue st |> stm) results
      e = Op (GPUBody (map patElemType pes) body)
  dev_pes <- mapM arrayizePatElem pes
  pure $ Let (Pat dev_pes) (StmAux mempty mempty ()) e
-- | 'ReduceM' extended with an 'RState'. It is used while rewriting code that
-- is placed in a 'GPUBody' (see 'inGPUBody') and while renaming the free
-- variables of kernel bodies and lambdas (see 'addReadsHelper').
type RewriteM = StateT RState ReduceM
-- | State threaded through 'RewriteM'.
data RState = RState
  { -- | Fresh names introduced by 'rewriteName', keyed by the 'baseTag' of
    -- the name each one replaces.
    rewriteRenames :: IM.IntMap VName
  , -- | Statements that must run before the rewritten code. 'inGPUBody'
    -- places them first in the generated 'GPUBody', and 'addReadsHelper'
    -- returns them as a prologue.
    rewritePrologue :: Stms GPU
  }
-- | An 'RState' with no renames and an empty prologue.
initialRState :: RState
initialRState = RState {rewritePrologue = mempty, rewriteRenames = mempty}
-- | Apply 'addReadsToLambda' to the operator lambda of a 'SegBinOp'.
addReadsToSegBinOp :: SegBinOp GPU -> ReduceM (SegBinOp GPU)
addReadsToSegBinOp op =
  setLambda <$> addReadsToLambda (segBinOpLambda op)
  where
    setLambda lam = op {segBinOpLambda = lam}
-- | Apply 'addReadsToLambda' to the operator lambda of a 'HistOp'.
addReadsToHistOp :: HistOp GPU -> ReduceM (HistOp GPU)
addReadsToHistOp op =
  setLambda <$> addReadsToLambda (histOp op)
  where
    setLambda lam = op {histOp = lam}
-- | Apply 'addReadsToBody' to the body of a lambda. The parameters and
-- return types of the lambda are kept.
addReadsToLambda :: Lambda GPU -> ReduceM (Lambda GPU)
addReadsToLambda f =
  setBody <$> addReadsToBody (lambdaBody f)
  where
    setBody body = f {lambdaBody = body}
-- | Rename the free variables of a body with 'addReadsHelper', and place the
-- resulting prologue before the body's own statements.
addReadsToBody :: Body GPU -> ReduceM (Body GPU)
addReadsToBody body = prepend <$> addReadsHelper body
  where
    prepend (body', prologue) =
      body' {bodyStms = prologue >< bodyStms body'}
-- | Rename the free variables of a kernel body with 'addReadsHelper', and
-- place the resulting prologue before the kernel body's own statements.
addReadsToKernelBody :: KernelBody GPU -> ReduceM (KernelBody GPU)
addReadsToKernelBody kbody = prepend <$> addReadsHelper kbody
  where
    prepend (kbody', prologue) =
      kbody' {kernelBodyStms = prologue >< kernelBodyStms kbody'}
-- | Pass every free variable of @x@ through 'rename', which is defined
-- elsewhere in this module. Each occurrence is then substituted with the
-- returned name.
--
-- Returns the substituted @x@ together with the prologue statements that
-- the renaming accumulated. Callers such as 'addReadsToBody' place that
-- prologue first.
addReadsHelper :: (FreeIn a, Substitute a) => a -> ReduceM (a, Stms GPU)
addReadsHelper x = do
  (renamed, st) <- runStateT (traverse rename free_names) initialRState
  let substitution = M.fromList (zip free_names renamed)
  pure (substituteNames substitution x, rewritePrologue st)
  where
    free_names = namesToList (freeIn x)
-- | Create a fresh name based on @n@ and record it in 'rewriteRenames' under
-- the 'baseTag' of @n@.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  fresh <- lift $ newName n
  let register st =
        st {rewriteRenames = IM.insert (baseTag n) fresh $ rewriteRenames st}
  fresh <$ modify register
-- | Rewrite the statements of a body first, then rename its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) =
  Body () <$> rewriteStms stms <*> renameResult res
-- | Rewrite each statement in order with 'rewriteStm'.
--
-- A rewritten statement whose expression is a 'GPUBody' is flattened: the
-- statements of its body are spliced in directly, and each of its pattern
-- elements is then bound from the matching body result (see 'bnd').
-- Any other rewritten statement is appended as is.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms stms = foldM rewriteTo mempty stms
  where
    rewriteTo out stm = flatten out <$> rewriteStm stm

    flatten out stm' = case stmExp stm' of
      Op (GPUBody _ (Body _ inner res)) ->
        foldl' bnd (out >< inner) (zip (patElems (stmPat stm')) res)
      _ -> out |> stm'

-- | Append a statement that binds @pe@ from a 'GPUBody' result, keeping the
-- result's certificates.
--
-- If the type of @pe@ has at least one array dimension, the value is
-- wrapped in a one-element array literal whose element type is that type
-- with one dimension peeled off. Otherwise the value is bound directly.
bnd :: Stms GPU -> (PatElem Type, SubExpRes) -> Stms GPU
bnd out (pe, SubExpRes cs se) =
  out |> Let (Pat [pe]) (StmAux cs mempty ()) (BasicOp op)
  where
    op = case peelArray 1 (typeOf pe) of
      Just t' -> ArrayLit [se] t'
      Nothing -> SubExp se
-- | Rewrite a single statement.
--
-- The expression is rewritten first, so the renamings introduced by the
-- pattern (via 'rewriteName') are not yet visible while the expression is
-- processed. The certificates are renamed last.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let pat aux e) = do
  e' <- rewriteExp e
  Let <$> rewritePat pat <*> rewriteStmAux aux <*> pure e'
-- | Rewrite every element of a pattern, in order, using 'rewritePatElem'.
rewritePat :: Pat Type -> RewriteM (Pat Type)
rewritePat = fmap Pat . traverse rewritePatElem . patElems
-- | Bind a fresh name for a pattern element, then rename the variables
-- occurring in its type.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = PatElem <$> rewriteName n <*> renameType t
-- | Rename the certificates of a statement. The attributes are kept
-- unchanged.
rewriteStmAux :: StmAux () -> RewriteM (StmAux ())
rewriteStmAux (StmAux certs attrs _) =
  fmap (\certs' -> StmAux certs' attrs ()) (renameCerts certs)
-- | Rewrite an expression with 'mapExpM'.
--
-- Sub-expressions, variable names and (extended) types are renamed. Nested
-- bodies are rewritten with 'rewriteBody', and parameters are given fresh
-- names. Encountering any 'Op' is reported with 'compilerBugS'.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp = mapExpM mapper
  where
    mapper =
      Mapper
        { mapOnSubExp = renameSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rename,
          mapOnRetType = renameExtType,
          mapOnBranchType = renameExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = \_ ->
            compilerBugS "Cannot migrate a host-only operation to device."
        }
-- | Bind a fresh name for a parameter, then rename the variables in its type.
-- The parameter's attributes are kept.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = Param attrs <$> rewriteName n <*> renameType t
-- | Look up the replacement for a variable name.
--
-- If the name's 'baseTag' already has an entry in 'rewriteRenames', that
-- entry is returned. Otherwise 'useScalar' is called with the current
-- 'rewritePrologue' and the name. The returned prologue and name are stored
-- in the state, and the new name is returned.
--
-- NOTE(review): 'useScalar' presumably adds prologue statements that make
-- the host value available on the device; confirm at its definition.
rename :: VName -> RewriteM VName
rename n = do
  renames <- gets rewriteRenames
  case IM.lookup key renames of
    Just n' -> pure n'
    Nothing -> do
      prologue <- gets rewritePrologue
      (prologue', n') <- lift $ useScalar prologue n
      modify $ \st ->
        st
          { rewriteRenames = IM.insert key n' renames,
            rewritePrologue = prologue'
          }
      pure n'
  where
    key = baseTag n
-- | Rename every entry of a body result, in order.
renameResult :: Result -> RewriteM Result
renameResult res = traverse renameSubExpRes res
-- | Rename a result's certificates first, then its 'SubExp'.
renameSubExpRes :: SubExpRes -> RewriteM SubExpRes
renameSubExpRes (SubExpRes certs se) =
  SubExpRes <$> renameCerts certs <*> renameSubExp se
-- | Rename every certificate name, preserving their order.
renameCerts :: Certs -> RewriteM Certs
renameCerts = fmap Certs . traverse rename . unCerts
-- | Rename the variable inside a 'SubExp'. Any 'SubExp' that is not a 'Var'
-- is returned unchanged.
renameSubExp :: SubExp -> RewriteM SubExp
renameSubExp se = case se of
  Var n -> Var <$> rename n
  _ -> pure se
-- | Apply 'renameSubExp' to the 'SubExp's of a type via 'mapOnType'.
renameType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
renameType t = mapOnType renameSubExp t
-- | Apply 'renameSubExp' to the 'SubExp's of an existentially-shaped type
-- via 'mapOnExtType'.
renameExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
renameExtType t = mapOnExtType renameSubExp t