{-# LANGUAGE CPP #-}
module LiberateCase ( liberateCase ) where
#include "HsVersions.h"
import GhcPrelude
import DynFlags
import CoreSyn
import CoreUnfold ( couldBeSmallEnoughToInline )
import TysWiredIn ( unitDataConId )
import Id
import VarEnv
import Util ( notNull )
-- | Run the liberate-case transformation over a whole program.
--
-- Top-level bindings are processed left to right.  The environment returned
-- by 'libCaseBind' for each binding is threaded into the next one, so later
-- bindings see the binders of earlier ones.
liberateCase :: DynFlags -> CoreProgram -> CoreProgram
liberateCase dflags binds = do_prog (initEnv dflags) binds
  where
    -- Thread the environment through the bindings in order.
    do_prog _   []          = []
    do_prog env (bind:rest) = bind' : do_prog env' rest
      where
        (env', bind') = libCaseBind env bind
-- | Process one binding group and return the environment for its scope.
--
-- * 'NonRec': the RHS is processed in the incoming environment.  The binder
--   is then recorded at the current level for the body.
--
-- * 'Rec': every RHS is processed in @env_rhs@.  When the group is
--   "dupable" (small enough, and every binder is a non-bottoming function
--   with positive arity), @env_rhs@ records a copy of the group via
--   'addRecBinds'.  Recursive calls inside the RHSs can then be wrapped in
--   that copy by 'libCaseApp'.  The copy itself is processed in @env_body@,
--   which does not contain this group's recursive bindings.  As a result,
--   calls inside the copy are not unfolded again.
libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
libCaseBind env (NonRec binder rhs)
  = (addBinders env [binder], NonRec binder (libCase env rhs))

libCaseBind env (Rec pairs)
  = (env_body, Rec pairs')
  where
    binders  = map fst pairs
    env_body = addBinders env binders

    pairs' = [ (binder, libCase env_rhs rhs) | (binder, rhs) <- pairs ]

    env_rhs | is_dupable_bind = addRecBinds env dup_pairs
            | otherwise       = env

    -- The copy of the group that may be duplicated at call sites.
    -- Binders are passed through 'localiseId'.
    -- NOTE(review): presumably this keeps a duplicated, nested copy of an
    -- exported/external binder from clashing with the original -- confirm.
    dup_pairs = [ (localiseId binder, libCase env_body rhs)
                | (binder, rhs) <- pairs ]

    is_dupable_bind = small_enough && all ok_pair pairs

    -- Size the copy by wrapping it in a dummy @let rec dup_pairs in ()@ and
    -- asking whether that could be inlined under the threshold.  No
    -- threshold ('Nothing') means "always small enough".
    small_enough = case bombOutSize env of
                     Nothing   -> True
                     Just size -> couldBeSmallEnoughToInline (lc_dflags env) size $
                                  Let (Rec dup_pairs) (Var unitDataConId)

    -- Only functions (arity > 0) that are not bottoming qualify.
    ok_pair (bndr, _) = idArity bndr > 0 && not (isBottomingId bndr)
-- | Apply liberate-case to an expression.
--
-- Variable occurrences, and applications headed by a variable, go through
-- 'libCaseApp'.  That function may wrap the occurrence in a copy of its
-- recursive binding group.  Binders introduced by lambdas, lets and case
-- alternatives are recorded at the current level.  When a case scrutinee is
-- a variable (possibly under casts), the alternatives additionally record
-- it via 'addScrutedVar'.
libCase :: LibCaseEnv
        -> CoreExpr
        -> CoreExpr
libCase env (Var v)             = libCaseApp env v []
libCase _   (Lit lit)           = Lit lit
libCase _   (Type ty)           = Type ty
libCase _   (Coercion co)       = Coercion co
libCase env e@(App {})
  | let (fun, args) = collectArgs e
  , Var v <- fun
  = libCaseApp env v args
libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
libCase env (Tick tickish body) = Tick tickish (libCase env body)
libCase env (Cast e co)         = Cast (libCase env e) co

libCase env (Lam binder body)
  = Lam binder (libCase (addBinders env [binder]) body)

libCase env (Let bind body)
  = Let bind' (libCase env_body body)
  where
    (env_body, bind') = libCaseBind env bind

libCase env (Case scrut bndr ty alts)
  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
  where
    env_alts = addBinders (mk_alt_env scrut) [bndr]

    -- Look through casts to find a scrutinised variable, if any.
    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
    mk_alt_env (Cast inner _)  = mk_alt_env inner
    mk_alt_env _               = env
-- | Process one case alternative.  The alternative's field binders are
-- recorded at the current level before its right-hand side is processed.
libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
                         -> (AltCon, [CoreBndr], CoreExpr)
libCaseAlt env (con, args, rhs) = (con, args, libCase (addBinders env args) rhs)
-- | Process an occurrence of @v@ applied to @args@ (possibly none).
--
-- The arguments are always processed.  The whole call is wrapped in the
-- recorded binding group (@let rec { ... } in v args@) when both hold:
--
-- * @v@ has a group in 'lc_rec_env' (only dupable groups are recorded
--   there, by 'addRecBinds');
--
-- * 'freeScruts' reports some variable bound at a level no deeper than
--   @v@'s, but scrutinised at a deeper level.
libCaseApp :: LibCaseEnv -> Id -> [CoreExpr] -> CoreExpr
libCaseApp env v args
  | Just the_bind <- lookupRecId env v
  , notNull free_scruts
  = Let the_bind expr'
  | otherwise
  = expr'
  where
    rec_id_level = lookupLevel env v
    free_scruts  = freeScruts env rec_id_level
    expr'        = mkApps (Var v) (map (libCase env) args)
-- | Variables scrutinised between the binding of a recursive Id and here.
--
-- Returns each 'lc_scruts' variable that satisfies both conditions:
--
-- * its binding level is @<=@ the given level;
--
-- * its scrutinising level is strictly @>@ the given level.
freeScruts :: LibCaseEnv
           -> LibCaseLevel   -- ^ level of the recursive Id
           -> [Id]
freeScruts env rec_bind_lvl
  = [ v | (v, scrut_bind_lvl, scrut_at_lvl) <- lc_scruts env
        , scrut_bind_lvl <= rec_bind_lvl
        , scrut_at_lvl   >  rec_bind_lvl ]
-- | Record the given binders at the current level ('lc_lvl').
-- The current level itself is left unchanged.
addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
  = env { lc_lvl_env = extendVarEnvList lvl_env [ (b, lvl) | b <- binders ] }
-- | Environment for the right-hand sides of a recursive group whose copy
-- is @pairs@.  It makes three changes:
--
-- * increments the current level;
--
-- * records each binder at the old level;
--
-- * maps each binder to the whole group (@Rec pairs@) in 'lc_rec_env', so
--   that 'libCaseApp' can reinstate the group at a call site.
addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
addRecBinds env@(LibCaseEnv { lc_lvl     = lvl
                            , lc_lvl_env = lvl_env
                            , lc_rec_env = rec_env }) pairs
  = env { lc_lvl     = lvl + 1
        , lc_lvl_env = extendVarEnvList lvl_env [ (b, lvl)      | b <- binders ]
        , lc_rec_env = extendVarEnvList rec_env [ (b, rec_bind) | b <- binders ] }
  where
    binders  = map fst pairs
    rec_bind = Rec pairs   -- one shared group value for every binder
-- | Record that @scrut_var@ is scrutinised by a case at the current level.
--
-- The entry @(scrut_var, bind_lvl, lvl)@ is added to 'lc_scruts' only when
-- the variable's binding level is strictly below the current level.
-- Otherwise the environment is returned unchanged.
addScrutedVar :: LibCaseEnv
              -> Id             -- ^ Id scrutinised by a case expression
              -> LibCaseEnv
addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_scruts = scruts }) scrut_var
  | bind_lvl < lvl = env { lc_scruts = (scrut_var, bind_lvl, lvl) : scruts }
  | otherwise      = env
  where
    -- Ids without a recorded level count as 'topLevel' (see 'lookupLevel').
    bind_lvl = lookupLevel env scrut_var
-- | The recursive binding group recorded for an Id by 'addRecBinds', if any.
lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
lookupRecId env v = lookupVarEnv (lc_rec_env env) v
-- | The level at which an Id was bound.  Ids absent from 'lc_lvl_env'
-- are treated as 'topLevel'.
lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
lookupLevel env v
  = case lookupVarEnv (lc_lvl_env env) v of
      Just lvl -> lvl
      Nothing  -> topLevel
-- | Depth of nesting inside recursive right-hand sides (see 'addRecBinds').
type LibCaseLevel = Int

-- | The outermost level.  It is also the default level for Ids that have
-- no recorded level.
topLevel :: LibCaseLevel
topLevel = 0
-- | Environment threaded through the liberate-case pass.
data LibCaseEnv
  = LibCaseEnv
      { lc_dflags  :: DynFlags
          -- ^ Compiler flags; supplies the size threshold ('bombOutSize').
      , lc_lvl     :: LibCaseLevel
          -- ^ Current level.  Starts at 0 and is incremented only by
          -- 'addRecBinds', i.e. when entering the RHSs of a recursive
          -- group judged dupable.
      , lc_lvl_env :: IdEnv LibCaseLevel
          -- ^ Level at which each recorded binder was bound.  Ids not
          -- found here are treated as 'topLevel'.
      , lc_rec_env :: IdEnv CoreBind
          -- ^ Maps binders of dupable recursive groups to a copy of their
          -- whole binding group.
      , lc_scruts  :: [(Id, LibCaseLevel, LibCaseLevel)]
          -- ^ Entries @(v, bind_lvl, scrut_lvl)@: @v@ was bound at
          -- @bind_lvl@ and scrutinised by an enclosing case at the deeper
          -- @scrut_lvl@.
      }
-- | The starting environment: the outermost level, with empty level,
-- recursive-group and scrutinee tables.
initEnv :: DynFlags -> LibCaseEnv
initEnv dflags
  = LibCaseEnv { lc_dflags  = dflags
               , lc_lvl     = topLevel
               , lc_lvl_env = emptyVarEnv
               , lc_rec_env = emptyVarEnv
               , lc_scruts  = [] }
-- | Size threshold for duplicating a recursive group, read from
-- 'liberateCaseThreshold' in the flags.  'Nothing' means no limit: the
-- @small_enough@ check in 'libCaseBind' treats it as always small enough.
bombOutSize :: LibCaseEnv -> Maybe Int
bombOutSize = liberateCaseThreshold . lc_dflags