{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}

-- | Provides the heuristics for when it's beneficial to lambda lift bindings.
-- Most significantly, this employs a cost model to estimate impact on heap
-- allocations, by looking at an STG expression's 'Skeleton'.
module StgLiftLams.Analysis (
    -- * #when# When to lift
    -- $when

    -- * #clogro# Estimating closure growth
    -- $clogro

    -- * AST annotation
    Skeleton(..), BinderInfo(..), binderInfoBndr,
    LlStgBinding, LlStgExpr, LlStgRhs, LlStgAlt, tagSkeletonTopBind,
    -- * Lifting decision
    goodToLift,
    closureGrowth -- Exported just for the docs
  ) where

import GhcPrelude

import BasicTypes
import Demand
import DynFlags
import Id
import SMRep ( WordOff )
import StgSyn
import qualified StgCmmArgRep
import qualified StgCmmClosure
import qualified StgCmmLayout
import Outputable
import Util
import VarSet

import Data.Maybe ( mapMaybe )

-- Note [When to lift]
-- ~~~~~~~~~~~~~~~~~~~
-- $when
-- The analysis proceeds in two steps:
--
--   1. It tags the syntax tree with analysis information in the form of
--      'BinderInfo' at each binder and 'Skeleton's at each let-binding
--      by 'tagSkeletonTopBind' and friends.
--   2. The resulting syntax tree is treated by the "StgLiftLams.Transformation"
--      module, calling out to 'goodToLift' to decide if a binding is worthwhile
--      to lift.
--      'goodToLift' consults argument occurrence information in 'BinderInfo'
--      and estimates 'closureGrowth', for which it needs the 'Skeleton'.
--
-- So the annotations from 'tagSkeletonTopBind' ultimately fuel 'goodToLift',
-- which employs a number of heuristics to identify and exclude lambda lifting
-- opportunities deemed non-beneficial:
--
--  [Top-level bindings] can't be lifted.
--  [Thunks] and data constructors shouldn't be lifted in order not to destroy
--    sharing.
--  [Argument occurrences] #arg_occs# of binders prohibit them from being
--    lifted. Doing the lift would re-introduce the very allocation at call
--    sites that we tried to get rid of in the first place. We capture this
--    analysis information in 'BinderInfo'. Note that we also consider a
--    nullary application as an argument occurrence, because it would turn
--    into an n-ary partial application created by a generic apply function.
--    This occurs in CPS-heavy code like the CS benchmark.
--  [Join points] should not be lifted, simply because there's no reduction in
--    allocation to be had.
--  [Abstracting over join points] destroys join points, because they end up as
--    arguments to the lifted function.
--  [Abstracting over known local functions] turns a known call into an unknown
--    call (e.g. some @stg_ap_*@), which is generally slower. Can be turned off
--    with @-fstg-lift-lams-known@.
--  [Calling convention] Don't lift when the resulting function would have a
--    higher arity than available argument registers for the calling convention.
--    Can be influenced with @-fstg-lift-(non)rec-args(-any)@.
--  [Closure growth] introduced when former free variables have to be made
--    available at call sites may actually lead to an increase in overall
--    allocations resulting from a lift. Estimating closure growth is
--    described in "StgLiftLams.Analysis#clogro" and is what most of this
--    module is ultimately concerned with.
--
-- There's a <https://ghc.haskell.org/trac/ghc/wiki/LateLamLift wiki page> with
-- some more background and history.

-- Note [Estimating closure growth]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- $clogro
-- We estimate closure growth by abstracting the syntax tree into a 'Skeleton',
-- capturing only syntactic details relevant to 'closureGrowth', such as
--
--   * 'ClosureSk', representing closure allocation.
--   * 'RhsSk', representing a RHS of a binding and how many times it's called
--     by an appropriate 'DmdShell'.
--   * 'AltSk', 'BothSk' and 'NilSk' for choice, sequence and empty element.
--
-- This abstraction is mostly so that the main analysis function 'closureGrowth'
-- can stay simple and focused. Also, skeletons tend to be much smaller than
-- the syntax tree they abstract, so it makes sense to construct them once
-- and operate on them instead of the actual syntax tree.
--
-- A more detailed treatment of computing closure growth, including examples,
-- can be found in the paper referenced from the
-- <https://ghc.haskell.org/trac/ghc/wiki/LateLamLift wiki page>.

-- | Debug-tracing hook for the analysis. Currently a no-op: the label and
-- message are ignored and the third argument is returned unchanged. Swap in
-- the commented-out definition below to get 'pprTrace' output.
llTrace :: String -> SDoc -> a -> a
llTrace _ _ c = c
-- llTrace a b c = pprTrace a b c

-- Extension points of the STG AST for the 'LiftLams pass: binders carry a
-- 'BinderInfo', closure right-hand sides record their free variables as a
-- 'DIdSet', and both kinds of let carry a 'Skeleton' (the scope skeleton
-- produced by 'tagSkeletonBinding').
type instance BinderP 'LiftLams = BinderInfo
type instance XRhsClosure 'LiftLams = DIdSet
type instance XLet 'LiftLams = Skeleton
type instance XLetNoEscape 'LiftLams = Skeleton

