module GHC.Core.Opt.Exitify ( exitifyProgram ) where
import GHC.Prelude
import GHC.Types.Var
import GHC.Types.Id
import GHC.Types.Id.Info
import GHC.Core
import GHC.Core.Utils
import GHC.Utils.Monad.State
import GHC.Builtin.Uniques
import GHC.Types.Var.Set
import GHC.Types.Var.Env
import GHC.Core.FVs
import GHC.Data.FastString
import GHC.Core.Type
import GHC.Utils.Misc( mapSnd )
import Data.Bifunctor
import Control.Monad
-- | Run the exitification pass over a whole program.
--
-- Every binding is traversed by 'go' (below). Whenever a nested @let rec@
-- group binds at least one join point, the already-traversed group is
-- handed to 'exitifyRec'. 'exitifyRec' may float exit expressions out of
-- the group as separate non-recursive bindings placed in front of it.
-- All other constructs are rebuilt with only their sub-expressions
-- traversed.
exitifyProgram :: CoreProgram -> CoreProgram
exitifyProgram binds = map goTopLvl binds
  where
    -- Top-level bindings are only traversed. Exitification itself is
    -- triggered solely by nested recursive join-point groups inside 'go'.
    goTopLvl :: CoreBind -> CoreBind
    goTopLvl (NonRec v e) = NonRec v (go in_scope_toplvl e)
    goTopLvl (Rec pairs)  = Rec (map (second (go in_scope_toplvl)) pairs)

    -- Every top-level binder is in scope in every right-hand side.
    in_scope_toplvl :: InScopeSet
    in_scope_toplvl = emptyInScopeSet `extendInScopeSetList` bindersOfBinds binds

    -- Traverse an expression, extending the in-scope set at every binder.
    -- The set reaches 'exitifyRec', which uses it to pick fresh names for
    -- the exits it creates.
    go :: InScopeSet -> CoreExpr -> CoreExpr
    go _ e@(Var {})      = e
    go _ e@(Lit {})      = e
    go _ e@(Type {})     = e
    go _ e@(Coercion {}) = e
    go in_scope (Cast e' c) = Cast (go in_scope e') c
    go in_scope (Tick t e') = Tick t (go in_scope e')
    go in_scope (App e1 e2) = App (go in_scope e1) (go in_scope e2)

    go in_scope (Lam v e')
      = Lam v (go in_scope' e')
      where
        in_scope' = in_scope `extendInScopeSet` v

    go in_scope (Case scrut bndr ty alts)
      = Case (go in_scope scrut) bndr ty (map go_alt alts)
      where
        -- The case binder scopes over every alternative ...
        in_scope1 = in_scope `extendInScopeSet` bndr
        -- ... and each alternative adds its own pattern binders.
        go_alt (Alt dc pats rhs) = Alt dc pats (go in_scope' rhs)
          where
            in_scope' = in_scope1 `extendInScopeSetList` pats

    -- Non-recursive let: the binder is in scope in the body only.
    go in_scope (Let (NonRec bndr rhs) body)
      = Let (NonRec bndr (go in_scope rhs)) (go in_scope' body)
      where
        in_scope' = in_scope `extendInScopeSet` bndr

    -- Recursive let: the binders are in scope in the RHSs and in the body.
    -- The RHSs and body are traversed first, so inner groups are
    -- exitified before the enclosing one is.
    go in_scope (Let (Rec pairs) body)
      | is_join_rec = mkLets (exitifyRec in_scope' pairs') body'
      | otherwise   = Let (Rec pairs') body'
      where
        is_join_rec = any (isJoinId . fst) pairs
        in_scope'   = in_scope `extendInScopeSetList` bindersOf (Rec pairs)
        pairs'      = mapSnd (go in_scope') pairs
        body'       = go in_scope' body
-- | State threaded through the exitification of one recursive group.
-- It holds the exit bindings created so far: each entry pairs a fresh
-- join-point binder with its right-hand side.
type ExitifyM = State [(JoinId, CoreExpr)]
-- | Float exit expressions out of a recursive binding group that contains
-- join points.
--
-- An exit is an expression in tail position of a right-hand side that
-- mentions none of the group's binders. An exit may be replaced by a jump
-- to a fresh non-recursive join point. That join point is abstracted over
-- the locally captured variables the expression uses.
--
-- The result is the list of the new exit bindings followed by the
-- rewritten @Rec@ group.
--
-- @in_scope@ is the set of variables in scope at the group. It is extended
-- with the locally captured variables and used to choose fresh exit names.
exitifyRec :: InScopeSet -> [(Var, CoreExpr)] -> [CoreBind]
exitifyRec in_scope pairs
  = [ NonRec xid rhs | (xid, rhs) <- exits ] ++ [Rec pairs']
  where
    -- Free-variable annotations are needed to recognise exit expressions.
    ann_pairs = map (second freeVars) pairs

    -- The group's own binders; an exit must mention none of them.
    recursive_calls = mkVarSet (map fst pairs)

    -- Rewrite each RHS, collecting the floated exits in the state.
    (pairs', exits) = (`runState` []) $
      forM ann_pairs $ \(x, rhs) -> do
        -- Step past the join point's own lambdas; those parameters
        -- count as captured variables for everything below them.
        let (args, body) = collectNAnnBndrs (idJoinArity x) rhs
        body' <- go args body
        return (x, mkLams args body')

    ---------------------
    -- Walk the tail positions of a right-hand side.
    -- @captured@ lists the variables bound on the way down from the group,
    -- in binding order. An exit that uses any of them must be abstracted
    -- over them.
    go :: [Var]             -- variables bound between the group and here
       -> CoreExprWithFVs   -- expression in tail position
       -> ExitifyM CoreExpr
    go captured ann_e
        -- No recursive call inside: this is a candidate exit.
      | let fvs = dVarSetToVarSet (freeVarsOf ann_e)
      , disjointVarSet fvs recursive_calls
      = go_exit captured (deAnnotate ann_e) fvs

    -- Case alternatives are in tail position; the scrutinee is not.
    go captured (_, AnnCase scrut bndr ty alts) = do
      alts' <- forM alts $ \(AnnAlt dc pats rhs) -> do
        rhs' <- go (captured ++ [bndr] ++ pats) rhs
        return (Alt dc pats rhs')
      return $ Case (deAnnotate scrut) bndr ty alts'

    go captured (_, AnnLet ann_bind body)
      -- Non-recursive join point: its body (under its parameters) and the
      -- let body are both in tail position.
      | AnnNonRec j rhs <- ann_bind
      , Just join_arity <- isJoinId_maybe j
      = do let (params, join_body) = collectNAnnBndrs join_arity rhs
           join_body' <- go (captured ++ params) join_body
           body'      <- go (captured ++ [j]) body
           return $ Let (NonRec j (mkLams params join_body')) body'

      -- Recursive join points: every RHS body and the let body are in
      -- tail position. The first pair is matched by pattern rather than
      -- with 'head', which keeps this total. An empty group falls through
      -- to the ordinary-let case below.
      | AnnRec rec_pairs@((j0, _) : _) <- ann_bind
      , isJoinId j0
      = do let js = map fst rec_pairs
           rec_pairs' <- forM rec_pairs $ \(j, rhs) -> do
             let (params, join_body) = collectNAnnBndrs (idJoinArity j) rhs
             join_body' <- go (captured ++ js ++ params) join_body
             return (j, mkLams params join_body')
           body' <- go (captured ++ js) body
           return $ Let (Rec rec_pairs') body'

      -- Any other let: only the body is in tail position.
      | otherwise
      = do body' <- go (captured ++ bindersOf bind) body
           return $ Let bind body'
      where
        bind = deAnnBind ann_bind

    -- Not an exit, and no tail-position sub-expressions: leave it alone.
    go _ ann_e = return (deAnnotate ann_e)

    ---------------------
    -- Decide whether a tail expression without recursive calls should be
    -- floated out. If so, record it as a new exit and return a jump to it.
    go_exit :: [Var]       -- variables captured locally (binding order)
            -> CoreExpr    -- the candidate exit expression
            -> VarSet      -- its free variables
            -> ExitifyM CoreExpr
    go_exit captured e fvs
      -- Already a jump to a join point whose arguments are all captured
      -- variables. It is left unchanged; presumably this keeps the pass
      -- idempotent.
      | (Var f, args) <- collectArgs e
      , isJoinId f
      , all isCapturedVarArg args
      = return e

      -- Only float when e mentions some local Id that is not captured.
      | not is_interesting
      = return e

      -- An abstracted variable would be a join point; presumably join
      -- points must not become lambda-bound parameters, so bail out.
      | captures_join_points
      = return e

      | otherwise
      = do let rhs   = mkLams abs_vars e
               avoid = in_scope `extendInScopeSetList` captured
           v <- addExit avoid (length abs_vars) rhs
           return $ mkVarApps (Var v) abs_vars
      where
        isCapturedVarArg (Var v) = v `elem` captured
        isCapturedVarArg _       = False

        is_interesting = anyVarSet isLocalId (fvs `minusVarSet` mkVarSet captured)

        -- The captured variables that are free in e, kept in the order of
        -- @captured@. foldr visits from the right and deletes each variable
        -- from the set once picked. So if a variable occurs more than once
        -- in @captured@, only its last (innermost) occurrence is used.
        abs_vars = snd (foldr pick (fvs, []) captured)
          where
            pick v (fvs', acc)
              | v `elemVarSet` fvs' = (fvs' `delVarSet` v, zap v : acc)
              | otherwise           = (fvs', acc)

        -- Abstracted Ids become lambda binders of the exit, so their IdInfo
        -- is reset to vanilla. Presumably this stops info from the original
        -- binding site from leaking onto the lambda binder.
        zap v
          | isId v    = setIdInfo v vanillaIdInfo
          | otherwise = v

        captures_join_points = any isJoinId abs_vars
-- | Make a fresh join-point binder named @exit@ with the given type and
-- join arity.
--
-- The binder is produced with 'uniqAway'. Its avoid set is the supplied
-- in-scope set plus every exit binder already recorded in the state. The
-- template itself is also added to the avoid set, so the returned binder
-- does not keep the template's unique.
mkExitJoinId :: InScopeSet -> Type -> JoinArity -> ExitifyM JoinId
mkExitJoinId in_scope ty join_arity = do
    fs <- get
    let avoid = in_scope `extendInScopeSetList` (map fst fs)
                         `extendInScopeSet` exit_id_tmpl
    return (uniqAway avoid exit_id_tmpl)
  where
    exit_id_tmpl = mkSysLocal (fsLit "exit") initExitJoinUnique Many ty
                     `asJoinId` join_arity
-- | Record a new exit binding for @rhs@ and return its fresh binder.
--
-- The binder has the type of @rhs@ and the given join arity. Its name is
-- chosen to avoid @in_scope@ and all previously recorded exits. The new
-- pair is pushed onto the front of the state's exit list.
addExit :: InScopeSet -> JoinArity -> CoreExpr -> ExitifyM JoinId
addExit in_scope join_arity rhs = do
    v  <- mkExitJoinId in_scope (exprType rhs) join_arity
    fs <- get
    put ((v, rhs) : fs)
    return v