module Language.Haskell.Liquid.Transforms.Rewrite
(
rewriteBinds
) where
import CoreSyn
import Type
import Language.Haskell.Liquid.GHC.TypeRep
import TyCon
import qualified CoreUtils
import qualified Var
import qualified MkCore
import Data.Maybe (fromMaybe)
import Control.Monad (msum)
import Control.Monad.State hiding (lift)
import Language.Fixpoint.Misc ( mapSnd)
import qualified Language.Fixpoint.Types as F
import Language.Haskell.Liquid.Misc (safeZipWithError, mapThd3, Nat)
import Language.Haskell.Liquid.GHC.Play (substExpr)
import Language.Haskell.Liquid.GHC.Resugar
import Language.Haskell.Liquid.GHC.Misc (unTickExpr, isTupleId, showPpr, mkAlive)
import Language.Haskell.Liquid.UX.Config (Config, noSimplifyCore)
import qualified Data.List as L
import qualified Data.HashMap.Strict as M
-- | Simplify a program's top-level bindings, unless core simplification is
--   switched off in the configuration (see 'noSimplifyCore').
--
--   Each binding goes through three passes, in this order:
--
--   1. 'simplifyPatTuple', offered bottom-up at every sub-expression;
--   2. 'tidyTuples', offered bottom-up at every sub-expression;
--   3. 'normalizeTuples'.
rewriteBinds :: Config -> [CoreBind] -> [CoreBind]
rewriteBinds cfg =
  if simplifyCore cfg
    then map simplifyBind
    else id
  where
    simplifyBind =
      normalizeTuples
        . rewriteBindWith tidyTuples
        . rewriteBindWith simplifyPatTuple
