{-# LANGUAGE FlexibleContexts #-}
module Futhark.Optimise.InliningDeadFun
( inlineAndRemoveDeadFunctions
, removeDeadFunctions
)
where
import Control.Monad.Identity
import Data.List
import Data.Loc
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Representation.SOACS
import Futhark.Transform.Rename
import Futhark.Analysis.CallGraph
import Futhark.Binder
import Futhark.Pass
-- | Aggressively inline functions into their callers, then prune the
-- functions that are no longer needed.
--
-- The work proceeds in rounds over a shrinking working set of
-- functions.  In each round, a function is selected for inlining when
-- the call graph @cg@ has an entry for it and none of its recorded
-- callees (nor the function itself) is still in the working set.  The
-- selected functions are inlined into all the others and then leave
-- the working set; those of them that are entry points are still part
-- of the output.  The rounds stop as soon as nothing is selected.
--
-- Finally, only entry points and functions that call a directly
-- recursive function (which includes directly recursive functions
-- themselves) are kept.
--
-- Note that @cg@ is never updated while inlining; all lookups refer to
-- the call graph exactly as it was passed in.
aggInlining :: CallGraph -> [FunDef] -> [FunDef]
aggInlining cg = filter keep . recurse
  where noInterestingCalls :: S.Set Name -> FunDef -> Bool
        noInterestingCalls interesting fundec =
          case M.lookup (funDefName fundec) cg of
            -- interesting' is a set, so test membership with
            -- S.member instead of a linear Foldable 'elem' scan.
            Just calls | not $ any (`S.member` interesting') calls -> True
            -- No call graph entry, or at least one interesting callee.
            _ -> False
          where interesting' = funDefName fundec `S.insert` interesting

        recurse funs =
          let interesting = S.fromList $ map funDefName funs
              (to_be_inlined, to_inline_in) =
                partition (noInterestingCalls interesting) funs
              -- Inlined functions disappear from the working set, but
              -- entry points among them must still be emitted.
              inlined_but_entry_points =
                filter (isJust . funDefEntryPoint) to_be_inlined
          in if null to_be_inlined then funs
             else inlined_but_entry_points ++
                  recurse (map (`doInlineInCaller` to_be_inlined) to_inline_in)

        keep fundec = isJust (funDefEntryPoint fundec) || callsRecursive fundec

        -- Does the function call something that is directly recursive?
        callsRecursive fundec = maybe False (any recursive) $
                                M.lookup (funDefName fundec) cg

        -- Is the named function listed among its own callees?
        recursive fname = case M.lookup fname cg of
                            Just calls -> fname `elem` calls
                            Nothing -> False
-- | Inline calls to the given callees in the body of a single caller.
-- Every other component of the caller is passed through unchanged.
doInlineInCaller :: FunDef -> [FunDef] -> FunDef
doInlineInCaller (FunDef entry_point fname ret_type params body) inlcallees =
  FunDef entry_point fname ret_type params $ inlineInBody inlcallees body
-- | Inline calls to any of @inlcallees@ in the statements of a body.
--
-- A statement that is a direct 'Apply' of a function found (by name)
-- in @inlcallees@ is replaced by three groups of statements: bindings
-- of the callee's parameters to the call arguments, the callee's body
-- statements (passed through 'addLocations' with the safety of the
-- call and its non-'NoLoc' locations), and bindings of the call's
-- pattern to the callee's body result.  Every other statement is
-- handed to 'inlineInStm'.  The body result itself is left untouched.
--
-- NOTE(review): the auxiliary field of the replaced call statement is
-- dropped here; confirm nothing in it needs to survive inlining.
inlineInBody :: [FunDef] -> Body -> Body
inlineInBody inlcallees (Body attr stms res) =
  Body attr (stmsFromList $ concatMap expand $ stmsToList stms) res
  where expand stm@(Let pat _ (Apply fname args _ (safety, loc, locs))) =
          -- Only the first callee with a matching name is considered.
          case [ callee | callee <- inlcallees, funDefName callee == fname ] of
            callee:_ ->
              let callee_body = funDefBody callee
                  call_locs = filter notNoLoc (loc:locs)
              in concat
                 [ zipWith bindTo (map paramIdent $ funDefParams callee) $
                   map fst args
                 , stmsToList $
                   addLocations safety call_locs $ bodyStms callee_body
                 , zipWith bindTo (patternIdents pat) $
                   bodyResult callee_body
                 ]
            [] ->
              [inlineInStm inlcallees stm]
        expand stm = [inlineInStm inlcallees stm]

        -- Bind an identifier to a subexpression.  When the identifier
        -- has array type and the subexpression is a variable, the
        -- variable is coerced to the dimensions of that type; in all
        -- other cases the subexpression is bound as it is.
        bindTo ident se =
          mkLet [] [ident] $
          case (identType ident, se) of
            (t@Array{}, Var v) -> shapeCoerce (arrayDims t) v
            _ -> BasicOp $ SubExp se
-- | Is the location of this 'SrcLoc' something other than 'NoLoc'?
notNoLoc :: SrcLoc -> Bool
notNoLoc srcloc = locOf srcloc /= NoLoc
-- | A 'Mapper' that inlines calls to the given functions in bodies
-- (via 'inlineInBody') and in operations (via 'inlineInSOAC').  All
-- other fields are those of 'identityMapper'.
inliner :: Monad m => [FunDef] -> Mapper SOACS SOACS m
inliner funs =
  identityMapper { mapOnBody = \_ body -> return $ inlineInBody funs body
                 , mapOnOp = \op -> return $ inlineInSOAC funs op
                 }
-- | Inline calls to the given functions in the lambdas of a SOAC.
-- The traversal runs in the 'Identity' monad, so the result is pure.
inlineInSOAC :: [FunDef] -> SOAC SOACS -> SOAC SOACS
inlineInSOAC inlcallees soac = runIdentity $ mapSOACM lambda_inliner soac
  where lambda_inliner =
          identitySOACMapper
          { mapOnSOACLambda = return . inlineInLambda inlcallees }
-- | Inline calls to the given functions inside the expression of a
-- statement, by mapping 'inliner' over that expression.  The rest of
-- the statement is kept as it is.
inlineInStm :: [FunDef] -> Stm -> Stm
inlineInStm inlcallees stm =
  stm { stmExp = mapExp (inliner inlcallees) $ stmExp stm }
-- | Inline calls to the given functions in the body of a lambda.  The
-- parameters and return type of the lambda are not touched.
inlineInLambda :: [FunDef] -> Lambda -> Lambda
inlineInLambda inlcallees lam =
  lam { lambdaBody = inlineInBody inlcallees $ lambdaBody lam }
-- | Adjust statements for the call site they are inlined at, given
-- the safety of the call and the locations to add.
--
--   * A nested 'Apply' gets the minimum of the caller's safety and its
--     own, and has @more_locs@ appended to its location list.
--
--   * An 'Assert' has @more_locs@ appended to its location list when
--     the caller is 'Safe'; when the caller is 'Unsafe' it is replaced
--     by the constant 'Checked'.
--
--   * For a SOAC, and for any other expression, the same rewriting is
--     applied to the statements of the lambda bodies and bodies that
--     the respective mapper visits.
addLocations :: Safety -> [SrcLoc] -> Stms SOACS -> Stms SOACS
addLocations caller_safety more_locs = fmap relocateStm
  where -- Append the call-site locations to an existing location list.
        extend = (++ more_locs)

        relocateStm stm = stm { stmExp = relocateExp $ stmExp stm }

        relocateExp e =
          case e of
            Apply fname args t (safety, loc, locs) ->
              Apply fname args t (min caller_safety safety, loc, extend locs)
            BasicOp (Assert cond desc (loc, locs)) ->
              case caller_safety of
                Safe -> BasicOp $ Assert cond desc (loc, extend locs)
                Unsafe -> BasicOp $ SubExp $ Constant Checked
            Op soac ->
              Op $ runIdentity $ mapSOACM lambda_mapper soac
            _ ->
              mapExp body_mapper e

        lambda_mapper =
          identitySOACMapper { mapOnSOACLambda = return . relocateLambda }

        body_mapper =
          identityMapper { mapOnBody = \_ body -> return $ relocateBody body }

        relocateBody body =
          body { bodyStms = addLocations caller_safety more_locs $ bodyStms body }

        relocateLambda :: Lambda -> Lambda
        relocateLambda lam = lam { lambdaBody = relocateBody $ lambdaBody lam }
-- | A pass that renames the program, runs 'aggInlining' on its
-- functions, and renames the result again.  The call graph handed to
-- 'aggInlining' is built from the program as it was before the first
-- renaming.
inlineAndRemoveDeadFunctions :: Pass SOACS SOACS
inlineAndRemoveDeadFunctions =
  Pass { passName = "Inline and remove dead functions"
       , passDescription = "Inline and remove resulting dead functions."
       , passFunction = pass
       }
  where pass prog = do
          renamed <- renameProg prog
          let inlined = aggInlining (buildCallGraph prog) $ progFunctions renamed
          renameProg $ Prog inlined
-- | A pass that keeps only the functions whose name has an entry in
-- the call graph built from the program.
--
-- NOTE(review): this presumes 'buildCallGraph' only creates entries
-- for functions reachable from the main function, as the pass
-- description states; confirm against its definition.
removeDeadFunctions :: Pass SOACS SOACS
removeDeadFunctions =
  Pass { passName = "Remove dead functions"
       , passDescription = "Remove the functions that are unreachable from the main function"
       , passFunction = return . pass
       }
  where pass prog =
          Prog [ fundec | fundec <- progFunctions prog, inCallGraph fundec ]
          where cg = buildCallGraph prog
                inCallGraph fundec = funDefName fundec `M.member` cg