{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | "Sinking" is conceptually the opposite of hoisting.  The idea is
-- to take code that looks like this:
--
-- @
-- x = xs[i]
-- y = ys[i]
-- if x != 0 then {
--   y
-- } else {
--   0
-- }
-- @
--
-- and turn it into
--
-- @
-- x = xs[i]
-- if x != 0 then {
--   y = ys[i]
--   y
-- } else {
--   0
-- }
-- @
--
-- The idea is to delay loads from memory until (if) they are actually
-- needed.  Code patterns like the above are particularly common in
-- code that makes use of pattern matching on sum types.
--
-- We are currently quite conservative about when we do this.  In
-- particular, if any consumption is going on in a body, we don't do
-- anything.  This is far too conservative.  Also, we are careful
-- never to duplicate work.
--
-- This pass redundantly computes free-variable information a lot.  If
-- you ever see this pass as being a compilation speed bottleneck,
-- start by caching that a bit.
--
-- This pass is defined on the Kernels and MC representations (see
-- 'sinkKernels' and 'sinkMC').  This is not because we do anything
-- kernel-specific here, but simply because more explicit indexing is
-- going on after SOACs are gone.
module Futhark.Optimise.Sink (sinkKernels, sinkMC) where

import Control.Monad.State
import Data.Bifunctor
import Data.List (foldl')
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Aliases
import Futhark.IR.Kernels
import Futhark.IR.MC
import Futhark.Pass

-- | Names bound by statements that were moved somewhere during the
-- traversal.
type Sunk = S.Set VName

-- | Shorthand for the symbol table used while traversing.
type SymbolTable lore = ST.SymbolTable lore

-- | Statements that are candidates for sinking, keyed by a 'VName'.
type Sinking lore = M.Map VName (Stm lore)

-- | Common shape of the traversal functions in this module: given the
-- symbol table and the current candidates, rewrite a value and report
-- the names that were sunk into it.
type Sinker lore a = SymbolTable lore -> Sinking lore -> a -> (a, Sunk)

-- | Class constraints shared by the traversal functions.
type Constraints lore = (ASTLore lore, Aliased lore, ST.IndexOp (Op lore))

-- | Given a statement, compute how often each of its free variables
-- are used.  Not accurate: what we care about are only 1, and greater
-- than 1.
multiplicity :: Constraints lore => Stm lore -> M.Map VName Int
multiplicity :: forall lore. Constraints lore => Stm lore -> Map VName Int
multiplicity Stm lore
stm =
  case Stm lore -> Exp lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm of
    If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch IfDec (BranchType lore)
_ ->
      SubExp -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free SubExp
cond Int
1 Map VName Int -> Map VName Int -> Map VName Int
`comb` BodyT lore -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free BodyT lore
tbranch Int
1 Map VName Int -> Map VName Int -> Map VName Int
`comb` BodyT lore -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free BodyT lore
fbranch Int
1
    Op {} -> Stm lore -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free Stm lore
stm Int
2
    DoLoop {} -> Stm lore -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free Stm lore
stm Int
2
    Exp lore
_ -> Stm lore -> Int -> Map VName Int
forall {a} {a}. FreeIn a => a -> a -> Map VName a
free Stm lore
stm Int
1
  where
    free :: a -> a -> Map VName a
free a
x a
k = [(VName, a)] -> Map VName a
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, a)] -> Map VName a) -> [(VName, a)] -> Map VName a
forall a b. (a -> b) -> a -> b
$ [VName] -> [a] -> [(VName, a)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ a -> Names
forall a. FreeIn a => a -> Names
freeIn a
x) ([a] -> [(VName, a)]) -> [a] -> [(VName, a)]
forall a b. (a -> b) -> a -> b
$ a -> [a]
forall a. a -> [a]
repeat a
k
    comb :: Map VName Int -> Map VName Int -> Map VName Int
comb = (Int -> Int -> Int)
-> Map VName Int -> Map VName Int -> Map VName Int
forall k a. Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
M.unionWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(+)

optimiseBranch ::
  Constraints lore =>
  Sinker lore (Op lore) ->
  Sinker lore (Body lore)
