module Futhark.Optimise.MergeGPUBodies (mergeGPUBodies) where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first)
import Data.Foldable
import Data.IntMap qualified as IM
import Data.IntSet ((\\))
import Data.IntSet qualified as IS
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Sequence ((|>))
import Data.Sequence qualified as SQ
import Futhark.Analysis.Alias
import Futhark.Construct (sliceDim)
import Futhark.Error
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.MonadFreshNames hiding (newName)
import Futhark.Pass
-- | A pass that first runs alias analysis on the program and then rewrites
-- the constants with 'transformStms' and every function body with
-- 'transformBody', each starting from an empty alias table. The dependency
-- sets computed by those transformations are discarded here.
mergeGPUBodies :: Pass GPU GPU
mergeGPUBodies =
  Pass
    "merge GPU bodies"
    "Reorder and merge GPUBody constructs to reduce kernels executions."
    (intraproceduralTransformationWithConsts onStms onFunDef . aliasAnalysis)
  where
    -- Rewrite the body of a function definition. The first argument is
    -- ignored.
    onFunDef _ (FunDef entry attrs name types params body) =
      let rebuild = FunDef entry attrs name types params
       in rebuild . fst <$> transformBody mempty body

    -- Rewrite the top-level constant statements.
    onStms stms = fmap fst (transformStms mempty stms)
-- | A set of name tags (see 'namesToSet') that some construct depends on.
type Dependencies = IS.IntSet
-- | A set of name tags for the variables that some construct binds.
type Bindings = IS.IntSet
-- | A set of name tags describing what a statement consumes; 'reorderStm'
-- fills it with the root aliases of the names consumed by an expression.
type Consumption = IS.IntSet
-- | The tags of all names that occur free in the given construct.
depsOf :: FreeIn a => a -> Dependencies
depsOf construct = namesToSet (freeIn construct)
-- | Collect the 'baseTag' of every name in a 'Names' set into an integer set.
namesToSet :: Names -> IS.IntSet
namesToSet names = IS.fromList [baseTag name | name <- namesToList names]
-- | Transform the body of a lambda with 'transformBody', keeping its
-- parameters and return types, and return the dependencies of that body.
transformLambda ::
  AliasTable ->
  Lambda (Aliases GPU) ->
  PassM (Lambda GPU, Dependencies)
transformLambda aliases (Lambda params body types) =
  first rebuild <$> transformBody aliases body
  where
    rebuild body' = Lambda params body' types
-- | Transform a body: every statement is fed through 'reorderStm' (threading
-- the alias table from one statement to the next) and the resulting groups
-- are then combined by 'collapse'.
--
-- The returned dependencies are those of the collapsed group plus the free
-- names of the body result, minus the names bound by the group.
transformBody ::
  AliasTable ->
  Body (Aliases GPU) ->
  PassM (Body GPU, Dependencies)
transformBody aliases (Body _ stms res) = do
  let reorderAll = foldM_ reorderStm aliases stms
  grp <- evalStateT (reorderAll *> collapse) initialState
  let needed = groupDependencies grp <> depsOf res
  pure (Body () (groupStms grp) res, needed \\ groupBindings grp)
-- | Transform a sequence of statements by wrapping it in a body with an empty
-- result, running 'transformBody' on it, and unwrapping the statements again.
transformStms ::
  AliasTable ->
  Stms (Aliases GPU) ->
  PassM (Stms GPU, Dependencies)
transformStms aliases stms =
  unwrap <$> transformBody aliases (Body mempty stms [])
  where
    unwrap (Body _ stms' _, deps) = (stms', deps)
