-- | Free variable analysis on STG terms.
module StgFVs (
    annTopBindingsFreeVars,
    annBindingFreeVars
  ) where

import GhcPrelude

import StgSyn
import Id
import VarSet
import CoreSyn    ( Tickish(Breakpoint) )
import Outputable
import Util

import Data.Maybe ( mapMaybe )

-- | Environment threaded through the free-variable traversal.
-- 'locals' holds the non-top-level, non-imported binders that are in scope;
-- see Note [Tracking local binders].
newtype Env = Env { locals :: IdSet }

-- | An environment in which no local binders are in scope.
emptyEnv :: Env
emptyEnv = Env { locals = emptyVarSet }

-- | Brings the given binders into scope as locals.
-- See Note [Tracking local binders].
addLocals :: [Id] -> Env -> Env
addLocals bndrs (Env in_scope) = Env (extendVarSetList in_scope bndrs)

-- | Annotates a top-level STG binding group with its free variables.
--
-- String literals carry no free variables and are rebuilt unchanged at the
-- code-generation pass type; lifted bindings go through 'annBindingFreeVars'.
annTopBindingsFreeVars :: [StgTopBinding] -> [CgStgTopBinding]
annTopBindingsFreeVars binds = [ annotate top | top <- binds ]
  where
    annotate :: StgTopBinding -> CgStgTopBinding
    annotate top = case top of
      StgTopStringLit name str -> StgTopStringLit name str
      StgTopLifted bind        -> StgTopLifted (annBindingFreeVars bind)

-- | Annotates an STG binding with its free variables.
--
-- The binding is analysed with no locals in scope and an empty set of body
-- free variables; only the annotated binding is returned.
annBindingFreeVars :: StgBinding -> CgStgBinding
annBindingFreeVars bind = fst (binding emptyEnv emptyDVarSet bind)

-- | The binders introduced by a (non-)recursive binding group.
boundIds :: StgBinding -> [Id]
boundIds bind = case bind of
  StgNonRec b _ -> [b]
  StgRec pairs  -> [ b | (b, _) <- pairs ]

-- Note [Tracking local binders]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- 'locals' contains non-toplevel, non-imported binders.
-- We maintain the set in 'expr', 'alt' and 'rhs', which are the only
-- places where new local binders are introduced.
-- Why do it there rather than in 'binding'? Two reasons:
--
--   1. We call 'binding' from 'annTopBindingsFreeVars', which would
--      add top-level bindings to the 'locals' set.
--   2. In the let(-no-escape) case, we need to extend the environment
--      prior to analysing the body, but we also need the fvs from the
--      body to analyse the RHSs. No way to do this without some
--      knot-tying.

-- | This makes sure that only local, non-global free vars make it into the set.
--
-- Keeps exactly those ids that are members of 'locals'; everything else
-- (top-level or imported ids) is dropped.
mkFreeVarSet :: Env -> [Id] -> DIdSet
mkFreeVarSet env ids = mkDVarSet [ v | v <- ids, v `elemVarSet` in_scope ]
  where
    in_scope = locals env

-- | Free local variables of an argument list. Non-variable arguments
-- contribute nothing; variables are filtered through 'mkFreeVarSet'.
args :: Env -> [StgArg] -> DIdSet
args env as = mkFreeVarSet env (mapMaybe argVar as)
  where
    argVar :: StgArg -> Maybe Id
    argVar arg = case arg of
      StgVarArg v -> Just v
      _           -> Nothing