optimiseBranch :: forall lore.
Constraints lore =>
Sinker lore (Op lore) -> Sinker lore (Body lore)
optimiseBranch Sinker lore (Op lore)
onOp SymbolTable lore
vtable Sinking lore
sinking (Body BodyDec lore
dec Stms lore
stms Result
res) =
  let (Stms lore
stms', Sunk
stms_sunk) = Sinker lore (Op lore)
-> SymbolTable lore
-> Sinking lore
-> Stms lore
-> Names
-> (Stms lore, Sunk)
forall lore.
Constraints lore =>
Sinker lore (Op lore)
-> SymbolTable lore
-> Sinking lore
-> Stms lore
-> Names
-> (Stms lore, Sunk)
optimiseStms Sinker lore (Op lore)
onOp SymbolTable lore
vtable Sinking lore
sinking' Stms lore
stms (Names -> (Stms lore, Sunk)) -> Names -> (Stms lore, Sunk)
forall a b. (a -> b) -> a -> b
$ Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
res
   in ( BodyDec lore -> Stms lore -> Result -> BodyT lore
forall lore. BodyDec lore -> Stms lore -> Result -> BodyT lore
Body BodyDec lore
dec (Stms lore
sunk_stms Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stms lore
stms') Result
res,
        Sunk
sunk Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
stms_sunk
      )
  where
    free_in_stms :: Names
free_in_stms = Stms lore -> Names
forall a. FreeIn a => a -> Names
freeIn Stms lore
stms Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Result -> Names
forall a. FreeIn a => a -> Names
freeIn Result
res
    (Sinking lore
sinking_here, Sinking lore
sinking') = (VName -> Stm lore -> Bool)
-> Sinking lore -> (Sinking lore, Sinking lore)
forall k a. (k -> a -> Bool) -> Map k a -> (Map k a, Map k a)
M.partitionWithKey VName -> Stm lore -> Bool
sunkHere Sinking lore
sinking
    sunk_stms :: Stms lore
sunk_stms = [Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList ([Stm lore] -> Stms lore) -> [Stm lore] -> Stms lore
forall a b. (a -> b) -> a -> b
$ Sinking lore -> [Stm lore]
forall k a. Map k a -> [a]
M.elems Sinking lore
sinking_here
    sunkHere :: VName -> Stm lore -> Bool
sunkHere VName
v Stm lore
stm =
      VName
v VName -> Names -> Bool
`nameIn` Names
free_in_stms
        Bool -> Bool -> Bool
&& (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable lore -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.available` SymbolTable lore
vtable) (Names -> [VName]
namesToList (Stm lore -> Names
forall a. FreeIn a => a -> Names
freeIn Stm lore
stm))
    sunk :: Sunk
sunk = [VName] -> Sunk
forall a. Ord a => [a] -> Set a
S.fromList ([VName] -> Sunk) -> [VName] -> Sunk
forall a b. (a -> b) -> a -> b
$ (Stm lore -> [VName]) -> Stms lore -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (PatternT (LetDec lore) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (PatternT (LetDec lore) -> [VName])
-> (Stm lore -> PatternT (LetDec lore)) -> Stm lore -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm lore -> PatternT (LetDec lore)
forall lore. Stm lore -> Pattern lore
stmPattern) Stms lore
sunk_stms

optimiseStms ::
  Constraints lore =>
  Sinker lore (Op lore) ->
  SymbolTable lore ->
  Sinking lore ->
  Stms lore ->
  Names ->
  (Stms lore, Sunk)
optimiseStms :: forall lore.
Constraints lore =>
Sinker lore (Op lore)
-> SymbolTable lore
-> Sinking lore
-> Stms lore
-> Names
-> (Stms lore, Sunk)
optimiseStms Sinker lore (Op lore)
onOp SymbolTable lore
init_vtable Sinking lore
init_sinking Stms lore
all_stms Names
free_in_res =
  let ([Stm lore]
all_stms', Sunk
sunk) =
        SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
init_vtable Sinking lore
init_sinking ([Stm lore] -> ([Stm lore], Sunk))
-> [Stm lore] -> ([Stm lore], Sunk)
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
all_stms
   in ([Stm lore] -> Stms lore
forall lore. [Stm lore] -> Stms lore
stmsFromList [Stm lore]
all_stms', Sunk
sunk)
  where
    multiplicities :: Map VName Int
multiplicities =
      (Map VName Int -> Map VName Int -> Map VName Int)
-> Map VName Int -> [Map VName Int] -> Map VName Int
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl'
        ((Int -> Int -> Int)
-> Map VName Int -> Map VName Int -> Map VName Int
forall k a. Ord k => (a -> a -> a) -> Map k a -> Map k a -> Map k a
M.unionWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(+))
        ([(VName, Int)] -> Map VName Int
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([VName] -> [Int] -> [(VName, Int)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Names -> [VName]
namesToList Names
free_in_res) (Int -> [Int]
forall a. a -> [a]
repeat Int
1)))
        ((Stm lore -> Map VName Int) -> [Stm lore] -> [Map VName Int]
forall a b. (a -> b) -> [a] -> [b]
map Stm lore -> Map VName Int
forall lore. Constraints lore => Stm lore -> Map VName Int
multiplicity ([Stm lore] -> [Map VName Int]) -> [Stm lore] -> [Map VName Int]
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
all_stms)

    optimiseStms' :: SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
_ Sinking lore
_ [] = ([], Sunk
forall a. Monoid a => a
mempty)
    optimiseStms' SymbolTable lore
vtable Sinking lore
sinking (Stm lore
stm : [Stm lore]
stms)
      | BasicOp Index {} <- Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm,
        [PatElemT (LetDec lore)
pe] <- PatternT (LetDec lore) -> [PatElemT (LetDec lore)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements (Stm lore -> PatternT (LetDec lore)
forall lore. Stm lore -> Pattern lore
stmPattern Stm lore
stm),
        TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ PatElemT (LetDec lore) -> TypeBase Shape NoUniqueness
forall dec.
Typed dec =>
PatElemT dec -> TypeBase Shape NoUniqueness
patElemType PatElemT (LetDec lore)
pe,
        Bool -> (Int -> Bool) -> Maybe Int -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
True (Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1) (Maybe Int -> Bool) -> Maybe Int -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Map VName Int -> Maybe Int
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe) Map VName Int
multiplicities =
        let ([Stm lore]
stms', Sunk
sunk) =
              SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
vtable' (VName -> Stm lore -> Sinking lore -> Sinking lore
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe) Stm lore
stm Sinking lore
sinking) [Stm lore]
stms
         in if PatElemT (LetDec lore) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
pe VName -> Sunk -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Sunk
sunk
              then ([Stm lore]
stms', Sunk
sunk)
              else (Stm lore
stm Stm lore -> [Stm lore] -> [Stm lore]
forall a. a -> [a] -> [a]
: [Stm lore]
stms', Sunk
sunk)
      | If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch IfDec (BranchType lore)
ret <- Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm =
        let (BodyT lore
tbranch', Sunk
tsunk) = Sinker lore (Op lore) -> Sinker lore (BodyT lore)
forall lore.
Constraints lore =>
Sinker lore (Op lore) -> Sinker lore (Body lore)
optimiseBranch Sinker lore (Op lore)
onOp SymbolTable lore
vtable Sinking lore
sinking BodyT lore
tbranch
            (BodyT lore
fbranch', Sunk
fsunk) = Sinker lore (Op lore) -> Sinker lore (BodyT lore)
forall lore.
Constraints lore =>
Sinker lore (Op lore) -> Sinker lore (Body lore)
optimiseBranch Sinker lore (Op lore)
onOp SymbolTable lore
vtable Sinking lore
sinking BodyT lore
fbranch
            ([Stm lore]
stms', Sunk
sunk) = SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
vtable' Sinking lore
sinking [Stm lore]
stms
         in ( Stm lore
stm {stmExp :: ExpT lore
stmExp = SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond BodyT lore
tbranch' BodyT lore
fbranch' IfDec (BranchType lore)
ret} Stm lore -> [Stm lore] -> [Stm lore]
forall a. a -> [a] -> [a]
: [Stm lore]
stms',
              Sunk
tsunk Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
fsunk Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
sunk
            )
      | Op Op lore
op <- Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm =
        let (Op lore
op', Sunk
op_sunk) = Sinker lore (Op lore)
onOp SymbolTable lore
vtable Sinking lore
sinking Op lore
op
            ([Stm lore]
stms', Sunk
stms_sunk) = SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
vtable' Sinking lore
sinking [Stm lore]
stms
         in ( Stm lore
stm {stmExp :: ExpT lore
stmExp = Op lore -> ExpT lore
forall lore. Op lore -> ExpT lore
Op Op lore
op'} Stm lore -> [Stm lore] -> [Stm lore]
forall a. a -> [a] -> [a]
: [Stm lore]
stms',
              Sunk
stms_sunk Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
op_sunk
            )
      | Bool
otherwise =
        let ([Stm lore]
stms', Sunk
stms_sunk) = SymbolTable lore
-> Sinking lore -> [Stm lore] -> ([Stm lore], Sunk)
optimiseStms' SymbolTable lore
vtable' Sinking lore
sinking [Stm lore]
stms
            (ExpT lore
e', Sunk
stm_sunk) = State Sunk (ExpT lore) -> Sunk -> (ExpT lore, Sunk)
forall s a. State s a -> s -> (a, s)
runState (Mapper lore lore (StateT Sunk Identity)
-> ExpT lore -> State Sunk (ExpT lore)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM Mapper lore lore (StateT Sunk Identity)
mapper (Stm lore -> ExpT lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm)) Sunk
forall a. Monoid a => a
mempty
         in ( Stm lore
stm {stmExp :: ExpT lore
stmExp = ExpT lore
e'} Stm lore -> [Stm lore] -> [Stm lore]
forall a. a -> [a] -> [a]
: [Stm lore]
stms',
              Sunk
stm_sunk Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
stms_sunk
            )
      where
        vtable' :: SymbolTable lore
vtable' = Stm lore -> SymbolTable lore -> SymbolTable lore
forall lore.
(ASTLore lore, IndexOp (Op lore), Aliased lore) =>
Stm lore -> SymbolTable lore -> SymbolTable lore
ST.insertStm Stm lore
stm SymbolTable lore
vtable
        mapper :: Mapper lore lore (StateT Sunk Identity)
mapper =
          Mapper lore lore (StateT Sunk Identity)
forall (m :: * -> *) lore. Monad m => Mapper lore lore m
identityMapper
            { mapOnBody :: Scope lore -> BodyT lore -> StateT Sunk Identity (BodyT lore)
mapOnBody = \Scope lore
scope BodyT lore
body -> do
                let (BodyT lore
body', Sunk
sunk) =
                      Sinker lore (Op lore) -> Sinker lore (BodyT lore)
forall lore.
Constraints lore =>
Sinker lore (Op lore) -> Sinker lore (Body lore)
optimiseBody
                        Sinker lore (Op lore)
onOp
                        (Scope lore -> SymbolTable lore
forall lore. ASTLore lore => Scope lore -> SymbolTable lore
ST.fromScope Scope lore
scope SymbolTable lore -> SymbolTable lore -> SymbolTable lore
forall a. Semigroup a => a -> a -> a
<> SymbolTable lore
vtable)
                        Sinking lore
sinking
                        BodyT lore
body
                (Sunk -> Sunk) -> StateT Sunk Identity ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify (Sunk -> Sunk -> Sunk
forall a. Semigroup a => a -> a -> a
<> Sunk
sunk)
                BodyT lore -> StateT Sunk Identity (BodyT lore)
forall (m :: * -> *) a. Monad m => a -> m a
return BodyT lore
body'
            }

-- | Optimise the statements of a body.  The variables in the body's
-- result each count as one use.  The body's decoration and result
-- are preserved unchanged.
optimiseBody ::
  Constraints lore =>
  Sinker lore (Op lore) ->
  Sinker lore (Body lore)
optimiseBody onOp vtable sinking (Body attr stms res) =
  let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res
   in (Body attr stms' res, sunk)

-- | Like 'optimiseBody', but for a 'KernelBody', whose results are
-- 'KernelResult's.  The variables in those results each count as one
-- use.
optimiseKernelBody ::
  Constraints lore =>
  Sinker lore (Op lore) ->
  Sinker lore (KernelBody lore)
optimiseKernelBody onOp vtable sinking (KernelBody attr stms res) =
  let (stms', sunk) = optimiseStms onOp vtable sinking stms $ freeIn res
   in (KernelBody attr stms' res, sunk)

-- | Optimise the lambda and the kernel body inside a 'SegOp'.
--
-- The variables bound by the segment space are added to the symbol
-- table before descending.  The names sunk into any nested body are
-- collected and returned.
optimiseSegOp ::
  Constraints lore =>
  Sinker lore (Op lore) ->
  Sinker lore (SegOp lvl lore)
optimiseSegOp onOp vtable sinking op =
  let scope = scopeOfSegSpace $ segSpace op
   in runState (mapSegOpM (opMapper scope) op) mempty
  where
    opMapper scope =
      identitySegOpMapper
        { mapOnSegOpLambda = \lam -> do
            let (body, sunk) =
                  optimiseBody onOp op_vtable sinking $ lambdaBody lam
            modify (<> sunk)
            return lam {lambdaBody = body},
          mapOnSegOpBody = \body -> do
            let (body', sunk) =
                  optimiseKernelBody onOp op_vtable sinking body
            modify (<> sunk)
            return body'
        }
      where
        -- Symbol table extended with the segment space's bindings.
        op_vtable = ST.fromScope scope <> vtable

-- | The representation the pass works on internally: the input
-- representation annotated with alias information.
type SinkLore lore = Aliases lore

-- | Build the sinking pass for a representation, given a 'Sinker' for
-- its operations.
--
-- The program is first alias-annotated with 'Alias.aliasAnalysis'.
-- Then every function body and the top-level constants are
-- optimised.  Finally the alias annotations are removed again.
sink ::
  ( ASTLore lore,
    CanBeAliased (Op lore),
    ST.IndexOp (OpWithAliases (Op lore))
  ) =>
  Sinker (SinkLore lore) (Op (SinkLore lore)) ->
  Pass lore lore
sink onOp =
  Pass "sink" "move memory loads closer to their uses" $
    fmap removeProgAliases
      . intraproceduralTransformationWithConsts onConsts onFun
      . Alias.aliasAnalysis
  where
    -- A function body starts from a symbol table that holds only the
    -- function's parameters.
    -- NOTE(review): the top-level constants (the ignored first
    -- argument) are not added to this symbol table.  A statement that
    -- mentions a constant is therefore presumably never considered
    -- available for sinking inside a function.  Confirm this is
    -- intended.
    onFun _ fd = do
      let vtable = ST.insertFParams (funDefParams fd) mempty
          (body, _) = optimiseBody onOp vtable mempty $ funDefBody fd
      return fd {funDefBody = body}

    -- Every constant is treated as used by the rest of the program.
    onConsts consts =
      pure $
        fst $
          optimiseStms onOp mempty mempty consts $
            namesFromList $ M.keys $ scopeOf consts

-- | Sinking in GPU kernels.
sinkKernels :: Pass Kernels Kernels
sinkKernels = sink onHostOp
  where
    -- Only 'SegOp's are traversed.  Every other host operation is
    -- returned unchanged and sinks nothing.
    onHostOp :: Sinker (SinkLore Kernels) (Op (SinkLore Kernels))
    onHostOp vtable sinking (SegOp op) =
      first SegOp $ optimiseSegOp onHostOp vtable sinking op
    onHostOp _ _ op = (op, mempty)

-- | Sinking for multicore.
sinkMC :: Pass MC MC
sinkMC = sink onHostOp
  where
    -- For a 'ParOp', both the optional first 'SegOp' and the
    -- mandatory second one are optimised with the same candidates.
    -- Their sunk names are combined.  Every other operation is
    -- returned unchanged.
    onHostOp :: Sinker (SinkLore MC) (Op (SinkLore MC))
    onHostOp vtable sinking (ParOp par_op op) =
      let (par_op', par_sunk) =
            maybe
              (Nothing, mempty)
              (first Just . optimiseSegOp onHostOp vtable sinking)
              par_op
          (op', sunk) = optimiseSegOp onHostOp vtable sinking op
       in (ParOp par_op' op', par_sunk <> sunk)
    onHostOp _ _ op = (op, mempty)