-- | Transform a single statement and place it in one of the statement groups
-- held in the 'ReorderM' state.
--
-- The expression is first rewritten with 'transformExp' and the alias
-- annotations are stripped from the pattern. The statement is then handed to
-- 'moveGPUBody' if the rewritten expression is a 'GPUBody', and to
-- 'moveOther' otherwise, together with:
--
-- * its 'Usage': the names bound by the pattern, and as dependencies the
--   root aliases of the expression's aliases, the dependencies reported by
--   'transformExp', and the free names of the pattern and certificates;
--
-- * the root aliases of everything the expression consumes.
--
-- Returns the alias table extended with an entry for every pattern element
-- that has aliases.
reorderStm :: AliasTable -> Stm (Aliases GPU) -> ReorderM AliasTable
reorderStm aliases (Let pat (StmAux cs attrs _) e) = do
  (e', deps) <- lift (transformExp aliases e)
  let pat' = removePatAliases pat
      stm' = Let pat' (StmAux cs attrs ()) e'
      bound = IS.fromList (map (baseTag . patElemName) (patElems pat'))
      observed = namesToSet (rootAliasesOf (fold (expAliases e)) aliases)
      consumed = namesToSet (rootAliasesOf (consumedInExp e) aliases)
      usage =
        Usage
          { usageBindings = bound,
            usageDependencies = observed <> deps <> depsOf pat' <> depsOf cs
          }

  case e' of
    Op GPUBody {} -> moveGPUBody stm' usage consumed
    _ -> moveOther stm' usage consumed

  -- A strict left fold: the table is fully built one pattern element at a
  -- time instead of accumulating a chain of suspended insertions.
  pure (foldl' recordAliases aliases (patElems pat))
  where
    -- Replace every name by the aliases recorded for it in the table; a name
    -- without an entry stands for itself.
    rootAliasesOf names atable =
      let look n = M.findWithDefault (oneName n) n atable
       in foldMap look (namesToList names)

    -- Record the root aliases of a pattern element, unless it has no aliases
    -- at all, in which case the table is left untouched.
    recordAliases atable pe
      | aliasesOf pe == mempty = atable
      | otherwise =
          M.insert (patElemName pe) (rootAliasesOf (aliasesOf pe) atable) atable
-- | Transform an expression and compute its dependencies.
--
-- 'BasicOp', 'Apply' and 'Op' expressions only have their alias annotations
-- removed and depend on their free names. For 'Match', 'DoLoop' and
-- 'WithAcc' the nested bodies and lambdas are transformed recursively and
-- the dependencies are assembled from the pieces.
transformExp ::
  AliasTable ->
  Exp (Aliases GPU) ->
  PassM (Exp GPU, Dependencies)
transformExp aliases e =
  case e of
    BasicOp {} -> unchanged
    Apply {} -> unchanged
    Match ses cases defbody dec -> do
      (cases', cases_deps) <- mapAndUnzipM onCase cases
      (defbody', defbody_deps) <- transformBody aliases defbody
      let deps =
            mconcat
              [depsOf ses, mconcat cases_deps, defbody_deps, depsOf dec]
      pure (Match ses cases' defbody' dec, deps)
    DoLoop merge lform body -> do
      (body', body_deps) <- transformBody aliases body
      let (params, args) = unzip merge
          deps = body_deps <> depsOf params <> depsOf args <> depsOf lform
          -- Names introduced by the loop form and the loop parameters are
          -- not dependencies of the loop as a whole.
          scope =
            scopeOf lform <> scopeOfFParams params ::
              Scope (Aliases GPU)
          bound = IS.fromList (map baseTag (M.keys scope))
          -- Strip the aliases from the merge parameters and the loop form
          -- by passing them through 'removeExpAliases' inside a loop with
          -- an empty body.
          dummy =
            DoLoop merge lform (Body (bodyDec body) SQ.empty []) ::
              Exp (Aliases GPU)
          DoLoop merge' lform' _ = removeExpAliases dummy
      pure (DoLoop merge' lform' body', deps \\ bound)
    WithAcc inputs lambda -> do
      (inputs', input_deps) <-
        mapAndUnzipM (transformWithAccInput aliases) inputs
      (lambda', lambda_deps) <- transformLambda aliases lambda
      pure (WithAcc inputs' lambda', lambda_deps <> fold input_deps)
    Op {} -> unchanged
  where
    -- The expression itself without alias information, depending on exactly
    -- its free names.
    unchanged = pure (removeExpAliases e, depsOf e)

    -- Transform the body of a single match case.
    onCase (Case vs body) = first (Case vs) <$> transformBody aliases body
-- | Transform a single 'WithAcc' input. The optional combining lambda is
-- transformed with 'transformLambda'; the dependencies are those of the
-- lambda and its neutral elements (if present) plus the free names of the
-- shape and the arrays.
transformWithAccInput ::
  AliasTable ->
  WithAccInput (Aliases GPU) ->
  PassM (WithAccInput GPU, Dependencies)
transformWithAccInput aliases (shape, arrs, op) = do
  (op', op_deps) <- maybe (pure (Nothing, mempty)) onOp op
  pure ((shape, arrs, op'), op_deps <> depsOf shape <> depsOf arrs)
  where
    onOp (f, nes) = do
      (f', f_deps) <- transformLambda aliases f
      pure (Just (f', nes), f_deps <> depsOf nes)
-- | The monad in which statements are reordered: a 'State' layered on top of
-- 'PassM'.
type ReorderM = StateT State PassM
-- | The state carried by 'ReorderM'.
data State = State
  { -- | The groups that the statements processed so far have been placed
    -- in. 'moveOther' only ever targets even indices and 'moveGPUBody' only
    -- odd ones.
    stateGroups :: Groups,
    -- | The entries registered through 'recordEquivalent'.
    stateEquivalents :: EquivalenceTable
  }
-- | A map from name tags to the 'Entry' recorded for that name.
type EquivalenceTable = IM.IntMap Entry
-- | An entry of an 'EquivalenceTable'. Entries are created by 'moveGPUBody'
-- for every pattern element of a 'GPUBody' statement and propagated to
-- other names by 'recordEquivalentsOf'.
data Entry = Entry
  { -- | The result expression of the 'GPUBody' that corresponds to the
    -- pattern element.
    entryValue :: SubExp,
    -- | The row type of the pattern element if it is an array, otherwise
    -- the type of the pattern element itself.
    entryType :: Type,
    -- | The name of the pattern element bound by the 'GPUBody' statement.
    entryResult :: VName,
    -- | The index of the group that the 'GPUBody' statement was moved to.
    entryGroupIdx :: Int,
    -- | 'True' when the entry was created for an array-typed pattern
    -- element; 'recordEquivalentsOf' resets it to 'False' for a name bound
    -- by indexing such an array at row zero.
    entryStored :: Bool
  }
-- | A sequence of statement groups.
type Groups = SQ.Seq Group
-- | A run of statements together with the combined 'Usage' of those
-- statements.
data Group = Group
  { -- | The statements of the group, in the order they were added.
    groupStms :: Stms GPU,
    -- | The accumulated usage of every statement added to the group.
    groupUsage :: Usage
  }
-- | What a statement, or a set of statements, binds and depends on.
data Usage = Usage
  { -- | The tags of the names that are bound.
    usageBindings :: Bindings,
    -- | The tags of the names that are depended upon.
    usageDependencies :: Dependencies
  }
-- | Groups combine by concatenating their statements (left before right) and
-- combining their usage.
instance Semigroup Group where
  Group stms_l usage_l <> Group stms_r usage_r =
    Group (stms_l SQ.>< stms_r) (usage_l <> usage_r)
-- | The empty group has no statements and empty usage.
instance Monoid Group where
  mempty = Group mempty mempty
-- | Usages combine by taking the union of their bindings and the union of
-- their dependencies.
instance Semigroup Usage where
  Usage bound_l deps_l <> Usage bound_r deps_r =
    Usage (IS.union bound_l bound_r) (IS.union deps_l deps_r)
-- | The empty usage binds nothing and depends on nothing.
instance Monoid Usage where
  mempty = Usage IS.empty IS.empty
-- | The names bound by the statements of a group.
groupBindings :: Group -> Bindings
groupBindings grp = usageBindings (groupUsage grp)
-- | The names that the statements of a group depend on.
groupDependencies :: Group -> Dependencies
groupDependencies grp = usageDependencies (groupUsage grp)
-- | The state that reordering starts from: a single empty group and an empty
-- equivalence table.
initialState :: State
initialState =
  State
    { stateGroups = SQ.singleton emptyGroup,
      stateEquivalents = IM.empty
    }
  where
    emptyGroup = mempty :: Group
-- | Apply a function to the groups held in the state.
modifyGroups :: (Groups -> Groups) -> ReorderM ()
modifyGroups f = modify update
  where
    update st = st {stateGroups = f (stateGroups st)}
-- | Drop the entries for the given keys from the equivalence table.
removeEquivalents :: IS.IntSet -> ReorderM ()
removeEquivalents keys = modify prune
  where
    prune st =
      st {stateEquivalents = IM.withoutKeys (stateEquivalents st) keys}
-- | Store an entry in the equivalence table under the tag of the given name,
-- replacing any entry already stored under that tag.
recordEquivalent :: VName -> Entry -> ReorderM ()
recordEquivalent n entry = modify insertEntry
  where
    insertEntry st =
      st {stateEquivalents = IM.insert (baseTag n) entry (stateEquivalents st)}
-- | Place a 'GPUBody' statement in an odd-indexed group and record an
-- 'Entry' for each of its pattern elements.
--
-- Before the last blocking group is searched for, every dependency that has
-- an entry in the equivalence table is replaced by the 'entryResult' of that
-- entry. The usage stored with the statement is the original, unmodified
-- one.
moveGPUBody :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveGPUBody stm usage consumed = do
  eqs <- gets stateEquivalents
  grps <- gets stateGroups
  let resolve tag = maybe tag (baseTag . entryResult) (IM.lookup tag eqs)
      resolved =
        usage {usageDependencies = IS.map resolve (usageDependencies usage)}
      -- NOTE(review): the fallback index is 1 here, whereas 'moveOther'
      -- falls back to 0. Both send an unblocked statement to group 1 unless
      -- 'consumes' holds for group 1 -- confirm that 1 is intended.
      blocker = fromMaybe 1 (SQ.findIndexR (groupBlocks resolved consumed) grps)
      target
        -- Blocked by an even-indexed group: use the odd group right after.
        | even blocker = blocker + 1
        -- Blocked by an odd-indexed group that binds something this
        -- statement consumes: skip ahead to the next odd group.
        | consumes blocker grps = blocker + 2
        -- Otherwise join the blocking odd-indexed group.
        | otherwise = blocker
  modifyGroups (moveToGrp (stm, usage) target)

  let Op (GPUBody _ (Body _ _ res)) = stmExp stm
  forM_ (zip (patElems (stmPat stm)) (map resSubExp res)) (stores target)
  where
    -- Does the group at this index bind any of the consumed names? A missing
    -- group consumes nothing.
    consumes idx grps =
      case SQ.lookup idx grps of
        Just grp -> not (IS.disjoint (groupBindings grp) consumed)
        Nothing -> False

    -- Record the entry for one pattern element and its result. For an array
    -- the row type is stored and the entry is flagged as stored.
    stores idx (PatElem n t, se) =
      recordEquivalent n (Entry se entry_t n idx stored)
      where
        (entry_t, stored) =
          case peelArray 1 t of
            Just row_t -> (row_t, True)
            Nothing -> (t, False)
-- | Place a statement that is not a 'GPUBody' in an even-indexed group: the
-- index of the last blocking group (or 0 if no group blocks) rounded up to
-- the nearest even number. Afterwards 'recordEquivalentsOf' is run on the
-- statement with that index.
moveOther :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveOther stm usage consumed = do
  grps <- gets stateGroups
  let blocker = fromMaybe 0 (SQ.findIndexR (groupBlocks usage consumed) grps)
      target = 2 * ((blocker + 1) `div` 2)
  modifyGroups (moveToGrp (stm, usage) target)
  recordEquivalentsOf stm target
-- | Propagate equivalence entries through a statement that was moved to the
-- group with the given index. Only entries whose 'entryGroupIdx' is the
-- index immediately before it are considered.
--
-- * @x = n@: @x@ receives the entry of @n@.
--
-- * @x = arr[0, ...]@ where the remaining slice is exactly the full slice
--   of the dimensions of the entry's type: @x@ receives the entry of @arr@
--   with 'entryStored' set to 'False'.
--
-- Every other statement leaves the table unchanged.
recordEquivalentsOf :: Stm GPU -> Int -> ReorderM ()
recordEquivalentsOf stm idx = do
  eqs <- gets stateEquivalents
  let fromPreceding name = do
        entry <- IM.lookup (baseTag name) eqs
        guard (entryGroupIdx entry == idx - 1)
        pure entry
  case stm of
    Let (Pat [PatElem x _]) _ (BasicOp (SubExp (Var n)))
      | Just entry <- fromPreceding n ->
          recordEquivalent x entry
    Let (Pat [PatElem x _]) _ (BasicOp (Index arr (Slice (DimFix i : dims))))
      | Just entry <- fromPreceding arr,
        i == intConst Int64 0,
        dims == map sliceDim (arrayDims (entryType entry)) ->
          recordEquivalent x entry {entryStored = False}
    _ -> pure ()
-- | Does the group block a statement with the given usage and consumption?
-- It does when the group binds a name that the statement depends on, or
-- when the group depends on a name that the statement consumes.
groupBlocks :: Usage -> Consumption -> Group -> Bool
groupBlocks usage consumed grp = bindsUsed || dependsOnConsumed
  where
    bindsUsed =
      not (IS.disjoint (groupBindings grp) (usageDependencies usage))
    dependsOnConsumed =
      not (IS.disjoint (groupDependencies grp) consumed)
-- | Append a statement and its usage to the group at the given index. If the
-- index lies beyond the last group, empty groups are appended first until
-- the index exists.
moveToGrp :: (Stm GPU, Usage) -> Int -> Groups -> Groups
moveToGrp stm idx grps = SQ.adjust' (moveTo stm) idx padded
  where
    -- How many empty groups are needed for @idx@ to be a valid index; never
    -- negative, as 'SQ.replicate' rejects negative counts.
    missing = max 0 (idx + 1 - SQ.length grps)
    padded = grps <> SQ.replicate missing mempty
-- | Append a statement to the end of a group and merge its usage into the
-- usage of the group.
moveTo :: (Stm GPU, Usage) -> Group -> Group
moveTo (stm, usage) (Group stms grp_usage) =
  Group (stms |> stm) (grp_usage <> usage)
-- | 'ReorderM' extended with a state of GPU statements; 'execRewrite' starts
-- it out empty and prepends the final statements to the rewritten body.
type RewriteM = StateT (Stms GPU) ReorderM
-- | Combine all groups of the state into a single group, which also becomes
-- the only group of the state.
--
-- Groups are visited in order. The statements of every odd-indexed group are
-- first passed through 'mergeKernels'. After each group is handled, the
-- names it binds are removed from the equivalence table.
collapse :: ReorderM Group
collapse = do
  grps <- gets (toList . stateGroups)
  merged <- foldM clps mempty (zip (cycle [False, True]) grps)
  modify (\st -> st {stateGroups = SQ.singleton merged})
  pure merged
  where
    clps acc (is_gpu_bodies, grp@(Group stms usage)) = do
      grp' <-
        if is_gpu_bodies
          then (`Group` usage) <$> mergeKernels stms
          else pure grp
      removeEquivalents (groupBindings grp')
      pure (acc <> grp')
-- | Merge a sequence of 'GPUBody' statements into a single 'GPUBody'
-- statement. Sequences of fewer than two statements are returned unchanged.
--
-- The statements are folded from the right, starting from an empty
-- 'GPUBody'. The body of each statement is passed through 'rewriteBody'
-- before its pattern, certificates, attributes, types, statements and
-- results are put in front of those merged so far.
--
-- Reports a compiler bug if a statement is not a 'GPUBody'.
mergeKernels :: Stms GPU -> ReorderM (Stms GPU)
mergeKernels stms
  | SQ.length stms < 2 = pure stms
  | otherwise = SQ.singleton <$> foldrM merge emptyKernel stms
  where
    -- A GPUBody statement that binds nothing and computes nothing.
    emptyKernel =
      Let mempty (StmAux mempty mempty ()) (Op (GPUBody [] (Body () SQ.empty [])))

    merge :: Stm GPU -> Stm GPU -> ReorderM (Stm GPU)
    merge
      (Let pat0 (StmAux cs0 attrs0 _) (Op (GPUBody types0 body0)))
      (Let pat1 (StmAux cs1 attrs1 _) (Op (GPUBody types1 (Body _ stms1 res1)))) = do
        Body _ stms0 res0 <- execRewrite (rewriteBody body0)
        let aux = StmAux (cs0 <> cs1) (attrs0 <> attrs1) ()
            body = Body () (stms0 <> stms1) (res0 <> res1)
        pure (Let (pat0 <> pat1) aux (Op (GPUBody (types0 ++ types1) body)))
    merge _ _ =
      compilerBugS "mergeGPUBodies: cannot merge non-GPUBody statements"
-- | Run a body rewrite, starting from an empty statement state, and prepend
-- whatever statements the rewrite accumulated in that state to the statements
-- of the body it produced.
execRewrite :: RewriteM (Body GPU) -> ReorderM (Body GPU)
execRewrite m = evalStateT (m >>= withPrologue) SQ.empty
  where
    -- The state is read only after the rewrite has finished, so it holds
    -- every statement added during the rewrite.
    withPrologue (Body _ stms res) =
      gets $ \prologue -> Body () (prologue <> stms) res
-- | Fetch the equivalence table stored in the underlying 'ReorderM' state.
equivalents :: RewriteM EquivalenceTable
equivalents = stateEquivalents <$> lift get
-- | Rewrite a body: first its statements, then its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- rewriteResult res
  pure (Body () stms' res')
-- | Rewrite every statement of a sequence, in order.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms stms = traverse rewriteStm stms
-- | Rewrite a single statement: the types of its pattern elements, then its
-- certificates, then its expression. Attributes are kept as they are.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let (Pat pes) (StmAux cs attrs _) e) =
  rebuild
    <$> traverse rewritePatElem pes
    <*> rewriteCerts cs
    <*> rewriteExp e
  where
    rebuild :: [PatElem Type] -> Certs -> Exp GPU -> Stm GPU
    rebuild pes' cs' e' = Let (Pat pes') (StmAux cs' attrs ()) e'
-- | Rewrite the type of a pattern element; its name is left untouched.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  t' <- rewriteType t
  pure (PatElem n t')
-- | Rewrite an expression.
--
-- An 'Index' into an array that has an entry in the equivalence table, and
-- whose outermost index is the constant @0@, is replaced by a direct use of
-- the entry's value: the value itself if no further indices remain, otherwise
-- an 'Index' into that value using the remaining indices. Every other
-- expression is traversed with 'mapExpM', rewriting its constituents.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp e = do
  eqs <- equivalents
  case e of
    BasicOp (Index arr slice)
      | Just entry <- IM.lookup (baseTag arr) eqs,
        DimFix idx : rest <- unSlice slice,
        idx == intConst Int64 0 ->
          pure (BasicOp (readEntry rest (entryValue entry)))
    _ -> mapExpM rewriter e
  where
    -- Read the entry value through the remaining slice dimensions. Remaining
    -- dimensions on a value that is not a variable is a compiler bug.
    readEntry :: [DimIndex SubExp] -> SubExp -> BasicOp
    readEntry [] se = SubExp se
    readEntry rest (Var src) = Index src (Slice rest)
    readEntry _ _ = compilerBugS "rewriteExp: bad equivalence entry"

    rewriter =
      Mapper
        { mapOnSubExp = rewriteSubExp,
          mapOnBody = \_ body -> rewriteBody body,
          mapOnVName = rewriteName,
          mapOnRetType = rewriteExtType,
          mapOnBranchType = rewriteExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          -- Any 'Op' reached by the traversal is reported as a compiler bug.
          mapOnOp = \_ ->
            compilerBugS "rewriteExp: unhandled HostOp in GPUBody"
        }
-- | Rewrite every element of a body result, in order.
rewriteResult :: Result -> RewriteM Result
rewriteResult res = traverse rewriteSubExpRes res
-- | Rewrite a single result: first its certificates, then its subexpression.
rewriteSubExpRes :: SubExpRes -> RewriteM SubExpRes
rewriteSubExpRes (SubExpRes cs se) = do
  cs' <- rewriteCerts cs
  se' <- rewriteSubExp se
  pure (SubExpRes cs' se')
-- | Rewrite every name in a set of certificates, in order.
rewriteCerts :: Certs -> RewriteM Certs
rewriteCerts (Certs cs) = do
  cs' <- traverse rewriteName cs
  pure (Certs cs')
-- | Rewrite the subexpressions that occur in a type by applying
-- 'rewriteSubExp' through 'mapOnType'.
rewriteType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
rewriteType t = mapOnType rewriteSubExp t
-- | Rewrite the subexpressions that occur in an existential type by applying
-- 'rewriteSubExp' through 'mapOnExtType'.
rewriteExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
rewriteExtType t = mapOnExtType rewriteSubExp t
-- | Rewrite the type of a parameter; its attributes and name are kept.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = do
  t' <- rewriteType t
  pure (Param attrs n t')
-- | Rewrite a subexpression using the equivalence table.
--
-- Constants, and variables without a table entry, are returned unchanged. A
-- variable with an entry is replaced by the entry's value when the entry's
-- final (boolean) field is 'False'; when it is 'True' the value is first
-- wrapped in a single-element array via 'asArray' and that array's name is
-- used instead.
rewriteSubExp :: SubExp -> RewriteM SubExp
rewriteSubExp se@(Constant _) = pure se
rewriteSubExp se@(Var n) = do
  found <- IM.lookup (baseTag n) <$> equivalents
  maybe (pure se) substitute found
  where
    substitute (Entry val t _ _ wrapped)
      | wrapped = Var <$> asArray val t
      | otherwise = pure val
-- | Rewrite a variable name with 'rewriteSubExp'. If the name is replaced by
-- a constant, a fresh variable bound to that constant is introduced with
-- 'referConst' and its name is returned.
rewriteName :: VName -> RewriteM VName
rewriteName n = rewriteSubExp (Var n) >>= toName
  where
    toName (Var n') = pure n'
    toName (Constant c) = referConst c
-- | Bind a fresh variable to a single-element array literal holding the given
-- subexpression, whose type is the given row type. The binding is appended to
-- the statements kept in the rewrite state, and the new name is returned.
asArray :: SubExp -> Type -> RewriteM VName
asArray se row_t = do
  name <- newName "arr"
  let arr_t = row_t `arrayOfRow` intConst Int64 1
      stm :: Stm GPU
      stm =
        Let
          (Pat [PatElem name arr_t])
          (StmAux mempty mempty ())
          (BasicOp (ArrayLit [se] row_t))
  modify (|> stm)
  pure name
-- | Bind a fresh variable to the given constant. The binding is appended to
-- the statements kept in the rewrite state, and the new name is returned.
referConst :: PrimValue -> RewriteM VName
referConst c = do
  name <- newName "cnst"
  let stm :: Stm GPU
      stm =
        Let
          (Pat [PatElem name (Prim (primValueType c))])
          (StmAux mempty mempty ())
          (BasicOp (SubExp (Constant c)))
  modify (|> stm)
  pure name
-- | Produce a fresh name based on the given string, by lifting
-- 'newNameFromString' through both state layers of 'RewriteM'.
newName :: String -> RewriteM VName
newName = lift . lift . newNameFromString