{-# LANGUAGE CPP #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE NoMonomorphismRestriction #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE FlexibleContexts #-} -- | This module contains functions for recursively "rewriting" -- GHC core using "rules". module Language.Haskell.Liquid.Transforms.Rewrite ( -- * Top level rewrite function rewriteBinds -- * Low-level Rewriting Function -- , rewriteWith -- * Rewrite Rule -- , RewriteRule ) where import CoreSyn import Type import Language.Haskell.Liquid.GHC.TypeRep import TyCon import qualified CoreUtils import qualified Var import qualified MkCore import Data.Maybe (fromMaybe) import Control.Monad (msum) import Control.Monad.State hiding (lift) import Language.Fixpoint.Misc ({- mapFst, -} mapSnd) import qualified Language.Fixpoint.Types as F import Language.Haskell.Liquid.Misc (safeZipWithError, mapThd3, Nat) import Language.Haskell.Liquid.GHC.Play (substExpr) import Language.Haskell.Liquid.GHC.Resugar import Language.Haskell.Liquid.GHC.Misc (unTickExpr, isTupleId, showPpr, mkAlive) -- , showPpr, tracePpr) import Language.Haskell.Liquid.UX.Config (Config, noSimplifyCore) -- import Debug.Trace import qualified Data.List as L import qualified Data.HashMap.Strict as M -------------------------------------------------------------------------------- -- | Top-level rewriter -------------------------------------------------------- -------------------------------------------------------------------------------- rewriteBinds :: Config -> [CoreBind] -> [CoreBind] rewriteBinds cfg | simplifyCore cfg = fmap (normalizeTuples . rewriteBindWith tidyTuples . rewriteBindWith simplifyPatTuple) | otherwise = id simplifyCore :: Config -> Bool simplifyCore = not . noSimplifyCore tidyTuples :: RewriteRule tidyTuples e = Just $ evalState (go e) [] where go (Tick t e) = Tick t <$> go e go (Let (NonRec x ex) e) = do ex' <- go ex e' <- go e return $ Let (NonRec x ex') e' go (Let (Rec bes) e) = Let <$> (Rec <$> mapM goRec bes) <*> go e go (Case (Var v) x t alts) = Case (Var v) x t <$> mapM (goAltR v) alts go (Case e x t alts) = Case e x t <$> mapM goAlt alts go (App e1 e2) = App <$> go e1 <*> go e2 go (Lam x e) = Lam x <$> go e go (Cast e c) = (`Cast` c) <$> go e go e = return e goRec (x, e) = (x,) <$> go e goAlt (c, bs, e) = (c, bs,) <$> go e goAltR v (c, bs, e) = do m <- get case L.lookup (c,v) m of Just bs' -> return (c, bs', substTuple bs' bs e) Nothing -> do let bs' = mkAlive <$> bs modify (((c,v),bs'):) return $ (c, bs', e) normalizeTuples :: CoreBind -> CoreBind normalizeTuples b | NonRec x e <- b = NonRec x $ go e | Rec xes <- b = let (xs,es) = unzip xes in Rec $ zip xs (go <$> es) where go (Let (NonRec x ex) e) | Case _ _ _ alts <- unTickExpr ex , [(_, vs, Var z)] <- alts , z `elem` vs = Let (NonRec z (go ex)) (substTuple [z] [x] (go e)) go (Let (NonRec x ex) e) = Let (NonRec x (go ex)) (go e) go (Let (Rec xes) e) = Let (Rec (mapSnd go <$> xes)) (go e) go (App e1 e2) = App (go e1) (go e2) go (Lam x e) = Lam x (go e) go (Case e b t alt) = Case (go e) b t (mapThd3 go <$> alt) go (Cast e c) = Cast (go e) c go (Tick t e) = Tick t (go e) go (Type t) = Type t go (Coercion c) = Coercion c go (Lit l) = Lit l go (Var x) = Var x -------------------------------------------------------------------------------- -- | A @RewriteRule@ is a function that maps a CoreExpr to another -------------------------------------------------------------------------------- type RewriteRule = CoreExpr -> Maybe CoreExpr -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- rewriteBindWith :: RewriteRule -> CoreBind -> CoreBind -------------------------------------------------------------------------------- rewriteBindWith r (NonRec x e) = NonRec x (rewriteWith r e) rewriteBindWith r (Rec xes) = Rec (mapSnd (rewriteWith r) <$> xes) -------------------------------------------------------------------------------- rewriteWith :: RewriteRule -> CoreExpr -> CoreExpr -------------------------------------------------------------------------------- rewriteWith tx = go where go = txTop . step txTop e = fromMaybe e (tx e) goB (Rec xes) = Rec (mapSnd go <$> xes) goB (NonRec x e) = NonRec x (go e) step (Let b e) = Let (goB b) (go e) step (App e e') = App (go e) (go e') step (Lam x e) = Lam x (go e) step (Cast e c) = Cast (go e) c step (Tick t e) = Tick t (go e) step (Case e x t cs) = Case (go e) x t (mapThd3 go <$> cs) step e@(Type _) = e step e@(Lit _) = e step e@(Var _) = e step e@(Coercion _) = e -------------------------------------------------------------------------------- -- | Rewriting Pattern-Match-Tuples -------------------------------------------- -------------------------------------------------------------------------------- {- let CrazyPat x1 ... xn = e in e' let t : (t1,...,tn) = "CrazyPat e ... (y1, ..., yn)" xn = Proj t n ... x1 = Proj t 1 in e' "crazy-pat" -} {- [NOTE] The following is the structure of a @PatMatchTup@ let x :: (t1,...,tn) = E[(x1,...,xn)] yn = case x of (..., yn) -> yn … y1 = case x of (y1, ...) -> y1 in E' GOAL: simplify the above to: E [ (x1,...,xn) := E' [y1 := x1,...,yn := xn] ] TODO: several tests (e.g. tests/pos/zipper000.hs) fail because the above changes the "type" the expression `E` and in "other branches" the new type may be different than the old, e.g. let (x::y::_) = e in x + y let t = case e of h1::t1 -> case t1 of (h2::t2) -> (h1, h2) DEFAULT -> error @ (Int, Int) DEFAULT -> error @ (Int, Int) x = case t of (h1, _) -> h1 y = case t of (_, h2) -> h2 in x + y is rewritten to: h1::t1 -> case t1 of (h2::t2) -> h1 + h2 DEFAULT -> error @ (Int, Int) DEFAULT -> error @ (Int, Int) case e of h1 :: h2 :: _ -> h1 + h2 DEFAULT -> error @ (Int, Int) which, alas, is ill formed. -} -------------------------------------------------------------------------------- -- simplifyPatTuple :: RewriteRule -- simplifyPatTuple e = -- case simplifyPatTuple' e of -- Just e' -> if CoreUtils.exprType e == CoreUtils.exprType e' -- then Just e' -- else Just (tracePpr ("YIKES: RWR " ++ showPpr e) e') -- Nothing -> Nothing _safeSimplifyPatTuple :: RewriteRule _safeSimplifyPatTuple e | Just e' <- simplifyPatTuple e , CoreUtils.exprType e' == CoreUtils.exprType e = Just e' | otherwise = Nothing -------------------------------------------------------------------------------- simplifyPatTuple :: RewriteRule -------------------------------------------------------------------------------- _tidyAlt :: Int -> Maybe CoreExpr -> Maybe CoreExpr _tidyAlt n (Just (Let (NonRec x e) rest)) | Just (yes, e') <- takeBinds n rest = Just $ Let (NonRec x e) $ foldl (\e (x, ex) -> Let (NonRec x ex) e) e' ((reverse $ go $ reverse yes)) where go xes@((_, e):_) = let bs = grapBinds e in mapSnd (replaceBinds bs) <$> xes go [] = [] replaceBinds bs (Case c x t alt) = Case c x t (replaceBindsAlt bs <$> alt) replaceBinds bs (Tick t e) = Tick t (replaceBinds bs e) replaceBinds _ e = e replaceBindsAlt bs (c, _, e) = (c, bs, e) grapBinds (Case _ _ _ alt) = grapBinds' alt grapBinds (Tick _ e) = grapBinds e grapBinds _ = [] grapBinds' [] = [] grapBinds' ((_,bs,_):_) = bs _tidyAlt _ e = e simplifyPatTuple (Let (NonRec x e) rest) | Just (n, ts ) <- varTuple x , 2 <= n , Just (yes, e') <- takeBinds n rest , let ys = fst <$> yes , Just _ <- hasTuple ys e , matchTypes yes ts = replaceTuple ys e e' simplifyPatTuple _ = Nothing varTuple :: Var -> Maybe (Int, [Type]) varTuple x | TyConApp c ts <- Var.varType x , isTupleTyCon c = Just (length ts, ts) | otherwise = Nothing takeBinds :: Nat -> CoreExpr -> Maybe ([(Var, CoreExpr)], CoreExpr) takeBinds n e | n < 2 = Nothing | otherwise = {- mapFst reverse <$> -} go n e where go 0 e = Just ([], e) go n (Let (NonRec x e) e') = do (xes, e'') <- go (n-1) e' Just ((x,e) : xes, e'') go _ _ = Nothing matchTypes :: [(Var, CoreExpr)] -> [Type] -> Bool matchTypes xes ts = xN == tN && all (uncurry eqType) (safeZipWithError msg xts ts) && all isProjection es where xN = length xes tN = length ts xts = Var.varType <$> xs (xs, es) = unzip xes msg = "RW:matchTypes" isProjection :: CoreExpr -> Bool isProjection e = case lift e of Just (PatProject {}) -> True _ -> False -------------------------------------------------------------------------------- -- | `hasTuple ys e` CHECKS if `e` contains a tuple that "looks like" (y1...yn) -------------------------------------------------------------------------------- hasTuple :: [Var] -> CoreExpr -> Maybe [Var] -------------------------------------------------------------------------------- hasTuple ys = stepE where stepE e | Just xs <- isVarTup ys e = Just xs | otherwise = go e stepA (DEFAULT,_,_) = Nothing stepA (_, _, e) = stepE e go (Let _ e) = stepE e go (Case _ _ _ cs) = msum (stepA <$> cs) go _ = Nothing -------------------------------------------------------------------------------- -- | `replaceTuple ys e e'` REPLACES tuples that "looks like" (y1...yn) with e' -------------------------------------------------------------------------------- replaceTuple :: [Var] -> CoreExpr -> CoreExpr -> Maybe CoreExpr replaceTuple ys e e' = stepE e where t' = CoreUtils.exprType e' stepE e | Just xs <- isVarTup ys e = Just $ substTuple xs ys e' | otherwise = go e stepA (DEFAULT, xs, err) = Just (DEFAULT, xs, replaceIrrefutPat t' err) stepA (c, xs, e) = (c, xs,) <$> stepE e go (Let b e) = Let b <$> stepE e go (Case e x t cs) = fixCase e x t <$> mapM stepA cs go _ = Nothing _showExpr :: CoreExpr -> String _showExpr e = show' e where show' (App e1 e2) = show' e1 ++ " " ++ show' e2 show' (Var x) = _showVar x show' (Let (NonRec x ex) e) = "Let " ++ _showVar x ++ " = " ++ show' ex ++ "\nIN " ++ show' e show' (Tick _ e) = show' e show' (Case e x _ alt) = "Case " ++ _showVar x ++ " = " ++ show' e ++ " OF " ++ unlines (showAlt' <$> alt) show' e = showPpr e showAlt' (c, bs, e) = showPpr c ++ unwords (_showVar <$> bs) ++ " -> " ++ show' e _showVar :: Var -> String _showVar = show . F.symbol _errorSkip :: String -> a -> b _errorSkip x _ = error x -- replaceTuple :: [Var] -> CoreExpr -> CoreExpr -> Maybe CoreExpr -- replaceTuple ys e e' = tracePpr msg (_replaceTuple ys e e') -- where -- msg = "replaceTuple: ys = " ++ showPpr ys ++ -- " e = " ++ showPpr e ++ -- " e' =" ++ showPpr e' -- | The substitution (`substTuple`) can change the type of the overall -- case-expression, so we must update the type of each `Case` with its -- new, possibly updated type. See: -- https://github.com/ucsd-progsys/liquidhaskell/pull/752#issuecomment-228946210 fixCase :: CoreExpr -> Var -> Type -> ListNE (Alt Var) -> CoreExpr fixCase e x _t cs' = Case e x t' cs' where t' = CoreUtils.exprType body (_,_,body) = c c:_ = cs' {-@ type ListNE a = {v:[a] | len v > 0} @-} type ListNE a = [a] replaceIrrefutPat :: Type -> CoreExpr -> CoreExpr replaceIrrefutPat t (App (Lam z e) eVoid) | Just e' <- replaceIrrefutPat' t e = App (Lam z e') eVoid replaceIrrefutPat t e | Just e' <- replaceIrrefutPat' t e = e' replaceIrrefutPat _ e = e replaceIrrefutPat' :: Type -> CoreExpr -> Maybe CoreExpr replaceIrrefutPat' t e | (Var x, rep:_:args) <- collectArgs e , isIrrefutErrorVar x = Just (MkCore.mkCoreApps (Var x) (rep : Type t : args)) | otherwise = Nothing isIrrefutErrorVar :: Var -> Bool isIrrefutErrorVar x = MkCore.iRREFUT_PAT_ERROR_ID == x -------------------------------------------------------------------------------- -- | `substTuple xs ys e'` returns e' [y1 := x1,...,yn := xn] -------------------------------------------------------------------------------- substTuple :: [Var] -> [Var] -> CoreExpr -> CoreExpr substTuple xs ys = substExpr (M.fromList $ zip ys xs) -------------------------------------------------------------------------------- -- | `isVarTup xs e` returns `Just ys` if e == (y1, ... , yn) and xi ~ yi -------------------------------------------------------------------------------- isVarTup :: [Var] -> CoreExpr -> Maybe [Var] isVarTup xs e | Just ys <- isTuple e , eqVars xs ys = Just ys isVarTup _ _ = Nothing eqVars :: [Var] -> [Var] -> Bool eqVars xs ys = {- F.tracepp ("eqVars: " ++ show xs' ++ show ys') -} xs' == ys' where xs' = {- F.symbol -} show <$> xs ys' = {- F.symbol -} show <$> ys isTuple :: CoreExpr -> Maybe [Var] isTuple e | (Var t, es) <- collectArgs e , isTupleId t , Just xs <- mapM isVar (secondHalf es) = Just xs | otherwise = Nothing isVar :: CoreExpr -> Maybe Var isVar (Var x) = Just x isVar _ = Nothing secondHalf :: [a] -> [a] secondHalf xs = drop (n `div` 2) xs where n = length xs