{-# LANGUAGE FlexibleContexts #-}
-- | This module implements a compiler pass for inlining functions,
-- then removing those that have become dead.
module Futhark.Optimise.InliningDeadFun
  ( inlineFunctions
  , removeDeadFunctions
  )
  where

import Control.Monad.Identity
import Data.List (partition)
import Data.Loc
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Futhark.Representation.SOACS
import Futhark.Representation.SOACS.Simplify
  (simpleSOACS, simplifyFun, simplifyConsts)
import Futhark.Optimise.CSE
import Futhark.Optimise.Simplify.Lore (addScopeWisdom)
import Futhark.Transform.CopyPropagate
  (copyPropagateInProg, copyPropagateInFun)
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Transform.Rename
import Futhark.Analysis.CallGraph
import Futhark.Binder
import Futhark.Pass

-- | Aggressively inline functions.  Each round takes the "leaf"
-- functions (those that call none of the functions still under
-- consideration) and inlines them into the functions that call them,
-- and into the top-level constants if the constants call them.  This
-- repeats until no leaf functions remain.  Functions accepted by
-- 'keep' (entry points, and functions that call a recursive function)
-- are never inlined, and the final function list contains only such
-- functions.
aggInlineFunctions :: MonadFreshNames m =>
                      CallGraph
                   -> (Stms SOACS, [FunDef SOACS])
                   -> m (Stms SOACS, [FunDef SOACS])
aggInlineFunctions cg =
  fmap (fmap (filter keep)) . recurse 0 . addVtable
  where -- Index functions by name, for lookup by the inliner.
        fdmap fds =
          M.fromList [ (funDefName fd, fd) | fd <- fds ]

        -- Pair the constants with a symbol table built from their scope.
        addVtable (consts, funs) =
          (ST.fromScope (addScopeWisdom (scopeOf consts)), consts, funs)

        -- True if the function calls none of the functions named in 'which'.
        noCallsTo which fundec =
          not $ any (`S.member` which) $ allCalledBy (funDefName fundec) cg

        -- The inverse rate at which we perform full simplification
        -- after inlining.  For the other steps we just do copy
        -- propagation.  The rate here has been determined
        -- heuristically and is probably not optimal for any given
        -- program.
        simplifyRate :: Int
        simplifyRate = 4

        -- We apply simplification after every round of inlining,
        -- because it is more efficient to shrink the program as soon
        -- as possible, rather than wait until it has ballooned after
        -- full inlining.
        recurse i (vtable, consts, funs)
          | null to_be_inlined =
              -- No leaves left: nothing more can be inlined.
              return (consts, funs)
          | otherwise = do
              let inline_map = fdmap to_be_inlined'

              -- Only touch the constants if they call something being inlined.
              (vtable', consts') <-
                if any ((`calledByConsts` cg) . funDefName) to_be_inlined'
                  then simplifyConsts =<<
                       performCSEOnStms True <$>
                       inlineInStms inline_map consts
                  else pure (vtable, consts)

              -- Full simplification (plus CSE) every simplifyRate-th
              -- round; otherwise only copy propagation.
              let simplifyFun' fd
                    | i `rem` simplifyRate == 0 =
                        copyPropagateInFun simpleSOACS vtable' =<<
                        performCSEOnFunDef True <$>
                        simplifyFun vtable' fd
                    | otherwise =
                        copyPropagateInFun simpleSOACS vtable' fd
                  onFun = simplifyFun' <=< inlineInFunDef inline_map

              to_inline_in' <- mapM onFun to_inline_in
              fmap (not_actually_inlined <>) <$>
                recurse (i + 1)
                (vtable', consts', not_to_inline_in <> to_inline_in')
          where remaining = S.fromList $ map funDefName funs
                -- Leaves: functions calling none of the remaining functions.
                (to_be_inlined, maybe_inline_in) =
                  partition (noCallsTo remaining) funs
                -- Split the others by whether they call any of the leaves.
                (not_to_inline_in, to_inline_in) =
                  partition (noCallsTo (S.fromList $ map funDefName to_be_inlined))
                            maybe_inline_in
                -- Leaves that must be retained are set aside, not inlined.
                (not_actually_inlined, to_be_inlined') =
                  partition keep to_be_inlined

        -- Entry points, and functions that call a recursive function,
        -- are retained in the program and never inlined.
        keep fd =
          isJust (funDefEntryPoint fd) || callsRecursive fd

        callsRecursive fd = any recursive $ allCalledBy (funDefName fd) cg

        recursive fname = calls fname fname cg

-- | @inlineInFunDef fdmap caller@ inlines into @caller@ every call to a
-- function found in @fdmap@; calls to other functions are left alone.
-- The precondition is that the functions in @fdmap@ do not themselves
-- call any function that should be inlined, because the statements
-- produced by inlining are not scanned again.
inlineInFunDef :: MonadFreshNames m =>
                  M.Map Name (FunDef SOACS) -> FunDef SOACS
               -> m (FunDef SOACS)
inlineInFunDef fdmap fd = do
  body' <- inlineInBody fdmap $ funDefBody fd
  -- Only the body changes; entry point, name, return type and
  -- parameters are carried over unchanged.
  return fd { funDefBody = body' }

-- | Produce the statements that replace a call to @fun@ whose result is
-- bound to @pat@.
--
-- The generated statements are, in order:
--
--   * bindings of the parameters of @fun@ to the call arguments;
--   * the body of @fun@, with the call's safety and source locations
--     applied via 'addLocations';
--   * bindings of the function result to @pat@.
--
-- The parameter bindings and the body are renamed so that they do not
-- clash with names in the caller.  The certificates of the call
-- statement are attached to the result bindings.
inlineFunction :: MonadFreshNames m =>
                  Pattern
               -> StmAux attr
               -> [(SubExp, Diet)]
               -> (Safety, SrcLoc, [SrcLoc])
               -> FunDef SOACS
               -> m [Stm]
inlineFunction pat aux args (safety, loc, locs) fun = do
  Body _ stms res <-
    renameBody $
    mkBody (stmsFromList param_stms <> body_stms) (bodyResult fun_body)
  let res_stms =
        certify (stmAuxCerts aux) <$>
        zipWith (reshapeIfNecessary (patternNames pat)) (patternIdents pat) res
  pure $ stmsToList stms <> res_stms
  where fun_body = funDefBody fun
        fun_params = funDefParams fun

        -- Bind each parameter to the corresponding call argument.
        param_stms =
          zipWith (reshapeIfNecessary (map paramName fun_params))
                  (map paramIdent fun_params)
                  (map fst args)

        -- The callee's statements, annotated with the call site's
        -- safety and (informative) source locations.
        body_stms =
          addLocations safety (filter notNoLoc (loc : locs)) $ bodyStms fun_body

        -- Bind 'ident' to 'se'.  If 'ident' is an array whose shape
        -- mentions one of 'dim_names' and 'se' is a variable, a shape
        -- coercion is inserted instead of a plain copy.
        -- NOTE(review): presumably because the argument's shape is not
        -- syntactically the declared shape; confirm.
        reshapeIfNecessary dim_names ident se
          | t@Array{} <- identType ident,
            any (`elem` dim_names) (subExpVars $ arrayDims t),
            Var v <- se =
              mkLet [] [ident] $ shapeCoerce (arrayDims t) v
          | otherwise =
              mkLet [] [ident] $ BasicOp $ SubExp se

        notNoLoc = (/= NoLoc) . locOf

-- | Inline calls to the functions in @fdmap@ that occur in a sequence of
-- statements (used for the top-level constants).
inlineInStms :: MonadFreshNames m =>
                M.Map Name (FunDef SOACS) -> Stms SOACS -> m (Stms SOACS)
inlineInStms fdmap stms =
  -- Wrap the statements in a result-less body so that the body
  -- inliner can be reused, then unwrap again.
  bodyStms <$> inlineInBody fdmap (mkBody stms [])

-- | Inline every call to a function in @fdmap@ found in the body.
--
-- This includes calls inside nested bodies of expressions and inside
-- SOAC lambdas.  Calls to functions not in @fdmap@ are left alone.
-- The statements produced by inlining are not themselves searched for
-- further calls.
inlineInBody :: MonadFreshNames m =>
                M.Map Name (FunDef SOACS) -> Body -> m Body
inlineInBody fdmap = onBody
  where -- Walk a statement list, replacing calls to known functions
        -- with the callee's (renamed) statements.
        inline (Let pat aux (Apply fname args _ what) : rest)
          | Just fd <- M.lookup fname fdmap =
              (<>) <$> inlineFunction pat aux args what fd <*> inline rest
        inline (stm : rest) =
          (:) <$> onStm stm <*> inline rest
        inline [] =
          pure []

        onBody (Body attr stms res) =
          Body attr . stmsFromList <$> inline (stmsToList stms) <*> pure res

        -- A statement that is not an inlinable call: descend into the
        -- bodies nested within its expression.
        onStm (Let pat aux e) =
          Let pat aux <$> mapExpM inliner e

        inliner =
          identityMapper { mapOnBody = const onBody
                         , mapOnOp = onSOAC
                         }

        onSOAC =
          mapSOACM identitySOACMapper { mapOnSOACLambda = onLambda }

        onLambda (Lambda params body ret) =
          Lambda params <$> onBody body <*> pure ret

-- | Propagate a call site's safety and source locations into the
-- statements of an inlined function body.
--
--   * Function calls get the 'min' of their own safety and
--     @caller_safety@, and have @more_locs@ appended to their
--     location stack.
--   * Assertions have @more_locs@ appended when @caller_safety@ is
--     'Safe'.  When it is 'Unsafe', they are replaced by the constant
--     'Checked'.
--
-- The transformation also descends into nested bodies and SOAC lambdas.
addLocations :: Safety -> [SrcLoc] -> Stms SOACS -> Stms SOACS
addLocations caller_safety more_locs = fmap onStm
  where onStm stm = stm { stmExp = onExp $ stmExp stm }

        onExp (Apply fname args t (safety, loc, locs)) =
          Apply fname args t (min caller_safety safety, loc, locs ++ more_locs)
        onExp (BasicOp (Assert cond desc (loc, locs))) =
          case caller_safety of
            Safe -> BasicOp $ Assert cond desc (loc, locs ++ more_locs)
            Unsafe -> BasicOp $ SubExp $ Constant Checked
        onExp (Op soac) =
          Op $ runIdentity $
          mapSOACM identitySOACMapper { mapOnSOACLambda = return . onLambda } soac
        onExp e =
          mapExp identityMapper { mapOnBody = const $ return . onBody } e

        onBody body =
          body { bodyStms = addLocations caller_safety more_locs $ bodyStms body }

        onLambda :: Lambda -> Lambda
        onLambda lam = lam { lambdaBody = onBody $ lambdaBody lam }

-- | Inline functions into their callers (and into the top-level
-- constants), then remove the functions that have become dead.
-- Entry points, and functions that call a recursive function, are
-- kept rather than inlined.  Copy propagation is run on the resulting
-- program.
inlineFunctions :: Pass SOACS SOACS
inlineFunctions =
  Pass { passName = "Inline functions"
       , passDescription = "Inline and remove resulting dead functions."
       , passFunction = pass
       }
  where pass prog@(Prog consts funs) = do
          let cg = buildCallGraph prog
          (consts', funs') <- aggInlineFunctions cg (consts, funs)
          copyPropagateInProg simpleSOACS $ Prog consts' funs'

-- | Remove from the program every function that is unreachable from
-- the entry points.  Reachability is decided by the program's call
-- graph; the constants are left untouched.
removeDeadFunctions :: Pass SOACS SOACS
removeDeadFunctions =
  Pass { passName = "Remove dead functions"
       , passDescription = "Remove the functions that are unreachable from entry points"
       , passFunction = return . pass
       }
  where pass prog =
          let cg = buildCallGraph prog
              -- A function is live iff it appears in the call graph.
              live_funs =
                filter ((`isFunInCallGraph` cg) . funDefName) $ progFuns prog
          in prog { progFuns = live_funs }