{-# LANGUAGE TypeFamilies #-}
module StgCse (stgCse) where
import DataCon
import Id
import StgSyn
import Outputable
import VarEnv
import CoreSyn (AltCon(..))
import Data.List (mapAccumL)
import Data.Maybe (fromMaybe)
import TrieMap
import NameEnv
import Control.Monad( (>=>) )
-- | A trie keyed on a single STG argument. Variable arguments are stored in
-- a deterministic variable environment, literal arguments in a literal map.
data StgArgMap a = SAM
    { sam_var :: DVarEnv a     -- ^ entries whose key is a 'StgVarArg'
    , sam_lit :: LiteralMap a  -- ^ entries whose key is a 'StgLitArg'
    }

instance TrieMap StgArgMap where
    type Key StgArgMap = StgArg

    emptyTM = SAM { sam_var = emptyTM, sam_lit = emptyTM }

    -- Dispatch on the argument form and consult the matching sub-map.
    lookupTM key m = case key of
        StgVarArg var -> lkDFreeVar var (sam_var m)
        StgLitArg lit -> lookupTM lit (sam_lit m)

    -- Alter only the sub-map selected by the argument form.
    alterTM key f m = case key of
        StgVarArg var -> m { sam_var = xtDFreeVar var f (sam_var m) }
        StgLitArg lit -> m { sam_lit = alterTM lit f (sam_lit m) }

    -- Fold the literal entries first, then the variable entries on top.
    foldTM k m z = foldTM k (sam_var m) (foldTM k (sam_lit m) z)

    mapTM f (SAM vars lits) = SAM (mapTM f vars) (mapTM f lits)
-- | A trie keyed on a constructor application: first on the data
-- constructor (via 'lkDNamed' / 'xtDNamed'), then on its argument list,
-- using a 'ListMap' of 'StgArgMap's.
newtype ConAppMap a = CAM { un_cam :: DNameEnv (ListMap StgArgMap a) }

instance TrieMap ConAppMap where
    type Key ConAppMap = (DataCon, [StgArg])
    emptyTM  = CAM emptyTM
    -- Find the argument trie for this constructor, then look up the arguments.
    lookupTM (dataCon, args) = un_cam >.> lkDNamed dataCon >=> lookupTM args
    -- Alter the argument trie stored under this constructor.
    alterTM  (dataCon, args) f m =
        m { un_cam = un_cam m |> xtDNamed dataCon |>> alterTM args f }
    -- Fold over every argument trie in the constructor map.
    foldTM k = un_cam >.> foldTM (foldTM k)
    mapTM f  = un_cam >.> mapTM (mapTM f) >.> CAM
-- | The environment threaded through the CSE pass.
data CseEnv = CseEnv
    { ce_conAppMap :: ConAppMap OutId
      -- ^ constructor applications seen so far, each mapped to the binder
      -- that holds it (extended by 'addDataCon', cleared by 'forgetCse')
    , ce_subst     :: IdEnv OutId
      -- ^ substitution applied to variable occurrences by 'substVar'
    , ce_bndrMap   :: IdEnv OutId
      -- ^ case binders of trivial (variable) scrutinees, mapped to that
      -- scrutinee; applied to arguments by 'envLookup' before lookup
    , ce_in_scope  :: InScopeSet
      -- ^ variables currently in scope; 'substBndr' uses it to choose binders
    }
-- | An environment with nothing recorded yet (no constructor applications,
-- an empty substitution and an empty case-binder map) over the given
-- in-scope set.
initEnv :: InScopeSet -> CseEnv
initEnv scope =
    CseEnv { ce_in_scope  = scope
           , ce_bndrMap   = emptyVarEnv
           , ce_subst     = emptyVarEnv
           , ce_conAppMap = emptyTM
           }
-- | Find a binder already recorded as holding @dataCon args@, if any.
--
-- Variable arguments that are case binders of a trivial scrutinee are first
-- replaced by that scrutinee (via 'ce_bndrMap'). An application mentioning
-- the case binder therefore matches one recorded with the scrutinee.
envLookup :: DataCon -> [OutStgArg] -> CseEnv -> Maybe OutId
envLookup dataCon args env =
    lookupTM (dataCon, map canonical args) (ce_conAppMap env)
  where
    aliases = ce_bndrMap env
    canonical arg = case arg of
        StgVarArg v -> StgVarArg (fromMaybe v (lookupVarEnv aliases v))
        StgLitArg _ -> arg
-- | Record that @bndr@ is bound to the application @dataCon args@.
-- Applications with no arguments are not recorded.
addDataCon :: OutId -> DataCon -> [OutStgArg] -> CseEnv -> CseEnv
addDataCon bndr dataCon args env
    | null args = env
    | otherwise = env { ce_conAppMap = extended }
  where
    extended = insertTM (dataCon, args) bndr (ce_conAppMap env)
-- | Discard every recorded constructor application. The substitution, the
-- case-binder map and the in-scope set are kept.
forgetCse :: CseEnv -> CseEnv
forgetCse env = env { ce_conAppMap = emptyTM :: ConAppMap OutId }
-- | Extend the substitution so that later occurrences of @from@ are
-- rewritten to @to@ by 'substVar'.
addSubst :: OutId -> OutId -> CseEnv -> CseEnv
addSubst from to env = env { ce_subst = subst' }
  where
    subst' = extendVarEnv (ce_subst env) from to
-- | Record that the case binder @from@ is an alias for the variable
-- scrutinee @to@. 'envLookup' applies this map to arguments before lookup.
addTrivCaseBndr :: OutId -> OutId -> CseEnv -> CseEnv
addTrivCaseBndr from to env = env { ce_bndrMap = aliases' }
  where
    aliases' = extendVarEnv (ce_bndrMap env) from to
-- | Apply the substitution to each argument in turn.
substArgs :: CseEnv -> [InStgArg] -> [OutStgArg]
substArgs env args = [ substArg env arg | arg <- args ]
-- | Apply the substitution to one argument. Literals pass through unchanged.
substArg :: CseEnv -> InStgArg -> OutStgArg
substArg env arg = case arg of
    StgVarArg v   -> StgVarArg (substVar env v)
    StgLitArg lit -> StgLitArg lit
-- | Apply the substitution to each variable in turn.
substVars :: CseEnv -> [InId] -> [OutId]
substVars env vs = [ substVar env v | v <- vs ]
-- | Look a variable up in the substitution. A variable with no entry
-- maps to itself.
substVar :: CseEnv -> InId -> OutId
substVar env v = case lookupVarEnv (ce_subst env) v of
    Just v' -> v'
    Nothing -> v
-- | Bring a binder into scope.
--
-- 'uniqAway' may return a binder different from @old_id@, chosen against the
-- current in-scope set. The returned environment:
--
--   * has the resulting binder added to 'ce_in_scope';
--   * if the binder changed, maps @old_id@ to the new binder in 'ce_subst',
--     so 'substVar' rewrites occurrences of the old binder within its scope.
--
-- An unchanged binder needs no substitution entry.
substBndr :: CseEnv -> InId -> (CseEnv, OutId)
substBndr env old_id
    = (new_env, new_id)
  where
    new_id    = uniqAway (ce_in_scope env) old_id
    no_change = new_id == old_id
    env'      = env { ce_in_scope = ce_in_scope env `extendInScopeSet` new_id }
    -- Only a renamed binder needs a substitution entry. (These guards were
    -- previously inverted, so renamed binders were never substituted.)
    new_env
        | no_change = env'
        | otherwise = env' { ce_subst = extendVarEnv (ce_subst env) old_id new_id }
-- | Bring a list of binders into scope left to right, threading the
-- environment through 'substBndr'.
substBndrs :: CseEnv -> [InVar] -> (CseEnv, [OutVar])
substBndrs = mapAccumL substBndr
-- | Like 'substBndrs', but for binders paired with a payload. The payload
-- is carried along untouched.
substPairs :: CseEnv -> [(InVar, a)] -> (CseEnv, [(OutVar, a)])
substPairs env pairs = mapAccumL step env pairs
  where
    step acc (bndr, payload) = (acc', (bndr', payload))
      where
        (acc', bndr') = substBndr acc bndr
-- | Entry point: run CSE over a program's top-level bindings in order.
-- Each top-level binder is added to the in-scope set seen by later bindings.
stgCse :: [InStgTopBinding] -> [OutStgTopBinding]
stgCse binds = binds'
  where
    (_final_scope, binds') = mapAccumL stgCseTopLvl emptyInScopeSet binds
-- | CSE one top-level binding, returning the in-scope set for the bindings
-- that follow.
--
-- String literals are left alone. For a non-recursive binding the RHS is
-- processed without its own binder in scope. For a recursive group all the
-- group's binders are in scope for every RHS.
stgCseTopLvl :: InScopeSet -> InStgTopBinding -> (InScopeSet, OutStgTopBinding)
stgCseTopLvl in_scope bind = case bind of
    StgTopStringLit _ _ ->
        (in_scope, bind)
    StgTopLifted (StgNonRec bndr rhs) ->
        ( in_scope `extendInScopeSet` bndr
        , StgTopLifted (StgNonRec bndr (stgCseTopLvlRhs in_scope rhs)) )
    StgTopLifted (StgRec eqs) ->
        let group_scope = in_scope `extendInScopeSetList` map fst eqs
            eqs'        = [ (b, stgCseTopLvlRhs group_scope rhs) | (b, rhs) <- eqs ]
        in  (group_scope, StgTopLifted (StgRec eqs'))
-- | CSE the right-hand side of a top-level binding.
--
-- A closure's body is processed in a fresh environment. Its in-scope set is
-- the given top-level scope plus the closure's own lambda arguments, which
-- scope over the body. Without the arguments, a body binder that shadows an
-- argument would not be renamed by 'substBndr', and a constructor
-- application recorded with the argument could wrongly match one built from
-- the shadowing binder. Nested closures already get this via 'substBndrs'
-- in 'stgCseRhs'.
--
-- Constructor RHSs are returned unchanged.
stgCseTopLvlRhs :: InScopeSet -> InStgRhs -> OutStgRhs
stgCseTopLvlRhs in_scope (StgRhsClosure ccs info occs upd args body)
    = let body_scope = in_scope `extendInScopeSetList` args
          body'      = stgCseExpr (initEnv body_scope) body
      in  StgRhsClosure ccs info occs upd args body'
stgCseTopLvlRhs _ (StgRhsCon ccs dataCon args)
    = StgRhsCon ccs dataCon args
-- | CSE an expression.
--
-- * Variable occurrences are rewritten through the substitution.
-- * A constructor application already recorded in the environment is
--   replaced by a reference to the binder holding it.
-- * Case alternatives and let bindings extend the environment for the code
--   they scope over.
--
-- 'StgLam' is not expected here and triggers a panic.
stgCseExpr :: CseEnv -> InStgExpr -> OutStgExpr
stgCseExpr env (StgApp fun args)
    = StgApp fun' args'
  where
    fun'  = substVar env fun
    args' = substArgs env args
stgCseExpr _ (StgLit lit)
    = StgLit lit
stgCseExpr env (StgOpApp op args tys)
    = StgOpApp op args' tys
  where
    args' = substArgs env args
stgCseExpr _ (StgLam _ _)
    = pprPanic "stgCseExp" (text "StgLam")
stgCseExpr env (StgTick tick body)
    = let body' = stgCseExpr env body
      in  StgTick tick body'
stgCseExpr env (StgCase scrut bndr ty alts)
    = StgCase scrut' bndr' ty alts'
  where
    scrut' = stgCseExpr env scrut
    (env1, bndr') = substBndr env bndr
    -- When the scrutinee is a bare variable, the case binder aliases it.
    -- Key the alias on the *output* binder bndr': envLookup consults
    -- ce_bndrMap with already-substituted arguments, so a renamed binder
    -- would never be found under the input name.
    env2
        | StgApp trivial_scrut [] <- scrut' = addTrivCaseBndr bndr' trivial_scrut env1
        | otherwise                         = env1
    alts' = map (stgCseAlt env2 bndr') alts
stgCseExpr env (StgConApp dataCon args tys)
    | Just bndr' <- envLookup dataCon args' env
    = StgApp bndr' []
    | otherwise
    = StgConApp dataCon args' tys
  where
    args' = substArgs env args
stgCseExpr env (StgLet binds body)
    = let (binds', env') = stgCseBind env binds
          body'          = stgCseExpr env' body
      in  mkStgLet StgLet binds' body'
stgCseExpr env (StgLetNoEscape binds body)
    = let (binds', env') = stgCseBind env binds
          body'          = stgCseExpr env' body
      in  mkStgLet StgLetNoEscape binds' body'
-- | CSE one case alternative, given the (output) case binder.
--
-- The pattern binders are brought into scope. For a 'DataAlt', the case
-- binder is also recorded as holding the matched constructor applied to
-- those binders, so the same application in the RHS can reuse it.
stgCseAlt :: CseEnv -> OutId -> InStgAlt -> OutStgAlt
stgCseAlt env case_bndr (altCon, args, rhs)
    = (altCon, args', stgCseExpr (record env1) rhs)
  where
    (env1, args') = substBndrs env args
    record = case altCon of
        DataAlt dataCon -> addDataCon case_bndr dataCon (map StgVarArg args')
        _               -> id
-- | CSE a local binding group and return the environment for its body.
--
-- Yields 'Nothing' when every binding in the group was eliminated (each
-- replaced by a substitution to an existing binder).
stgCseBind :: CseEnv -> InStgBinding -> (Maybe OutStgBinding, CseEnv)
stgCseBind env bind = case bind of
    StgNonRec b rhs ->
        let (env1, b') = substBndr env b
        in  case stgCseRhs env1 b' rhs of
                (mb_pair, env2) -> (fmap (uncurry StgNonRec) mb_pair, env2)
    StgRec pairs ->
        let (env1, pairs1) = substPairs env pairs
        in  case stgCsePairs env1 pairs1 of
                (pairs2, env2)
                    | null pairs2 -> (Nothing, env2)
                    | otherwise   -> (Just (StgRec pairs2), env2)
-- | CSE a list of (already renamed) binder/RHS pairs left to right,
-- threading the environment. Pairs that 'stgCseRhs' eliminates are dropped.
stgCsePairs :: CseEnv -> [(OutId, InStgRhs)] -> ([(OutId, OutStgRhs)], CseEnv)
stgCsePairs env0 pairs = ([ p | Just p <- results ], env_final)
  where
    (env_final, results) = mapAccumL step env0 pairs
    step env (b, rhs) =
        let (mb_pair, env') = stgCseRhs env b rhs
        in  (env', mb_pair)
-- | CSE one binding's RHS, given its (output) binder.
--
-- * Constructor RHS already recorded: the binding is dropped ('Nothing') and
--   the binder is substituted by the existing one.
-- * Any other constructor RHS: kept and recorded in the environment.
-- * Closure: kept; the environment returned is the one passed in.
stgCseRhs :: CseEnv -> OutId -> InStgRhs -> (Maybe (OutId, OutStgRhs), CseEnv)
stgCseRhs env bndr (StgRhsCon ccs dataCon args) =
    case envLookup dataCon args' env of
        Just other_bndr ->
            (Nothing, addSubst bndr other_bndr env)
        Nothing ->
            ( Just (bndr, StgRhsCon ccs dataCon args')
            , addDataCon bndr dataCon args' env )
  where
    args' = substArgs env args
stgCseRhs env bndr (StgRhsClosure ccs info occs upd args body) =
    (Just (bndr', closure'), env)
  where
    bndr'             = substVar env bndr
    occs'             = substVars env occs
    (env_args, args') = substBndrs env args
    -- The body starts with no recorded constructor applications. occs' comes
    -- from the input free-variable list and is not recomputed, so reusing an
    -- outer binder inside the body could add a free variable not in occs'.
    body'             = stgCseExpr (forgetCse env_args) body
    closure'          = StgRhsClosure ccs info occs' upd args' body'
-- | Wrap @body@ in a let built by @stgLet@ when there are bindings.
-- Otherwise return @body@ unchanged.
mkStgLet :: (a -> b -> b) -> Maybe a -> b -> b
mkStgLet stgLet mb_binds body = maybe body (`stgLet` body) mb_binds