module Language.HERMIT.Primitive.Local.Case
(
externals
, caseElim
, caseFloatApp
, caseFloatArg
, caseFloatCase
, caseFloatCast
, caseFloatLet
, caseFloat
, caseUnfloat
, caseUnfloatApp
, caseUnfloatArgs
, caseReduce
, caseReduceDatacon
, caseReduceLiteral
, caseSplit
, caseSplitInline
) where
import GhcPlugins
import Data.List
import Data.Monoid
import Data.Set (intersection, fromList, toList, member, unions)
import qualified Data.Set as S
import Control.Arrow
import Control.Monad (liftM)
import Language.HERMIT.Core
import Language.HERMIT.Context
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.GHC
import Language.HERMIT.External
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC hiding (externals)
import Language.HERMIT.Primitive.Inline hiding (externals)
import Language.HERMIT.Primitive.AlphaConversion hiding (externals)
import qualified Language.Haskell.TH as TH
-- | Shell commands for the case-related rewrites defined in this module.
-- Each entry pairs a command name with its rewrite, its help text, and its tags.
externals :: [External]
externals =
    [ external "case-elim" (promoteExprR caseElim :: RewriteH Core)
        [ "case s of w; C vs -> e ==> e if w and vs are not free in e" ] .+ Shallow
    , external "case-float-app" (promoteExprR caseFloatApp :: RewriteH Core)
        [ "(case ec of alt -> e) v ==> case ec of alt -> e v" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float-arg" (promoteExprR caseFloatArg :: RewriteH Core)
        [ "f (case s of alt -> e) ==> case s of alt -> f e" ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-float-case" (promoteExprR caseFloatCase :: RewriteH Core)
        [ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval .+ Bash
    , external "case-float-cast" (promoteExprR caseFloatCast :: RewriteH Core)
        [ "cast (case s of p -> e) co ==> case s of p -> cast e co" ] .+ Shallow .+ Commute .+ Bash
    , external "case-float-let" (promoteExprR caseFloatLet :: RewriteH Core)
        [ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float" (promoteExprR caseFloat :: RewriteH Core)
        [ "Float a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-unfloat" (promoteExprR caseUnfloat :: RewriteH Core)
        [ "Unfloat a Case whatever the context." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-unfloat-args" (promoteExprR caseUnfloatArgs :: RewriteH Core)
        [ "Unfloat a Case whose alternatives are parallel applications of the same function." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-reduce" (promoteExprR caseReduce :: RewriteH Core)
        [ "case-reduce-datacon <+ case-reduce-literal" ] .+ Shallow .+ Eval .+ Bash
    , external "case-reduce-datacon" (promoteExprR caseReduceDatacon :: RewriteH Core)
        [ "case-of-known-constructor"
        , "case C v1..vn of C w1..wn -> e ==> let { w1 = v1 ; .. ; wn = vn } in e" ] .+ Shallow .+ Eval
    , external "case-reduce-literal" (promoteExprR caseReduceLiteral :: RewriteH Core)
        [ "case L of L -> e ==> e" ] .+ Shallow .+ Eval
    , external "case-split" (promoteExprR . caseSplit :: TH.Name -> RewriteH Core)
        [ "case-split 'x"
        , "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ]
    , external "case-split-inline" (caseSplitInline :: TH.Name -> RewriteH Core)
        [ "Like case-split, but additionally inlines the matched constructor "
        , "applications for all occurrences of the named variable." ]
    ]
-- | Eliminate a single-alternative case whose binders are all unused:
--
-- @case s of w { C vs -> e }  ==>  e@   (when neither @w@ nor any of @vs@ is free in @e@)
--
-- The scrutinee @s@ is discarded, so it is no longer evaluated by the result.
-- Fails on a non-case expression, on a case with zero alternatives, on a case
-- with several alternatives, or when a binder is still used in the right-hand side.
caseElim :: Rewrite c HermitM CoreExpr
caseElim = prefixFailMsg "Case elimination failed: " $
           withPatFailMsg (wrongExprForm "Case s bnd ty alts") $
  do Case _ bnd _ alts <- idR
     case alts of
       [(_, vs, e)] -> do fvs <- applyInContextT freeVarsT e
                          guardMsg (S.null $ intersection (fromList (bnd:vs)) fvs) "wildcard or pattern binders free in RHS."
                          return e
       -- An empty alternative list must not be reported as "more than one".
       []           -> fail "no case alternatives."
       _            -> fail "more than one case alternative."
-- | Float a case out of the function position of an application:
--
-- @(case ec of alt -> e) v  ==>  case ec of alt -> e v@
--
-- The argument @v@ is moved under the case binders.  So the case binder and any
-- alternative whose pattern binders are free in @v@ are alpha-renamed first.
caseFloatApp :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseFloatApp = prefixFailMsg "Case floating from App function failed: " $
  do
    -- For each alternative: the set of its pattern binders that are free in the argument.
    captures    <- appT (liftM (map fromList) caseAltVarsT) freeVarsT (flip (map . intersection))
    -- Whether the case binder itself is free in the argument.
    wildCapture <- appT caseWildIdT freeVarsT member
    -- Floating preserves the type of the whole expression.  So the type of the original
    -- application is the annotation for the new case.  This avoids taking 'head' of the
    -- alternatives, which is partial on a case with no alternatives.
    resultTy    <- liftM exprType idR
    appT ((if not wildCapture then idR else alphaCaseBinder Nothing)
            >>> caseAllR idR idR idR (\i -> if S.null (captures !! i) then idR else alphaAlt)
         )
         idR
         (\(Case s b _ty alts) v -> Case s b resultTy $ mapAlts (flip App v) alts)
-- | Float a case out of the argument position of an application:
--
-- @f (case s of alt -> e)  ==>  case s of alt -> f e@
--
-- The function @f@ is moved under the case binders.  So the case binder and any
-- alternative whose pattern binders are free in @f@ are alpha-renamed first.
-- The result forces @s@ before @f@ is called, which can change termination
-- when @f@ is lazy in its argument.  This is why the shell command carries a PreCondition tag.
caseFloatArg :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseFloatArg = prefixFailMsg "Case floating from App argument failed: " $
  do
    -- For each alternative: the set of its pattern binders that are free in the function.
    captures    <- appT freeVarsT (liftM (map fromList) caseAltVarsT) (map . intersection)
    -- Whether the case binder itself is free in the function.
    wildCapture <- appT freeVarsT caseWildIdT (flip member)
    -- Floating preserves the type of the whole expression.  Taking it from the original
    -- application avoids 'head', which is partial on a case with no alternatives.
    resultTy    <- liftM exprType idR
    appT idR
         ((if not wildCapture then idR else alphaCaseBinder Nothing)
            >>> caseAllR idR idR idR (\i -> if S.null (captures !! i) then idR else alphaAlt)
         )
         (\f (Case s b _ty alts) -> Case s b resultTy $ mapAlts (App f) alts)
-- | Float a case out of the scrutinee of another case:
--
-- @case (case ec of alt1 -> e1) of alta -> ea  ==>  case ec of alt1 -> case e1 of alta -> ea@
--
-- The outer alternatives are copied into every inner alternative.  So any inner
-- binder (the inner case binder or a pattern variable) that is free in the outer
-- alternatives is alpha-renamed before the float.
caseFloatCase :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseFloatCase = prefixFailMsg "Case floating from Case failed: " $
  do
    -- For each inner alternative: its pattern binders that are free in some outer
    -- alternative.  Judging by the name altFreeVarsExclWildT, the outer case binder
    -- (passed as bndr) is excluded from those free variables.  TODO confirm.
    captures <- caseT (liftM (map fromList) caseAltVarsT) idR mempty (const altFreeVarsExclWildT) (\ vss bndr () fs -> map (intersection (unions $ map ($ bndr) fs)) vss)
    -- Whether the inner case binder is free in some outer alternative.
    wildCapture <- caseT caseWildIdT idR mempty (const altFreeVarsExclWildT) (\ innerBndr bndr () fvs -> innerBndr `member` unions (map ($ bndr) fvs))
    -- Rename the inner case (the scrutinee of the outer one) only where needed, then
    -- rebuild.  Both the new outer case and every new inner case are annotated with
    -- ty, the type of the original outer case.
    caseT ((if not wildCapture then idR else alphaCaseBinder Nothing)
            >>> caseAllR idR idR idR (\i -> if S.null (captures !! i) then idR else alphaAlt)
          )
          idR
          idR
          (const idR)
          (\ (Case s1 b1 _ alts1) b2 ty alts2 -> Case s1 b1 ty $ mapAlts (\s -> Case s b2 ty alts2) alts1)
-- | Float a case out of the right-hand side of a non-recursive let:
--
-- @let v = case ec of alt1 -> e1 in e  ==>  case ec of alt1 -> let v = e1 in e@
--
-- The let body @e@ ends up underneath the case binder and the pattern binders.
-- The case is alpha-renamed first if one of those binders is free in @e@ (which
-- would be captured), or if the let variable is also a pattern binder.
caseFloatLet :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseFloatLet = prefixFailMsg "Case floating from Let failed: " $
  do -- Is the let-bound variable also bound by one of the case patterns?
     letVarClash <- letNonRecT idR caseAltVarsT mempty (\ letVar caseVars () -> letVar `elem` concat caseVars)
     -- Is the case binder, or any pattern binder, free in the let body?  Moving the
     -- body under those binders would otherwise capture the free occurrence.
     bodyCapture <- letNonRecT mempty
                               (do wild <- caseWildIdT
                                   vss  <- caseAltVarsT
                                   return (wild : concat vss))
                               freeVarsT
                               (\ () caseBinders fvs -> any (`member` fvs) caseBinders)
     -- NOTE(review): assumes alphaCase renames both the case binder and the pattern
     -- binders to fresh names.  Confirm against Language.HERMIT.Primitive.AlphaConversion.
     let bdsAction = if letVarClash || bodyCapture then nonRecAllR idR alphaCase else idR
     letT bdsAction idR $ \ (NonRec v (Case s b _ alts)) e -> Case s b (exprType e) $ mapAlts (flip Let e . NonRec v) alts
-- | Float a case out of a cast:
--
-- @cast (case s of p -> e) co  ==>  case s of p -> cast e co@
--
-- Fails if the expression is not a cast of a case.
caseFloatCast :: MonadCatch m => Rewrite c m CoreExpr
caseFloatCast = prefixFailMsg "Case float from cast failed: " $
                withPatFailMsg (wrongExprForm "Cast (Case s bnd ty alts) co") $
  do castExpr@(Cast (Case s bnd _ alts) co) <- idR
     -- The float preserves the type of the whole expression (the target type of co).
     -- So take it from the original cast instead of from the alternatives.  In GHC,
     -- coreAltsType needs at least one alternative.
     return $ Case s bnd (exprType castExpr) (mapAlts (flip Cast co) alts)
-- | Float a case out of whichever construct directly encloses it.  Tries, in order:
-- the function of an application, the argument of an application, the scrutinee
-- of another case, the right-hand side of a let, and the body of a cast.
caseFloat :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseFloat = setFailMsg "Unsuitable expression for Case floating." $
            foldr1 (<+) [ caseFloatApp, caseFloatArg, caseFloatCase, caseFloatLet, caseFloatCast ]
-- | Push a case back into its surrounding context.  Tries 'caseUnfloatApp' first,
-- then 'caseUnfloatArgs'.  Reports a single generic message if both fail.
caseUnfloat :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr
caseUnfloat = setFailMsg "Case unfloating failed." $ foldr1 (<+) [caseUnfloatApp, caseUnfloatArgs]
-- | Placeholder for unfloating a case out of an application.  It is not yet
-- implemented, so it always fails.
caseUnfloatApp :: Monad m => Rewrite c m CoreExpr
caseUnfloatApp = fail todoMsg
  where
    todoMsg = "caseUnfloatApp: TODO"
-- | Unfloat a case whose alternatives all call the same function:
--
-- @case s of { p1 -> f a1 b1 ; p2 -> f a2 b2 }  ==>  f (case s of { p1 -> a1 ; p2 -> a2 }) (case s of { p1 -> b1 ; p2 -> b2 })@
--
-- Type and coercion arguments must be identical in every alternative; they are
-- hoisted unchanged.  Each value-argument position becomes its own case.
-- Fails if there are no alternatives, if the function or a hoisted type argument
-- mentions a case binder, or if the calls differ in arity or type arguments.
caseUnfloatArgs :: (ExtendPath c Crumb, AddBindings c, MonadCatch m) => Rewrite c m CoreExpr
caseUnfloatArgs = prefixFailMsg "Case unfloating into arguments failed: " $
                  withPatFailMsg (wrongExprForm "Case s v t alts") $
  do Case s wild _ty alts <- idR
     -- With no alternatives there is no function to hoist (and 'head fs' would crash).
     guardMsg (not (null alts)) "case has no alternatives."
     -- For each alternative: its binders (case binder included), the called function, its arguments.
     (vss, fs, argss) <- caseT mempty mempty mempty (\ _ -> altT mempty (\ _ -> idR) callT $ \ () vs (fn, args) -> (vs, fn, args))
                               (\ () () () alts' -> unzip3 [ (wild:vs, fn, args) | (vs,fn,args) <- alts' ])
     guardMsg (exprsEqual fs) "alternatives are not parallel in function call."
     guardMsg (all null $ zipWith intersect (map (toList.coreExprFreeVars) fs) vss) "function bound by case binders."
     -- 'transpose' silently misaligns argument positions when the calls differ in arity.
     let arities = map length argss
     guardMsg (and (zipWith (==) arities (drop 1 arities))) "function applied to different numbers of arguments."
     let argss'  = transpose argss
         tyArgss = filter (isTyCoArg . head) argss'
     guardMsg (all exprsEqual tyArgss) "function applied at different types."
     -- Type and coercion arguments are hoisted out of the case.  So they must not mention
     -- a variable bound by the alternative they came from (e.g. an existential type variable).
     guardMsg (and [ null (toList (coreExprFreeVars arg) `intersect` vs) | tyArgs <- tyArgss, (vs, arg) <- zip vss tyArgs ])
              "type argument bound by case binders."
     return $ mkCoreApps (head fs) [ if isTyCoArg (head args)
                                       then head args
                                       else let alts' = [ (ac, vs, arg) | ((ac,vs,_),arg) <- zip alts args ]
                                            in Case s wild (coreAltsType alts') alts'
                                   | args <- argss' ]
-- | Reduce a case on a known scrutinee.  Tries case-of-known-constructor
-- ('caseReduceDatacon') first, then case-of-known-literal ('caseReduceLiteral').
caseReduce :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseReduce = setFailMsg unsuitable (caseReduceDatacon <+ caseReduceLiteral)
  where
    unsuitable = "Unsuitable expression for Case reduction."
-- | Reduce a case whose scrutinee is recognised as an unlifted literal:
--
-- @case L of w { ... ; L -> e ; ... }  ==>  let w = L in e@
--
-- The scrutinee is inspected with exprIsLiteral_maybe, using idUnfolding to look
-- through variables.  Fails for non-literal scrutinees, for lifted literals, and
-- when no alternative matches.
caseReduceLiteral :: MonadCatch m => Rewrite c m CoreExpr
caseReduceLiteral = prefixFailMsg "Case reduction failed: " $
                    withPatFailMsg (wrongExprForm "Case (Lit l) v t alts") $
  do Case s wild _ alts <- idR
#if __GLASGOW_HASKELL__ > 706
     -- For GHC newer than 7.6, exprIsLiteral_maybe takes an (in-scope set, unfolding)
     -- pair.  The in-scope set is built from the free variables of the scrutinee.
     let in_scope = mkInScopeSet (mkVarEnv [ (v,v) | v <- S.toList (coreExprFreeVars s) ])
     case exprIsLiteral_maybe (in_scope, idUnfolding) s of
#else
     case exprIsLiteral_maybe idUnfolding s of
#endif
       Nothing -> fail "scrutinee is not a literal."
       Just l -> do guardMsg (not (litIsLifted l)) "cannot case-reduce lifted literals"
                    case findAlt (LitAlt l) alts of
                      Nothing -> fail "no matching alternative."
                      -- Keep the case binder alive by binding it to the literal.
                      Just (_, _, rhs) -> return $ mkCoreLet (NonRec wild (Lit l)) rhs
-- | Case-of-known-constructor:
--
-- @case e of w { C w1..wn -> rhs }  ==>  let { w = e ; w1 = a1 ; .. ; wn = an } in rhs@
--
-- Here @e@ is recognised (via exprIsConApp_maybe) as the constructor @C@ applied to
-- arguments @a1..an@.  The lets are nested, so a binder could capture a free variable
-- of a later argument.  When that might happen, the offending case binder or pattern
-- binders are alpha-renamed and the rewrite is retried.
caseReduceDatacon :: forall c. (ExtendPath c Crumb, AddBindings c, ReadBindings c) => Rewrite c HermitM CoreExpr
caseReduceDatacon = prefixFailMsg "Case reduction failed: " $
                    withPatFailMsg (wrongExprForm "Case e v t alts")
                    go
  where
    -- The forall above lets this local signature refer to the same c.
    go :: Rewrite c HermitM CoreExpr
    go = do Case e wild _ alts <- idR
#if __GLASGOW_HASKELL__ > 706
            -- For GHC newer than 7.6, exprIsConApp_maybe takes an (in-scope set, unfolding)
            -- pair.  The in-scope set is built from the free variables of the scrutinee.
            let in_scope = mkInScopeSet (mkVarEnv [ (v,v) | v <- S.toList (coreExprFreeVars e) ])
            case exprIsConApp_maybe (in_scope, idUnfolding) e of
#else
            case exprIsConApp_maybe idUnfolding e of
#endif
              Nothing -> fail "head of scrutinee is not a data constructor."
              Just (dc, univTys, es) -> case findAlt (DataAlt dc) alts of
                Nothing -> fail "no matching alternative."
                Just (dc', vs, rhs) ->
                    -- Free variables of the universal type arguments followed by the other arguments.
                    let fvss = map coreExprFreeVars $ map Type univTys ++ es
                        -- A pattern binder at 1-based position n is flagged if it is free in any
                        -- element of fvss from index n on.  This looks like a conservative
                        -- over-approximation of the arguments it will scope over.
                        shadows = [ v | (v,n) <- zip vs [1..], any (member v) (drop n fvss) ]
                    in if | any (member wild) fvss -> alphaCaseBinder Nothing >>> go
                          -- Rename only the clashing binders of the matching alternative, then retry.
                          | not (null shadows) -> caseOneR (fail "scrutinee") (fail "binder") (fail "type") (\ _ -> acceptR (\ (dc'',_,_) -> dc'' == dc') >>> alphaAltVars shadows) >>> go
                          -- No capture possible: bind the case binder and pattern binders with nested lets.
                          | null shadows -> return $ flip mkCoreLets rhs $ zipWith NonRec (wild : vs) (e : es)
-- | Case-split on a named free variable:
--
-- @e  ==>  case x of x { C1 vs -> e ; .. ; Cn vs -> e }@
--
-- The variable is used as its own case binder, and fresh single-letter variables
-- are made for each constructor's fields.  The type of the variable must be an
-- ordinary algebraic data type, i.e. not a newtype.  It must have at least one
-- constructor, and no existential or GADT constructors (those would need extra
-- type binders in the pattern).
caseSplit :: TH.Name -> Rewrite c HermitM CoreExpr
caseSplit nm = do
    frees <- freeIdsT
    contextfreeT $ \ e ->
        case filter (cmpTHName2Var nm) (toList frees) of
            []    -> fail "caseSplit: provided name is not free"
            (i:_) -> do
                -- splitTyConApp panics on non-TyConApp types (functions, type variables),
                -- so use the total variant and fail gracefully instead.
                (tycon, tys) <- maybe (fail "caseSplit: type of provided name is not a data type")
                                      return
                                      (splitTyConApp_maybe (idType i))
                let dcs  = tyConDataCons tycon
                    aNms = map (:[]) $ cycle ['a'..'z']
                -- Core cannot case on a newtype constructor.
                guardMsg (not (isNewTyCon tycon)) "caseSplit: cannot case-split on a newtype"
                -- A case with no alternatives would replace e by a bottoming expression.
                guardMsg (not (null dcs)) "caseSplit: type of provided name has no data constructors"
                -- dataConInstArgTys only instantiates universal type variables.
                guardMsg (all isVanillaDataCon dcs) "caseSplit: existential or GADT data constructors are not supported"
                dcsAndVars <- mapM (\dc -> do
                                      as <- sequence [ newIdH a ty | (a,ty) <- zip aNms $ dataConInstArgTys dc tys ]
                                      return (dc,as)) dcs
                return $ Case (Var i) i (exprType e) [ (DataAlt dc, as, e) | (dc,as) <- dcsAndVars ]
-- | Case-split on the named variable (see 'caseSplit').  Then traverse the result
-- bottom-up, inlining the variable wherever 'inlineName' succeeds.
caseSplitInline :: (ExtendPath c Crumb, AddBindings c, ReadBindings c) => TH.Name -> Rewrite c HermitM Core
caseSplitInline nm = splitOnName >>> inlineOccurrences
  where
    splitOnName       = promoteR (caseSplit nm)
    inlineOccurrences = anybuR (promoteExprR (inlineName nm))