-- |
-- This module implements an optimization pass that merges 'GPUBody' kernels to
-- eliminate memory transactions and reduce the number of kernel launches.
-- This is useful because the "Futhark.Optimise.ReduceDeviceSyncs" pass introduces
-- 'GPUBody' kernels that only execute single statements.
--
-- To merge as many 'GPUBody' kernels as possible, this pass reorders statements
-- with the goal of bringing as many 'GPUBody' statements as possible next to
-- each other in a sequence. Such a sequence can then trivially be merged.
module Futhark.Optimise.MergeGPUBodies (mergeGPUBodies) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first)
import Data.Foldable
import Data.IntMap qualified as IM
import Data.IntSet ((\\))
import Data.IntSet qualified as IS
import Data.Map qualified as M
import Data.Maybe (fromMaybe)
import Data.Sequence ((|>))
import Data.Sequence qualified as SQ
import Futhark.Analysis.Alias
import Futhark.Construct (sliceDim)
import Futhark.Error
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.MonadFreshNames hiding (newName)
import Futhark.Pass

-- | An optimization pass that reorders and merges 'GPUBody' statements to
-- eliminate memory transactions and reduce the number of kernel launches.
--
-- The program is first run through alias analysis; the constants and every
-- function body are then transformed independently, each starting from an
-- empty alias table.
mergeGPUBodies :: Pass GPU GPU
mergeGPUBodies =
  Pass
    "merge GPU bodies"
    "Reorder and merge GPUBody constructs to reduce kernels executions."
    $ intraproceduralTransformationWithConsts onStms onFunDef . aliasAnalysis
  where
    -- Transform the body of a function definition. The (already transformed)
    -- constants passed as first argument are not needed.
    onFunDef :: p -> FunDef (Aliases GPU) -> PassM (FunDef GPU)
    onFunDef _ (FunDef entry attrs name types params body) =
      FunDef entry attrs name types params . fst <$> transformBody mempty body

    -- Transform the top-level constants; their dependencies are discarded.
    onStms :: Stms (Aliases GPU) -> PassM (Stms GPU)
    onStms stms =
      fst <$> transformStms mempty stms

--------------------------------------------------------------------------------
--                               COMMON - TYPES                               --
--------------------------------------------------------------------------------

-- | A set of 'VName' tags that denote all variables that some group of
-- statements depends upon. Those variables must be computed before the
-- statements of the group.
type Dependencies = IS.IntSet

-- | A set of 'VName' tags that denote all variables that some group of
-- statements binds.
type Bindings = IS.IntSet

-- | A set of 'VName' tags that denote the root aliases of all arrays that some
-- statement consumes.
type Consumption = IS.IntSet

--------------------------------------------------------------------------------
--                              COMMON - HELPERS                              --
--------------------------------------------------------------------------------

-- | All free variables of a construct as 'Dependencies', i.e. as the set of
-- tags of the names returned by 'freeIn'.
depsOf :: FreeIn a => a -> Dependencies
depsOf = namesToSet . freeIn

-- | Convert 'Names' to an integer set of name tags ('baseTag' of each name).
namesToSet :: Names -> IS.IntSet
namesToSet = IS.fromList . map baseTag . namesToList

--------------------------------------------------------------------------------
--                            AD HOC OPTIMIZATION                             --
--------------------------------------------------------------------------------

-- | Optimize a lambda and determine its dependencies.
--
-- Only the body is transformed; the parameters and return types are kept
-- unchanged. The returned dependencies are those of the transformed body.
transformLambda ::
  AliasTable ->
  Lambda (Aliases GPU) ->
  PassM (Lambda GPU, Dependencies)
