{-# LANGUAGE CPP #-}
module GHC.Stg.Lift
(
stgLiftLams
)
where
#include "HsVersions.h"
import GHC.Prelude
import GHC.Types.Basic
import GHC.Driver.Session
import GHC.Types.Id
import GHC.Stg.FVs ( annBindingFreeVars )
import GHC.Stg.Lift.Analysis
import GHC.Stg.Lift.Monad
import GHC.Stg.Syntax
import GHC.Utils.Outputable
import GHC.Types.Unique.Supply
import GHC.Utils.Misc
import GHC.Utils.Panic
import GHC.Types.Var.Set
import Control.Monad ( when )
import Data.Maybe ( isNothing )
-- | Lambda-lifts to top level those bindings deemed worth lifting
-- (the decision is made per binding group by 'goodToLift', see
-- 'withLiftedBindPairs').
--
-- The top-level bindings are processed in order. Each one receives the
-- processing of all remaining bindings as its continuation (hence the
-- 'foldr' with @pure ()@ as the final continuation), and the resulting
-- 'LiftM' computation is run with 'runLiftM'.
stgLiftLams :: DynFlags -> UniqSupply -> [InStgTopBinding] -> [OutStgTopBinding]
stgLiftLams dflags us = runLiftM dflags us . foldr liftTopLvl (pure ())
-- | Processes one top-level binding, then runs @rest@ (the processing of all
-- subsequent top-level bindings) inside the binder's substitution scope.
--
-- * String literals are only renamed and re-emitted via 'addTopStringLit'.
-- * Ordinary bindings are annotated with their free variables, tagged with
--   binder information ('tagSkeletonTopBind'), and sent through
--   'withLiftedBind'. A recursive group is bracketed by
--   'startBindingGroup' / 'endBindingGroup'.
liftTopLvl :: InStgTopBinding -> LiftM () -> LiftM ()
liftTopLvl (StgTopStringLit bndr lit) rest = withSubstBndr bndr $ \bndr' -> do
  addTopStringLit bndr' lit
  rest
liftTopLvl (StgTopLifted bind) rest = do
  let is_rec = isRec $ fst $ decomposeStgBinding bind
  when is_rec startBindingGroup
  let bind_w_fvs = annBindingFreeVars bind
  withLiftedBind TopLevel (tagSkeletonTopBind bind_w_fvs) NilSk $ \mb_bind' -> do
    -- 'withLiftedBind' signals "this binding was lifted" with 'Nothing'.
    -- That can never be right for a binding that is already at top level.
    case mb_bind' of
      Nothing    -> pprPanic "StgLiftLams" (text "Lifted top-level binding")
      Just bind' -> addLiftedBinding bind'
    when is_rec endBindingGroup
    -- Continue inside the continuation so later bindings see the
    -- substitution for this group's binders.
    rest
-- | Splits @bind@ into its recursivity flag and binder/RHS pairs, and hands
-- them to 'withLiftedBindPairs'. Any surviving pairs are reassembled into a
-- binding of the same recursivity before @k@ is called.
--
-- @k@ receives 'Nothing' when the binding was lifted (it has already been
-- emitted and must not be re-emitted at its original site). Otherwise it
-- receives @Just bind'@, the processed binding to keep in place.
withLiftedBind
  :: TopLevelFlag
  -> LlStgBinding
  -> Skeleton
  -> (Maybe OutStgBinding -> LiftM a)
  -> LiftM a
withLiftedBind top_lvl bind scope k =
  withLiftedBindPairs top_lvl rec pairs scope (k . fmap (mkStgBinding rec))
  where
    (rec, pairs) = decomposeStgBinding bind
-- | Decides via 'goodToLift' whether the binder/RHS pairs of one binding
-- group should be lifted to top level, then continues with @k@.
--
-- * @Just abs_ids@: the group is lifted. Its binders are renamed through
--   'withLiftedBndrs'. Each RHS is processed by 'liftRhs' with @abs_ids@
--   prepended as extra parameter binders. The result is emitted with
--   'addLiftedBinding' (bracketed as a binding group when recursive), and
--   @k@ is called with 'Nothing'.
--
-- * 'Nothing': the binders are only substituted. Each RHS is still processed,
--   so bindings nested inside it may be lifted. @k@ receives the processed
--   pairs.
withLiftedBindPairs
  :: TopLevelFlag
  -> RecFlag
  -> [(BinderInfo, LlStgRhs)]
  -> Skeleton
  -> (Maybe [(Id, OutStgRhs)] -> LiftM a)
  -> LiftM a
withLiftedBindPairs top rec pairs scope k = do
  let (infos, rhss) = unzip pairs
      bndrs = map binderInfoBndr infos
  expander <- liftedIdsExpander
  dflags <- getDynFlags
  case goodToLift dflags top rec expander pairs scope of
    -- @abs_ids@ are the variables that become extra parameters of the
    -- lifted RHSs (see the @Just former_fvs@ case of 'liftRhs').
    Just abs_ids -> withLiftedBndrs abs_ids bndrs $ \bndrs' -> do
      when (isRec rec) startBindingGroup
      rhss' <- traverse (liftRhs (Just abs_ids)) rhss
      addLiftedBinding (mkStgBinding rec (zip bndrs' rhss'))
      when (isRec rec) endBindingGroup
      -- Nothing is left at the original site; the binding now lives at top level.
      k Nothing
    Nothing -> withSubstBndrs bndrs $ \bndrs' -> do
      rhss' <- traverse (liftRhs Nothing) rhss
      k (Just (zip bndrs' rhss'))
-- | Processes the right-hand side of a binding.
--
-- The first argument is @Just former_fvs@ when this RHS belongs to a lifted
-- binding. In that case @former_fvs@ (in 'dVarSetElems' order) are prepended
-- to the closure's lambda binders. It is 'Nothing' when the binding stays
-- where it is. Constructor applications are never lifted, and this is
-- asserted.
liftRhs
  :: Maybe (DIdSet)
  -> LlStgRhs
  -> LiftM OutStgRhs
liftRhs mb_former_fvs rhs@(StgRhsCon ccs con mn ts args)
  = ASSERT2(isNothing mb_former_fvs, text "Should never lift a constructor" $$ pprStgRhs panicStgPprOpts rhs)
    StgRhsCon ccs con mn ts <$> traverse liftArgs args
liftRhs mb_former_fvs (StgRhsClosure _ ccs upd infos body) =
  withSubstBndrs (map binderInfoBndr infos) $ \bndrs' -> do
    -- A lifted RHS abstracts over its former free variables, which become
    -- leading parameters. An unlifted RHS keeps exactly its own binders.
    let extra_bndrs = maybe [] dVarSetElems mb_former_fvs
    StgRhsClosure noExtFieldSilent ccs upd (extra_bndrs ++ bndrs') <$> liftExpr body
-- | Processes an argument.
--
-- Literal arguments are returned unchanged. Variable arguments are renamed
-- with 'substOcc'. The variable is asserted not to refer to a lifted
-- binder, because a lifted function in argument position could not receive
-- its former free variables.
liftArgs :: InStgArg -> LiftM OutStgArg
liftArgs a@(StgLitArg _) = pure a
liftArgs (StgVarArg occ) = do
  ASSERTM2( not <$> isLifted occ, text "StgArgs should never be lifted" $$ ppr occ )
  StgVarArg <$> substOcc occ
-- | Processes an expression. It renames binders and occurrences, and lifts
-- nested @let@ / @let-no-escape@ bindings wherever 'withLiftedBind' decides
-- to.
--
-- In an application @f args@, the former free variables of @f@ (from
-- 'formerFreeVars'; presumably empty when @f@ was not lifted) are prepended
-- to the arguments.
liftExpr :: LlStgExpr -> LiftM OutStgExpr
liftExpr (StgLit lit) = pure (StgLit lit)
liftExpr (StgTick t e) = StgTick t <$> liftExpr e
liftExpr (StgApp f args) = do
  f' <- substOcc f
  args' <- traverse liftArgs args
  fvs' <- formerFreeVars f
  let top_lvl_args = map StgVarArg fvs' ++ args'
  pure (StgApp f' top_lvl_args)
liftExpr (StgConApp con mn args tys) =
  StgConApp con mn <$> traverse liftArgs args <*> pure tys
liftExpr (StgOpApp op args ty) =
  StgOpApp op <$> traverse liftArgs args <*> pure ty
liftExpr (StgCase scrut info ty alts) = do
  scrut' <- liftExpr scrut
  withSubstBndr (binderInfoBndr info) $ \bndr' -> do
    alts' <- traverse liftAlt alts
    pure (StgCase scrut' bndr' ty alts')
liftExpr (StgLet scope bind body) =
  withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
    body' <- liftExpr body
    case mb_bind' of
      -- The binding was lifted and already emitted, so only the body remains.
      Nothing    -> pure body'
      Just bind' -> pure (StgLet noExtFieldSilent bind' body')
liftExpr (StgLetNoEscape scope bind body) =
  withLiftedBind NotTopLevel bind scope $ \mb_bind' -> do
    body' <- liftExpr body
    case mb_bind' of
      Nothing    -> pprPanic "stgLiftLams" (text "Should never decide to lift LNEs")
      Just bind' -> pure (StgLetNoEscape noExtFieldSilent bind' body')
-- | Processes a case alternative. The pattern binders are renamed, then the
-- right-hand side is processed with 'liftExpr' inside that substitution.
liftAlt :: LlStgAlt -> LiftM OutStgAlt
liftAlt (con, infos, rhs) = withSubstBndrs (map binderInfoBndr infos) $ \bndrs' ->
  (,,) con bndrs' <$> liftExpr rhs