module Language.HERMIT.Primitive.Local.Case
(
externals
, caseFloatApp
, caseFloatArg
, caseFloatCase
, caseFloatCast
, caseFloatLet
, caseFloat
, caseUnfloat
, caseUnfloatApp
, caseUnfloatArgs
, caseReduce
, caseReduceDatacon
, caseReduceLiteral
, caseSplit
, caseSplitInline
)
where
import GhcPlugins
import Data.List
import Data.Monoid
import Control.Arrow
import Language.HERMIT.Core
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.GHC
import Language.HERMIT.External
import Language.HERMIT.Primitive.Common
import Language.HERMIT.Primitive.GHC hiding (externals)
import Language.HERMIT.Primitive.Inline hiding (externals)
import Language.HERMIT.Primitive.AlphaConversion hiding (externals)
import qualified Language.Haskell.TH as TH
-- | The 'External' entries for the Case rewrites defined in this module.
--   Each entry pairs a command name with the rewrite it runs, its help text
--   and its tags.
externals :: [External]
externals =
    [ external "case-float-app" (promoteExprR caseFloatApp :: RewriteH Core)
        [ "(case ec of alt -> e) v ==> case ec of alt -> e v" ]              .+ Commute .+ Shallow .+ Bash
    , external "case-float-arg" (promoteExprR caseFloatArg :: RewriteH Core)
        [ "f (case s of alt -> e) ==> case s of alt -> f e" ]                .+ Commute .+ Shallow .+ PreCondition
    , external "case-float-case" (promoteExprR caseFloatCase :: RewriteH Core)
        [ "case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea" ] .+ Commute .+ Eval .+ Bash
      -- Annotated like every other entry so the promoted rewrite's type is explicit.
    , external "case-float-cast" (promoteExprR caseFloatCast :: RewriteH Core)
        [ "cast (case s of p -> e) co ==> case s of p -> cast e co" ]        .+ Shallow .+ Commute .+ Bash
    , external "case-float-let" (promoteExprR caseFloatLet :: RewriteH Core)
        [ "let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e" ] .+ Commute .+ Shallow .+ Bash
    , external "case-float" (promoteExprR caseFloat :: RewriteH Core)
        [ "Float a Case whatever the context." ]                             .+ Commute .+ Shallow .+ PreCondition
    , external "case-unfloat" (promoteExprR caseUnfloat :: RewriteH Core)
        [ "Unfloat a Case whatever the context." ]                           .+ Commute .+ Shallow .+ PreCondition
    , external "case-unfloat-args" (promoteExprR caseUnfloatArgs :: RewriteH Core)
        [ "Unfloat a Case whose alternatives are parallel applications of the same function." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-unfloat-app" (promoteExprR caseUnfloatApp :: RewriteH Core)
        [ "Unfloat a Case whose alternatives are applications of different functions with the same arguments." ] .+ Commute .+ Shallow .+ PreCondition
    , external "case-reduce" (promoteExprR caseReduce :: RewriteH Core)
        [ "case-reduce-datacon <+ case-reduce-literal" ]                     .+ Shallow .+ Eval .+ Bash
    , external "case-reduce-datacon" (promoteExprR caseReduceDatacon :: RewriteH Core)
        [ "case-of-known-constructor"
        , "case C v1..vn of C w1..wn -> e ==> let { w1 = v1 ; .. ; wn = vn } in e" ] .+ Shallow .+ Eval
    , external "case-reduce-literal" (promoteExprR caseReduceLiteral :: RewriteH Core)
        [ "case L of L -> e ==> e" ]                                         .+ Shallow .+ Eval
    , external "case-split" (promoteExprR . caseSplit :: TH.Name -> RewriteH Core)
        [ "case-split 'x"
        , "e ==> case x of C1 vs -> e; C2 vs -> e, where x is free in e" ]
    , external "case-split-inline" (caseSplitInline :: TH.Name -> RewriteH Core)
        [ "Like case-split, but additionally inlines the matched constructor "
        , "applications for all occurrences of the named variable." ]
    ]
-- | @(case ec of alt -> e) v ==> case ec of alt -> e v@
--
--   Before floating, the case binder and any alternative whose binders
--   occur free in the argument are alpha-renamed.
caseFloatApp :: RewriteH CoreExpr
caseFloatApp = prefixFailMsg "Case floating from App function failed: " $ do
    -- For each alternative: those of its binders that are free in the argument.
    captures    <- appT caseAltVarsT freeVarsT $ \ altVarss argFrees -> map (intersect argFrees) altVarss
    -- Is the case binder itself free in the argument?
    wildCapture <- appT caseWildIdT freeVarsT $ \ wild argFrees -> wild `elem` argFrees
    let renameWild  = if wildCapture then alphaCaseBinder Nothing else idR
        renameAlt n = if null (captures !! n) then idR else alphaAlt
    appT (renameWild >>> caseAllR idR renameAlt) idR pushArg
  where
    -- Push the argument into every alternative; the new result type is read
    -- off the first alternative applied to the argument.
    pushArg (Case scrut wild _ alts) arg =
        let (_, _, firstRhs) = head alts
            resTy            = exprType (App firstRhs arg)
        in  Case scrut wild resTy (mapAlts (\ rhs -> App rhs arg) alts)
-- | @f (case s of alt -> e) ==> case s of alt -> f e@
--
--   Before floating, the case binder and any alternative whose binders
--   occur free in the function are alpha-renamed.
caseFloatArg :: RewriteH CoreExpr
caseFloatArg = prefixFailMsg "Case floating from App argument failed: " $ do
    -- For each alternative: the function's free variables that it binds.
    captures    <- appT freeVarsT caseAltVarsT $ \ funFrees altVarss -> map (intersect funFrees) altVarss
    -- Is the case binder free in the function?
    wildCapture <- appT freeVarsT caseWildIdT $ \ funFrees wild -> wild `elem` funFrees
    let renameWild  = if wildCapture then alphaCaseBinder Nothing else idR
        renameAlt n = if null (captures !! n) then idR else alphaAlt
    appT idR (renameWild >>> caseAllR idR renameAlt) pushFun
  where
    -- Apply the function inside every alternative; the new result type is
    -- read off the function applied to the first alternative.
    pushFun fun (Case scrut wild _ alts) =
        let (_, _, firstRhs) = head alts
            resTy            = exprType (App fun firstRhs)
        in  Case scrut wild resTy (mapAlts (App fun) alts)
-- | @case (case ec of alt1 -> e1) of alta -> ea ==> case ec of alt1 -> case e1 of alta -> ea@
--
--   The inner case binder and inner alternatives are alpha-renamed where
--   they clash with free variables of the outer alternatives.
caseFloatCase :: RewriteH CoreExpr
caseFloatCase = prefixFailMsg "Case floating from Case failed: " $ do
    -- For each inner alternative: the outer alternatives' free variables it binds.
    captures <- caseT caseAltVarsT (const altFreeVarsExclWildT) $
                    \ innerVarss outerWild _ outerFrees ->
                        let frees = concatMap ($ outerWild) outerFrees
                        in  map (intersect frees) innerVarss
    -- Is the inner case binder free in an outer alternative?
    wildCapture <- caseT caseWildIdT (const altFreeVarsExclWildT) $
                    \ innerWild outerWild _ outerFrees ->
                        innerWild `elem` concatMap ($ outerWild) outerFrees
    let renameWild  = if wildCapture then alphaCaseBinder Nothing else idR
        renameAlt n = if null (captures !! n) then idR else alphaAlt
    caseT (renameWild >>> caseAllR idR renameAlt) (const idR) $
        \ (Case innerScrut innerWild _ innerAlts) outerWild outerTy outerAlts ->
            -- Both cases take the outer case's result type.
            Case innerScrut innerWild outerTy $
                mapAlts (\ rhs -> Case rhs outerWild outerTy outerAlts) innerAlts
-- | @let v = case ec of alt1 -> e1 in e ==> case ec of alt1 -> let v = e1 in e@
--
--   The floated case is annotated with the type of the let body @e@: every
--   new alternative is @let v = e1 in e@, so its type is that of @e@, not
--   that of the original right-hand sides.
caseFloatLet :: RewriteH CoreExpr
caseFloatLet = prefixFailMsg "Case floating from Let failed: " $
  do -- Does the let-bound variable clash with a binder of some alternative?
     vs <- letNonRecT caseAltVarsT idR (\ letVar caseVars _ -> elem letVar $ concat caseVars)
     -- On a clash, apply alphaCase to the bound case before floating.
     -- NOTE(review): free variables of the let body that coincide with the
     -- alternatives' binders or the case binder are not checked here; confirm
     -- whether callers rely on an earlier renaming pass.
     let bdsAction = if not vs then idR else nonRecR alphaCase
     letT bdsAction idR $ \ (NonRec v (Case s b _ alts)) e ->
         Case s b (exprType e) $ mapAlts (flip Let e . NonRec v) alts
-- | @cast (case s of p -> e) co ==> case s of p -> cast e co@
--
--   The result type of the new case is recomputed from the cast alternatives.
caseFloatCast :: RewriteH CoreExpr
caseFloatCast =
    prefixFailMsg "Case float from cast failed: " $
    withPatFailMsg (wrongExprForm "Cast (Case s bnd ty alts) co") $ do
        Cast (Case scrut wild _ alts) co <- idR
        let castAlts = mapAlts (\ rhs -> Cast rhs co) alts
            resTy    = coreAltsType castAlts
        return (Case scrut wild resTy castAlts)
-- | Float a Case out of its context, trying each specific floating rewrite
--   in turn: application function, application argument, case scrutinee,
--   let binding, then cast.
caseFloat :: RewriteH CoreExpr
caseFloat =
    setFailMsg "Unsuitable expression for Case floating." $
           caseFloatApp
        <+ caseFloatArg
        <+ caseFloatCase
        <+ caseFloatLet
        <+ caseFloatCast
-- | Unfloat a Case, trying 'caseUnfloatApp' first and 'caseUnfloatArgs' second.
caseUnfloat :: RewriteH CoreExpr
caseUnfloat =
    setFailMsg "Case unfloating failed." $
           caseUnfloatApp
        <+ caseUnfloatArgs
-- | Placeholder: this rewrite is not implemented and always fails.
caseUnfloatApp :: RewriteH CoreExpr
caseUnfloatApp =
    fail "caseUnfloatApp: TODO"
-- | Unfloat a Case whose alternatives are parallel applications of the same
--   function: the case is pushed into each (non type/coercion) argument.
--
--   Fails (rather than crashing or dropping alternatives) when
--
--   * the case has no alternatives,
--   * the alternatives do not all call the same function,
--   * the function mentions a variable bound by the case, or
--   * the alternatives pass differing numbers of arguments.
caseUnfloatArgs :: RewriteH CoreExpr
caseUnfloatArgs = prefixFailMsg "Case unfloating into arguments failed: " $
                  withPatFailMsg (wrongExprForm "Case s v t alts") $
  do Case s wild _ty alts <- idR
     -- Per alternative: binders (plus the case binder), called function, arguments.
     (vss, fs, argss) <- caseT mempty (\ _ -> altT callT $ \ _ac vs (fn, args) -> (vs, fn, args))
                                      (\ () _ _ calls -> unzip3 [ (wild:vs, fn, args) | (vs,fn,args) <- calls ])
     guardMsg (exprsEqual fs) "alternatives are not parallel in function call."
     guardMsg (all null $ zipWith intersect (map coreExprFreeVars fs) vss) "function bound by case binders."
     -- A ragged argument matrix would make 'zip' below silently drop alternatives.
     let sameArity = case map length argss of
                       []     -> True
                       n : ns -> all (== n) ns
     guardMsg sameArity "alternatives apply the function to differing numbers of arguments."
     -- Build one argument of the result from the corresponding argument of
     -- every alternative.
     -- NOTE(review): a type/coercion argument is taken from the first
     -- alternative only; confirm that such arguments always agree.
     let unfloatArg args = case args of
             a : _ | isTyCoArg a -> a
             _ -> let alts' = [ (ac, vs, arg) | ((ac,vs,_),arg) <- zip alts args ]
                  in  Case s wild (coreAltsType alts') alts'
     case fs of
       []    -> fail "no alternatives."
       f : _ -> return $ mkCoreApps f (map unfloatArg (transpose argss))
-- | Reduce a Case on a known scrutinee: tries 'caseReduceDatacon', then
--   'caseReduceLiteral'.
caseReduce :: RewriteH CoreExpr
caseReduce =
    setFailMsg "Unsuitable expression for Case reduction." $
           caseReduceDatacon
        <+ caseReduceLiteral
-- | @case L of L -> e ==> e@, for an unlifted literal scrutinee.
--
--   The case binder is kept in scope by let-binding it to the literal
--   around the selected alternative's right-hand side.
caseReduceLiteral :: RewriteH CoreExpr
caseReduceLiteral =
    prefixFailMsg "Case reduction failed: " $
    withPatFailMsg (wrongExprForm "Case (Lit l) v t alts") $ do
        Case scrut wild _ alts <- idR
        lit <- maybe (fail "scrutinee is not a literal.") return
                     (exprIsLiteral_maybe idUnfolding scrut)
        guardMsg (not (litIsLifted lit)) "cannot case-reduce lifted literals"
        (_, _, rhs) <- maybe (fail "no matching alternative.") return
                             (findAlt (LitAlt lit) alts)
        return $ mkCoreLet (NonRec wild (Lit lit)) rhs
-- | Case of known constructor:
--   @case C v1..vn of C w1..wn -> e ==> let { w1 = v1 ; .. ; wn = vn } in e@
--
--   The case binder is also let-bound, to the whole scrutinee.  Binders that
--   would capture variables of the constructor arguments are alpha-renamed
--   first, and the reduction is then retried.
caseReduceDatacon :: RewriteH CoreExpr
caseReduceDatacon =
    prefixFailMsg "Case reduction failed: " $
    withPatFailMsg (wrongExprForm "Case e v t alts")
    reduce
  where
    reduce :: RewriteH CoreExpr
    reduce = do
        Case scrut wild _ alts <- idR
        (dc, univTys, args) <- maybe (fail "head of scrutinee is not a data constructor.") return
                                     (exprIsConApp_maybe idUnfolding scrut)
        (altCon, binders, rhs) <- maybe (fail "no matching alternative.") return
                                        (findAlt (DataAlt dc) alts)
        let freeVarss = map coreExprFreeVars (map Type univTys ++ args)
            -- Binders occurring free in the entries of 'freeVarss' past their own position.
            shadowed  = [ b | (b, n) <- zip binders [1 ..], any (elem b) (drop n freeVarss) ]
            sameCon (con, _, _) = con == altCon
        if | any (elem wild) freeVarss ->
               -- The case binder clashes: rename it and try again.
               alphaCaseBinder Nothing >>> reduce
           | not (null shadowed) ->
               -- Rename the clashing binders of the matching alternative and try again.
               caseOneR (fail "scrutinee") (\ _ -> acceptR sameCon >>> alphaAltVars shadowed) >>> reduce
           | otherwise ->
               return $ mkCoreLets (zipWith NonRec (wild : binders) (scrut : args)) rhs
-- | @e ==> case x of C1 vs -> e; C2 vs -> e@, where @x@ is free in @e@.
--
--   The named identifier is used both as scrutinee and as case binder, and
--   each alternative binds freshly created identifiers for the constructor's
--   arguments.
--
--   Fails when the name is not a free identifier of the expression, when the
--   identifier's type is not an application of a type constructor, or when
--   that type constructor has no data constructors to split on.
caseSplit :: TH.Name -> RewriteH CoreExpr
caseSplit nm = do
    frees <- freeIdsT
    contextfreeT $ \ e ->
        case [ i | i <- frees, cmpTHName2Var nm i ] of
          []      -> fail "caseSplit: provided name is not free"
          (i : _) ->
            -- splitTyConApp_maybe instead of the partial splitTyConApp, so a
            -- type-variable (or otherwise unsuitable) type is a rewrite failure.
            case splitTyConApp_maybe (idType i) of
              Nothing -> fail "caseSplit: type of the provided name is not a type constructor application"
              Just (tycon, tys)
                | null dcs  -> fail "caseSplit: type of the provided name has no data constructors"
                | otherwise -> do
                    dcsAndVars <- mapM (\ dc -> do
                                           as <- sequence [ newIdH a ty | (a, ty) <- zip aNms $ dataConInstArgTys dc tys ]
                                           return (dc, as)) dcs
                    return $ Case (Var i) i (exprType e) [ (DataAlt dc, as, e) | (dc, as) <- dcsAndVars ]
                where
                  dcs  = tyConDataCons tycon
                  -- Name supply for the fresh binders: "a", "b", ..., "z", "a", ...
                  aNms = map (:[]) $ cycle ['a'..'z']
-- | Like 'caseSplit', but afterwards runs 'inlineName' for the same name
--   bottom-up over the result, succeeding if it applies anywhere.
caseSplitInline :: TH.Name -> RewriteH Core
caseSplitInline nm = splitR >>> anybuR inlineR
  where
    splitR :: RewriteH Core
    splitR = promoteR (caseSplit nm)

    inlineR :: RewriteH Core
    inlineR = promoteExprR (inlineName nm)