-- |
-- This module implements an optimization pass that merges 'GPUBody' kernels to
-- eliminate memory transactions and reduce the number of kernel launches.
-- This is useful because the "Futhark.Optimise.ReduceDeviceSyncs" pass introduces
-- 'GPUBody' kernels that only execute single statements.
--
-- To merge as many 'GPUBody' kernels as possible, this pass reorders statements
-- with the goal of bringing as many 'GPUBody' statements next to each other in
-- a sequence. Such sequence can then trivially be merged.
module Futhark.Optimise.MergeGPUBodies (mergeGPUBodies) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first)
import Data.Foldable
import Data.IntMap qualified as IM
import Data.IntSet ((\\))
import Data.IntSet qualified as IS
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Sequence ((|>))
import Data.Sequence qualified as SQ
import Futhark.Analysis.Alias
import Futhark.Construct (sliceDim)
import Futhark.Error
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.MonadFreshNames hiding (newName)
import Futhark.Pass

-- | An optimization pass that reorders and merges 'GPUBody' statements to
-- eliminate memory transactions and reduce the number of kernel launches.
--
-- The program is first annotated by 'aliasAnalysis'. The top-level statements
-- handed over by 'intraproceduralTransformationWithConsts' and the body of
-- every function are then transformed independently, each starting from an
-- empty alias table. The dependency sets computed along the way are only
-- needed internally and are discarded here.
mergeGPUBodies :: Pass GPU GPU
mergeGPUBodies =
  Pass
    "merge GPU bodies"
    "Reorder and merge GPUBody constructs to reduce kernels executions."
    $ intraproceduralTransformationWithConsts onStms onFunDef . aliasAnalysis
  where
    -- Rebuild the function around its optimized body. The first argument
    -- (the already transformed top-level statements) is not needed.
    onFunDef _ (FunDef entry attrs name types params body) =
      FunDef entry attrs name types params . fst <$> transformBody mempty body

    -- Optimize the top-level statements.
    onStms stms =
      fst <$> transformStms mempty stms

--------------------------------------------------------------------------------
--                               COMMON - TYPES                               --
--------------------------------------------------------------------------------

-- | A set of 'VName' tags that denotes all variables that some group of
-- statements depends upon. Those variables must be computed before the
-- statements of the group.
type Dependencies = IS.IntSet

-- | A set of 'VName' tags that denotes all variables that some group of
-- statements binds.
type Bindings = IS.IntSet

-- | A set of 'VName' tags that denotes the root aliases of all arrays that
-- some statement consumes.
type Consumption = IS.IntSet

--------------------------------------------------------------------------------
--                              COMMON - HELPERS                              --
--------------------------------------------------------------------------------

-- | All free variables of a construct as 'Dependencies', i.e. the tags of
-- every name reported by 'freeIn'.
depsOf :: (FreeIn a) => a -> Dependencies
depsOf = namesToSet . freeIn

-- | Convert 'Names' to an integer set of name tags. Each name is represented
-- by its 'baseTag'.
namesToSet :: Names -> IS.IntSet
namesToSet = IS.fromList . map baseTag . namesToList

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a lambda and determine its dependencies.
--
-- Only the body is transformed. The parameters and return types are reused
-- unchanged, and the returned dependencies are exactly those reported by
-- 'transformBody' for the lambda body.
transformLambda ::
  AliasTable ->
  Lambda (Aliases GPU) ->
  PassM (Lambda GPU, Dependencies)