-- | The free variables of a right-hand side.
--
-- For a constructor application these are its variable arguments (literal
-- arguments are skipped). For a closure it is the set already stored in the
-- closure's extension field.
freeVarsOfRhs :: (XRhsClosure pass ~ DIdSet) => GenStgRhs pass -> DIdSet
freeVarsOfRhs (StgRhsCon _ _ args) = mkDVarSet [ v | StgVarArg v <- args ]
freeVarsOfRhs (StgRhsClosure fvs _ _ _ _) = fvs

-- | Captures details of the syntax tree relevant to the cost model, such as
-- closures, multi-shot lambdas and case expressions.
--
--   * 'ClosureSk': allocation of a closure for the given binder. It carries
--     the closure's free vars and the skeleton of its RHS.
--   * 'RhsSk': a RHS body, annotated with a 'DmdShell' recording how often
--     the RHS was entered.
--   * 'AltSk': a choice between two skeletons (e.g. case alternatives).
--   * 'BothSk': a sequence of two skeletons.
--   * 'NilSk': the empty skeleton.
data Skeleton
  = ClosureSk !Id !DIdSet !Skeleton
  | RhsSk !DmdShell !Skeleton
  | AltSk !Skeleton !Skeleton
  | BothSk !Skeleton !Skeleton
  | NilSk

-- | Smart constructor for 'BothSk' (sequencing). A 'NilSk' operand is
-- dropped, so the other operand is returned unchanged.
bothSk :: Skeleton -> Skeleton -> Skeleton
bothSk NilSk b = b
bothSk a NilSk = a
bothSk a b     = BothSk a b

-- | Smart constructor for 'AltSk' (choice). A 'NilSk' operand is dropped, so
-- the other operand is returned unchanged.
altSk :: Skeleton -> Skeleton -> Skeleton
altSk NilSk b = b
altSk a NilSk = a
altSk a b     = AltSk a b

-- | Smart constructor for 'RhsSk'. An empty body skeleton is returned as
-- 'NilSk' rather than being wrapped with the demand annotation.
rhsSk :: DmdShell -> Skeleton -> Skeleton
rhsSk _        NilSk = NilSk
rhsSk body_dmd skel  = RhsSk body_dmd skel

-- | The type used in binder positions in 'GenStgExpr's of the 'LiftLams
-- pass.
data BinderInfo
  = BindsClosure !Id !Bool
    -- ^ A let- or let-no-escape-bound binder. The flag is 'True' iff the
    -- binder occurs as an argument or in a nullary application
    -- (see "StgLiftLams.Analysis#arg_occs").
  | BoringBinder !Id
    -- ^ Every other kind of binder (lambda, case and case-alternative
    -- binders).

-- | Gets the bound 'Id' out of a 'BinderInfo'.
binderInfoBndr :: BinderInfo -> Id
binderInfoBndr (BoringBinder bndr)   = bndr
binderInfoBndr (BindsClosure bndr _) = bndr

-- | Returns 'Nothing' for 'BoringBinder's. Otherwise returns 'Just' the flag
-- indicating whether the binder occurs as an argument or in a nullary
-- application.
binderInfoOccursAsArg :: BinderInfo -> Maybe Bool
binderInfoOccursAsArg BoringBinder{}     = Nothing
binderInfoOccursAsArg (BindsClosure _ b) = Just b

-- | Debug rendering of a 'Skeleton'.
--
-- * 'AltSk' is printed as a braced block with the two sides separated by
--   @ALT@.
-- * 'ClosureSk' prints the binder and its free vars, with the RHS skeleton
--   nested below.
-- * 'RhsSk' is printed as @λ[s, u]. body@, where:
--
--     * @s@ is @1@ for a strict demand and @0@ otherwise;
--     * @u@ is @0@ when absent, @1@ when used once and @ω@ otherwise.
instance Outputable Skeleton where
  ppr NilSk = text ""
  ppr (AltSk l r) = vcat
    [ text "{ " <+> ppr l
    , text "ALT"
    , text "  " <+> ppr r
    , text "}"
    ]
  ppr (BothSk l r) = ppr l $$ ppr r
  ppr (ClosureSk f fvs body) = ppr f <+> ppr fvs $$ nest 2 (ppr body)
  ppr (RhsSk body_dmd body) = hcat
    [ text "λ["
    , ppr str
    , text ", "
    , ppr use
    , text "]. "
    , ppr body
    ]
    where
      str
        | isStrictDmd body_dmd = '1'
        | otherwise            = '0'
      use
        | isAbsDmd body_dmd   = '0'
        | isUsedOnce body_dmd = '1'
        | otherwise           = 'ω'

-- | Prints only the bound 'Id'; the analysis information is not shown.
instance Outputable BinderInfo where
  ppr = ppr . binderInfoBndr

-- | Every method delegates to the 'OutputableBndr' instance of the bound
-- 'Id', as extracted by 'binderInfoBndr'.
instance OutputableBndr BinderInfo where
  pprBndr b        = pprBndr b . binderInfoBndr
  pprPrefixOcc     = pprPrefixOcc . binderInfoBndr
  pprInfixOcc      = pprInfixOcc . binderInfoBndr
  bndrIsJoin_maybe = bndrIsJoin_maybe . binderInfoBndr

