{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
-- | This module implements a compiler pass for inlining functions,
-- then removing those that have become dead.
module Futhark.Optimise.InliningDeadFun
  ( inlineFunctions
  , removeDeadFunctions
  )
  where

import Control.Monad.Identity
import Control.Monad.State
import Data.List (partition)
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Control.Parallel.Strategies

import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify
  (simpleSOACS, simplifyFun, simplifyConsts)
import Futhark.Optimise.CSE
import Futhark.Optimise.Simplify.Lore (addScopeWisdom)
import Futhark.Transform.CopyPropagate
  (copyPropagateInProg, copyPropagateInFun)
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Transform.Rename
import Futhark.Analysis.CallGraph
import Futhark.Binder
import Futhark.Pass

-- | Run a fresh-name-generating action on every list element,
-- evaluating the elements with @'parMap' 'rpar'@, and install the
-- combined name source afterwards.
--
-- Every element is started from the /same/ name source @src@, so
-- different elements may be handed the same fresh names; this is
-- presumably acceptable because each element is processed
-- independently (NOTE(review): confirm).  The per-element sources are
-- merged with 'mconcat', which relies on the 'VNameSource' 'Monoid'
-- instance yielding a source that will not re-issue any name generated
-- by any element (NOTE(review): confirm).
--
-- 'rpar' only evaluates each result pair to weak head normal form, so
-- how much work really happens in parallel depends on how strict the
-- 'State' computation is.
parMapM :: MonadFreshNames m => (a -> State VNameSource b) -> [a] -> m [b]
parMapM f as =
  modifyNameSource $ \src ->
    let f' a = runState (f a) src
        (bs, srcs) = unzip $ parMap rpar f' as
     in (bs, mconcat srcs)

-- | Aggressively inline functions into each other, in rounds, and then
-- drop the functions that are no longer needed.
--
-- Each round (see @recurse@) selects the functions none of whose
-- 'allCalledBy' entries (in the call graph of the /original/ program)
-- are among the functions still remaining.  It inlines them into the
-- remaining functions that call them.  It also inlines them into the
-- top-level constants when 'calledByConsts' reports a use there.  The
-- rounds stop when no function can be selected.  The final function
-- list is filtered with @keep@, so only entry points and functions in
-- 'findNoninlined' survive.
aggInlineFunctions :: MonadFreshNames m => Prog SOACS -> m (Prog SOACS)
aggInlineFunctions prog =
  let Prog consts funs = prog
   in uncurry Prog . fmap (filter keep)
        <$> recurse 0 (ST.fromScope (addScopeWisdom (scopeOf consts)), consts, funs)
  where
    fdmap :: [FunDef lore] -> M.Map Name (FunDef lore)
    fdmap fds = M.fromList $ zip (map funDefName fds) fds

    cg = buildCallGraph prog

    noninlined :: S.Set Name
    noninlined = findNoninlined prog

    -- True when no function in 'allCalledBy' of @fundec@ is in @which@.
    noCallsTo :: S.Set Name -> FunDef lore -> Bool
    noCallsTo which fundec =
      not $ any (`S.member` which) $ allCalledBy (funDefName fundec) cg

    -- The inverse rate at which we perform full simplification
    -- after inlining.  For the other steps we just do copy
    -- propagation.  The rate here has been determined
    -- heuristically and is probably not optimal for any given
    -- program.
    simplifyRate :: Int
    simplifyRate = 4

    -- We apply simplification after every round of inlining,
    -- because it is more efficient to shrink the program as soon
    -- as possible, rather than wait until it has ballooned after
    -- full inlining.
    --
    -- @i@ counts rounds and decides when full simplification is done;
    -- @vtable@ is the symbol table of the (simplified) constants, used
    -- when simplifying functions.
    recurse i (vtable, consts, funs) = do
      let remaining = S.fromList $ map funDefName funs
          -- Functions that call none of the remaining functions.
          (to_be_inlined, maybe_inline_in) =
            partition (noCallsTo remaining) funs
          -- Only functions that call something in to_be_inlined need
          -- to be rewritten this round.
          (not_to_inline_in, to_inline_in) =
            partition
              (noCallsTo (S.fromList $ map funDefName to_be_inlined))
              maybe_inline_in
          keep_although_inlined = filter keep to_be_inlined
          inline_map = fdmap to_be_inlined

      if null to_be_inlined
        then return (consts, funs)
        else do
          -- Only rewrite (and re-simplify) the constants when one of the
          -- inlined functions is actually used there.
          (vtable', consts') <-
            if any ((`calledByConsts` cg) . funDefName) to_be_inlined
              then
                simplifyConsts . performCSEOnStms True
                  =<< inlineInStms inline_map consts
              else pure (vtable, consts)

          let simplifyFun' fd
                | i `rem` simplifyRate == 0 =
                    copyPropagateInFun simpleSOACS vtable'
                      . performCSEOnFunDef True
                      =<< simplifyFun vtable' fd
                | otherwise =
                    copyPropagateInFun simpleSOACS vtable' fd

              onFun = simplifyFun' <=< inlineInFunDef inline_map

          to_inline_in' <- parMapM onFun to_inline_in
          fmap (keep_although_inlined <>)
            <$> recurse
              (i + 1)
              (vtable', consts', not_to_inline_in <> to_inline_in')

    -- A function must be retained (even once inlined everywhere) if it
    -- is an entry point or is in the 'findNoninlined' set.
    keep :: FunDef lore -> Bool
    keep fd =
      isJust (funDefEntryPoint fd)
        || funDefName fd `S.member` noninlined

-- | @inlineInFunDef fdmap caller@ inlines into the body of @caller@ the
-- calls to the functions in @fdmap@ (see 'inlineInBody' for which calls
-- are skipped).  Everything except the body is left unchanged.
--
-- Inlined code is not searched again for further calls, so the
-- functions in @fdmap@ are expected not to call any other functions.
inlineInFunDef :: MonadFreshNames m =>
                  M.Map Name (FunDef SOACS) -> FunDef SOACS
               -> m (FunDef SOACS)
inlineInFunDef fdmap (FunDef entry attrs name rtp args body) =
  FunDef entry attrs name rtp args <$> inlineInBody fdmap body

-- | Produce the statements that replace a call of @fun@ bound to @pat@.
--
-- * @aux@ is the call statement's annotation.
-- * @args@ are the call's arguments.
-- * The triple holds the call's safety and source locations.
--
-- The result consists of three parts:
--
-- 1. Statements binding the callee's parameters to the arguments.
-- 2. The callee's body, passed through 'addLocations' with the caller's
--    attributes, safety and non-empty locations.  Parts 1 and 2 are
--    renamed together so that all inlined names are fresh.
-- 3. Statements binding @pat@ to the body's result, carrying the call's
--    certificates.
inlineFunction :: MonadFreshNames m =>
                  Pattern
               -> StmAux dec
               -> [(SubExp, Diet)]
               -> (Safety, SrcLoc, [SrcLoc])
               -> FunDef SOACS
               -> m [Stm]
inlineFunction pat aux args (safety, loc, locs) fun = do
  Body _ stms res <-
    renameBody $
      mkBody
        (stmsFromList param_stms <> stmsFromList body_stms)
        (bodyResult (funDefBody fun))
  let res_stms =
        certify (stmAuxCerts aux)
          <$> zipWith
            (reshapeIfNecessary (patternNames pat))
            (patternIdents pat)
            res
  pure $ stmsToList stms <> res_stms
  where
    param_names = map paramName $ funDefParams fun

    -- Bind each callee parameter to the corresponding argument.
    param_stms =
      zipWith
        (reshapeIfNecessary param_names)
        (map paramIdent $ funDefParams fun)
        (map fst args)

    body_stms =
      stmsToList $
        addLocations (stmAuxAttrs aux) safety (filter notmempty (loc : locs)) $
          bodyStms $ funDefBody fun

    -- Bind @ident@ to @se@.  If @ident@ is an array whose dimensions
    -- mention one of @dim_names@ and @se@ is a variable, the value is
    -- wrapped in a shape coercion to @ident@'s dimensions; otherwise it
    -- is bound directly.
    reshapeIfNecessary dim_names ident se
      | t@Array {} <- identType ident,
        any (`elem` dim_names) (subExpVars $ arrayDims t),
        Var v <- se =
          mkLet [] [ident] $ shapeCoerce (arrayDims t) v
      | otherwise =
          mkLet [] [ident] $ BasicOp $ SubExp se

    -- Drop locations that carry no information.
    notmempty = (/= mempty) . locOf

-- | Inline calls to the functions in @fdmap@ within a sequence of
-- statements.  The statements are wrapped in a body with an empty
-- result, handed to 'inlineInBody', and unwrapped again.
inlineInStms :: MonadFreshNames m =>
                M.Map Name (FunDef SOACS) -> Stms SOACS -> m (Stms SOACS)
inlineInStms fdmap stms =
  bodyStms <$> inlineInBody fdmap (mkBody stms [])

-- | Inline every call to a function in @fdmap@ found in the body.
--
-- The search also covers nested bodies reached through 'mapExpM' and
-- the lambdas of SOACs.  A call is left alone when either the callee or
-- the call statement carries the @noinline@ attribute.  Statements
-- produced by 'inlineFunction' are not searched again for further calls.
inlineInBody :: MonadFreshNames m =>
                M.Map Name (FunDef SOACS) -> Body -> m Body
inlineInBody fdmap = onBody
  where
    inline (Let pat aux (Apply fname args _ what) : rest)
      | Just fd <- M.lookup fname fdmap,
        not $ "noinline" `inAttrs` funDefAttrs fd,
        not $ "noinline" `inAttrs` stmAuxAttrs aux =
          (<>) <$> inlineFunction pat aux args what fd <*> inline rest
    inline (stm : rest) =
      (:) <$> onStm stm <*> inline rest
    inline [] =
      pure mempty

    onBody (Body dec stms res) =
      Body dec . stmsFromList <$> inline (stmsToList stms) <*> pure res

    -- Statements that are not inlinable calls are traversed for nested
    -- bodies and SOAC lambdas.
    onStm (Let pat aux e) =
      Let pat aux <$> mapExpM inliner e

    inliner =
      identityMapper
        { mapOnBody = const onBody,
          mapOnOp = onSOAC
        }

    onSOAC =
      mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda}

    onLambda (Lambda params body ret) =
      Lambda params <$> onBody body <*> pure ret

-- | Propagate source locations and attributes to the inlined
-- statements.  Attributes are propagated only when applicable (this
-- probably means that every supported attribute needs to be handled
-- specially here).
--
-- Per statement:
--
-- * Function calls: the caller's attributes are prepended to the call's
--   own attributes.  The safety becomes the minimum of the caller's and
--   the call's safety.  @more_locs@ is appended to the call's
--   locations.
-- * Assertions: the statement receives only @'attrsForAssert' attrs@.
--   Under a 'Safe' caller, @more_locs@ is appended to the assertion's
--   locations.  Under an 'Unsafe' caller, the assertion is replaced by
--   the constant 'Checked'.
-- * SOACs: the statement receives the attributes minus the
--   assertion-related ones.  The lambda bodies are processed
--   recursively with only the assertion-related ones.
-- * Anything else: the statement's own attributes are unchanged, and
--   nested bodies are processed recursively with all of @attrs@.
addLocations :: Attrs -> Safety -> [SrcLoc] -> Stms SOACS -> Stms SOACS
addLocations attrs caller_safety more_locs = fmap onStm
  where
    onStm (Let pat aux (Apply fname args t (safety, loc, locs))) =
      Let pat (withAttrs attrs aux) $
        Apply fname args t (min caller_safety safety, loc, locs ++ more_locs)
    onStm (Let pat aux (BasicOp (Assert cond desc (loc, locs)))) =
      Let pat (withAttrs (attrsForAssert attrs) aux) $
        case caller_safety of
          Safe -> BasicOp $ Assert cond desc (loc, locs ++ more_locs)
          Unsafe -> BasicOp $ SubExp $ Constant Checked
    onStm (Let pat aux (Op soac)) =
      Let pat (withAttrs attrs' aux) $
        Op $
          runIdentity $
            mapSOACM
              identitySOACMapper {mapOnSOACLambda = return . onLambda}
              soac
      where
        for_assert = attrsForAssert attrs
        attrs' = attrs `withoutAttrs` for_assert
        onLambda lam =
          lam {lambdaBody = onBody for_assert $ lambdaBody lam}
    onStm (Let pat aux e) =
      Let pat aux $ onExp e

    onExp =
      mapExp identityMapper {mapOnBody = const $ return . onBody attrs}

    -- Prepend the given attributes to a statement's own attributes.
    withAttrs extra aux =
      aux {stmAuxAttrs = extra <> stmAuxAttrs aux}

    onBody body_attrs body =
      body
        { bodyStms =
            addLocations body_attrs caller_safety more_locs $ bodyStms body
        }

-- | Inline all functions and remove the resulting dead functions.
--
-- The program is first transformed by 'aggInlineFunctions'; copy
-- propagation ('copyPropagateInProg' with 'simpleSOACS') is then run
-- over the result.
inlineFunctions :: Pass SOACS SOACS
inlineFunctions =
  Pass
    { passName = "Inline functions"
    , passDescription = "Inline and remove resulting dead functions."
    , passFunction = \prog ->
        -- Inline first, then clean up the inlined code.
        aggInlineFunctions prog >>= copyPropagateInProg simpleSOACS
    }

-- | @removeDeadFunctions prog@ drops every function definition whose name
-- is not present in the call graph built by 'buildCallGraph', i.e. the
-- functions the pass description calls "unreachable from entry points".
-- The remaining functions keep their original relative order.
removeDeadFunctions :: Pass SOACS SOACS
removeDeadFunctions =
  Pass
    { passName = "Remove dead functions"
    , passDescription = "Remove the functions that are unreachable from entry points"
    , passFunction = pure . pruneDead
    }
  where
    pruneDead :: Prog SOACS -> Prog SOACS
    pruneDead prog = prog {progFuns = reachable}
      where
        graph = buildCallGraph prog
        -- Keep only definitions the call graph knows about.
        reachable =
          [ fd
          | fd <- progFuns prog
          , funDefName fd `isFunInCallGraph` graph
          ]