-- | Whether core simplification is enabled; this is simply the negation of
--   the 'noSimplifyCore' configuration flag.
simplifyCore :: Config -> Bool
simplifyCore cfg = not (noSimplifyCore cfg)
-- | Make repeated pattern matches on the same variable share binders.
--
--   The expression is walked left to right. For each pair of
--   (constructor, scrutinised variable), the walk remembers the binders of
--   the first alternative that matched it.
--
--   * That first alternative keeps its own binders, wrapped with 'mkAlive'.
--   * Every later alternative for the same constructor and variable is given
--     the remembered binders, and its right-hand side is renamed to match
--     with 'substTuple'.
--
--   Notes on the traversal as written:
--
--   * the right-hand sides of alternatives of a @case v of ...@ on a
--     variable are not walked any further;
--   * the scrutinee of a @case@ is never walked;
--   * the rule always succeeds (returns 'Just').
tidyTuples :: RewriteRule
tidyTuples expr = Just (evalState (walk expr) [])
  where
    walk (Tick tk body) =
      Tick tk <$> walk body
    walk (Let (NonRec b rhs) body) = do
      rhs'  <- walk rhs
      body' <- walk body
      pure (Let (NonRec b rhs') body')
    walk (Let (Rec pairs) body) = do
      pairs' <- traverse walkPair pairs
      body'  <- walk body
      pure (Let (Rec pairs') body')
    walk (Case scrut@(Var v) b ty alts) =
      Case scrut b ty <$> traverse (shareAlt v) alts
    walk (Case scrut b ty alts) =
      Case scrut b ty <$> traverse walkAlt alts
    walk (App fun arg) = do
      fun' <- walk fun
      arg' <- walk arg
      pure (App fun' arg')
    walk (Lam b body) =
      Lam b <$> walk body
    walk (Cast body co) =
      (\body' -> Cast body' co) <$> walk body
    walk other =
      pure other

    walkPair (b, rhs) = (\rhs' -> (b, rhs')) <$> walk rhs

    walkAlt (con, bs, rhs) = (\rhs' -> (con, bs, rhs')) <$> walk rhs

    -- Reuse the binders already chosen for this (constructor, variable)
    -- pair, or record this alternative's binders as the ones to reuse.
    shareAlt v (con, bs, rhs) = do
      known <- gets (L.lookup (con, v))
      case known of
        Just bs' ->
          pure (con, bs', substTuple bs' bs rhs)
        Nothing -> do
          let bs' = map mkAlive bs
          modify (((con, v), bs') :)
          pure (con, bs', rhs)
-- | Rebind single-field selections under the selected variable's own name.
--
--   A binding @let x = case s of { K .. z .. -> z } in body@ (ticks on the
--   right-hand side are looked through) whose only alternative returns one
--   of its own pattern binders @z@ is rewritten to bind @z@ instead of @x@,
--   and @x@ is renamed to @z@ in @body@. All other expressions are
--   traversed structurally and rebuilt unchanged.
normalizeTuples :: CoreBind -> CoreBind
normalizeTuples bind = case bind of
    NonRec x e -> NonRec x (norm e)
    Rec xes    -> Rec [ (x, norm e) | (x, e) <- xes ]
  where
    norm (Let (NonRec v rhs) body)
      | Case _ _ _ [(_, vs, Var z)] <- unTickExpr rhs
      , z `elem` vs
      = Let (NonRec z (norm rhs)) (substTuple [z] [v] (norm body))
      | otherwise
      = Let (NonRec v (norm rhs)) (norm body)
    norm (Let (Rec pairs) body) =
      Let (Rec [ (v, norm rhs) | (v, rhs) <- pairs ]) (norm body)
    norm (App fun arg) = App (norm fun) (norm arg)
    norm (Lam b body) = Lam b (norm body)
    norm (Case scrut b ty alts) =
      Case (norm scrut) b ty [ (con, bs, norm rhs) | (con, bs, rhs) <- alts ]
    norm (Cast body co) = Cast (norm body) co
    norm (Tick tk body) = Tick tk (norm body)
    -- Remaining constructors (Type, Coercion, Lit, Var) are leaves.
    norm leaf = leaf
-- | A local rewrite on a single expression node: return 'Just' the
--   replacement, or 'Nothing' to leave the node as it is.
type RewriteRule = CoreExpr -> Maybe CoreExpr
-- | Apply 'rewriteWith' to the right-hand side of every binding in a
--   (possibly recursive) binding group; the binders are left untouched.
rewriteBindWith :: RewriteRule -> CoreBind -> CoreBind
rewriteBindWith rule bind = case bind of
  NonRec x e -> NonRec x (rewriteWith rule e)
  Rec xes    -> Rec [ (x, rewriteWith rule e) | (x, e) <- xes ]
-- | Apply a rewrite rule bottom-up.
--
--   All children of a node are rewritten first. The rebuilt node is then
--   offered to the rule once; if the rule declines ('Nothing'), the rebuilt
--   node is kept. The rule is not re-applied to its own output.
rewriteWith :: RewriteRule -> CoreExpr -> CoreExpr
rewriteWith rule = rewrite
  where
    rewrite expr =
      let expr' = descend expr
      in fromMaybe expr' (rule expr')

    descend expr = case expr of
      Let bind body       -> Let (rewriteBind bind) (rewrite body)
      App fun arg         -> App (rewrite fun) (rewrite arg)
      Lam b body          -> Lam b (rewrite body)
      Cast body co        -> Cast (rewrite body) co
      Tick tk body        -> Tick tk (rewrite body)
      Case scrut b t alts ->
        Case (rewrite scrut) b t [ (con, bs, rewrite rhs) | (con, bs, rhs) <- alts ]
      Type _              -> expr
      Lit _               -> expr
      Var _               -> expr
      Coercion _          -> expr

    rewriteBind (NonRec b rhs) = NonRec b (rewrite rhs)
    rewriteBind (Rec pairs)    = Rec [ (b, rewrite rhs) | (b, rhs) <- pairs ]
-- | Like 'simplifyPatTuple', but the rewrite is accepted only when the new
--   expression has the same type ('CoreUtils.exprType') as the old one.
_safeSimplifyPatTuple :: RewriteRule
_safeSimplifyPatTuple e = case simplifyPatTuple e of
  Just e' | CoreUtils.exprType e' == CoreUtils.exprType e -> Just e'
  _ -> Nothing
-- | Collapse "build a tuple, then immediately take it apart".
--
--   Matches
--
--   > let x = e in let y1 = p1 in ... let yn = pn in body
--
--   The match requires all of the following:
--
--   * @x@'s type is a tuple type constructor applied to @n >= 2@ arguments
--     ('varTuple');
--   * the next @n@ non-recursive lets bind projections whose variable types
--     line up with those arguments ('matchTypes');
--   * @e@ contains a tuple built from the variables @y1 .. yn@ ('hasTuple').
--
--   The result is @e@ with that tuple construction replaced by @body@
--   ('replaceTuple').
simplifyPatTuple :: RewriteRule
simplifyPatTuple (Let (NonRec tup rhs) continuation)
  | Just (arity, compTys) <- varTuple tup
  , arity >= 2
  , Just (projBinds, body) <- takeBinds arity continuation
  , let projVars = map fst projBinds
  , Just _ <- hasTuple projVars rhs
  , matchTypes projBinds compTys
  = replaceTuple projVars rhs body
simplifyPatTuple _ = Nothing

-- | Experimental helper; no caller is visible in this module.
--
--   For @Just (let x = e in rest)@, it takes @n@ lets from @rest@ and
--   rewrites every @case@ alternative in their right-hand sides (looking
--   through ticks) to use the binders of the first alternative of the
--   /last/ taken binding. It then re-nests the taken lets around the
--   remaining body.
--
--   NOTE(review): the fold re-nests the taken lets in the reverse of their
--   original order, which may break scoping if a later binding refers to an
--   earlier one. Confirm before using.
_tidyAlt :: Int -> Maybe CoreExpr -> Maybe CoreExpr
_tidyAlt n (Just (Let (NonRec x e) rest))
  | Just (yes, e') <- takeBinds n rest
  = let rebound = reverse (unifyBinders (reverse yes))
        wrap acc (y, ey) = Let (NonRec y ey) acc
    in Just (Let (NonRec x e) (foldl wrap e' rebound))
  where
    -- Give every binding's alternatives the binders found in the head one.
    unifyBinders [] = []
    unifyBinders xes@((_, ex) : _) =
      let bs = binderOf ex
      in [ (y, setBinders bs ey) | (y, ey) <- xes ]

    setBinders bs (Case c b t alts) =
      Case c b t [ (con, bs, rhs) | (con, _, rhs) <- alts ]
    setBinders bs (Tick t ex) = Tick t (setBinders bs ex)
    setBinders _ ex = ex

    binderOf (Case _ _ _ ((_, bs, _) : _)) = bs
    binderOf (Case _ _ _ [])               = []
    binderOf (Tick _ ex)                   = binderOf ex
    binderOf _                             = []
_tidyAlt _ e = e
-- | If the variable's type is literally a tuple type constructor applied to
--   arguments, return how many arguments there are, together with the
--   arguments themselves. Type synonyms are not expanded.
--
--   NOTE(review): for unboxed tuples the argument list may also include
--   runtime-representation arguments, so the count may not be the arity.
--   Confirm.
varTuple :: Var -> Maybe (Int, [Type])
varTuple x = case Var.varType x of
  TyConApp tc args
    | isTupleTyCon tc -> Just (length args, args)
  _ -> Nothing
-- | Peel exactly @n@ consecutive non-recursive @let@ bindings off the front
--   of an expression. Returns the bindings (outermost first) together with
--   the remaining body.
--
--   Returns 'Nothing' when @n < 2@, or when fewer than @n@ non-recursive
--   @let@s occur in a row.
takeBinds :: Nat -> CoreExpr -> Maybe ([(Var, CoreExpr)], CoreExpr)
takeBinds n e
  | n < 2     = Nothing
  | otherwise = go n e
  where
    -- @k@ is the number of bindings still to take.
    go 0 body = Just ([], body)
    go k (Let (NonRec x rhs) body) = do
      (xes, rest) <- go (k - 1) body
      Just ((x, rhs) : xes, rest)
    go _ _ = Nothing
-- | Check that let-bound projections line up with a tuple's component types.
--
--   Requires that:
--
--   * both lists have the same length;
--   * each bound variable's type is 'eqType' to the corresponding type;
--   * every right-hand side is recognised by 'isProjection'.
matchTypes :: [(Var, CoreExpr)] -> [Type] -> Bool
matchTypes xes ts
  | length xes /= length ts = False
  | otherwise               = typesAgree && all (isProjection . snd) xes
  where
    typesAgree =
      all (uncurry eqType)
          (safeZipWithError "RW:matchTypes" (map (Var.varType . fst) xes) ts)
-- | Whether the expression resugars (via 'lift') to a tuple-projection
--   pattern ('PatProject').
isProjection :: CoreExpr -> Bool
isProjection e
  | Just PatProject{} <- lift e = True
  | otherwise                   = False
-- | Search for an expression that 'isVarTup' recognises as a tuple of the
--   variables @ys@.
--
--   The search looks through @let@ bodies and through the right-hand sides
--   of non-@DEFAULT@ @case@ alternatives. It returns the tuple's variables
--   for the first match found, or 'Nothing'.
hasTuple :: [Var] -> CoreExpr -> Maybe [Var]
hasTuple ys = search
  where
    search e = maybe (descend e) Just (isVarTup ys e)

    descend (Let _ body) = search body
    descend (Case _ _ _ alts) =
      msum [ search rhs | (con, _, rhs) <- alts, notDefault con ]
    descend _ = Nothing

    notDefault DEFAULT = False
    notDefault _       = True
-- | Replace tuples of the variables @ys@ inside @e@ with the continuation
--   @e'@. The search goes through @let@ bodies and @case@ alternatives.
--
--   * At each replaced site, @ys@ is renamed inside @e'@ to the tuple's own
--     variables.
--   * Every non-@DEFAULT@ alternative of each @case@ on the way must succeed,
--     otherwise the whole result is 'Nothing'.
--   * @DEFAULT@ alternatives are kept, but an irrefutable-pattern error in
--     them is retyped to the type of @e'@ ('replaceIrrefutPat').
--   * Each @case@ is rebuilt with 'fixCase'.
replaceTuple :: [Var] -> CoreExpr -> CoreExpr -> Maybe CoreExpr
replaceTuple ys e e' = rewrite e
  where
    resultTy = CoreUtils.exprType e'

    rewrite ex = case isVarTup ys ex of
      Just xs -> Just (substTuple xs ys e')
      Nothing -> descend ex

    descend (Let b body) = Let b <$> rewrite body
    descend (Case scrut b ty alts) =
      fixCase scrut b ty <$> traverse rewriteAlt alts
    descend _ = Nothing

    rewriteAlt (DEFAULT, bs, err) =
      Just (DEFAULT, bs, replaceIrrefutPat resultTy err)
    rewriteAlt (con, bs, rhs) =
      (\rhs' -> (con, bs, rhs')) <$> rewrite rhs
-- | Debugging helper: a rough rendering of a core expression.
--
--   Applications, variables, non-recursive @let@s and @case@s are spelled
--   out, and ticks are looked through. Everything else falls back to
--   'showPpr'.
_showExpr :: CoreExpr -> String
_showExpr = render
  where
    render (App e1 e2) = render e1 ++ " " ++ render e2
    render (Var x) = _showVar x
    render (Let (NonRec x ex) e) = "Let " ++ _showVar x ++ " = " ++ render ex ++ "\nIN " ++ render e
    render (Tick _ e) = render e
    render (Case e x _ alt) = "Case " ++ _showVar x ++ " = " ++ render e ++ " OF " ++ unlines (renderAlt <$> alt)
    render e = showPpr e
    -- Separate the constructor from its binders with spaces, so a pair
    -- pattern renders as "(,) a b" rather than "(,)a b".
    renderAlt (c, bs, e) = unwords (showPpr c : map _showVar bs) ++ " -> " ++ render e
-- | Render a variable through its fixpoint symbol ('F.symbol').
_showVar :: Var -> String
_showVar x = show (F.symbol x)
-- | Ignore the second argument and fail with the given message.
_errorSkip :: String -> a -> b
_errorSkip msg = \_ -> error msg
-- | Rebuild a 'Case' after its alternatives have been rewritten.
--
--   The result type is recomputed from the first alternative's right-hand
--   side, because the rewritten bodies (e.g. from 'replaceTuple') may have
--   a different type than the original annotation.
--
--   With no alternatives there is no body to inspect, so the supplied type
--   is kept. Previously a partial pattern failed here.
fixCase :: CoreExpr -> Var -> Type -> ListNE (Alt Var) -> CoreExpr
fixCase e x t [] = Case e x t []
fixCase e x _ cs'@((_, _, body) : _) = Case e x (CoreUtils.exprType body) cs'
-- | Marks lists that callers intend to be non-empty. It is only a plain
--   alias for a list, so non-emptiness is NOT checked by the type system.
type ListNE a = [a]
-- | Retype an irrefutable-pattern error call to @t@ (see
--   'replaceIrrefutPat'').
--
--   The call may appear either directly or as the body of a lambda applied
--   to one argument (@(\\z -> err) arg@). Any other expression is returned
--   unchanged.
replaceIrrefutPat :: Type -> CoreExpr -> CoreExpr
replaceIrrefutPat ty expr = case expr of
  App (Lam z body) arg
    | Just body' <- replaceIrrefutPat' ty body -> App (Lam z body') arg
  _ -> fromMaybe expr (replaceIrrefutPat' ty expr)
-- | If the expression is a call of the irrefutable-pattern error function
--   with at least two arguments, rebuild it with its second argument
--   replaced by @Type t@. All other arguments are kept as they are.
--
--   NOTE(review): the second argument is presumably the result-type
--   argument. Confirm against 'MkCore.iRREFUT_PAT_ERROR_ID'.
replaceIrrefutPat' :: Type -> CoreExpr -> Maybe CoreExpr
replaceIrrefutPat' ty e = case collectArgs e of
  (Var f, rep : _oldTy : rest)
    | isIrrefutErrorVar f -> Just (MkCore.mkCoreApps (Var f) (rep : Type ty : rest))
  _ -> Nothing
-- | Whether the variable is 'MkCore.iRREFUT_PAT_ERROR_ID'.
isIrrefutErrorVar :: Var -> Bool
isIrrefutErrorVar = (MkCore.iRREFUT_PAT_ERROR_ID ==)
-- | @substTuple new old e@ replaces each variable of @old@ in @e@ by the
--   variable at the same position in @new@ (pairs are formed with 'zip').
substTuple :: [Var] -> [Var] -> CoreExpr -> CoreExpr
substTuple new old = substExpr (M.fromList (zip old new))
-- | If the expression is a tuple of variables ('isTuple') matching @xs@
--   according to 'eqVars', return the tuple's own variables.
isVarTup :: [Var] -> CoreExpr -> Maybe [Var]
isVarTup xs e = case isTuple e of
  Just ys | eqVars xs ys -> Just ys
  _ -> Nothing
-- | Compare two variable lists element-wise by their 'show' rendering,
--   not by 'Var' identity.
eqVars :: [Var] -> [Var] -> Bool
eqVars xs ys = map show xs == map show ys
-- | Recognise an application of a tuple constructor ('isTupleId') whose
--   second half of arguments are all plain variables, and return those
--   variables.
--
--   NOTE(review): the first half is presumably the type arguments. Confirm.
isTuple :: CoreExpr -> Maybe [Var]
isTuple e = case collectArgs e of
  (Var con, args) | isTupleId con -> mapM isVar (secondHalf args)
  _ -> Nothing
-- | Return the variable if the expression is exactly a variable occurrence.
isVar :: CoreExpr -> Maybe Var
isVar e = case e of
  Var x -> Just x
  _     -> Nothing
-- | Drop the first @length xs `div` 2@ elements. For odd lengths the middle
--   element is kept.
secondHalf :: [a] -> [a]
secondHalf xs = snd (splitAt half xs)
  where
    half = length xs `div` 2