-- | The set of variables occurring in the given argument list. Literal
-- arguments are skipped.
mkArgOccs :: [StgArg] -> IdSet
mkArgOccs = mkVarSet . mapMaybe stg_arg_var
  where
    stg_arg_var (StgVarArg occ) = Just occ
    stg_arg_var _               = Nothing

-- | Tags every binder with its 'BinderInfo' and let bindings with their
-- 'Skeleton's.
--
-- The arguments passed to 'tagSkeletonBinding' are chosen as follows:
--
-- * 'NilSk' is fine as the body skeleton when tagging top-level bindings.
-- * Top-level things are never lambda-lifted, so there is no need to track
--   their argument occurrences (hence 'emptyVarSet').
-- * Top-level bindings can never be let-no-escapes (hence 'False').
tagSkeletonTopBind :: CgStgBinding -> LlStgBinding
tagSkeletonTopBind bind = bind'
  where
    (_, _, _, bind') = tagSkeletonBinding False NilSk emptyVarSet bind

-- | Tags binders of an 'StgExpr' with their 'BinderInfo' and let bindings
-- with their 'Skeleton's.
--
-- Additionally returns:
--
-- * the expression's 'Skeleton';
-- * the set of binder occurrences in argument and nullary application
--   position (cf. "StgLiftLams.Analysis#arg_occs").
--
-- Panics on 'StgLam'.
tagSkeletonExpr :: CgStgExpr -> (Skeleton, IdSet, LlStgExpr)
tagSkeletonExpr (StgLit lit)
  = (NilSk, emptyVarSet, StgLit lit)
tagSkeletonExpr (StgConApp con args tys)
  = (NilSk, mkArgOccs args, StgConApp con args tys)
tagSkeletonExpr (StgOpApp op args ty)
  = (NilSk, mkArgOccs args, StgOpApp op args ty)
tagSkeletonExpr (StgApp f args)
  = (NilSk, arg_occs, StgApp f args)
  where
    arg_occs
      -- This checks for nullary applications, which we treat the same as
      -- argument occurrences, see "StgLiftLams.Analysis#arg_occs".
      | null args = unitVarSet f
      | otherwise = mkArgOccs args
tagSkeletonExpr (StgLam _ _) = pprPanic "stgLiftLams" (text "StgLam")
tagSkeletonExpr (StgCase scrut bndr ty alts)
  = (skel, arg_occs, StgCase scrut' bndr' ty alts')
  where
    (scrut_skel, scrut_arg_occs, scrut') = tagSkeletonExpr scrut
    (alt_skels, alt_arg_occss, alts') = mapAndUnzip3 tagSkeletonAlt alts
    -- The scrutinee in sequence with a choice between the alternatives.
    skel = bothSk scrut_skel (foldr altSk NilSk alt_skels)
    -- The case binder is bound here, so it is removed from the occurrences
    -- we report upwards.
    arg_occs = unionVarSets (scrut_arg_occs:alt_arg_occss) `delVarSet` bndr
    bndr' = BoringBinder bndr
tagSkeletonExpr (StgTick t e)
  = (skel, arg_occs, StgTick t e')
  where
    (skel, arg_occs, e') = tagSkeletonExpr e
tagSkeletonExpr (StgLet _ bind body) = tagSkeletonLet False body bind
tagSkeletonExpr (StgLetNoEscape _ bind body) = tagSkeletonLet True body bind

-- | Rebuilds a let expression: 'StgLetNoEscape' when the flag is 'True',
-- 'StgLet' otherwise. The 'Skeleton' is stored in the let's extension field.
mkLet :: Bool -> Skeleton -> LlStgBinding -> LlStgExpr -> LlStgExpr
mkLet True = StgLetNoEscape
mkLet _    = StgLet

-- | Tags a @let@ or @let-no-escape@ expression. It proceeds in three steps:
--
-- 1. Tag the body.
-- 2. Tag the binding group with the body's skeleton and argument
--    occurrences (see 'tagSkeletonBinding').
-- 3. Rebuild the let via 'mkLet', storing the binding's scope skeleton in
--    the let's extension field.
tagSkeletonLet
  :: Bool
  -- ^ Is the binding a let-no-escape?
  -> CgStgExpr
  -- ^ Let body
  -> CgStgBinding
  -- ^ Binding group
  -> (Skeleton, IdSet, LlStgExpr)
  -- ^ Let skeleton, argument occurrences and the annotated let expression
tagSkeletonLet is_lne body bind
  = (let_skel, arg_occs, mkLet is_lne scope bind' body')
  where
    (body_skel, body_arg_occs, body') = tagSkeletonExpr body
    (let_skel, arg_occs, scope, bind')
      = tagSkeletonBinding is_lne body_skel body_arg_occs bind

-- | Tags a binding group, given the skeleton and argument occurrences of the
-- let body.
--
-- Each binder becomes a 'BindsClosure' whose flag records whether it occurs
-- as an argument (in the body, or, for recursive groups, in the body or any
-- RHS of the group).
--
-- Each RHS skeleton is wrapped in a 'ClosureSk', except for let-no-escapes,
-- which allocate no closure.
tagSkeletonBinding
  :: Bool
  -- ^ Is the binding a let-no-escape?
  -> Skeleton
  -- ^ Let body skeleton
  -> IdSet
  -- ^ Argument occurrences in the body
  -> CgStgBinding
  -- ^ Binding group
  -> (Skeleton, IdSet, Skeleton, LlStgBinding)
  -- ^ Let skeleton, argument occurrences (minus the group's binders), scope
  --   skeleton of binding and the annotated binding
tagSkeletonBinding is_lne body_skel body_arg_occs (StgNonRec bndr rhs)
  = (let_skel, arg_occs, scope, bind')
  where
    (rhs_skel, rhs_arg_occs, rhs') = tagSkeletonRhs bndr rhs
    arg_occs = (body_arg_occs `unionVarSet` rhs_arg_occs) `delVarSet` bndr
    bind_skel
      | is_lne    = rhs_skel -- no closure is allocated for let-no-escapes
      | otherwise = ClosureSk bndr (freeVarsOfRhs rhs) rhs_skel
    let_skel = bothSk body_skel bind_skel
    occurs_as_arg = bndr `elemVarSet` body_arg_occs
    -- Compared to the recursive case, this exploits the fact that @bndr@ is
    -- never free in @rhs@.
    scope = body_skel
    bind' = StgNonRec (BindsClosure bndr occurs_as_arg) rhs'
tagSkeletonBinding is_lne body_skel body_arg_occs (StgRec pairs)
  = (let_skel, arg_occs, scope, StgRec pairs')
  where
    bndrs = map fst pairs
    bndrs_set = mkVarSet bndrs
    skel_occs_rhss' = map (uncurry tagSkeletonRhs) pairs
    rhss_arg_occs = map sndOf3 skel_occs_rhss'
    scope_occs = unionVarSets (body_arg_occs:rhss_arg_occs)
    arg_occs = scope_occs `delVarSetList` bndrs
    -- The RHS skeletons in @skel_occs_rhss'@ aren't yet wrapped in closures.
    -- We'll do that in a moment, but we also need the un-wrapped skeletons
    -- for calculating the @scope@ of the group, as the outer closures don't
    -- contribute to closure growth when we lift this specific binding.
    scope = foldr (bothSk . fstOf3) body_skel skel_occs_rhss'
    -- Now we can build the actual Skeleton for the expression just by
    -- iterating over each bind pair.
    (bind_skels, pairs') = unzip (zipWith single_bind bndrs skel_occs_rhss')
    let_skel = foldr bothSk body_skel bind_skels
    single_bind bndr (skel_rhs, _, rhs') = (bind_skel, (bndr', rhs'))
      where
        -- Here, we finally add the closure around each @skel_rhs@.
        bind_skel
          | is_lne    = skel_rhs -- no closure is allocated for let-no-escapes
          | otherwise = ClosureSk bndr fvs skel_rhs
        -- Local recursive STG bindings also regard the defined binders as
        -- free vars. We want to delete those for our cost model, as these
        -- are known calls anyway when we add them to the same top-level
        -- recursive group as the top-level binding currently being analysed.
        fvs = freeVarsOfRhs rhs' `dVarSetMinusVarSet` bndrs_set
        bndr' = BindsClosure bndr (bndr `elemVarSet` scope_occs)

-- | Tags a right-hand side bound to the given 'Id'.
--
-- * A constructor application gets 'NilSk' and reports its variable
--   arguments as argument occurrences.
-- * A closure's lambda binders become 'BoringBinder's. Its body skeleton is
--   wrapped via 'rhsSk' with the binder's 'rhsDmdShell'.
tagSkeletonRhs :: Id -> CgStgRhs -> (Skeleton, IdSet, LlStgRhs)
tagSkeletonRhs _ (StgRhsCon ccs dc args)
  = (NilSk, mkArgOccs args, StgRhsCon ccs dc args)
tagSkeletonRhs bndr (StgRhsClosure fvs ccs upd bndrs body)
  = (rhs_skel, body_arg_occs, StgRhsClosure fvs ccs upd bndrs' body')
  where
    bndrs' = map BoringBinder bndrs
    (body_skel, body_arg_occs, body') = tagSkeletonExpr body
    rhs_skel = rhsSk (rhsDmdShell bndr) body_skel

-- | How many times will the lambda body of the RHS bound to the given
-- identifier be evaluated, relative to its defining context? This function
-- computes the answer in the form of a 'DmdShell', from the binder's
-- demand info.
--
-- * Thunks (arity 0): the outer demand shell is passed through 'oneifyDmd'
--   (presumably because a thunk's body is evaluated at most once before
--   being updated).
-- * Other binders: as many call demands as the binder's arity are peeled
--   off via 'peelManyCalls'.
rhsDmdShell :: Id -> DmdShell
rhsDmdShell bndr
  | is_thunk  = oneifyDmd ds
  | otherwise = peelManyCalls (idArity bndr) cd
  where
    is_thunk = idArity bndr == 0
    -- NOTE: relies on idDemandInfo still being accurate after unarise.
    (ds, cd) = toCleanDmd (idDemandInfo bndr)

-- | Tags a case alternative.
--
-- The pattern binders become 'BoringBinder's and the RHS is tagged via
-- 'tagSkeletonExpr'. Variables bound by the pattern are removed from the
-- reported argument occurrences.
tagSkeletonAlt :: CgStgAlt -> (Skeleton, IdSet, LlStgAlt)
tagSkeletonAlt (con, bndrs, rhs)
  = (alt_skel, arg_occs, (con, map BoringBinder bndrs, rhs'))
  where
    (alt_skel, alt_arg_occs, rhs') = tagSkeletonExpr rhs
    arg_occs = alt_arg_occs `delVarSetList` bndrs

-- | Combines several heuristics to decide whether to lambda-lift a given
-- @let@-binding to top-level. See "StgLiftLams.Analysis#when" for details.
goodToLift
  :: DynFlags
  -> TopLevelFlag
  -> RecFlag
  -> (DIdSet -> DIdSet) -- ^ An expander function, turning 'InId's into
                        -- 'OutId's. See 'StgLiftLams.LiftM.liftedIdsExpander'.
  -> [(BinderInfo, LlStgRhs)]
  -> Skeleton
  -> Maybe DIdSet       -- ^ @Just abs_ids@ <=> This binding is beneficial to
                        -- lift and @abs_ids@ are the variables it would
                        -- abstract over
goodToLift :: DynFlags
-> TopLevelFlag
-> RecFlag
-> (DIdSet -> DIdSet)
-> [(BinderInfo, LlStgRhs)]
-> Skeleton
-> Maybe DIdSet
goodToLift dflags :: DynFlags
dflags top_lvl :: TopLevelFlag
top_lvl rec_flag :: RecFlag
rec_flag expander :: DIdSet -> DIdSet
expander pairs :: [(BinderInfo, LlStgRhs)]
pairs scope :: Skeleton
scope = [(String, Bool)] -> Maybe DIdSet
decide
  [ ("top-level", TopLevelFlag -> Bool
isTopLevel TopLevelFlag
top_lvl) -- keep in sync with Note [When to lift]
  , ("memoized", Bool
any_memoized)
  , ("argument occurrences", Bool
arg_occs)
  , ("join point", Bool
is_join_point)
  , ("abstracts join points", Bool
abstracts_join_ids)
  , ("abstracts known local function", Bool
abstracts_known_local_fun)
  , ("args spill on stack", Bool
args_spill_on_stack)
  , ("increases allocation", Bool
inc_allocs)
  ] where
      decide :: [(String, Bool)] -> Maybe DIdSet
decide deciders :: [(String, Bool)]
deciders
        | Bool -> Bool
not ([(String, Bool)] -> Bool
fancy_or [(String, Bool)]
deciders)
        = String -> SDoc -> Maybe DIdSet -> Maybe DIdSet
forall a. String -> SDoc -> a -> a
llTrace "stgLiftLams:lifting"
                  ([Var] -> SDoc
forall a. Outputable a => a -> SDoc
ppr [Var]
bndrs SDoc -> SDoc -> SDoc
<+> DIdSet -> SDoc
forall a. Outputable a => a -> SDoc
ppr DIdSet
abs_ids SDoc -> SDoc -> SDoc
$$
                   IntWithInf -> SDoc
forall a. Outputable a => a -> SDoc
ppr IntWithInf
allocs SDoc -> SDoc -> SDoc
$$
                   Skeleton -> SDoc
forall a. Outputable a => a -> SDoc
ppr Skeleton
scope) (Maybe DIdSet -> Maybe DIdSet) -> Maybe DIdSet -> Maybe DIdSet
forall a b. (a -> b) -> a -> b
$
          DIdSet -> Maybe DIdSet
forall a. a -> Maybe a
Just DIdSet
abs_ids
        | Bool
otherwise
        = Maybe DIdSet
forall a. Maybe a
Nothing
      ppr_deciders :: [(String, Bool)] -> SDoc
ppr_deciders = [SDoc] -> SDoc
vcat ([SDoc] -> SDoc)
-> ([(String, Bool)] -> [SDoc]) -> [(String, Bool)] -> SDoc
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((String, Bool) -> SDoc) -> [(String, Bool)] -> [SDoc]
forall a b. (a -> b) -> [a] -> [b]
map (String -> SDoc
text (String -> SDoc)
-> ((String, Bool) -> String) -> (String, Bool) -> SDoc
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (String, Bool) -> String
forall a b. (a, b) -> a
fst) ([(String, Bool)] -> [SDoc])
-> ([(String, Bool)] -> [(String, Bool)])
-> [(String, Bool)]
-> [SDoc]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((String, Bool) -> Bool) -> [(String, Bool)] -> [(String, Bool)]
forall a. (a -> Bool) -> [a] -> [a]
filter (String, Bool) -> Bool
forall a b. (a, b) -> b
snd
      fancy_or :: [(String, Bool)] -> Bool
fancy_or deciders :: [(String, Bool)]
deciders
        = String -> SDoc -> Bool -> Bool
forall a. String -> SDoc -> a -> a
llTrace "stgLiftLams:goodToLift" ([Var] -> SDoc
forall a. Outputable a => a -> SDoc
ppr [Var]
bndrs SDoc -> SDoc -> SDoc
$$ [(String, Bool)] -> SDoc
ppr_deciders [(String, Bool)]
deciders) (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
          ((String, Bool) -> Bool) -> [(String, Bool)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (String, Bool) -> Bool
forall a b. (a, b) -> b
snd [(String, Bool)]
deciders

      bndrs :: [Var]
bndrs = ((BinderInfo, LlStgRhs) -> Var)
-> [(BinderInfo, LlStgRhs)] -> [Var]
forall a b. (a -> b) -> [a] -> [b]
map (BinderInfo -> Var
binderInfoBndr (BinderInfo -> Var)
-> ((BinderInfo, LlStgRhs) -> BinderInfo)
-> (BinderInfo, LlStgRhs)
-> Var
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (BinderInfo, LlStgRhs) -> BinderInfo
forall a b. (a, b) -> a
fst) [(BinderInfo, LlStgRhs)]
pairs
      bndrs_set :: IdSet
bndrs_set = [Var] -> IdSet
mkVarSet [Var]
bndrs
      rhss :: [LlStgRhs]
rhss = ((BinderInfo, LlStgRhs) -> LlStgRhs)
-> [(BinderInfo, LlStgRhs)] -> [LlStgRhs]
forall a b. (a -> b) -> [a] -> [b]
map (BinderInfo, LlStgRhs) -> LlStgRhs
forall a b. (a, b) -> b
snd [(BinderInfo, LlStgRhs)]
pairs

      -- First objective: Calculate @abs_ids@, i.e. the former free variables
      -- the lifted binding would abstract over. We have to merge the free
      -- variables of all RHSs to get the set of variables that will have to be
      -- passed as parameters.
      fvs :: DIdSet
fvs = [DIdSet] -> DIdSet
unionDVarSets ((LlStgRhs -> DIdSet) -> [LlStgRhs] -> [DIdSet]
forall a b. (a -> b) -> [a] -> [b]
map LlStgRhs -> DIdSet
forall (pass :: StgPass).
(XRhsClosure pass ~ DIdSet) =>
GenStgRhs pass -> DIdSet
freeVarsOfRhs [LlStgRhs]
rhss)
      -- To lift the binding to top-level, we want to delete the lifted binders
      -- themselves from the free var set. Local let bindings track recursive
      -- occurrences in their free variable set. We neither want to apply our
      -- cost model to them (see 'tagSkeletonRhs'), nor pass them as parameters
      -- when lifted, as these are known calls. We call the resulting set the
      -- identifiers we abstract over, thus @abs_ids@. These are all 'OutId's.
      -- We will save the set in 'LiftM.e_expansions' for each of the variables
      -- if we perform the lift.
      abs_ids :: DIdSet
abs_ids = DIdSet -> DIdSet
expander (DIdSet -> [Var] -> DIdSet
delDVarSetList DIdSet
fvs [Var]
bndrs)

      -- We don't lift updatable thunks or constructors
      any_memoized :: Bool
any_memoized = (LlStgRhs -> Bool) -> [LlStgRhs] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any LlStgRhs -> Bool
forall (pass :: StgPass). GenStgRhs pass -> Bool
is_memoized_rhs [LlStgRhs]
rhss
      is_memoized_rhs :: GenStgRhs pass -> Bool
is_memoized_rhs StgRhsCon{} = Bool
True
      is_memoized_rhs (StgRhsClosure _ _ upd :: UpdateFlag
upd _ _) = UpdateFlag -> Bool
isUpdatable UpdateFlag
upd

      -- Don't lift binders occurring as arguments. This would result in complex
      -- argument expressions which would have to be given a name, reintroducing
      -- the very allocation at each call site that we wanted to get rid of in
      -- the first place.
      arg_occs :: Bool
arg_occs = [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or (((BinderInfo, LlStgRhs) -> Maybe Bool)
-> [(BinderInfo, LlStgRhs)] -> [Bool]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (BinderInfo -> Maybe Bool
binderInfoOccursAsArg (BinderInfo -> Maybe Bool)
-> ((BinderInfo, LlStgRhs) -> BinderInfo)
-> (BinderInfo, LlStgRhs)
-> Maybe Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (BinderInfo, LlStgRhs) -> BinderInfo
forall a b. (a, b) -> a
fst) [(BinderInfo, LlStgRhs)]
pairs)

      -- These don't allocate anyway.
      is_join_point :: Bool
is_join_point = (Var -> Bool) -> [Var] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Var -> Bool
isJoinId [Var]
bndrs

      -- Abstracting over join points/let-no-escapes spoils them.
      abstracts_join_ids :: Bool
abstracts_join_ids = (Var -> Bool) -> [Var] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Var -> Bool
isJoinId (DIdSet -> [Var]
dVarSetElems DIdSet
abs_ids)

      -- Abstracting over known local functions that aren't floated themselves
      -- turns a known, fast call into an unknown, slow call:
      --
      --    let f x = ...
      --        g y = ... f x ... -- this was a known call
      --    in g 4
      --
      -- After lifting @g@, but not @f@:
      --
      --    l_g f y = ... f y ... -- this is now an unknown call
      --    let f x = ...
      --    in l_g f 4
      --
      -- We can abuse the results of arity analysis for this:
      -- idArity f > 0 ==> known
      known_fun :: Var -> Bool
known_fun id :: Var
id = Var -> Int
idArity Var
id Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> 0
      abstracts_known_local_fun :: Bool
abstracts_known_local_fun
        = Bool -> Bool
not (DynFlags -> Bool
liftLamsKnown DynFlags
dflags) Bool -> Bool -> Bool
&& (Var -> Bool) -> [Var] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Var -> Bool
known_fun (DIdSet -> [Var]
dVarSetElems DIdSet
abs_ids)

      -- Number of arguments of a RHS in the current binding group if we decide
      -- to lift it
      n_args :: LlStgRhs -> Int
n_args
        = [NonVoid Var] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length
        ([NonVoid Var] -> Int)
-> (LlStgRhs -> [NonVoid Var]) -> LlStgRhs -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Var] -> [NonVoid Var]
StgCmmClosure.nonVoidIds -- void parameters don't appear in Cmm
        ([Var] -> [NonVoid Var])
-> (LlStgRhs -> [Var]) -> LlStgRhs -> [NonVoid Var]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (DIdSet -> [Var]
dVarSetElems DIdSet
abs_ids [Var] -> [Var] -> [Var]
forall a. [a] -> [a] -> [a]
++)
        ([Var] -> [Var]) -> (LlStgRhs -> [Var]) -> LlStgRhs -> [Var]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LlStgRhs -> [Var]
rhsLambdaBndrs
      max_n_args :: Maybe Int
max_n_args
        | RecFlag -> Bool
isRec RecFlag
rec_flag = DynFlags -> Maybe Int
liftLamsRecArgs DynFlags
dflags
        | Bool
otherwise      = DynFlags -> Maybe Int
liftLamsNonRecArgs DynFlags
dflags
      -- We have 5 hardware registers on x86_64 to pass arguments in. Any excess
      -- args are passed on the stack, which means slow memory accesses
      args_spill_on_stack :: Bool
args_spill_on_stack
        | Just n :: Int
n <- Maybe Int
max_n_args = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Ord a) => t a -> a
maximum ((LlStgRhs -> Int) -> [LlStgRhs] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map LlStgRhs -> Int
n_args [LlStgRhs]
rhss) Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
n
        | Bool
otherwise = Bool
False

      -- We only perform the lift if allocations didn't increase.
      -- Note that @clo_growth@ will be 'infinity' if there was positive growth
      -- under a multi-shot lambda.
      -- Also, abstracting over LNEs is unacceptable. LNEs might return
      -- unlifted tuples, which idClosureFootprint can't cope with.
      inc_allocs :: Bool
inc_allocs = Bool
abstracts_join_ids Bool -> Bool -> Bool
|| IntWithInf
allocs IntWithInf -> IntWithInf -> Bool
forall a. Ord a => a -> a -> Bool
> 0
      allocs :: IntWithInf
allocs = IntWithInf
clo_growth IntWithInf -> IntWithInf -> IntWithInf
forall a. Num a => a -> a -> a
+ Int -> IntWithInf
mkIntWithInf (Int -> Int
forall a. Num a => a -> a
negate Int
closuresSize)
      -- We calculate and then add up the size of each binding's closure.
      -- GHC does not currently share closure environments, and we either lift
      -- the entire recursive binding group or none of it.
      closuresSize :: Int
closuresSize = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ ((LlStgRhs -> Int) -> [LlStgRhs] -> [Int])
-> [LlStgRhs] -> (LlStgRhs -> Int) -> [Int]
forall a b c. (a -> b -> c) -> b -> a -> c
flip (LlStgRhs -> Int) -> [LlStgRhs] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map [LlStgRhs]
rhss ((LlStgRhs -> Int) -> [Int]) -> (LlStgRhs -> Int) -> [Int]
forall a b. (a -> b) -> a -> b
$ \rhs :: LlStgRhs
rhs ->
        DynFlags -> [Var] -> Int
closureSize DynFlags
dflags
        ([Var] -> Int) -> (DIdSet -> [Var]) -> DIdSet -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DIdSet -> [Var]
dVarSetElems
        (DIdSet -> [Var]) -> (DIdSet -> DIdSet) -> DIdSet -> [Var]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DIdSet -> DIdSet
expander
        (DIdSet -> DIdSet) -> (DIdSet -> DIdSet) -> DIdSet -> DIdSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (DIdSet -> IdSet -> DIdSet) -> IdSet -> DIdSet -> DIdSet
forall a b c. (a -> b -> c) -> b -> a -> c
flip DIdSet -> IdSet -> DIdSet
dVarSetMinusVarSet IdSet
bndrs_set
        (DIdSet -> Int) -> DIdSet -> Int
forall a b. (a -> b) -> a -> b
$ LlStgRhs -> DIdSet
forall (pass :: StgPass).
(XRhsClosure pass ~ DIdSet) =>
GenStgRhs pass -> DIdSet
freeVarsOfRhs LlStgRhs
rhs
      clo_growth :: IntWithInf
clo_growth = (DIdSet -> DIdSet)
-> (Var -> Int) -> IdSet -> DIdSet -> Skeleton -> IntWithInf
closureGrowth DIdSet -> DIdSet
expander (DynFlags -> Var -> Int
idClosureFootprint DynFlags
dflags) IdSet
bndrs_set DIdSet
abs_ids Skeleton
scope

-- | The lambda-bound parameters of a right-hand side. A constructor
-- application ('StgRhsCon') binds none; a closure ('StgRhsClosure') binds the
-- 'Id's recorded in its 'BinderInfo's.
rhsLambdaBndrs :: LlStgRhs -> [Id]
rhsLambdaBndrs rhs = case rhs of
  StgRhsClosure _ _ _ infos _ -> [ binderInfoBndr info | info <- infos ]
  StgRhsCon{}                 -> []

-- | The size in words of a function closure closing over the given 'Id's,
-- including the header.
closureSize :: DynFlags -> [Id] -> WordOff
closureSize dflags ids = closure_words
  where
    -- Void ids are dropped (they don't appear in Cmm); the remaining ones are
    -- paired with their 'PrimRep' so that a heap layout can be computed.
    non_void_reps = StgCmmClosure.addIdReps (StgCmmClosure.nonVoidIds ids)
    -- Functions have a StdHeader (as opposed to a ThunkHeader). Only the first
    -- component of the computed layout is used as the result.
    -- Note that mkVirtHeapOffsets accounts for profiling headers, so lifting
    -- decisions vary if we begin to profile stuff. Maybe we shouldn't do this
    -- or deactivate profiling in @dflags@?
    (closure_words, _, _) =
      StgCmmLayout.mkVirtHeapOffsets dflags StgCmmLayout.StdHeader non_void_reps

-- | The number of words a single 'Id' adds to a closure's size, computed as
-- the word size of the 'Id''s argument representation.
-- Note that this can't handle unboxed tuples (which may still be present in
-- let-no-escapes, even after Unarise), in which case
-- @'StgCmmClosure.idPrimRep'@ will crash.
idClosureFootprint :: DynFlags -> Id -> WordOff
idClosureFootprint dflags v = StgCmmArgRep.argRepSizeW dflags arg_rep
  where
    arg_rep = StgCmmArgRep.idArgRep v

-- | @closureGrowth expander sizer group abs_ids skel@ computes the closure
-- growth in words that results from lifting the binding group @group@ to
-- top-level, by walking the 'Skeleton' @skel@ of its scope. If there was any
-- growing closure under a multi-shot lambda, the result will be 'infinity'.
-- Also see "StgLiftLams.Analysis#clogro".
closureGrowth
  :: (DIdSet -> DIdSet)
  -- ^ Expands outer free ids that were lifted to their free vars
  -> (Id -> Int)
  -- ^ Computes the closure footprint of an identifier
  -> IdSet
  -- ^ Binding group for which lifting is to be decided
  -> DIdSet
  -- ^ Free vars of the whole binding group prior to lifting it. These must be
  --   available at call sites if we decide to lift the binding group.
  -> Skeleton
  -- ^ Abstraction of the scope of the function
  -> IntWithInf
  -- ^ Closure growth. 'infinity' indicates there was growth under a
  --   (multi-shot) lambda.
closureGrowth expander sizer group abs_ids = walk
  where
    walk :: Skeleton -> IntWithInf
    walk skel = case skel of
      NilSk               -> 0
      BothSk l r          -> walk l + walk r
      -- Alternatives: account for the larger growth of the two.
      AltSk l r           -> max (walk l) (walk r)
      ClosureSk _ fvs rhs -> closure_growth fvs rhs
      RhsSk body_dmd body -> rhs_growth body_dmd body

    -- Growth caused by a closure with free variables @fvs@ and body @rhs@.
    closure_growth :: DIdSet -> Skeleton -> IntWithInf
    closure_growth fvs rhs
      -- If no binder of the @group@ occurs free in the closure, the lifting
      -- won't have any effect on it and we can omit the recursive call.
      | occs == 0 = 0
      -- Otherwise, we account the cost of allocating the closure and add it to
      -- the closure growth of its RHS.
      | otherwise = mkIntWithInf added + walk rhs
      where
        -- What we close over considering prior lifting decisions
        fvs' = expander fvs
        -- Number of binders from @group@ the closure captures
        occs = sizeDVarSet (fvs' `dVarSetIntersectVarSet` group)
        -- Variables that would additionally occur free in the closure body if
        -- we lift the group
        newbies = abs_ids `minusDVarSet` fvs'
        -- Lifting removes the captured group binders from the closure (one
        -- word each) but adds the footprint of all @newbies@
        added = foldDVarSet ((+) . sizer) 0 newbies - occs

    -- Scales the growth of a RHS body by how often the demand analyser says
    -- the body is evaluated relative to its defining context (@body_dmd@ is
    -- the result of 'rhsDmdShell'). The conservative alternative would be to
    -- assume every RHS with positive growth is called many times (giving
    -- 'infinity') and every RHS with negative growth is never called
    -- (giving 0); the demand information lets us do far better.
    rhs_growth :: DmdShell -> Skeleton -> IntWithInf
    rhs_growth body_dmd body
      | isAbsDmd body_dmd   = 0
      -- Savings only count when the body is certain to be evaluated.
      | growth <= 0         = if isStrictDmd body_dmd then growth else 0
      | isUsedOnce body_dmd = growth
      | otherwise           = infinity
      where
        growth = walk body