{-# LANGUAGE TypeFamilies #-}

-- | "Sinking" is conceptually the opposite of hoisting.  The idea is
-- to take code that looks like this:
--
-- @
-- x = xs[i]
-- y = ys[i]
-- if x != 0 then {
--   y
-- } else {
--   0
-- }
-- @
--
-- and turn it into
--
-- @
-- x = xs[i]
-- if x != 0 then {
--   y = ys[i]
--   y
-- } else {
--   0
-- }
-- @
--
-- The idea is to delay loads from memory until (if) they are actually
-- needed.  Code patterns like the above are particularly common in
-- code that makes use of pattern matching on sum types.
--
-- We are currently quite conservative about when we do this.  In
-- particular, if any consumption is going on in a body, we don't do
-- anything.  This is far too conservative.  Also, we are careful
-- never to duplicate work.
--
-- This pass redundantly computes free-variable information a lot.  If
-- you ever see this pass as being a compilation speed bottleneck,
-- start by caching that a bit.
--
-- This pass is defined on post-SOACS representations.  This is not
-- because we do anything GPU-specific here, but simply because more
-- explicit indexing is going on after SOACs are gone.
module Futhark.Optimise.Sink (sinkGPU, sinkMC) where

import Control.Monad.State
import Data.Bifunctor
import Data.List (foldl')
import Data.Map qualified as M
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Builder.Class
import Futhark.IR.Aliases
import Futhark.IR.GPU
import Futhark.IR.MC
import Futhark.Pass

-- | Shorthand for the symbol table used throughout this pass.
type SymbolTable rep = ST.SymbolTable rep

-- | Statements that are candidates for being sunk, keyed by the name
-- that each statement binds.
type Sinking rep = M.Map VName (Stm rep)

-- | Names whose defining statements have been moved (sunk) into some
-- nested body, and which must therefore be dropped from their original
-- position.
type Sunk = Names

-- | The shape of every sinking transformation in this module: given the
-- symbol table and the pending candidates, rewrite a value and report
-- which candidates were sunk somewhere inside it.
type Sinker rep a = SymbolTable rep -> Sinking rep -> a -> (a, Sunk)

-- | Constraints a representation must satisfy for this pass.  'Aliased'
-- and 'ST.IndexOp' are required by 'ST.insertStm'.
type Constraints rep =
  ( ASTRep rep,
    Aliased rep,
    Buildable rep,
    ST.IndexOp (Op rep)
  )

-- | Given a statement, compute how often each of its free variables
-- is used.  The counts are not exact: all we care about is whether a
-- variable is used exactly once or more than once.
--
-- The condition and each branch of a 'Match' contribute one use apiece
-- (so a variable used in two branches counts twice).  Everything free in
-- a 'Loop' or an 'Op' is counted as two uses, presumably because their
-- bodies may run many times and sinking into them could duplicate work.
-- Any other statement contributes one use per free variable.
multiplicity :: (Constraints rep) => Stm rep -> M.Map VName Int
multiplicity stm =
  case stmExp stm of
    Match cond cases defbody _ ->
      -- Strict left fold: avoids accumulating a chain of lazy unions.
      foldl' (M.unionWith (+)) mempty $
        free 1 cond
          : free 1 defbody
          : map (free 1 . caseBody) cases
    Op {} -> free 2 stm
    Loop {} -> free 2 stm
    _ -> free 1 stm
  where
    -- Assign the count @k@ to every free variable of @x@.
    free k x = M.fromList $ map (,k) $ namesToList $ freeIn x

-- | Optimise one branch of a 'Match'.
--
-- Every pending candidate whose bound name is referenced by the branch
-- (its statements or its result), and all of whose free variables are
-- 'ST.available' in the symbol table, is placed at the front of the
-- branch body.  The remaining candidates are passed further down.  The
-- reported 'Sunk' set contains the names bound by the statements placed
-- here, plus whatever was sunk deeper inside the branch.
optimiseBranch ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBranch onOp vtable sinking (Body dec stms res) =
  (Body dec stms' res, placedNames <> deeperSunk)
  where
    usedHere = freeIn stms <> freeIn res
    canPlaceHere v stm =
      v `nameIn` usedHere
        && all (`ST.available` vtable) (namesToList (freeIn stm))
    (placeable, remaining) = M.partitionWithKey canPlaceHere sinking
    placedStms = stmsFromList (M.elems placeable)
    placedNames = namesFromList (foldMap (patNames . stmPat) placedStms)
    (stms', deeperSunk) =
      optimiseStms onOp vtable remaining (placedStms <> stms) (freeIn res)

-- | Sink statements into the body of a loop.
--
-- The merge parameters, and for a @for@-loop also the index variable,
-- are added to the symbol table before the body is processed.  The merge
-- parameters and the loop form are returned unchanged.
optimiseLoop ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  Sinker rep ([(FParam rep, SubExp)], LoopForm, Body rep)
optimiseLoop onOp vtable sinking (merge, form, body) =
  ((merge, form, body'), sunk)
  where
    paramScope = scopeOfFParams $ map fst merge
    loopScope = case form of
      ForLoop i it _ -> M.insert i (IndexName it) paramScope
      WhileLoop {} -> paramScope
    (body', sunk) =
      optimiseBody onOp (ST.fromScope loopScope <> vtable) sinking body

-- | Sink statements within a sequence of statements.
--
-- The statements are walked front to back.  A statement is recorded as a
-- sinking candidate for the statements after it when all of these hold:
--
-- * it is an 'Index';
-- * it binds a single primitive-typed result;
-- * that result is used at most once according to 'multiplicity', where
--   each name in the final argument counts as one extra use.
--
-- If such a candidate is reported as sunk further in, it is removed from
-- this sequence.  'Match', 'Loop', 'Op', and the bodies of any other
-- expression are recursed into with the candidates collected so far.
-- Returns the rewritten statements and every name sunk inside them.
optimiseStms ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  SymbolTable rep ->
  Sinking rep ->
  Stms rep ->
  Names ->
  (Stms rep, Sunk)
optimiseStms onOp init_vtable init_sinking all_stms free_in_res =
  first stmsFromList $ go init_vtable init_sinking $ stmsToList all_stms
  where
    -- How often each name is used by these statements or by the result.
    uses =
      M.unionsWith (+) $
        M.fromList [(v, 1) | v <- namesToList free_in_res]
          : map multiplicity (stmsToList all_stms)

    -- The name bound by a statement if it may become a sinking candidate.
    sinkable stm
      | BasicOp Index {} <- stmExp stm,
        [pe] <- patElems (stmPat stm),
        primType (patElemType pe),
        M.findWithDefault 1 (patElemName pe) uses == 1 =
          Just (patElemName pe)
      | otherwise = Nothing

    go _ _ [] = ([], mempty)
    go vtable sinking (stm : stms)
      | Just v <- sinkable stm =
          let (stms', sunk) = go vtable' (M.insert v stm sinking) stms
              -- Drop the statement here if a later branch took it.
              kept = if v `nameIn` sunk then stms' else stm : stms'
           in (kept, sunk)
      | otherwise =
          case stmExp stm of
            Match cond cases defbody ret ->
              let branch = optimiseBranch onOp vtable sinking
                  (cases', cases_sunk) =
                    unzip
                      [ (Case vs b', s)
                        | Case vs b <- cases,
                          let (b', s) = branch b
                      ]
                  (defbody', defbody_sunk) = branch defbody
               in ( stm {stmExp = Match cond cases' defbody' ret} : rest,
                    mconcat cases_sunk <> defbody_sunk <> rest_sunk
                  )
            Loop merge lform body ->
              let (loop', loop_sunk) =
                    optimiseLoop onOp vtable sinking (merge, lform, body)
                  (merge', _, body') = loop'
               in ( stm {stmExp = Loop merge' lform body'} : rest,
                    rest_sunk <> loop_sunk
                  )
            Op op ->
              let (op', op_sunk) = onOp vtable sinking op
               in (stm {stmExp = Op op'} : rest, rest_sunk <> op_sunk)
            e ->
              let (e', e_sunk) = runState (mapExpM mapper e) mempty
               in (stm {stmExp = e'} : rest, e_sunk <> rest_sunk)
      where
        vtable' = ST.insertStm stm vtable
        -- The remaining statements, processed with the candidates unchanged.
        (rest, rest_sunk) = go vtable' sinking stms
        -- Recurse into every body nested in an otherwise-unhandled expression.
        mapper =
          identityMapper
            { mapOnBody = \scope b -> do
                let (b', b_sunk) =
                      optimiseBody onOp (ST.fromScope scope <> vtable) sinking b
                modify (<> b_sunk)
                pure b'
            }

-- | Sink statements within a body.  Names free in the body result are
-- counted as uses when deciding which statements may become candidates.
optimiseBody ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  Sinker rep (Body rep)
optimiseBody onOp vtable sinking (Body dec stms res) =
  (Body dec stms' res, sunk)
  where
    (stms', sunk) = optimiseStms onOp vtable sinking stms (freeIn res)

-- | Sink statements within a kernel body.  Names free in the kernel
-- results are counted as uses when deciding which statements may become
-- candidates.
optimiseKernelBody ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  Sinker rep (KernelBody rep)
optimiseKernelBody onOp vtable sinking (KernelBody dec stms res) =
  (KernelBody dec stms' res, sunk)
  where
    (stms', sunk) = optimiseStms onOp vtable sinking stms (freeIn res)

-- | Sink statements into the lambdas and the kernel body of a 'SegOp'.
-- The names bound by the operation's 'SegSpace' are added to the symbol
-- table before any of them is processed.  The reported 'Sunk' set is the
-- union of what was sunk in each of them.
optimiseSegOp ::
  (Constraints rep) =>
  Sinker rep (Op rep) ->
  Sinker rep (SegOp lvl rep)
optimiseSegOp onOp vtable sinking op =
  runState (mapSegOpM mapper op) mempty
  where
    inner_vtable = ST.fromScope (scopeOfSegSpace (segSpace op)) <> vtable
    -- Accumulate the sunk names in the state and yield the rewritten value.
    noteSunk (x, sunk) = x <$ modify (<> sunk)
    mapper =
      identitySegOpMapper
        { mapOnSegOpLambda = \lam ->
            noteSunk . first (\body -> lam {lambdaBody = body}) $
              optimiseBody onOp inner_vtable sinking (lambdaBody lam),
          mapOnSegOpBody =
            noteSunk . optimiseKernelBody onOp inner_vtable sinking
        }

-- | The representation the sinking pass works on internally: the
-- pass's input representation wrapped in 'Aliases'.  'sink' produces
-- it with 'Alias.aliasAnalysis' and strips it again with
-- 'removeProgAliases' before returning.
type SinkRep rep = Aliases rep

-- | Construct the sinking pass for some representation.
--
-- The program is first annotated with alias information (so the pass
-- operates on 'SinkRep'), then every function body and the top-level
-- constants are optimised, and finally the alias annotations are
-- removed again.
--
-- The argument is the representation-specific handler for 'Op's.  It
-- is what allows the pass to descend into e.g. GPU kernels or
-- multicore parallel operations; see 'sinkGPU' and 'sinkMC'.
sink ::
  ( Buildable rep,
    AliasableRep rep,
    ST.IndexOp (Op (Aliases rep))
  ) =>
  Sinker (SinkRep rep) (Op (SinkRep rep)) ->
  Pass rep rep
sink onOp =
  Pass "sink" "move memory loads closer to their uses" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . Alias.aliasAnalysis
  where
    -- A function body starts out with only the function parameters in
    -- the symbol table and nothing pending to be sunk.  The constants
    -- handed in as the first argument are ignored, and the 'Sunk' set
    -- is not needed at the top of a function.
    onFun _ fd = do
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody onOp vtable mempty $ funDefBody fd
      pure fd {funDefBody = body}

    -- Top-level constants are optimised with an empty symbol table and
    -- nothing pending.  Every name they bind is passed as the final
    -- argument to 'optimiseStms'.  NOTE(review): presumably this marks
    -- all constants as used afterwards (functions may refer to any of
    -- them), so none are sunk away; confirm against 'optimiseStms'.
    onConsts consts =
      pure $
        fst $
          optimiseStms onOp mempty mempty consts $
            namesFromList $
              M.keys $
                scopeOf consts

-- | Sinking in GPU kernels.
--
-- Besides ordinary host code, the pass descends into 'SegOp's and
-- 'GPUBody's.  Any other host operation is returned unchanged, with
-- nothing reported as sunk.
sinkGPU :: Pass GPU GPU
sinkGPU = sink onHostOp
  where
    onHostOp :: Sinker (SinkRep GPU) (Op (SinkRep GPU))
    onHostOp vtable sinking (SegOp op) =
      first SegOp $ optimiseSegOp onHostOp vtable sinking op
    onHostOp vtable sinking (GPUBody types body) =
      first (GPUBody types) $ optimiseBody onHostOp vtable sinking body
    onHostOp _ _ op = (op, mempty)

-- | Sinking for multicore.
--
-- The pass descends into both 'SegOp's of a 'ParOp': the optional one
-- and the mandatory one.  The names sunk in each are combined.  Any
-- other operation is returned unchanged, with nothing reported as sunk.
sinkMC :: Pass MC MC
sinkMC = sink onHostOp
  where
    onHostOp :: Sinker (SinkRep MC) (Op (SinkRep MC))
    onHostOp vtable sinking (ParOp par_op op) =
      let (par_op', par_sunk) =
            maybe
              (Nothing, mempty)
              (first Just . optimiseSegOp onHostOp vtable sinking)
              par_op
          (op', sunk) = optimiseSegOp onHostOp vtable sinking op
       in (ParOp par_op' op', par_sunk <> sunk)
    onHostOp _ _ op = (op, mempty)