transformLambda aliases (Lambda params types body) = do
  (body', deps) <- transformBody aliases body
  pure (Lambda params types body', deps)

-- | Optimize a body and determine its dependencies.
--
-- Every statement of the body is fed through 'reorderStm' in order, threading
-- the alias table from one statement to the next. The resulting reordering
-- state is then reduced to a single 'Group' by @collapse@ (defined elsewhere
-- in this module), whose statements become the statements of the new body.
--
-- The returned dependencies are those of the group together with the free
-- variables of the body result, minus everything the group binds itself.
transformBody ::
  AliasTable ->
  Body (Aliases GPU) ->
  PassM (Body GPU, Dependencies)
transformBody aliases (Body _ stms res) = do
  grp <- evalStateT (foldM_ reorderStm aliases stms >> collapse) initialState

  let stms' = groupStms grp
  -- Variables bound within the body are not dependencies of the body.
  let deps = (groupDependencies grp <> depsOf res) \\ groupBindings grp

  pure (Body () stms' res, deps)

-- | Optimize a sequence of statements and determine their dependencies.
--
-- The statements are wrapped in a temporary body with an empty result so that
-- 'transformBody' can be reused; only the statements of the transformed body
-- are kept.
transformStms ::
  AliasTable ->
  Stms (Aliases GPU) ->
  PassM (Stms GPU, Dependencies)
transformStms aliases stms = do
  (Body _ stms' _, deps) <- transformBody aliases (Body mempty stms [])
  pure (stms', deps)

-- | Optimizes and reorders a single statement within a sequence while tracking
-- the declaration, observation, and consumption of its dependencies.
-- This creates sequences of GPUBody statements that can be merged into single
-- kernels.
--
-- The given alias table maps already processed names to their root aliases.
-- The returned table additionally covers the names bound by this statement.
reorderStm :: AliasTable -> Stm (Aliases GPU) -> ReorderM AliasTable
reorderStm aliases (Let pat (StmAux cs attrs _) e) = do
  (e', deps) <- lift (transformExp aliases e)
  let pat' = removePatAliases pat
  let stm' = Let pat' (StmAux cs attrs ()) e'
  let pes' = patElems pat'

  -- Array aliases can be seen as a directed graph where vertices are arrays
  -- (or the names that bind them) and an edge x -> y denotes that x aliases y.
  -- The root aliases of some array A is then the set of arrays that can be
  -- reached from A in graph and which have no edges themselves.
  --
  -- All arrays that share a root alias are considered aliases of each other
  -- and will be consumed if either of them is consumed.
  -- When reordering statements we must ensure that no statement that consumes
  -- an array is moved before any statement that observes one of its aliases.
  --
  -- That is to move statement X before statement Y the set of root aliases of
  -- arrays consumed by X must not overlap with the root aliases of arrays
  -- observed by Y.
  --
  -- We consider the root aliases of Y's observed arrays as part of Y's
  -- dependencies and simply say that the root aliases of arrays consumed by X
  -- must not overlap those.
  --
  -- To move X before Y then the dependencies of X must also not overlap with
  -- the variables bound by Y.

  let observed = namesToSet $ rootAliasesOf (fold $ patAliases pat) aliases
  let consumed = namesToSet $ rootAliasesOf (consumedInExp e) aliases
  let usage =
        Usage
          { usageBindings = IS.fromList $ map (baseTag . patElemName) pes',
            usageDependencies = observed <> deps <> depsOf pat' <> depsOf cs
          }

  -- GPUBody statements and all other statements are placed by different
  -- strategies (both defined elsewhere in this module).
  case e' of
    Op GPUBody {} ->
      moveGPUBody stm' usage consumed
    _ ->
      moveOther stm' usage consumed

  -- A strict fold avoids building a chain of thunks over the alias table.
  pure $ foldl' recordAliases aliases (patElems pat)
  where
    -- The union of the root aliases of all given names. A name without an
    -- entry in the table is its own root alias.
    rootAliasesOf names atable =
      let look n = M.findWithDefault (oneName n) n atable
       in foldMap look (namesToList names)

    -- Register the root aliases of a pattern element, unless the element has
    -- no aliases at all.
    recordAliases atable pe
      | aliasesOf pe == mempty =
          atable
      | otherwise =
          let root_aliases = rootAliasesOf (aliasesOf pe) atable
           in M.insert (patElemName pe) root_aliases atable

-- | Optimize a single expression and determine its dependencies.
--
-- Expressions without nested bodies ('BasicOp', 'Apply' and 'Op') are only
-- stripped of their alias annotations, and their dependencies are their free
-- variables. For 'Match', 'Loop' and 'WithAcc' every nested body or lambda is
-- optimized recursively and the dependencies of all parts are combined.
transformExp ::
  AliasTable ->
  Exp (Aliases GPU) ->
  PassM (Exp GPU, Dependencies)
transformExp aliases e =
  case e of
    BasicOp {} -> pure (removeExpAliases e, depsOf e)
    Apply {} -> pure (removeExpAliases e, depsOf e)
    Match ses cases defbody dec -> do
      let transformCase (Case vs body) =
            first (Case vs) <$> transformBody aliases body
      (cases', cases_deps) <- mapAndUnzipM transformCase cases
      (defbody', defbody_deps) <- transformBody aliases defbody
      let deps = depsOf ses <> mconcat cases_deps <> defbody_deps <> depsOf dec
      pure (Match ses cases' defbody' dec, deps)
    Loop merge lform body -> do
      -- What merge and lform aliases outside the loop is irrelevant as those
      -- cannot be consumed within the loop.
      (body', body_deps) <- transformBody aliases body
      let (params, args) = unzip merge
      let deps = body_deps <> depsOf params <> depsOf args <> depsOf lform

      -- Names introduced by the loop itself are not dependencies of the loop.
      let scope =
            scopeOfLoopForm lform <> scopeOfFParams params ::
              Scope (Aliases GPU)
      let bound = IS.fromList $ map baseTag (M.keys scope)
      let deps' = deps \\ bound

      -- The loop parameters and the loop form carry no alias annotations and
      -- have the same type in both representations, so they are reused as is.
      pure (Loop merge lform body', deps')
    WithAcc inputs lambda -> do
      accs <- mapM (transformWithAccInput aliases) inputs
      let (inputs', input_deps) = unzip accs
      -- The lambda parameters are all unique and thus have no aliases.
      (lambda', deps) <- transformLambda aliases lambda
      pure (WithAcc inputs' lambda', deps <> fold input_deps)
    Op {} ->
      -- A GPUBody cannot be nested within other HostOp constructs.
      pure (removeExpAliases e, depsOf e)

-- | Optimize a single WithAcc input and determine its dependencies.
--
-- If the input has a combining operator then its lambda is optimized, and the
-- dependencies of the lambda and of the neutral elements are included. The
-- free variables of the shape and of the arrays are always included.
transformWithAccInput ::
  AliasTable ->
  WithAccInput (Aliases GPU) ->
  PassM (WithAccInput GPU, Dependencies)
transformWithAccInput aliases (shape, arrs, op) = do
  (op', op_deps) <- case op of
    Nothing -> pure (Nothing, mempty)
    Just (f, nes) -> do
      -- The lambda parameters have no aliases.
      (f', f_deps) <- transformLambda aliases f
      pure (Just (f', nes), f_deps <> depsOf nes)
  let deps = op_deps <> depsOf shape <> depsOf arrs
  pure ((shape, arrs, op'), deps)

--------------------------------------------------------------------------------
--                             REORDERING - TYPES                             --
--------------------------------------------------------------------------------

-- | The monad used to reorder statements within a sequence such that its
-- GPUBody statements can be merged into as few kernels as possible.
-- It is a state monad over 'PassM' that carries the reordering 'State'.
type ReorderM = StateT State PassM

-- | The state used by a 'ReorderM' monad.
data State = State
  { -- | All statements that already have been processed from the sequence,
    -- divided into alternating groups of non-GPUBody and GPUBody statements.
    -- Groups at even indices only contain non-GPUBody statements. Groups at
    -- odd indices only contain GPUBody statements.
    stateGroups :: Groups,
    -- | The 'EquivalenceTable' built up while processing the sequence.
    stateEquivalents :: EquivalenceTable
  }

-- | A map from variable tags to t'SubExp's returned from within GPUBodies.
-- Each returned value is described by an 'Entry'.
type EquivalenceTable = IM.IntMap Entry

-- | An entry in an 'EquivalenceTable'.
data Entry = Entry
  { -- | A value returned from within a GPUBody kernel.
    -- In @let res = gpu { x }@ this is @x@.
    entryValue :: SubExp,
    -- | The type of the 'entryValue'.
    entryType :: Type,
    -- | The name of the variable that binds the return value for 'entryValue'.
    -- In @let res = gpu { x }@ this is @res@.
    entryResult :: VName,
    -- | The index of the group that 'entryResult' is bound in.
    entryGroupIdx :: Int,
    -- | If 'False' then the entry key is a variable that binds the same value
    -- as the 'entryValue'. Otherwise it binds an array with an outer dimension
    -- of one whose row equals that value.
    entryStored :: Bool
  }

-- | A sequence of statement groups, in program order.
type Groups = SQ.Seq Group

-- | A subsequence of statements together with the accumulated 'Usage'
-- statistics of those statements. A group usually holds either GPUBody
-- statements only or non-GPUBody statements only.
data Group = Group
  { -- | The statements that make up the group.
    groupStms :: Stms GPU,
    -- | The combined usage statistics of 'groupStms'.
    groupUsage :: Usage
  }

-- | Usage statistics collected for a set of statements.
data Usage = Usage
  { -- | The variables bound by the statements.
    usageBindings :: Bindings,
    -- | The variables the statements depend upon: the free variables of each
    -- statement plus the root aliases of every array that they observe.
    usageDependencies :: Dependencies
  }

-- | Groups are combined by appending their statements (left operand first)
-- and combining their usage statistics.
instance Semigroup Group where
  Group stms_l usage_l <> Group stms_r usage_r =
    Group
      { groupStms = stms_l <> stms_r,
        groupUsage = usage_l <> usage_r
      }

-- | The empty group has no statements and empty usage statistics.
instance Monoid Group where
  mempty = Group mempty mempty

-- | Usage statistics are combined field by field.
instance Semigroup Usage where
  Usage binds_l deps_l <> Usage binds_r deps_r =
    Usage
      { usageBindings = binds_l <> binds_r,
        usageDependencies = deps_l <> deps_r
      }

-- | The empty usage binds nothing and depends on nothing.
instance Monoid Usage where
  mempty = Usage mempty mempty

--------------------------------------------------------------------------------
--                           REORDERING - FUNCTIONS                           --
--------------------------------------------------------------------------------

-- | The variables bound by the statements of a group, as recorded in its
-- usage statistics.
groupBindings :: Group -> Bindings
groupBindings grp = usageBindings (groupUsage grp)

-- | The variables that the statements of a group depend upon, as recorded in
-- its usage statistics.
groupDependencies :: Group -> Dependencies
groupDependencies grp = usageDependencies (groupUsage grp)

-- | The state a 'ReorderM' computation starts from: a single empty group (at
-- index 0) and an empty equivalence table.
initialState :: State
initialState =
  State
    { stateGroups = SQ.singleton mempty,
      stateEquivalents = mempty
    }

-- | Apply a function to the groups that the sequence has been split into so
-- far, leaving the rest of the state untouched.
modifyGroups :: (Groups -> Groups) -> ReorderM ()
modifyGroups f = modify update
  where
    update st = st {stateGroups = f (stateGroups st)}

-- | Delete every entry whose key is a member of the given set from the
-- equivalence table.
removeEquivalents :: IS.IntSet -> ReorderM ()
removeEquivalents keys = modify dropKeys
  where
    dropKeys st =
      st {stateEquivalents = IM.withoutKeys (stateEquivalents st) keys}

-- | Insert an entry into the equivalence table under the tag of the given
-- variable, replacing any entry previously stored under that tag.
recordEquivalent :: VName -> Entry -> ReorderM ()
recordEquivalent n entry = modify insertEntry
  where
    insertEntry st =
      st {stateEquivalents = IM.insert (baseTag n) entry (stateEquivalents st)}

-- | Moves a GPUBody statement into the earliest GPUBody group (a group at an
-- odd index) that it is allowed to be part of, possibly a new group at the
-- end of the sequence.
--
-- The target is found by locating the last group that blocks the statement
-- (see 'groupBlocks') after its dependencies have been replaced by their
-- recorded GPUBody result equivalents. If no group blocks it, group 1 is
-- used. If the blocking group is a non-GPUBody group, the statement goes into
-- the GPUBody group that follows it.
--
-- To simplify consumption handling a GPUBody is not allowed to merge with a
-- kernel whose result it consumes. Such a GPUBody is therefore not moved into
-- the same group as that kernel but into the next GPUBody group.
--
-- Finally, every value returned by the GPUBody is recorded in the equivalence
-- table under the variable that binds it.
--
-- The statement must be a GPUBody; any other statement is reported as a
-- compiler bug once its results are inspected.
moveGPUBody :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveGPUBody stm usage consumed = do
  -- Replace dependencies with their GPUBody result equivalents.
  eqs <- gets stateEquivalents
  let g i = maybe i (baseTag . entryResult) (IM.lookup i eqs)
  let deps' = IS.map g (usageDependencies usage)
  let usage' = usage {usageDependencies = deps'}

  -- Move the GPUBody.
  grps <- gets stateGroups
  let f = groupBlocks usage' consumed
  let idx = fromMaybe 1 (SQ.findIndexR f grps)
  let idx' = case idx `mod` 2 of
        -- Blocked by a non-GPUBody group: use the GPUBody group after it.
        0 -> idx + 1
        -- Blocked by a GPUBody group whose results are consumed: skip ahead
        -- to the next GPUBody group rather than merging with it.
        _ | consumes idx grps -> idx + 2
        -- Otherwise join the blocking GPUBody group itself.
        _ -> idx
  modifyGroups $ moveToGrp (stm, usage) idx'

  -- Record the kernel equivalents of the bound results.
  let pes = patElems (stmPat stm)
  -- Kept lazy on purpose: the expression is only inspected when a result is
  -- actually needed, exactly as with the irrefutable pattern it replaces, but
  -- a violation now yields a descriptive compiler bug message.
  let res = case stmExp stm of
        Op (GPUBody _ (Body _ _ r)) -> r
        _ -> compilerBugS "moveGPUBody: statement is not a GPUBody"
  mapM_ (stores idx') (zip pes (map resSubExp res))
  where
    -- Does the statement consume a variable bound by the group at this index?
    consumes idx grps
      | Just grp <- SQ.lookup idx grps =
          not $ IS.disjoint (groupBindings grp) consumed
      | otherwise =
          False

    -- Record that the pattern element binds the returned value. If the
    -- pattern element is an array, the entry is marked as stored and its type
    -- is the row type of that array.
    stores idx (PatElem n t, se)
      | Just row_t <- peelArray 1 t =
          recordEquivalent n $ Entry se row_t n idx True
      | otherwise =
          recordEquivalent n $ Entry se t n idx False

-- | Moves a non-GPUBody statement into the earliest non-GPUBody group (a
-- group at an even index) that it is allowed to be part of, possibly a new
-- group at the end of the sequence. Afterwards any equivalence that the
-- statement gives rise to is recorded via 'recordEquivalentsOf'.
moveOther :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveOther stm usage consumed = do
  groups <- gets stateGroups
  -- Index of the last group that blocks the statement, or 0 if none does.
  let blocker = fromMaybe 0 (SQ.findIndexR (groupBlocks usage consumed) groups)
      -- Round up to the nearest even index; indices are never negative.
      target = blocker + blocker `mod` 2
  modifyGroups (moveToGrp (stm, usage) target)
  recordEquivalentsOf stm target

-- | @recordEquivalentsOf stm idx@ records the GPUBody result and/or return
-- value that @stm@ is equivalent to, where @idx@ is the index of the group
-- that @stm@ belongs to.
--
-- Two statement forms are recognised: a plain copy @x = n@ of a variable with
-- a recorded equivalent, and an index @x = arr[0, ...]@ that takes the first
-- row of an array with a recorded equivalent, where the remaining slice
-- covers the full dimensions of the entry type. Any other statement records
-- nothing.
--
-- A GPUBody can have a dependency substituted with a result equivalent if it
-- merges with the source GPUBody, which allows it to be moved beyond the
-- binding site of that dependency.
--
-- To guarantee that a GPUBody which moves beyond a dependency also merges
-- with its source GPUBody, equivalents are only recorded for results bound
-- within the group at index @idx-1@.
recordEquivalentsOf :: Stm GPU -> Int -> ReorderM ()
recordEquivalentsOf stm idx = do
  eqs <- gets stateEquivalents
  -- The entry of a variable, provided it was bound in the preceding group.
  let fromPrevious v = do
        entry <- IM.lookup (baseTag v) eqs
        if entryGroupIdx entry == idx - 1 then Just entry else Nothing
  case stm of
    Let (Pat [PatElem x _]) _ (BasicOp (SubExp (Var n)))
      | Just entry <- fromPrevious n ->
          recordEquivalent x entry
    Let (Pat [PatElem x _]) _ (BasicOp (Index arr (Slice (DimFix i : dims))))
      | Just entry <- fromPrevious arr,
        i == intConst Int64 0,
        dims == map sliceDim (arrayDims (entryType entry)) ->
          recordEquivalent x (entry {entryStored = False})
    _ -> pure ()

-- | Does this group prevent a statement with the given usage and consumption
-- statistics from being moved past it? It does if the group binds a variable
-- that the statement depends upon, or if the group depends upon a variable
-- that the statement consumes.
groupBlocks :: Usage -> Consumption -> Group -> Bool
groupBlocks usage consumed grp = bindsDependency || usesConsumed
  where
    bindsDependency =
      not (IS.disjoint (groupBindings grp) (usageDependencies usage))
    usesConsumed =
      not (IS.disjoint (groupDependencies grp) consumed)

-- | @moveToGrp stm idx grps@ adds @stm@ to the group at index @idx@ of
-- @grps@. If @grps@ is too short to have such a group, empty groups are
-- appended until it does.
moveToGrp :: (Stm GPU, Usage) -> Int -> Groups -> Groups
moveToGrp stm idx grps
  | idx < SQ.length grps = SQ.adjust' (moveTo stm) idx grps
  | otherwise = moveToGrp stm idx (grps |> mempty)

-- | Appends the statement to the end of the group and merges the usage
-- statistics of the statement into those of the group.
moveTo :: (Stm GPU, Usage) -> Group -> Group
moveTo (stm, usage) (Group stms grp_usage) =
  Group (stms |> stm) (grp_usage <> usage)

--------------------------------------------------------------------------------
--                         MERGING GPU BODIES - TYPES                         --
--------------------------------------------------------------------------------

-- | The monad used for rewriting a GPUBody to use the t'SubExp's that are
-- returned from the kernels it is merged with, rather than the results that
-- those kernels bind.
--
-- The state is a prologue of statements to be added at the beginning of the
-- rewritten kernel body (see 'execRewrite'). The underlying 'ReorderM' gives
-- access to the equivalence table.
type RewriteM = StateT (Stms GPU) ReorderM

--------------------------------------------------------------------------------
--                       MERGING GPU BODIES - FUNCTIONS                       --
--------------------------------------------------------------------------------

-- | Collapses the processed sequence of groups into a single group and
-- returns it. Groups at odd indices (the GPUBody groups) are merged into
-- single kernels in the process. The state is updated to hold only the
-- resulting group.
collapse :: ReorderM Group
collapse = do
  groups <- gets (toList . stateGroups)
  -- Tag each group with whether it is a GPUBody group (odd index).
  merged <- foldM absorb mempty (zip (cycle [False, True]) groups)
  modifyGroups (const (SQ.singleton merged))
  pure merged
  where
    absorb acc (is_gpu, Group stms usage) = do
      stms' <- if is_gpu then mergeKernels stms else pure stms
      -- Remove equivalents that no longer are relevant for rewriting GPUBody
      -- kernels. This ensures that they are not substituted in later kernels
      -- where the replacement variables might not be in scope.
      removeEquivalents (usageBindings usage)
      pure (acc <> Group stms' usage)

-- | Merges a sequence of GPUBody statements into a single kernel. Sequences
-- with fewer than two statements are returned unchanged.
--
-- The statements are folded from the right onto an empty kernel. At each step
-- the body of the earlier statement is rewritten (see 'rewriteBody') and
-- placed in front of the kernel accumulated so far. Patterns, certificates,
-- attributes, return types and results are concatenated in the same order.
mergeKernels :: Stms GPU -> ReorderM (Stms GPU)
mergeKernels stms
  | SQ.length stms < 2 = pure stms
  | otherwise = SQ.singleton <$> foldrM merge seed stms
  where
    -- A kernel that binds nothing, does nothing and returns nothing.
    seed = Let mempty (StmAux mempty mempty ()) (Op (GPUBody [] seed_body))
    seed_body = Body () SQ.empty []

    merge :: Stm GPU -> Stm GPU -> ReorderM (Stm GPU)
    merge
      (Let pat0 (StmAux cs0 attrs0 _) (Op (GPUBody ts0 body0)))
      (Let pat1 (StmAux cs1 attrs1 _) (Op (GPUBody ts1 body1))) = do
        Body _ stms0 res0 <- execRewrite (rewriteBody body0)
        let Body _ stms1 res1 = body1
            merged_aux = StmAux (cs0 <> cs1) (attrs0 <> attrs1) ()
            merged_body = Body () (stms0 <> stms1) (res0 <> res1)
        pure $
          Let (pat0 <> pat1) merged_aux $
            Op (GPUBody (ts0 ++ ts1) merged_body)
    merge _ _ =
      compilerBugS "mergeGPUBodies: cannot merge non-GPUBody statements"

-- | Runs a rewrite starting from an empty prologue, then finishes it by
-- placing the prologue accumulated during the rewrite in front of the
-- statements of the resulting body.
execRewrite :: RewriteM (Body GPU) -> ReorderM (Body GPU)
execRewrite m = do
  (Body _ stms res, prologue) <- runStateT m SQ.empty
  pure (Body () (prologue <> stms) res)

-- | Fetch the equivalence table from the underlying 'ReorderM' state.
equivalents :: RewriteM EquivalenceTable
equivalents = stateEquivalents <$> lift get

-- | Rewrite a body by rewriting first its statements and then its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- rewriteResult res
  pure (Body () stms' res')

-- | Rewrite each statement of a sequence, in order.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms stms = traverse rewriteStm stms

-- | Rewrite a statement: its pattern elements, then its certificates, then
-- its expression. The attributes are carried over unchanged.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let (Pat pes) (StmAux cs attrs _) e) = do
  pes' <- mapM rewritePatElem pes
  cs' <- rewriteCerts cs
  e' <- rewriteExp e
  let aux' = StmAux cs' attrs ()
  pure (Let (Pat pes') aux' e')

-- | Rewrite the type of a pattern element. The bound name is kept as is.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  t' <- rewriteType t
  pure (PatElem n t')

-- | Rewrite an expression.
--
-- An index expression whose array has an entry in the equivalence table and
-- whose first slice dimension is fixed at the constant 0 is replaced by the
-- value of that entry: by the value itself if no slice dimensions remain,
-- otherwise by an index into the entry's variable using the remaining
-- dimensions.
--
-- Every other expression is traversed with 'mapExpM', rewriting its
-- subexpressions, names, types, parameters and bodies. Encountering a host
-- operation is a compiler bug.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp e = do
  eqs <- equivalents
  case e of
    BasicOp (Index arr slice)
      | Just entry <- IM.lookup (baseTag arr) eqs,
        DimFix i : rest <- unSlice slice,
        i == intConst Int64 0 ->
          pure (BasicOp (substitute (entryValue entry) rest))
    _ -> mapExpM rewriter e
  where
    -- The replacement for an index into a stored kernel result. The remaining
    -- slice dimensions are inspected before the value.
    substitute :: SubExp -> [DimIndex SubExp] -> BasicOp
    substitute se [] = SubExp se
    substitute (Var src) rest = Index src (Slice rest)
    substitute _ _ = compilerBugS "rewriteExp: bad equivalence entry"

    rewriter =
      Mapper
        { mapOnSubExp = rewriteSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rewriteName,
          mapOnRetType = rewriteExtType,
          mapOnBranchType = rewriteExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = const opError
        }

    opError = compilerBugS "rewriteExp: unhandled HostOp in GPUBody"

-- | Rewrite every 'SubExpRes' of a body result, in order, using
-- 'rewriteSubExpRes'.
rewriteResult :: Result -> RewriteM Result
rewriteResult = mapM rewriteSubExpRes

-- | Rewrite a single result: its certificates are rewritten with
-- 'rewriteCerts' and its value with 'rewriteSubExp'.
rewriteSubExpRes :: SubExpRes -> RewriteM SubExpRes
rewriteSubExpRes (SubExpRes cs se) =
  SubExpRes <$> rewriteCerts cs <*> rewriteSubExp se

-- | Rewrite every certificate name with 'rewriteName'.
rewriteCerts :: Certs -> RewriteM Certs
rewriteCerts (Certs cs) =
  Certs <$> mapM rewriteName cs

-- | Rewrite the 'SubExp's occurring in a type with 'rewriteSubExp'.
--
-- Note: 'mapOnType' also maps the 'VName' token of accumulators.
rewriteType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
rewriteType = mapOnType rewriteSubExp

-- | Rewrite the 'SubExp's occurring in an existential type with
-- 'rewriteSubExp'.
--
-- Note: 'mapOnExtType' also maps the 'VName' token of accumulators.
rewriteExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
rewriteExtType = mapOnExtType rewriteSubExp

-- | Rewrite the type of a parameter with 'rewriteType'. The parameter's
-- attributes and name are kept as they are.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) =
  Param attrs n <$> rewriteType t

-- | Rewrite a 'SubExp' according to the current equivalence table.
--
-- Constants are returned unchanged. A variable is looked up by its
-- 'baseTag' in the table produced by 'equivalents':
--
-- * no entry: the variable is returned unchanged;
--
-- * an entry whose final 'Bool' field is 'False': the entry's 'SubExp' is
--   returned in place of the variable;
--
-- * an entry whose final 'Bool' field is 'True': the entry's 'SubExp' is
--   wrapped in a single-element array by 'asArray' (which extends the
--   rewrite prologue) and the name of that array is returned.
rewriteSubExp :: SubExp -> RewriteM SubExp
rewriteSubExp (Constant c) = pure (Constant c)
rewriteSubExp (Var n) = do
  eqs <- equivalents
  case IM.lookup (baseTag n) eqs of
    Nothing -> pure (Var n)
    Just (Entry se _ _ _ False) -> pure se
    Just (Entry se t _ _ True) -> Var <$> asArray se t

-- | Rewrite a name via 'rewriteSubExp'.
--
-- If the rewrite yields a variable, its name is returned. If it yields a
-- constant, a name is still required, so the constant is bound to a fresh
-- name with 'referConst' and that name is returned.
rewriteName :: VName -> RewriteM VName
rewriteName n = do
  se <- rewriteSubExp (Var n)
  case se of
    Var n' -> pure n'
    Constant c -> referConst c

-- | @asArray se t@ adds @let x = [se]@ to the rewrite prologue and returns the
-- name of @x@. @t@ is the type of @se@.
asArray :: SubExp -> Type -> RewriteM VName
asArray se row_t = do
  name <- newName "arr"
  -- The bound array has a single row of type @row_t@.
  let t = row_t `arrayOfRow` intConst Int64 1

  let pat = Pat [PatElem name t]
  let aux = StmAux mempty mempty ()
  let e = BasicOp (ArrayLit [se] row_t)

  -- Append the new binding to the end of the accumulated prologue.
  modify (|> Let pat aux e)
  pure name

-- | @referConst c@ adds @let x = c@ to the rewrite prologue and returns the
-- name of @x@.
referConst :: PrimValue -> RewriteM VName
referConst c = do
  name <- newName "cnst"
  -- The binding has the primitive type of the constant itself.
  let t = Prim (primValueType c)

  let pat = Pat [PatElem name t]
  let aux = StmAux mempty mempty ()
  let e = BasicOp (SubExp $ Constant c)

  -- Append the new binding to the end of the accumulated prologue.
  modify (|> Let pat aux e)
  pure name

-- | Produce a fresh name, using the given string as a template.
--
-- The name is generated by 'newNameFromString' in the underlying 'PassM'
-- monad and lifted through the two state-transformer layers of 'RewriteM'.
newName :: String -> RewriteM VName
newName s = lift $ lift (newNameFromString s)