module GHC.Core.Opt.LiberateCase ( liberateCase ) where
import GHC.Prelude
import GHC.Driver.Session
import GHC.Core
import GHC.Core.Unfold
import GHC.Builtin.Types ( unitDataConId )
import GHC.Types.Id
import GHC.Types.Var.Env
import GHC.Utils.Misc ( notNull )
-- | Run the liberate-case transformation over a whole Core program.
--
-- The top-level bindings are processed in order, starting from the
-- environment built by 'initLiberateCaseEnv'.  The environment returned by
-- 'libCaseBind' for one binding is the one used for the bindings after it.
liberateCase :: DynFlags -> CoreProgram -> CoreProgram
liberateCase dflags binds = do_prog (initLiberateCaseEnv dflags) binds
  where
    do_prog :: LibCaseEnv -> CoreProgram -> CoreProgram
    do_prog _   []          = []
    do_prog env (bind:rest) = bind' : do_prog env' rest
      where
        -- env' has the binders of 'bind' in scope for the remaining bindings
        (env', bind') = libCaseBind env bind
-- | Build the initial environment for the pass from the compiler flags.
--
-- The size threshold and unfolding options are read from 'DynFlags'; the
-- level counter starts at 0 and all maps and the scrutinee list start empty.
initLiberateCaseEnv :: DynFlags -> LibCaseEnv
initLiberateCaseEnv dflags = LibCaseEnv
   { lc_threshold = liberateCaseThreshold dflags
   , lc_uf_opts   = unfoldingOpts dflags
   , lc_lvl       = 0
   , lc_lvl_env   = emptyVarEnv
   , lc_rec_env   = emptyVarEnv
   , lc_scruts    = []
   }
-- | Transform one binding group.
--
-- Returns the environment to use for whatever is in scope of the binding
-- (its binders have been added with 'addBinders') together with the
-- transformed binding.
libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)

-- Non-recursive: the right-hand side is processed in the incoming
-- environment, since the binder is not in scope there.
libCaseBind env (NonRec binder rhs)
  = (addBinders env [binder], NonRec binder (libCase env rhs))

libCaseBind env (Rec pairs)
  = (env_body, Rec pairs')
  where
    binders :: [Id]
    binders = map fst pairs

    -- Environment for the body of the binding group
    env_body :: LibCaseEnv
    env_body = addBinders env binders

    pairs' :: [(Id, CoreExpr)]
    pairs' = [(binder, libCase env_rhs rhs) | (binder, rhs) <- pairs]

    -- Environment for the right-hand sides.  When the group may be
    -- duplicated, the rec-env maps each binder to the copy 'dup_pairs'.
    -- Those copies are built with 'env_body', which does not carry the
    -- rec-env extension, so building them cannot trigger this again.
    env_rhs :: LibCaseEnv
    env_rhs | is_dupable_bind = addRecBinds env dup_pairs
            | otherwise       = env

    -- The copy of the group that 'libCaseApp' wraps in a 'Let' at use sites.
    -- NOTE(review): the binders go through 'localiseId', presumably because
    -- the copy becomes a local let-binding; confirm against 'localiseId'.
    dup_pairs :: [(Id, CoreExpr)]
    dup_pairs = [ (localiseId binder, libCase env_body rhs)
                | (binder, rhs) <- pairs ]

    is_dupable_bind :: Bool
    is_dupable_bind = small_enough && all ok_pair pairs

    -- Size check: 'dup_pairs' is what gets duplicated, so build the fake
    -- expression  let { dup_pairs } in ()  and measure that against the
    -- threshold.  With no threshold ('Nothing') anything is small enough.
    small_enough :: Bool
    small_enough = case lc_threshold env of
                      Nothing   -> True
                      Just size -> couldBeSmallEnoughToInline (lc_uf_opts env) size $
                                   Let (Rec dup_pairs) (Var unitDataConId)

    -- A binder qualifies only if it has positive arity and is not a
    -- dead-end Id.
    ok_pair :: (Id, b) -> Bool
    ok_pair (id, _)
        =  idArity id > 0
        && not (isDeadEndId id)
-- | Transform an expression.
--
-- Variable occurrences, bare or at the head of an application, are handed
-- to 'libCaseApp'.  Every other form is rebuilt from its transformed
-- sub-expressions, extending the environment at binding sites.
libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr

libCase env (Var v)             = libCaseApp env v []
libCase _   (Lit lit)           = Lit lit
libCase _   (Type ty)           = Type ty
libCase _   (Coercion co)       = Coercion co

-- An application whose head is a variable: process head and arguments
-- together.  If the head is not a variable the guard fails and the
-- generic 'App' equation below applies.
libCase env e@(App {})          | let (fun, args) = collectArgs e
                                , Var v <- fun
                                = libCaseApp env v args
libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
libCase env (Tick tickish body) = Tick tickish (libCase env body)
libCase env (Cast e co)         = Cast (libCase env e) co

libCase env (Lam binder body)
  = Lam binder (libCase (addBinders env [binder]) body)

libCase env (Let bind body)
  = Let bind' (libCase env_body body)
  where
    (env_body, bind') = libCaseBind env bind

libCase env (Case scrut bndr ty alts)
  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
  where
    -- The alternatives see the case binder, plus a record of the
    -- scrutinised variable when the scrutinee is one.
    env_alts :: LibCaseEnv
    env_alts = addBinders (mk_alt_env scrut) [bndr]

    mk_alt_env :: CoreExpr -> LibCaseEnv
    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
    mk_alt_env (Cast scrut' _) = mk_alt_env scrut'   -- look through casts
    mk_alt_env _               = env
-- | Transform one case alternative: the right-hand side is processed with
-- the alternative's pattern-bound variables added to the environment.
-- The constructor and binders are returned unchanged.
libCaseAlt :: LibCaseEnv -> Alt CoreBndr -> Alt CoreBndr
libCaseAlt env (Alt con args rhs) = Alt con args (libCase (addBinders env args) rhs)
-- | Transform the application of variable @v@ to @args@ (possibly none).
--
-- If @v@ has an entry in the rec-env and 'freeScruts' reports at least one
-- scrutinised variable for @v@'s level, the recorded binding group is
-- wrapped around the application as a @let@.  Otherwise the application
-- is rebuilt as is.  The arguments are transformed in both cases.
libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
libCaseApp env v args
  | Just the_bind <- lookupRecId env v  -- a use of a recorded recursive binder
  , notNull free_scruts                 -- and there is something to liberate
  = Let the_bind expr'

  | otherwise
  = expr'

  where
    rec_id_level :: LibCaseLevel
    rec_id_level = lookupLevel env v

    free_scruts :: [Id]
    free_scruts  = freeScruts env rec_id_level

    expr' :: CoreExpr
    expr'        = mkApps (Var v) (map (libCase env) args)
-- | Select the recorded scrutinised variables relevant to a recursive
-- binder at the given level: those bound at or below that level but
-- scrutinised at a strictly deeper one.
freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- ^ Level of the recursive Id
           -> [Id]              -- ^ Matching entries of 'lc_scruts'
freeScruts env rec_bind_lvl
  = [ v | (v, scrut_bind_lvl, scrut_at_lvl) <- lc_scruts env
        , scrut_bind_lvl <= rec_bind_lvl
        , scrut_at_lvl   >  rec_bind_lvl ]
-- | Record each of the given binders in the level environment at the
-- current level.  The current level itself is left unchanged.
addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
  = env { lc_lvl_env = lvl_env' }
  where
    lvl_env' :: IdEnv LibCaseLevel
    lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
-- | Enter a recursive binding group.
--
-- The current level is incremented.  Each binder of the group is recorded
-- in the level environment at the level in force *before* the increment,
-- and in the rec-env mapped to the whole group as a single 'Rec' binding.
addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env,
                             lc_rec_env = rec_env}) pairs
  = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
  where
    lvl' :: LibCaseLevel
    lvl'     = lvl + 1

    lvl_env' :: IdEnv LibCaseLevel
    lvl_env' = extendVarEnvList lvl_env [(binder, lvl) | (binder, _) <- pairs]

    rec_env' :: IdEnv CoreBind
    rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder, _) <- pairs]
