-- |
-- This module implements an optimization pass that merges 'GPUBody' kernels to
-- eliminate memory transactions and reduce the number of kernel launches.
-- This is useful because the "Futhark.Optimise.ReduceDeviceSyncs" pass introduces
-- 'GPUBody' kernels that only execute single statements.
--
-- To merge as many 'GPUBody' kernels as possible, this pass reorders statements
-- with the goal of placing as many 'GPUBody' statements as possible next to
-- each other in a sequence. Such sequences can then trivially be merged.
module Futhark.Optimise.MergeGPUBodies (mergeGPUBodies) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Foldable
import qualified Data.IntMap as IM
import Data.IntSet ((\\))
import qualified Data.IntSet as IS
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Data.Sequence ((|>))
import qualified Data.Sequence as SQ
import Futhark.Analysis.Alias
import Futhark.Construct (sliceDim)
import Futhark.Error
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.MonadFreshNames hiding (newName)
import Futhark.Pass

-- | An optimization pass that reorders and merges 'GPUBody' statements to
-- eliminate memory transactions and reduce the number of kernel launches.
--
-- The program is first annotated with alias information ('aliasAnalysis'),
-- which the reordering uses to avoid moving a consumption of an array past an
-- observation of one of its aliases. Constants and every function body are
-- then transformed independently.
mergeGPUBodies :: Pass GPU GPU
mergeGPUBodies =
  Pass
    "merge GPU bodies"
    "Reorder and merge GPUBody constructs to reduce kernels executions."
    runMerge
  where
    runMerge prog =
      intraproceduralTransformationWithConsts onStms onFunDef (aliasAnalysis prog)

    -- The transformed constants are not needed to transform a function.
    onFunDef _ (FunDef entry attrs fname ret params body) = do
      (body', _) <- transformBody mempty body
      pure $ FunDef entry attrs fname ret params body'

    -- Top-level constants start with an empty alias table.
    onStms stms = do
      (stms', _) <- transformStms mempty stms
      pure stms'

--------------------------------------------------------------------------------
--                               COMMON - TYPES                               --
--------------------------------------------------------------------------------

-- | A set of 'VName' tags that denote all variables that some group of
-- statements depends upon. Those must be computed before the group's
-- statements.
--
-- Like 'Bindings' and 'Consumption' this is a plain type synonym for
-- 'IS.IntSet', so the type checker does not distinguish between them; the
-- names only document intent.
type Dependencies = IS.IntSet

-- | A set of 'VName' tags that denote all variables that some group of
-- statements binds.
type Bindings = IS.IntSet

-- | A set of 'VName' tags that denote the root aliases of all arrays that some
-- statement consumes.
type Consumption = IS.IntSet

--------------------------------------------------------------------------------
--                              COMMON - HELPERS                              --
--------------------------------------------------------------------------------

-- | The free variables of a construct, as a set of name tags
-- ('Dependencies').
depsOf :: FreeIn a => a -> Dependencies
depsOf x = namesToSet (freeIn x)

-- | Convert 'Names' to an integer set containing the 'baseTag' of every name.
namesToSet :: Names -> IS.IntSet
namesToSet names = IS.fromList [baseTag v | v <- namesToList names]

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a lambda and determine its dependencies.
--
-- The body is transformed with 'transformBody'. The returned dependencies are
-- the free variables of the whole lambda:
--
--   * those of its body,
--   * those of its parameter types and return types (e.g. array sizes),
--
-- minus the names the lambda itself binds as parameters. This mirrors how the
-- 'DoLoop' case of 'transformExp' accounts for its parameters.
transformLambda ::
  AliasTable ->
  Lambda (Aliases GPU) ->
  PassM (Lambda GPU, Dependencies)
transformLambda aliases (Lambda params body types) = do
  (body', body_deps) <- transformBody aliases body
  -- Parameters are bound by the lambda, so they are not free in it.
  let bound = IS.fromList $ map (baseTag . paramName) params
  -- Types may mention variables from the enclosing scope, which must then be
  -- computed before the statement containing this lambda.
  let deps = (body_deps <> depsOf params <> depsOf types) \\ bound
  pure (Lambda params body' types, deps)

-- | Optimize a body and determine its dependencies.
--
-- 'reorderStm' is folded over the statements, starting from the given alias
-- table and 'initialState'. 'collapse' then turns the accumulated groups into
-- a single 'Group'.
--
-- The dependencies are the variables used by the statements or the result,
-- excluding the variables that the statements themselves bind.
transformBody ::
  AliasTable ->
  Body (Aliases GPU) ->
  PassM (Body GPU, Dependencies)
transformBody aliases (Body _ stms res) = do
  let reorderAll = foldM_ reorderStm aliases stms >> collapse
  grp <- evalStateT reorderAll initialState
  let used = groupDependencies grp <> depsOf res
      bound = groupBindings grp
  pure (Body () (groupStms grp) res, used \\ bound)

-- | Optimize a sequence of statements and determine their dependencies.
--
-- The statements are wrapped in a body with an empty result, so that
-- 'transformBody' can do the work; the body's statements are then extracted.
transformStms ::
  AliasTable ->
  Stms (Aliases GPU) ->
  PassM (Stms GPU, Dependencies)
transformStms aliases stms = do
  (body, deps) <- transformBody aliases (Body mempty stms [])
  pure (bodyStms body, deps)

-- | Optimizes and reorders a single statement within a sequence while tracking
-- the declaration, observation, and consumption of its dependencies.
-- This creates sequences of GPUBody statements that can be merged into single
-- kernels.
--
-- GPUBody statements are placed with 'moveGPUBody' and all other statements
-- with 'moveOther'. Returns the alias table extended with the root aliases of
-- every pattern element that aliases something.
reorderStm :: AliasTable -> Stm (Aliases GPU) -> ReorderM AliasTable
reorderStm aliases (Let pat (StmAux cs attrs _) e) = do
  (e', deps) <- lift (transformExp aliases e)
  let pat' = removePatAliases pat
  let stm' = Let pat' (StmAux cs attrs ()) e'
  let pes' = patElems pat'

  -- Array aliases can be seen as a directed graph where vertices are arrays
  -- (or the names that bind them) and an edge x -> y denotes that x aliases y.
  -- The root aliases of some array A is then the set of arrays that can be
  -- reached from A in graph and which have no edges themselves.
  --
  -- All arrays that share a root alias are considered aliases of each other
  -- and will be consumed if either of them is consumed.
  -- When reordering statements we must ensure that no statement that consumes
  -- an array is moved before any statement that observes one of its aliases.
  --
  -- That is to move statement X before statement Y the set of root aliases of
  -- arrays consumed by X must not overlap with the root aliases of arrays
  -- observed by Y.
  --
  -- We consider the root aliases of Y's observed arrays as part of Y's
  -- dependencies and simply say that the root aliases of arrays consumed by X
  -- must not overlap those.
  --
  -- To move X before Y then the dependencies of X must also not overlap with
  -- the variables bound by Y.

  -- A statement observes an array both when its result aliases the array and
  -- when it merely reads it (the array is free in the expression, e.g.
  -- @let s = a[0]@). Considering only the result aliases would allow a
  -- statement that consumes the root of @a@ to be moved before such a read.
  let observed =
        namesToSet $ rootAliasesOf (fold (expAliases e) <> freeIn e) aliases
  let consumed = namesToSet $ rootAliasesOf (consumedInExp e) aliases
  let usage =
        Usage
          { usageBindings = IS.fromList $ map (baseTag . patElemName) pes',
            usageDependencies = observed <> deps <> depsOf pat' <> depsOf cs
          }

  case e' of
    Op GPUBody {} ->
      moveGPUBody stm' usage consumed
    _ ->
      moveOther stm' usage consumed

  -- Strict fold: the table is threaded through every statement of a body.
  pure $ foldl' recordAliases aliases (patElems pat)
  where
    -- Names absent from the table have no aliases and are their own root.
    rootAliasesOf names atable =
      let look n = M.findWithDefault (oneName n) n atable
       in foldMap look (namesToList names)

    -- Only names that alias something get an entry; it maps directly to the
    -- root aliases, so lookups never need to follow chains.
    recordAliases atable pe
      | aliasesOf pe == mempty =
          atable
      | otherwise =
          let root_aliases = rootAliasesOf (aliasesOf pe) atable
           in M.insert (patElemName pe) root_aliases atable

-- | Optimize a single expression and determine its dependencies.
--
-- Only the expressions that contain bodies are recursed into:
--
--   * 'If' (both branches),
--   * 'DoLoop' (the loop body),
--   * 'WithAcc' (the inputs and the lambda).
--
-- Every other expression is returned with its aliases removed, and its free
-- variables are used as its dependencies.
transformExp ::
  AliasTable ->
  Exp (Aliases GPU) ->
  PassM (Exp GPU, Dependencies)
transformExp aliases e =
  case e of
    If c tbody fbody dec -> do
      (tbody', t_deps) <- transformBody aliases tbody
      (fbody', f_deps) <- transformBody aliases fbody
      let deps = mconcat [depsOf c, t_deps, f_deps, depsOf dec]
      pure (If c tbody' fbody' dec, deps)
    DoLoop merge lform body -> do
      -- What merge and lform alias outside the loop is irrelevant as those
      -- cannot be consumed within the loop.
      (body', body_deps) <- transformBody aliases body
      let (params, args) = unzip merge
          deps = mconcat [body_deps, depsOf params, depsOf args, depsOf lform]
          -- Names bound by the loop itself (loop form variables and merge
          -- parameters) are not dependencies of the loop statement.
          bound =
            IS.fromList . map baseTag . M.keys $
              scopeOf lform <> scopeOfFParams params
          -- Strip the aliases from merge and lform by way of a copy of the loop
          -- with an empty body, so the original body is not traversed again.
          DoLoop merge' lform' _ =
            removeExpAliases $
              DoLoop merge lform (Body (bodyDec body) mempty [])
      pure (DoLoop merge' lform' body', deps \\ bound)
    WithAcc inputs lambda -> do
      (inputs', input_deps) <-
        unzip <$> mapM (transformWithAccInput aliases) inputs
      -- The lambda parameters are all unique and thus have no aliases.
      (lambda', lambda_deps) <- transformLambda aliases lambda
      pure (WithAcc inputs' lambda', lambda_deps <> fold input_deps)
    -- A GPUBody cannot be nested within other HostOp constructs.
    Op {} -> unchanged
    BasicOp {} -> unchanged
    Apply {} -> unchanged
  where
    unchanged = pure (removeExpAliases e, depsOf e)

-- | Optimize a single WithAcc input and determine its dependencies.
--
-- The dependencies are the free variables of the shape and the input arrays.
-- If the input has a combining operator, they also include the operator
-- lambda's dependencies and the free variables of its neutral elements.
transformWithAccInput ::
  AliasTable ->
  WithAccInput (Aliases GPU) ->
  PassM (WithAccInput GPU, Dependencies)
transformWithAccInput aliases (shape, arrs, op) = do
  (op', op_deps) <- transformCombineOp op
  pure ((shape, arrs, op'), op_deps <> depsOf shape <> depsOf arrs)
  where
    transformCombineOp Nothing = pure (Nothing, mempty)
    transformCombineOp (Just (f, nes)) = do
      -- The lambda parameters have no aliases.
      (f', f_deps) <- transformLambda aliases f
      pure (Just (f', nes), f_deps <> depsOf nes)

--------------------------------------------------------------------------------
--                             REORDERING - TYPES                             --
--------------------------------------------------------------------------------

-- | The monad used to reorder statements within a sequence such that its
-- GPUBody statements can be merged into as few kernels as possible.
--
-- It is the strict 'StateT' transformer over 'PassM', threading a 'State'
-- through the processing of a statement sequence.
type ReorderM = StateT State PassM

-- | The state threaded through a 'ReorderM' computation.
data State = State
  { -- | The statements of the sequence that have been processed so far,
    -- divided into alternating groups of non-GPUBody and GPUBody statements.
    -- Groups at even indices contain only non-GPUBody statements, and groups
    -- at odd indices contain only GPUBody statements.
    stateGroups :: Groups,
    -- | Values returned from within GPUBody kernels; see 'Entry' for how each
    -- key relates to its value.
    stateEquivalents :: EquivalenceTable
  }

-- | A map from variable tags to t'SubExp's returned from within GPUBodies.
-- Each value is an 'Entry' that describes the GPUBody result the keyed
-- variable corresponds to.
type EquivalenceTable = IM.IntMap Entry

-- | An entry in an 'EquivalenceTable', describing a value that some GPUBody
-- kernel returns.
data Entry = Entry
  { -- | The value returned from within a GPUBody kernel.
    -- In @let res = gpu { x }@ this is @x@.
    entryValue :: SubExp,
    -- | The type of 'entryValue'.
    entryType :: Type,
    -- | The variable that binds the kernel's return value for 'entryValue'.
    -- In @let res = gpu { x }@ this is @res@.
    entryResult :: VName,
    -- | The index of the group in which 'entryResult' is bound.
    entryGroupIdx :: Int,
    -- | When 'False', the entry key is a variable that binds the same value
    -- as 'entryValue'. When 'True', it binds an array with an outer
    -- dimension of one whose single row equals that value.
    entryStored :: Bool
  }

-- | The ordered groups that a statement sequence is split into.
type Groups = SQ.Seq Group

-- | A contiguous subsequence of statements, normally consisting either solely
-- of GPUBody statements or solely of non-GPUBody statements, together with
-- the combined 'Usage' statistics of those statements.
data Group = Group
  { -- | The statements in the group, in order.
    groupStms :: Stms GPU,
    -- | Combined usage statistics of 'groupStms'.
    groupUsage :: Usage
  }

-- | Usage statistics for a set of statements.
data Usage = Usage
  { -- | Variables bound by the statements.
    usageBindings :: Bindings,
    -- | Variables the statements depend on: the free variables of each
    -- statement plus the root aliases of every array they observe.
    usageDependencies :: Dependencies
  }

-- Concatenates the statements and combines the usage statistics.
instance Semigroup Group where
  Group stmsA usageA <> Group stmsB usageB =
    Group
      { groupStms = stmsA <> stmsB,
        groupUsage = usageA <> usageB
      }

-- The empty group: no statements and empty usage statistics.
instance Monoid Group where
  mempty = Group mempty mempty

-- Unions the bindings and the dependencies pointwise.
instance Semigroup Usage where
  Usage bindsA depsA <> Usage bindsB depsB =
    Usage
      { usageBindings = bindsA <> bindsB,
        usageDependencies = depsA <> depsB
      }

-- No bindings and no dependencies.
instance Monoid Usage where
  mempty = Usage mempty mempty

--------------------------------------------------------------------------------
--                           REORDERING - FUNCTIONS                           --
--------------------------------------------------------------------------------

-- | The variables bound by the statements of a group.
groupBindings :: Group -> Bindings
groupBindings grp = usageBindings (groupUsage grp)

-- | The variables that the statements of a group depend upon.
groupDependencies :: Group -> Dependencies
groupDependencies grp = usageDependencies (groupUsage grp)

-- | The state a 'ReorderM' computation starts from: a single empty group at
-- index 0 (an even, i.e. non-GPUBody, group) and no recorded equivalents.
initialState :: State
initialState = State (SQ.singleton mempty) mempty

-- | Apply a function to the groups that the sequence has been split into so
-- far, leaving the equivalence table untouched.
modifyGroups :: (Groups -> Groups) -> ReorderM ()
modifyGroups f = modify update
  where
    update st = st {stateGroups = f (stateGroups st)}

-- | Drop every entry of the equivalence table whose key is in the given set.
removeEquivalents :: IS.IntSet -> ReorderM ()
removeEquivalents keys = modify dropKeys
  where
    dropKeys st =
      st {stateEquivalents = IM.withoutKeys (stateEquivalents st) keys}

-- | Insert an entry into the equivalence table under the tag of the given
-- name, replacing any entry already stored for that tag.
recordEquivalent :: VName -> Entry -> ReorderM ()
recordEquivalent n entry = modify insertEntry
  where
    insertEntry st =
      st {stateEquivalents = IM.insert (baseTag n) entry (stateEquivalents st)}

-- | Moves a GPUBody statement to the furthest possible group of the statement
-- sequence, possibly a new group at the end of the sequence.
--
-- To simplify consumption handling, a GPUBody is not allowed to merge with a
-- kernel whose result it consumes. Such a GPUBody may therefore not be moved
-- into the same group as that kernel.
--
-- Afterwards, every result bound by the GPUBody is recorded in the
-- equivalence table together with the value returned for it.
--
-- Reports a compiler bug if the statement is not a GPUBody.
moveGPUBody :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveGPUBody stm usage consumed = do
  -- Replace dependencies with their GPUBody result equivalents.
  eqs <- gets stateEquivalents
  let g i = maybe i (baseTag . entryResult) (IM.lookup i eqs)
  let deps' = IS.map g (usageDependencies usage)
  let usage' = usage {usageDependencies = deps'}

  -- Move the GPUBody.
  grps <- gets stateGroups
  let f = groupBlocks usage' consumed
  -- Without any blocking group the first GPUBody group (index 1) is used.
  let idx = fromMaybe 1 (SQ.findIndexR f grps)
  let idx' = case idx `mod` 2 of
        -- Blocked by a non-GPUBody group: go to the GPUBody group after it.
        0 -> idx + 1
        -- Blocked by a GPUBody group binding something this statement
        -- consumes: it may not merge there, so skip to the next GPUBody group.
        _ | consumes idx grps -> idx + 2
        -- Otherwise join the blocking GPUBody group.
        _ -> idx
  -- The group records the original (unsubstituted) usage.
  modifyGroups $ moveToGrp (stm, usage) idx'

  -- Record the kernel equivalents of the bound results.
  let pes = patElems (stmPat stm)
  res <- case stmExp stm of
    Op (GPUBody _ (Body _ _ body_res)) -> pure body_res
    _ -> compilerBugS "mergeGPUBodies: moveGPUBody called on a non-GPUBody statement"
  mapM_ (stores idx') (zip pes (map resSubExp res))
  where
    -- Does the group at this index bind a variable that the statement consumes?
    consumes :: Int -> Groups -> Bool
    consumes idx grps
      | Just grp <- SQ.lookup idx grps =
          not $ IS.disjoint (groupBindings grp) consumed
      | otherwise =
          False

    -- Record the returned value of one kernel result. An array-typed result
    -- is recorded as stored, with its row type as the entry type.
    stores :: Int -> (PatElem Type, SubExp) -> ReorderM ()
    stores idx (PatElem n t, se)
      | Just row_t <- peelArray 1 t =
          recordEquivalent n $ Entry se row_t n idx True
      | otherwise =
          recordEquivalent n $ Entry se t n idx False

-- | Moves a non-GPUBody statement as far towards the end of the statement
-- sequence as its usage allows, possibly into a new group at the end, and
-- then records any GPUBody equivalents the statement establishes.
moveOther :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveOther stm usage consumed = do
  grps <- gets stateGroups
  let blocker = SQ.findIndexR (groupBlocks usage consumed) grps
      -- Non-GPUBody statements live only in even-indexed groups, so round the
      -- index of the last blocking group (or 0 if none) up to an even number.
      target = ((fromMaybe 0 blocker + 1) `div` 2) * 2
  modifyGroups (moveToGrp (stm, usage) target)
  recordEquivalentsOf stm target

-- | @recordEquivalentsOf stm idx@ records which GPUBody return value, if any,
-- the single variable bound by @stm@ is equivalent to; @idx@ is the index of
-- the group that @stm@ was placed in.
--
-- Two forms are recognised: a plain copy @x = n@ of a variable with an entry,
-- which inherits that entry unchanged, and @x = arr[0, <full row>]@ where
-- @arr@ has an entry, which makes @x@ an unstored alias of the returned value.
--
-- A GPUBody may have a dependency replaced by an equivalent only if it merges
-- with the source GPUBody, which lets it move past the binding site of that
-- dependency. To guarantee such a merge, equivalents are recorded only for
-- results bound in the group at index @idx-1@.
recordEquivalentsOf :: Stm GPU -> Int -> ReorderM ()
recordEquivalentsOf stm idx = do
  eqs <- gets stateEquivalents
  let fromPreviousGroup v = do
        entry <- IM.lookup (baseTag v) eqs
        guard $ entryGroupIdx entry == idx - 1
        pure entry
  case stm of
    Let (Pat [PatElem x _]) _ (BasicOp (SubExp (Var n)))
      | Just entry <- fromPreviousGroup n ->
          recordEquivalent x entry
    Let (Pat [PatElem x _]) _ (BasicOp (Index arr slice))
      | Just entry <- fromPreviousGroup arr,
        DimFix i : dims <- unSlice slice,
        i == intConst Int64 0,
        dims == map sliceDim (arrayDims (entryType entry)) ->
          recordEquivalent x entry {entryStored = False}
    _ -> pure ()

-- | Whether a group prevents a statement with the given usage and consumption
-- from being moved past it. It does so when the statement uses a variable the
-- group binds, or consumes a variable the group depends upon.
groupBlocks :: Usage -> Consumption -> Group -> Bool
groupBlocks usage consumed grp = readsBinding || consumesDependency
  where
    readsBinding =
      not $ IS.disjoint (groupBindings grp) (usageDependencies usage)
    consumesDependency =
      not $ IS.disjoint (groupDependencies grp) consumed

-- | @moveToGrp stm idx grps@ appends @stm@, with its usage statistics, to the
-- group at index @idx@ of @grps@. If @idx@ lies beyond the last group, empty
-- groups are appended until it exists.
moveToGrp :: (Stm GPU, Usage) -> Int -> Groups -> Groups
moveToGrp stm idx grps
  | idx < SQ.length grps = SQ.adjust' (moveTo stm) idx grps
  | otherwise = moveToGrp stm idx (grps |> mempty)

-- | Appends the statement to the end of the group and adds its usage
-- statistics to the group's.
moveTo :: (Stm GPU, Usage) -> Group -> Group
moveTo (stm, usage) (Group stms grpUsage) =
  Group (stms |> stm) (grpUsage <> usage)

--------------------------------------------------------------------------------
--                         MERGING GPU BODIES - TYPES                         --
--------------------------------------------------------------------------------

-- | The monad in which a GPUBody is rewritten to use the t'SubExp's returned
-- by the kernels it is merged with, instead of the results those kernels bind.
--
-- Its state is a prologue of statements that is prepended to the rewritten
-- kernel body.
type RewriteM = StateT (Stms GPU) ReorderM

--------------------------------------------------------------------------------
--                       MERGING GPU BODIES - FUNCTIONS                       --
--------------------------------------------------------------------------------

-- | Collapses every group processed so far into a single group, merging the
-- statements of each GPUBody group into one kernel along the way. The state
-- is left with just that single group, which is also returned.
collapse :: ReorderM Group
collapse = do
  grps <- gets stateGroups
  -- Groups alternate: even indices are non-GPUBody, odd indices are GPUBody.
  let tagged = zip (cycle [False, True]) (toList grps)
  grp <- foldM clps mempty tagged
  modifyGroups (const (SQ.singleton grp))
  pure grp
  where
    clps acc (gpu_bodies, Group stms usage) = do
      stms' <- if gpu_bodies then mergeKernels stms else pure stms
      let grp1 = Group stms' usage
      -- Equivalents for variables bound in this group are no longer needed
      -- for rewriting; dropping them keeps them from being substituted into
      -- later kernels where the replacement might be out of scope.
      removeEquivalents (groupBindings grp1)
      pure (acc <> grp1)

-- | Merges a sequence of GPUBody statements into a single kernel.
--
-- Sequences of fewer than two statements are returned unchanged. Otherwise
-- the statements are folded from the right into an empty GPUBody. The body of
-- each statement is rewritten exactly once, when it is prepended to the
-- accumulated kernel, and its statements and results come before those of
-- the statements that follow it.
mergeKernels :: Stms GPU -> ReorderM (Stms GPU)
mergeKernels stms
  | SQ.length stms < 2 = pure stms
  | otherwise = SQ.singleton <$> foldrM merge emptyKernel stms
  where
    emptyKernel =
      Let mempty (StmAux mempty mempty ()) $
        Op (GPUBody [] (Body () SQ.empty []))

    merge :: Stm GPU -> Stm GPU -> ReorderM (Stm GPU)
    merge
      (Let pat0 (StmAux cs0 attrs0 _) (Op (GPUBody types0 body0)))
      (Let pat1 (StmAux cs1 attrs1 _) (Op (GPUBody types1 (Body _ stms1 res1)))) = do
        Body _ stms0 res0 <- execRewrite (rewriteBody body0)
        let aux = StmAux (cs0 <> cs1) (attrs0 <> attrs1) ()
            body = Body () (stms0 <> stms1) (res0 <> res1)
        pure $ Let (pat0 <> pat1) aux (Op (GPUBody (types0 ++ types1) body))
    merge _ _ =
      compilerBugS "mergeGPUBodies: cannot merge non-GPUBody statements"

-- | Runs a rewrite that starts with an empty prologue, then prepends the
-- prologue it accumulated to the statements of the resulting body.
execRewrite :: RewriteM (Body GPU) -> ReorderM (Body GPU)
execRewrite m = do
  (Body _ stms res, prologue) <- runStateT m SQ.empty
  pure $ Body () (prologue <> stms) res

-- | The current equivalence table of the enclosing 'ReorderM' state.
equivalents :: RewriteM EquivalenceTable
equivalents = stateEquivalents <$> lift get

-- | Rewrites the statements of a body and then its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- rewriteResult res
  pure (Body () stms' res')

-- | Rewrites each statement of a sequence, in order.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms = traverse rewriteStm

-- | Rewrites a statement's pattern elements, then its certificates, then its
-- expression; the attributes are kept unchanged.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let (Pat pes) (StmAux cs attrs _) e) = do
  pes' <- mapM rewritePatElem pes
  cs' <- rewriteCerts cs
  e' <- rewriteExp e
  pure $ Let (Pat pes') (StmAux cs' attrs ()) e'

-- | Rewrites the type of a pattern element, keeping its name.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  t' <- rewriteType t
  pure (PatElem n t')

-- | Rewrites an expression within a GPUBody that is being merged.
--
-- An index @arr[0, dims...]@ into a variable that binds a stored equivalence
-- entry is replaced by the returned value itself (when @dims@ is empty) or by
-- an index into it. All other expressions are rewritten structurally. A nested
-- 'Op' inside the GPUBody is a compiler bug.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp e = do
  eqs <- equivalents
  case e of
    BasicOp (Index arr slice)
      | Just entry <- IM.lookup (baseTag arr) eqs,
        -- Only a stored entry's key binds a one-row array whose row is the
        -- entry value, so only then may the leading zero index be dropped.
        -- For an unstored entry the key binds the value itself, and the
        -- generic rewrite below must be used instead.
        entryStored entry,
        DimFix idx : dims <- unSlice slice,
        idx == intConst Int64 0 ->
          let se = entryValue entry
           in pure . BasicOp $ case (dims, se) of
                ([], _) -> SubExp se
                (_, Var src) -> Index src (Slice dims)
                _ -> compilerBugS "rewriteExp: bad equivalence entry"
    _ -> mapExpM rewriter e
  where
    rewriter =
      Mapper
        { mapOnSubExp = rewriteSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rewriteName,
          mapOnRetType = rewriteExtType,
          mapOnBranchType = rewriteExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = const opError
        }

    opError = compilerBugS "rewriteExp: unhandled HostOp in GPUBody"

-- | Rewrite each element of a 'Result' in order using 'rewriteSubExpRes'.
rewriteResult :: Result -> RewriteM Result
rewriteResult = traverse rewriteSubExpRes

-- | Rewrite a single 'SubExpRes'.
-- Its certificates are rewritten first and its value second, and both
-- results are reassembled into a new 'SubExpRes'.
rewriteSubExpRes :: SubExpRes -> RewriteM SubExpRes
rewriteSubExpRes (SubExpRes certs value) = do
  certs' <- rewriteCerts certs
  value' <- rewriteSubExp value
  pure (SubExpRes certs' value')

-- | Rewrite every certificate name with 'rewriteName', preserving order.
rewriteCerts :: Certs -> RewriteM Certs
rewriteCerts (Certs names) = fmap Certs (traverse rewriteName names)

-- | Rewrite a type by applying 'rewriteSubExp' through 'mapOnType'.
--
-- Note: mapOnType also maps the VName token of accumulators.
rewriteType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
rewriteType ty = mapOnType rewriteSubExp ty

-- | Rewrite an existential type by applying 'rewriteSubExp' through
-- 'mapOnExtType'.
--
-- Note: mapOnExtType also maps the VName token of accumulators.
rewriteExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
rewriteExtType ty = mapOnExtType rewriteSubExp ty

-- | Rewrite only the type of a parameter.
-- Its attributes and name are kept unchanged.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs name dec) = do
  dec' <- rewriteType dec
  pure (Param attrs name dec')

-- | Rewrite a 'SubExp' according to the current equivalence table.
--
-- Constants are returned as-is. A variable is looked up by its 'baseTag':
--
-- * with no entry, the variable is kept;
-- * with an entry whose flag is 'False', the entry's value replaces it;
-- * with an entry whose flag is 'True', the entry's value is wrapped in a
--   fresh one-element array (see 'asArray'), and that array's name
--   replaces it.
rewriteSubExp :: SubExp -> RewriteM SubExp
rewriteSubExp se@(Constant _) = pure se
rewriteSubExp (Var name) = do
  table <- equivalents
  maybe (pure (Var name)) substitute (IM.lookup (baseTag name) table)
  where
    substitute (Entry replacement _ _ _ False) = pure replacement
    substitute (Entry replacement rowType _ _ True) =
      Var <$> asArray replacement rowType

-- | Rewrite a name via 'rewriteSubExp'.
-- If the rewrite produces a constant, that constant is bound to a fresh
-- name with 'referConst', so a name can always be returned.
rewriteName :: VName -> RewriteM VName
rewriteName name = rewriteSubExp (Var name) >>= toName
  where
    toName (Var name') = pure name'
    toName (Constant value) = referConst value

-- | @asArray se row_t@ appends @let x = [se]@ to the rewrite prologue and
-- returns the name of @x@.
--
-- @row_t@ is the type of @se@. The bound array therefore has type
-- @row_t@ with an outer dimension of size 1 added.
asArray :: SubExp -> Type -> RewriteM VName
asArray se row_t = do
  arr <- newName "arr"
  let arrType = arrayOfRow row_t (intConst Int64 1)
      stm :: Stm GPU
      stm =
        Let
          (Pat [PatElem arr arrType])
          (StmAux mempty mempty ())
          (BasicOp (ArrayLit [se] row_t))
  modify (|> stm)
  pure arr

-- | @referConst c@ appends @let x = c@ to the rewrite prologue and returns
-- the name of @x@.
-- @x@ gets the primitive type of @c@.
referConst :: PrimValue -> RewriteM VName
referConst c = do
  cnst <- newName "cnst"
  let stm :: Stm GPU
      stm =
        Let
          (Pat [PatElem cnst (Prim (primValueType c))])
          (StmAux mempty mempty ())
          (BasicOp (SubExp (Constant c)))
  modify (|> stm)
  pure cnst

-- | Produce a fresh name, using the given string as a template.
-- The name comes from 'newNameFromString' in 'PassM' and is lifted through
-- both 'StateT' layers of 'RewriteM'.
newName :: String -> RewriteM VName
newName = lift . lift . newNameFromString