-- | Annotates a binding group.
--
-- @body_fv@ are the free variables of the scope the binding covers (the let
-- body). The result pairs the annotated binding with the free variables of
-- body and RHSs together, minus the group's own binders. The caller is
-- responsible for having added the binders to @env@.
-- See Note [Tracking local binders]
binding :: Env -> DIdSet -> StgBinding -> (CgStgBinding, DIdSet)
binding env body_fv bind = case bind of
  StgNonRec bndr r ->
    -- A non-recursive binder scopes over the body only, so it is removed
    -- from the body's free variables before the RHS's are added.
    let (r', rhs_fvs) = rhs env r
        fvs           = delDVarSet body_fv bndr `unionDVarSet` rhs_fvs
    in  (StgNonRec bndr r', fvs)
  StgRec pairs ->
    -- Recursive binders scope over the body and every RHS.
    let bndrs            = map fst pairs
        (rhss, rhs_fvss) = mapAndUnzip (rhs env . snd) pairs
        fvs = delDVarSetList (unionDVarSets (body_fv : rhs_fvss)) bndrs
    in  (StgRec (zip bndrs rhss), fvs)

-- | Annotates an expression and returns it together with the set of its free
-- local variables (only ids in 'locals' are reported; see 'mkFreeVarSet').
--
-- Panics on 'StgLam', which this pass does not expect to encounter.
expr :: Env -> StgExpr -> (CgStgExpr, DIdSet)
expr env = go
  where
    go :: StgExpr -> (CgStgExpr, DIdSet)
    go (StgApp occ as)
      = (StgApp occ as, unionDVarSet (args env as) (mkFreeVarSet env [occ]))
    go (StgLit lit) = (StgLit lit, emptyDVarSet)
    go (StgConApp dc as tys) = (StgConApp dc as tys, args env as)
    go (StgOpApp op as ty) = (StgOpApp op as ty, args env as)
    go StgLam{} = pprPanic "StgFVs: StgLam" empty
    go (StgCase scrut bndr ty alts) = (StgCase scrut' bndr ty alts', fvs)
      where
        -- The scrutinee is analysed outside the scope of the case binder.
        (scrut', scrut_fvs) = go scrut
        -- The case binder scopes over every alternative.
        -- See Note [Tracking local binders]
        (alts', alt_fvss) = mapAndUnzip (alt (addLocals [bndr] env)) alts
        alt_fvs = unionDVarSets alt_fvss
        fvs = delDVarSet (unionDVarSet scrut_fvs alt_fvs) bndr
    go (StgLet ext bind body) = go_bind (StgLet ext) bind body
    go (StgLetNoEscape ext bind body) = go_bind (StgLetNoEscape ext) bind body
    go (StgTick tick e) = (StgTick tick e', fvs')
      where
        (e', fvs) = go e
        fvs' = unionDVarSet (tickish tick) fvs
        -- Ids captured by a breakpoint are occurrences like any other, so they
        -- go through 'mkFreeVarSet'; previously they were added unfiltered,
        -- which could let top-level/imported ids into the local fv set.
        tickish :: Tickish Id -> DIdSet
        tickish (Breakpoint _ ids) = mkFreeVarSet env ids
        tickish _                  = emptyDVarSet

    -- Shared code for 'StgLet' and 'StgLetNoEscape'; @dc@ rebuilds the
    -- let-form from the annotated binding and body.
    go_bind :: (CgStgBinding -> CgStgExpr -> a)
            -> StgBinding -> StgExpr -> (a, DIdSet)
    go_bind dc bind body = (dc bind' body', fvs)
      where
        -- The bound ids are in scope both in the body and in the RHSs.
        -- See Note [Tracking local binders]
        env' = addLocals (boundIds bind) env
        (body', body_fvs) = expr env' body
        (bind', fvs) = binding env' body_fvs bind

-- | Annotates a right-hand side.
--
-- For a closure, the free variables of its body minus its own arguments are
-- both stored in the annotated closure and returned. For a constructor
-- application, the free variables are those of its arguments.
rhs :: Env -> StgRhs -> (CgStgRhs, DIdSet)
rhs env r = case r of
  StgRhsClosure _ ccs uf bndrs body ->
    -- The closure's arguments scope over its body.
    -- See Note [Tracking local binders]
    let (body', body_fvs) = expr (addLocals bndrs env) body
        closure_fvs       = delDVarSetList body_fvs bndrs
    in  (StgRhsClosure closure_fvs ccs uf bndrs body', closure_fvs)
  StgRhsCon ccs dc as ->
    (StgRhsCon ccs dc as, args env as)

-- | Annotates a case alternative; the returned set is the free variables of
-- its right-hand side minus the variables bound by the pattern.
alt :: Env -> StgAlt -> (CgStgAlt, DIdSet)
alt env (con, bndrs, e) = ((con, bndrs, e'), delDVarSetList e_fvs bndrs)
  where
    -- Pattern-bound variables scope over the right-hand side.
    -- See Note [Tracking local binders]
    (e', e_fvs) = expr (addLocals bndrs env) e