-- | Note that a variable is being scrutinised by a @case@.
--
-- The variable is recorded in 'lc_scruts', with its binding level and the
-- current level, only when it was bound at a level strictly below the
-- current one.  Otherwise the environment is returned unchanged.
addScrutedVar :: LibCaseEnv
              -> Id             -- ^ The scrutinised variable
              -> LibCaseEnv
addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_scruts = scruts }) scrut_var
  | bind_lvl < lvl
  = env { lc_scruts = scruts' }

  | otherwise = env
  where
    scruts' :: [(Id, LibCaseLevel, LibCaseLevel)]
    scruts'  = (scrut_var, bind_lvl, lvl) : scruts

    -- Level at which the variable was bound; variables missing from the
    -- level environment are treated as bound at 'topLevel'.
    bind_lvl :: LibCaseLevel
    bind_lvl = lookupLevel env scrut_var
-- | Look up the recorded recursive binding group for an 'Id', if the
-- rec-env has an entry for it.
lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
lookupRecId env id = lookupVarEnv (lc_rec_env env) id
-- | The level at which an 'Id' was bound, according to the level
-- environment.  An 'Id' with no entry is treated as bound at 'topLevel'.
lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
lookupLevel env id
  = case lookupVarEnv (lc_lvl_env env) id of
      Just lvl -> lvl
      Nothing  -> topLevel
-- | Nesting level used by the pass's environment; a plain 'Int'.
type LibCaseLevel = Int
-- | The outermost level; also the default for variables that have no
-- entry in the level environment.
topLevel :: LibCaseLevel
topLevel = 0
-- | The environment threaded through the liberate-case pass.
data LibCaseEnv
  = LibCaseEnv {
        -- | Size limit for a binding group to be duplicated;
        -- 'Nothing' means there is no limit.
        lc_threshold :: Maybe Int,

        -- | Unfolding options passed to the size check.
        lc_uf_opts :: UnfoldingOpts,

        -- | Current level.  Starts at 0 and goes up by one each time a
        -- recursive group is entered via @addRecBinds@.
        lc_lvl :: LibCaseLevel,

        -- | Binds each recorded variable to the level at which it was bound.
        -- Variables with no entry are treated as top-level.
        lc_lvl_env :: IdEnv LibCaseLevel,

        -- | Maps each binder of an enclosing duplicable recursive group to
        -- the copy of the whole group to splice in at its use sites.
        lc_rec_env :: IdEnv CoreBind,

        -- | Variables scrutinised by enclosing @case@ expressions, each with
        -- its binding level and the level at which it was scrutinised.
        -- Only variables bound at a level strictly below the one where
        -- they were scrutinised are recorded.
        lc_scruts :: [(Id, LibCaseLevel, LibCaseLevel)]
  }