transformLambda aliases (Lambda params body types) = do
  (body', deps) <- transformBody aliases body
  pure (Lambda params body' types, deps)

-- | Optimize a body and determine its dependencies.
--
-- The statements of the body are reordered one at a time in a fresh
-- 'ReorderM' state and finally collapsed into a single 'Group'. The
-- dependencies of the body are those of the collapsed group and of the body
-- result, minus everything that the body binds itself.
transformBody ::
  AliasTable ->
  Body (Aliases GPU) ->
  PassM (Body GPU, Dependencies)
transformBody aliases (Body _ stms res) = do
  -- The alias table is threaded through the statements as the fold
  -- accumulator; only the collapsed group is of interest afterwards.
  grp <- evalStateT (foldM_ reorderStm aliases stms >> collapse) initialState

  let stms' = groupStms grp
  let deps = (groupDependencies grp <> depsOf res) \\ groupBindings grp

  pure (Body () stms' res, deps)

-- | Optimize a sequence of statements and determine their dependencies.
--
-- Implemented by wrapping the statements in a body with an empty result and
-- delegating to 'transformBody'.
transformStms ::
  AliasTable ->
  Stms (Aliases GPU) ->
  PassM (Stms GPU, Dependencies)
transformStms aliases stms = do
  (Body _ stms' _, deps) <- transformBody aliases (Body mempty stms [])
  pure (stms', deps)

-- | Optimizes and reorders a single statement within a sequence while tracking
-- the declaration, observation, and consumption of its dependencies.
-- This creates sequences of GPUBody statements that can be merged into single
-- kernels.
--
-- Returns the alias table extended with the root aliases of the names bound
-- by the statement, to be used for the next statement of the sequence.
reorderStm :: AliasTable -> Stm (Aliases GPU) -> ReorderM AliasTable
reorderStm aliases (Let pat (StmAux cs attrs _) e) = do
  (e', deps) <- lift (transformExp aliases e)
  let pat' = removePatAliases pat
  let stm' = Let pat' (StmAux cs attrs ()) e'
  let pes' = patElems pat'

  -- Array aliases can be seen as a directed graph where vertices are arrays
  -- (or the names that bind them) and an edge x -> y denotes that x aliases y.
  -- The root aliases of some array A is then the set of arrays that can be
  -- reached from A in the graph and which have no edges themselves.
  --
  -- All arrays that share a root alias are considered aliases of each other
  -- and will be consumed if either of them is consumed.
  -- When reordering statements we must ensure that no statement that consumes
  -- an array is moved before any statement that observes one of its aliases.
  --
  -- That is, to move statement X before statement Y the set of root aliases of
  -- arrays consumed by X must not overlap with the root aliases of arrays
  -- observed by Y.
  --
  -- We consider the root aliases of Y's observed arrays as part of Y's
  -- dependencies and simply say that the root aliases of arrays consumed by X
  -- must not overlap those.
  --
  -- To move X before Y the dependencies of X must also not overlap with
  -- the variables bound by Y.

  let observed = namesToSet $ rootAliasesOf (fold $ expAliases e) aliases
  let consumed = namesToSet $ rootAliasesOf (consumedInExp e) aliases
  let usage =
        Usage
          { usageBindings = IS.fromList $ map (baseTag . patElemName) pes',
            usageDependencies = observed <> deps <> depsOf pat' <> depsOf cs
          }

  case e' of
    Op GPUBody {} ->
      moveGPUBody stm' usage consumed
    _ ->
      moveOther stm' usage consumed

  -- Use the pattern that still carries alias information. A strict fold
  -- avoids building a chain of thunks over the pattern elements.
  pure $ foldl' recordAliases aliases (patElems pat)
  where
    -- The union of the root aliases of each name. A name without an entry in
    -- the table is its own root alias.
    rootAliasesOf names atable =
      let look n = M.findWithDefault (oneName n) n atable
       in foldMap look (namesToList names)

    -- Record the root aliases of a pattern element, unless it aliases nothing.
    recordAliases atable pe
      | aliasesOf pe == mempty =
          atable
      | otherwise =
          let root_aliases = rootAliasesOf (aliasesOf pe) atable
           in M.insert (patElemName pe) root_aliases atable

-- | Optimize a single expression and determine its dependencies.
--
-- Expressions without nested bodies only have their alias annotations
-- removed. For expressions with nested bodies or lambdas, these are
-- transformed recursively and their dependencies are combined with the free
-- variables of the remaining parts of the expression.
transformExp ::
  AliasTable ->
  Exp (Aliases GPU) ->
  PassM (Exp GPU, Dependencies)
transformExp aliases e =
  case e of
    BasicOp {} -> pure (removeExpAliases e, depsOf e)
    Apply {} -> pure (removeExpAliases e, depsOf e)
    Match ses cases defbody dec -> do
      let transformCase (Case vs body) =
            first (Case vs) <$> transformBody aliases body
      (cases', cases_deps) <- unzip <$> mapM transformCase cases
      (defbody', defbody_deps) <- transformBody aliases defbody
      let deps = depsOf ses <> mconcat cases_deps <> defbody_deps <> depsOf dec
      pure (Match ses cases' defbody' dec, deps)
    DoLoop merge lform body -> do
      -- What merge and lform aliases outside the loop is irrelevant as those
      -- cannot be consumed within the loop.
      (body', body_deps) <- transformBody aliases body
      let (params, args) = unzip merge
      let deps = body_deps <> depsOf params <> depsOf args <> depsOf lform

      -- Names bound by the loop itself are not dependencies of the loop.
      -- The annotation fixes the otherwise ambiguous representation type.
      let scope :: Scope (Aliases GPU)
          scope = scopeOf lform <> scopeOfFParams params
      let bound = IS.fromList $ map baseTag (M.keys scope)
      let deps' = deps \\ bound

      -- Remove the alias annotations of the merge parameters and the loop
      -- form by passing them through 'removeExpAliases' with an empty body.
      let dummy :: Exp (Aliases GPU)
          dummy = DoLoop merge lform (Body (bodyDec body) SQ.empty [])
      let DoLoop merge' lform' _ = removeExpAliases dummy

      pure (DoLoop merge' lform' body', deps')
    WithAcc inputs lambda -> do
      accs <- mapM (transformWithAccInput aliases) inputs
      let (inputs', input_deps) = unzip accs
      -- The lambda parameters are all unique and thus have no aliases.
      (lambda', deps) <- transformLambda aliases lambda
      pure (WithAcc inputs' lambda', deps <> fold input_deps)
    Op {} ->
      -- A GPUBody cannot be nested within other HostOp constructs.
      pure (removeExpAliases e, depsOf e)

-- | Optimize a single WithAcc input and determine its dependencies.
--
-- The dependencies are the free variables of the shape and the arrays and,
-- if a combining operator is present, the dependencies of its transformed
-- lambda and the free variables of its neutral elements.
transformWithAccInput ::
  AliasTable ->
  WithAccInput (Aliases GPU) ->
  PassM (WithAccInput GPU, Dependencies)
transformWithAccInput aliases (shape, arrs, op) = do
  (op', deps) <- case op of
    Nothing -> pure (Nothing, mempty)
    Just (f, nes) -> do
      -- The lambda parameters have no aliases.
      (f', deps) <- transformLambda aliases f
      pure (Just (f', nes), deps <> depsOf nes)
  let deps' = deps <> depsOf shape <> depsOf arrs
  pure ((shape, arrs, op'), deps')

--------------------------------------------------------------------------------
--                             REORDERING - TYPES                             --
--------------------------------------------------------------------------------

-- | The monad used to reorder the statements of a sequence such that its
-- GPUBody statements can be merged into as few kernels as possible.
type ReorderM = StateT State PassM

-- | The state used by a 'ReorderM' monad.
data State = State
  { -- | All statements that already have been processed from the sequence,
    -- divided into alternating groups of non-GPUBody and GPUBody statements.
    -- Groups at even indices only contain non-GPUBody statements. Groups at
    -- odd indices only contain GPUBody statements.
    stateGroups :: Groups,
    -- | The GPUBody result equivalents recorded so far, keyed by variable
    -- tag.
    stateEquivalents :: EquivalenceTable
  }

-- | A map from variable tags to 'Entry' records that describe t'SubExp's
-- returned from within GPUBodies.
type EquivalenceTable = IM.IntMap Entry

-- | An entry in an 'EquivalenceTable'.
data Entry = Entry
  { -- | A value returned from within a GPUBody kernel.
    -- In @let res = gpu { x }@ this is @x@.
    entryValue :: SubExp,
    -- | The type of the 'entryValue'.
    entryType :: Type,
    -- | The name of the variable that binds the return value for 'entryValue'.
    -- In @let res = gpu { x }@ this is @res@.
    entryResult :: VName,
    -- | The index of the group that 'entryResult' is bound in.
    entryGroupIdx :: Int,
    -- | If 'False' then the entry key is a variable that binds the same value
    -- as the 'entryValue'. Otherwise it binds an array with an outer dimension
    -- of one whose row equals that value.
    entryStored :: Bool
  }

-- | A sequence of 'Group's.
type Groups = SQ.Seq Group

-- | A group is a subsequence of statements, usually either only GPUBody
-- statements or only non-GPUBody statements. The 'Usage' statistics of those
-- statements are also stored.
data Group = Group
  { -- | The statements of the group.
    groupStms :: Stms GPU,
    -- | The usage statistics of the statements within the group.
    groupUsage :: Usage
  }

-- | Usage statistics for some set of statements.
data Usage = Usage
  { -- | The variables that the statements bind.
    usageBindings :: Bindings,
    -- | The variables that the statements depend upon, i.e. the free variables
    -- of each statement and the root aliases of every array that they observe.
    usageDependencies :: Dependencies
  }

-- | Groups are combined by concatenating their statements (left operand
-- first) and combining their usage statistics.
instance Semigroup Group where
  (Group s1 u1) <> (Group s2 u2) = Group (s1 <> s2) (u1 <> u2)

-- | The empty group has no statements and empty usage statistics.
instance Monoid Group where
  mempty = Group {groupStms = mempty, groupUsage = mempty}

-- | Usage statistics are combined by taking the union of the bindings and
-- the union of the dependencies.
instance Semigroup Usage where
  (Usage b1 d1) <> (Usage b2 d2) = Usage (b1 <> b2) (d1 <> d2)

-- | The empty usage binds nothing and depends on nothing.
instance Monoid Usage where
  mempty = Usage {usageBindings = mempty, usageDependencies = mempty}

--------------------------------------------------------------------------------
--                           REORDERING - FUNCTIONS                           --
--------------------------------------------------------------------------------

-- | Return the usage bindings of the group, i.e. the variables bound by its
-- statements.
groupBindings :: Group -> Bindings
groupBindings = usageBindings . groupUsage

-- | Return the usage dependencies of the group, i.e. the variables that its
-- statements depend upon.
groupDependencies :: Group -> Dependencies
groupDependencies = usageDependencies . groupUsage

-- | An initial state to use when running a 'ReorderM' monad.
--
-- It holds a single empty group (at index 0) and an empty equivalence table.
initialState :: State
initialState =
  State
    { stateGroups = SQ.singleton mempty,
      stateEquivalents = mempty
    }

-- | Modify the groups that the sequence has been split into so far.
-- The equivalence table is left untouched.
modifyGroups :: (Groups -> Groups) -> ReorderM ()
modifyGroups f =
  modify $ \st -> st {stateGroups = f (stateGroups st)}

-- | Remove these keys from the equivalence table. Keys that are not present
-- in the table are ignored.
removeEquivalents :: IS.IntSet -> ReorderM ()
removeEquivalents keys =
  modify $ \st ->
    let eqs' = stateEquivalents st `IM.withoutKeys` keys
     in st {stateEquivalents = eqs'}

-- | Add an entry to the equivalence table under the tag of the given name,
-- replacing any entry already recorded for that tag.
recordEquivalent :: VName -> Entry -> ReorderM ()
recordEquivalent n entry =
  modify $ \st ->
    let eqs = stateEquivalents st
        eqs' = IM.insert (baseTag n) entry eqs
     in st {stateEquivalents = eqs'}

-- | Moves a GPUBody statement to the furthest possible group of the statement
-- sequence, possibly a new group at the end of sequence.
--
-- To simplify consumption handling a GPUBody is not allowed to merge with a
-- kernel whose result it consumes. Such GPUBody may therefore not be moved
-- into the same group as such kernel.
--
-- Once moved, every result bound by the GPUBody is recorded in the
-- equivalence table together with the index of the group it was placed in.
moveGPUBody :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveGPUBody stm usage consumed = do
  -- Replace dependencies with their GPUBody result equivalents.
  eqs <- gets stateEquivalents
  let g i = maybe i (baseTag . entryResult) (IM.lookup i eqs)
  let deps' = IS.map g (usageDependencies usage)
  let usage' = usage {usageDependencies = deps'}

  -- Move the GPUBody.
  grps <- gets stateGroups
  let f = groupBlocks usage' consumed
  -- If no group blocks the statement, start from index 0 (as 'moveOther'
  -- does) so that the statement ends up in the first GPUBody group.
  let idx = fromMaybe 0 (SQ.findIndexR f grps)
  let idx' = case idx `mod` 2 of
        -- Blocked by a non-GPUBody group: use the GPUBody group after it.
        0 -> idx + 1
        -- Blocked by a GPUBody group whose results are consumed: the
        -- statement may not merge with it, so skip to the next GPUBody group.
        _ | consumes idx grps -> idx + 2
        -- Blocked by a GPUBody group that it can merge with.
        _ -> idx
  -- The group is updated with the original usage, not the substituted one.
  modifyGroups $ moveToGrp (stm, usage) idx'

  -- Record the kernel equivalents of the bound results.
  let pes = patElems (stmPat stm)
  -- The caller only passes statements whose expression is a GPUBody.
  let Op (GPUBody _ (Body _ _ res)) = stmExp stm
  mapM_ (stores idx') (zip pes (map resSubExp res))
  where
    -- Does the statement consume anything bound by the group at this index?
    consumes idx grps
      | Just grp <- SQ.lookup idx grps =
          not $ IS.disjoint (groupBindings grp) consumed
      | otherwise =
          False

    -- Record the returned value that a pattern element is equivalent to.
    -- If the element is an array then it is recorded as stored, with the
    -- row type as the type of the value.
    stores idx (PatElem n t, se)
      | Just row_t <- peelArray 1 t =
          recordEquivalent n $ Entry se row_t n idx True
      | otherwise =
          recordEquivalent n $ Entry se t n idx False

-- | Moves a non-GPUBody statement to the furthest possible groups of the
-- statement sequence, possibly a new group at the end of sequence.
--
-- Once moved, the GPUBody result equivalents of the statement (if any) are
-- recorded through 'recordEquivalentsOf'.
moveOther :: Stm GPU -> Usage -> Consumption -> ReorderM ()
moveOther stm usage consumed = do
  grps <- gets stateGroups
  let f = groupBlocks usage consumed
  -- Index of the last group that blocks the statement, or 0 if none does.
  let idx = fromMaybe 0 (SQ.findIndexR f grps)
  -- Round up to the nearest even index, as only groups at even indices hold
  -- non-GPUBody statements.
  let idx' = ((idx + 1) `div` 2) * 2
  modifyGroups $ moveToGrp (stm, usage) idx'
  recordEquivalentsOf stm idx'

-- | @recordEquivalentsOf stm idx@ records the GPUBody result and/or return
-- value that @stm@ is equivalent to. @idx@ is the index of the group that @stm@
-- belongs to.
--
-- Two statement shapes, each binding a single name @x@, are recognised:
--
-- * @x = n@, where @n@ has an equivalence entry: @x@ inherits that entry.
--
-- * @x = arr[0, d...]@, where @arr@ has an equivalence entry and the remaining
--   slice dimensions equal @map sliceDim@ of the dimensions of the entry type:
--   @x@ gets the entry with 'entryStored' cleared.
--
-- A GPUBody can have a dependency substituted with a result equivalent if it
-- merges with the source GPUBody, allowing it to be moved beyond the binding
-- site of that dependency. To guarantee that such a GPUBody also merges with
-- its source GPUBody, equivalents are only recorded for results bound within
-- the group at index @idx-1@.
recordEquivalentsOf :: Stm GPU -> Int -> ReorderM ()
recordEquivalentsOf (Let (Pat [PatElem x _]) _ (BasicOp op)) idx = do
  eqs <- gets stateEquivalents
  -- Entry of a name, provided it was bound in the immediately preceding group.
  let fromPrevGroup v = do
        entry <- IM.lookup (baseTag v) eqs
        guard (entryGroupIdx entry == idx - 1)
        pure entry
  case op of
    SubExp (Var n)
      | Just entry <- fromPrevGroup n ->
          recordEquivalent x entry
    Index arr (Slice (DimFix i : dims))
      | Just entry <- fromPrevGroup arr,
        i == intConst Int64 0,
        dims == map sliceDim (arrayDims (entryType entry)) ->
          recordEquivalent x entry {entryStored = False}
    _ -> pure ()
recordEquivalentsOf _ _ = pure ()

-- | Does this group block a statement with this usage/consumption statistics
-- from being moved past it?
--
-- A group blocks the statement if the statement depends on something the group
-- binds, or if the statement consumes something the group depends on.
groupBlocks :: Usage -> Consumption -> Group -> Bool
groupBlocks usage consumed grp = readsBinding || consumesDependency
  where
    readsBinding =
      not (IS.disjoint (groupBindings grp) (usageDependencies usage))
    consumesDependency =
      not (IS.disjoint (groupDependencies grp) consumed)

-- | @moveToGrp stm idx grps@ moves @stm@ into the group at index @idx@ of
-- @grps@. If @grps@ is too short to have such a group, empty groups are
-- appended until it does.
moveToGrp :: (Stm GPU, Usage) -> Int -> Groups -> Groups
moveToGrp stm idx grps = SQ.adjust' (moveTo stm) idx padded
  where
    -- Extend the sequence with empty groups until index @idx@ exists.
    padded = until ((idx <) . SQ.length) (|> mempty) grps

-- | Appends the statement to the end of the group's statements and merges its
-- usage statistics into those of the group.
moveTo :: (Stm GPU, Usage) -> Group -> Group
moveTo (stm, usage) (Group stms groupUse) =
  Group (stms |> stm) (groupUse <> usage)

--------------------------------------------------------------------------------
--                         MERGING GPU BODIES - TYPES                         --
--------------------------------------------------------------------------------

-- | The monad used for rewriting a GPUBody to use the t'SubExp's that are
-- returned from kernels it is merged with rather than the results that they
-- bind.
--
-- The state is a prologue of statements to be added at the beginning of the
-- rewritten kernel body. Helpers such as 'asArray' and 'referConst' append
-- bindings to this prologue, and 'execRewrite' prepends the accumulated
-- prologue to the rewritten body. The underlying 'ReorderM' gives access to
-- the equivalence table (see 'equivalents') and to fresh names (see
-- 'newName').
type RewriteM = StateT (Stms GPU) ReorderM

--------------------------------------------------------------------------------
--                       MERGING GPU BODIES - FUNCTIONS                       --
--------------------------------------------------------------------------------

-- | Collapses the processed sequence of groups into a single group and returns
-- it, merging GPUBody groups into single kernels in the process.
--
-- Groups alternate, starting with a group of ordinary statements: only the
-- groups at odd indices have their statements passed through 'mergeKernels'.
-- The state is updated so that the collapsed group is the only group left.
collapse :: ReorderM Group
collapse = do
  groups <- gets (toList . stateGroups)
  merged <- foldM absorb mempty (zip (cycle [False, True]) groups)
  modify (\st -> st {stateGroups = SQ.singleton merged})
  pure merged
  where
    absorb acc (isKernelGroup, grp@(Group stms usage)) = do
      grp' <-
        if isKernelGroup
          then (`Group` usage) <$> mergeKernels stms
          else pure grp
      -- Drop equivalents that are no longer relevant for rewriting GPUBody
      -- kernels, so that they are not substituted into later kernels where
      -- the replacement variables might not be in scope.
      removeEquivalents (groupBindings grp')
      pure (acc <> grp')

-- | Merges a sequence of GPUBody statements into a single kernel.
--
-- Sequences with fewer than two statements are returned untouched. Otherwise
-- the statements are folded from the right into one GPUBody, starting from an
-- empty kernel. Each statement's body is rewritten (see 'rewriteBody') before
-- it is placed in front of what has been accumulated so far.
--
-- It is a compiler bug for the sequence to contain anything but GPUBody
-- statements.
mergeKernels :: Stms GPU -> ReorderM (Stms GPU)
mergeKernels kernels
  | SQ.length kernels >= 2 =
      SQ.singleton <$> foldrM merge emptyKernel kernels
  | otherwise =
      pure kernels
  where
    -- A kernel that binds nothing, runs nothing and returns nothing.
    emptyKernel :: Stm GPU
    emptyKernel =
      Let mempty (StmAux mempty mempty ()) $
        Op (GPUBody [] (Body () SQ.empty []))

    -- @merge k acc@ puts the rewritten kernel @k@ in front of @acc@.
    merge :: Stm GPU -> Stm GPU -> ReorderM (Stm GPU)
    merge
      (Let pat0 (StmAux cs0 attrs0 _) (Op (GPUBody types0 body0)))
      (Let pat1 (StmAux cs1 attrs1 _) (Op (GPUBody types1 (Body _ stms1 res1)))) = do
        Body _ stms0 res0 <- execRewrite (rewriteBody body0)
        let aux = StmAux (cs0 <> cs1) (attrs0 <> attrs1) ()
            body = Body () (stms0 <> stms1) (res0 <> res1)
        pure $ Let (pat0 <> pat1) aux (Op (GPUBody (types0 ++ types1) body))
    merge _ _ =
      compilerBugS "mergeGPUBodies: cannot merge non-GPUBody statements"

-- | Runs a rewrite with an initially empty prologue and finishes it by placing
-- the prologue accumulated during the rewrite in front of the statements of
-- the resulting body.
execRewrite :: RewriteM (Body GPU) -> ReorderM (Body GPU)
execRewrite m = do
  (Body _ stms res, prologue) <- runStateT m SQ.empty
  pure (Body () (prologue <> stms) res)

-- | Fetches the equivalence table from the state of the underlying 'ReorderM'.
equivalents :: RewriteM EquivalenceTable
equivalents = stateEquivalents <$> lift get

-- | Rewrites the statements of a body and then its result.
rewriteBody :: Body GPU -> RewriteM (Body GPU)
rewriteBody (Body _ stms res) = do
  stms' <- rewriteStms stms
  res' <- rewriteResult res
  pure (Body () stms' res')

-- | Rewrites every statement of a sequence, in order.
rewriteStms :: Stms GPU -> RewriteM (Stms GPU)
rewriteStms stms = traverse rewriteStm stms

-- | Rewrites a single statement: first the types of its pattern elements, then
-- its certificates and finally its expression. Attributes are kept as they
-- are.
rewriteStm :: Stm GPU -> RewriteM (Stm GPU)
rewriteStm (Let (Pat pes) (StmAux cs attrs _) e) =
  Let <$> pat' <*> aux' <*> rewriteExp e
  where
    pat' = Pat <$> traverse rewritePatElem pes
    aux' = (\cs' -> StmAux cs' attrs ()) <$> rewriteCerts cs

-- | Rewrites the type of a pattern element; the bound name is kept.
rewritePatElem :: PatElem Type -> RewriteM (PatElem Type)
rewritePatElem (PatElem n t) = do
  t' <- rewriteType t
  pure (PatElem n t')

-- | Rewrites an expression so that names with an entry in the equivalence
-- table are replaced by the values those entries stand for.
--
-- Reading the first element of an array whose entry is /stored/ is special
-- cased. Such an array merely wraps the entry value in a one-element array
-- (cf. 'asArray'), so @arr[0]@ becomes the value itself and @arr[0, ds...]@
-- becomes @src[ds...]@ when the value is the variable @src@. Any other shape
-- of such an entry is a compiler bug.
--
-- The shortcut must only be taken for stored entries. An entry that is not
-- stored denotes the value directly, so the leading index of a slice into it
-- is a genuine index into the value and must not be dropped; such expressions
-- are handled by the general case, which only substitutes the array name.
--
-- All other expressions are rewritten component-wise. Encountering a host
-- operation is a compiler bug.
rewriteExp :: Exp GPU -> RewriteM (Exp GPU)
rewriteExp e = do
  eqs <- equivalents
  case e of
    BasicOp (Index arr slice)
      | Just entry <- IM.lookup (baseTag arr) eqs,
        -- Only a stored entry has the extra outer dimension that the leading
        -- index 0 strips away.
        entryStored entry,
        DimFix idx : dims <- unSlice slice,
        idx == intConst Int64 0 ->
          let se = entryValue entry
           in pure . BasicOp $ case (dims, se) of
                ([], _) -> SubExp se
                (_, Var src) -> Index src (Slice dims)
                _ -> compilerBugS "rewriteExp: bad equivalence entry"
    _ -> mapExpM rewriter e
  where
    rewriter :: Mapper GPU GPU RewriteM
    rewriter =
      Mapper
        { mapOnSubExp = rewriteSubExp,
          mapOnBody = const rewriteBody,
          mapOnVName = rewriteName,
          mapOnRetType = rewriteExtType,
          mapOnBranchType = rewriteExtType,
          mapOnFParam = rewriteParam,
          mapOnLParam = rewriteParam,
          mapOnOp = const opError
        }

    opError = compilerBugS "rewriteExp: unhandled HostOp in GPUBody"

-- | Rewrites every component of a body result, in order.
rewriteResult :: Result -> RewriteM Result
rewriteResult res = traverse rewriteSubExpRes res

-- | Rewrites the certificates of a result component and then its value.
rewriteSubExpRes :: SubExpRes -> RewriteM SubExpRes
rewriteSubExpRes (SubExpRes cs se) = do
  cs' <- rewriteCerts cs
  se' <- rewriteSubExp se
  pure (SubExpRes cs' se')

-- | Rewrites every certificate name, in order.
rewriteCerts :: Certs -> RewriteM Certs
rewriteCerts (Certs names) = do
  names' <- traverse rewriteName names
  pure (Certs names')

-- | Rewrites the sub-expressions that occur within a type.
rewriteType :: TypeBase Shape u -> RewriteM (TypeBase Shape u)
-- Note: mapOnType also maps the VName token of accumulators
rewriteType t = mapOnType rewriteSubExp t

-- | Rewrites the sub-expressions that occur within an existential type.
rewriteExtType :: TypeBase ExtShape u -> RewriteM (TypeBase ExtShape u)
-- Note: mapOnExtType also maps the VName token of accumulators
rewriteExtType t = mapOnExtType rewriteSubExp t

-- | Rewrites the type of a parameter; its attributes and name are kept.
rewriteParam :: Param (TypeBase Shape u) -> RewriteM (Param (TypeBase Shape u))
rewriteParam (Param attrs n t) = do
  t' <- rewriteType t
  pure (Param attrs n t')

-- | Rewrites a sub-expression. Constants, and variables without an entry in
-- the equivalence table, are returned as they are. A variable with an entry is
-- replaced by the entry value, or, if the entry is marked as stored, by a
-- fresh one-element array holding that value (see 'asArray').
rewriteSubExp :: SubExp -> RewriteM SubExp
rewriteSubExp se@(Constant _) = pure se
rewriteSubExp se@(Var n) = do
  eqs <- equivalents
  case IM.lookup (baseTag n) eqs of
    Just entry
      | entryStored entry ->
          Var <$> asArray (entryValue entry) (entryType entry)
      | otherwise ->
          pure (entryValue entry)
    Nothing -> pure se

-- | Rewrites a name as 'rewriteSubExp' would rewrite the corresponding
-- variable. If the outcome is a constant, the constant is bound to a fresh
-- name in the prologue (see 'referConst') and that name is returned.
rewriteName :: VName -> RewriteM VName
rewriteName n = rewriteSubExp (Var n) >>= toName
  where
    toName (Var n') = pure n'
    toName (Constant c) = referConst c

-- | @asArray se row_t@ adds @let x = [se]@ to the rewrite prologue and returns
-- the name of @x@. @row_t@ is the type of @se@.
asArray :: SubExp -> Type -> RewriteM VName
asArray se row_t = do
  name <- newName "arr"
  modify (|> singleton name)
  pure name
  where
    -- Binds the given name to a one-element array literal holding @se@.
    singleton :: VName -> Stm GPU
    singleton name =
      Let
        (Pat [PatElem name (row_t `arrayOfRow` intConst Int64 1)])
        (StmAux mempty mempty ())
        (BasicOp (ArrayLit [se] row_t))

-- | @referConst c@ adds @let x = c@ to the rewrite prologue and returns the
-- name of @x@.
referConst :: PrimValue -> RewriteM VName
referConst c = do
  name <- newName "cnst"
  modify (|> binding name)
  pure name
  where
    -- Binds the given name to the constant, at the primitive type of @c@.
    binding :: VName -> Stm GPU
    binding name =
      Let
        (Pat [PatElem name (Prim (primValueType c))])
        (StmAux mempty mempty ())
        (BasicOp (SubExp (Constant c)))

-- | Produce a fresh name, using the given string as a template. The name is
-- generated by 'newNameFromString' two monad layers below 'RewriteM'.
newName :: String -> RewriteM VName
newName = lift . lift . newNameFromString