module Language.Haskell.Liquid.Transforms.Rewrite
(
rewriteBinds
) where
import CoreSyn
import Type
import Language.Haskell.Liquid.GHC.TypeRep
import TyCon
import qualified CoreSubst
import qualified Outputable
import qualified CoreUtils
import qualified Var
import qualified MkCore
import Data.Maybe (fromMaybe)
import Control.Monad (msum)
import Control.Monad.State hiding (lift)
import Language.Fixpoint.Misc ( mapSnd)
import qualified Language.Fixpoint.Types as F
import Language.Haskell.Liquid.Misc (safeZipWithError, mapThd3, Nat)
import Language.Haskell.Liquid.GHC.Resugar
import Language.Haskell.Liquid.GHC.Misc (isTupleId, showPpr, mkAlive)
import Language.Haskell.Liquid.UX.Config (Config, noSimplifyCore)
import qualified Data.List as L
-- | Entry point: rewrite every top-level binding.
--
--   When core simplification is enabled (see 'simplifyCore'), each binding is
--   first rewritten with 'simplifyPatTuple' and then with 'tidyTuples'.
--   Otherwise the bindings are returned untouched.
rewriteBinds :: Config -> [CoreBind] -> [CoreBind]
rewriteBinds cfg = if simplifyCore cfg then map rewriteOne else id
  where
    rewriteOne = rewriteBindWith tidyTuples . rewriteBindWith simplifyPatTuple
-- | Core simplification is on unless the @noSimplifyCore@ flag is set.
simplifyCore :: Config -> Bool
simplifyCore cfg = not (noSimplifyCore cfg)
-- | Share pattern binders between repeated case-splits on the same variable.
--
--   A table keyed by (alternative constructor, scrutinee variable) is threaded
--   through a traversal of the expression.
--
--   * The first time an alternative for constructor @c@ is seen in a
--     @case v of ...@, its binders are passed through 'mkAlive' and recorded.
--   * Later alternatives for the same @(c, v)@ reuse the recorded binders, and
--     the alternative's body has the old binders substituted by them.
--
--   Always returns 'Just'.
--
--   NOTE(review): the bodies of alternatives of a case on a variable, and the
--   scrutinees of all cases, are not traversed further. The table is not
--   scoped, so reuse also happens across sibling branches. Both presumably
--   intentional; confirm.
tidyTuples :: CoreExpr -> Maybe CoreExpr
tidyTuples expr = Just (evalState (walk expr) [])
  where
    walk (Tick tk body) = fmap (Tick tk) (walk body)
    walk (Let (NonRec b rhs) body) = do
      rhs'  <- walk rhs
      body' <- walk body
      pure (Let (NonRec b rhs') body')
    walk (Let (Rec pairs) body) = do
      -- the recursive group is visited before the let body
      pairs' <- mapM walkPair pairs
      body'  <- walk body
      pure (Let (Rec pairs') body')
    walk (Case (Var v) b ty alts) = Case (Var v) b ty <$> mapM (reuseAlt v) alts
    walk (Case scrut b ty alts) = Case scrut b ty <$> mapM walkAlt alts
    walk (App fun arg) = do
      fun' <- walk fun
      arg' <- walk arg
      pure (App fun' arg')
    walk (Lam b body) = Lam b <$> walk body
    walk (Cast body co) = fmap (\body' -> Cast body' co) (walk body)
    walk other = pure other

    walkPair (b, rhs) = do
      rhs' <- walk rhs
      pure (b, rhs')

    walkAlt (con, bs, rhs) = do
      rhs' <- walk rhs
      pure (con, bs, rhs')

    -- Reuse previously recorded binders for (con, v), or record fresh ones.
    reuseAlt v (con, bs, rhs) = do
      seen <- get
      case L.lookup (con, v) seen of
        Just known -> pure (con, known, substTuple known bs rhs)
        Nothing    -> do
          let fresh = mkAlive <$> bs
          modify (((con, v), fresh) :)
          pure (con, fresh, rhs)
-- | A rewrite rule returns 'Just' the rewritten expression when it applies,
--   or 'Nothing' to leave the expression unchanged.
type RewriteRule = CoreExpr -> Maybe CoreExpr

-- | Apply a rule (via 'rewriteWith') to the right-hand side of every binder
--   in a binding group.
rewriteBindWith :: RewriteRule -> CoreBind -> CoreBind
rewriteBindWith rule bind = case bind of
  NonRec b rhs -> NonRec b (rewriteWith rule rhs)
  Rec pairs    -> Rec [ (b, rewriteWith rule rhs) | (b, rhs) <- pairs ]
-- | Apply a rule bottom-up over an expression.
--
--   Children are rewritten first. The rule is then tried once at the node, and
--   'Nothing' keeps the node as rebuilt. The rule's output is not fed back
--   through the traversal.
rewriteWith :: RewriteRule -> CoreExpr -> CoreExpr
rewriteWith tx = rw
  where
    rw e = let e' = descend e in fromMaybe e' (tx e')

    descend expr = case expr of
      Let bnd body       -> Let (rwBind bnd) (rw body)
      App fun arg        -> App (rw fun) (rw arg)
      Lam b body         -> Lam b (rw body)
      Cast body co       -> Cast (rw body) co
      Tick tk body       -> Tick tk (rw body)
      Case scrut b ty cs -> Case (rw scrut) b ty (mapThd3 rw <$> cs)
      Type _             -> expr
      Lit _              -> expr
      Var _              -> expr
      Coercion _         -> expr

    rwBind (NonRec b rhs) = NonRec b (rw rhs)
    rwBind (Rec pairs)    = Rec (mapSnd rw <$> pairs)
-- | Variant of 'simplifyPatTuple' that only accepts a rewrite whose
--   'CoreUtils.exprType' equals that of the input. Otherwise no rewrite.
_safeSimplifyPatTuple :: RewriteRule
_safeSimplifyPatTuple e = case simplifyPatTuple e of
  Just e' | CoreUtils.exprType e' == CoreUtils.exprType e -> Just e'
  _                                                      -> Nothing
-- | Eliminate a let-bound tuple that is immediately taken apart by projections.
--
--   Matches @let x = e in rest@ where:
--
--   * @x@'s type is a tuple tycon applied to @n >= 2@ types ('varTuple');
--   * @rest@ starts with @n@ non-recursive lets ('takeBinds');
--   * @e@ contains, through lets and non-DEFAULT case alternatives, a tuple
--     built from exactly those let-bound variables ('hasTuple');
--   * each projection's binder type matches the tuple component type, and each
--     bound expression is a projection ('matchTypes').
--
--   When all of these hold, the tuple inside @e@ is replaced by the body that
--   follows the projections ('replaceTuple').
simplifyPatTuple :: RewriteRule
simplifyPatTuple (Let (NonRec x e) rest) = do
  (arity, tys)  <- varTuple x
  -- takeBinds also rejects arities below 2; the explicit check mirrors that
  (binds, body) <- if 2 <= arity then takeBinds arity rest else Nothing
  let vars = fst <$> binds
  _ <- hasTuple vars e
  if matchTypes binds tys then replaceTuple vars e body else Nothing
simplifyPatTuple _ = Nothing

-- | Not called from this module.
--
--   For @Just (let x = e in rest)@ where @rest@ starts with @n@ non-recursive
--   lets, every projection's case alternatives get the binders of the first
--   alternative of the LAST projection. The projections are then re-nested in
--   reverse order around the remaining body. Any other input is returned as is.
_tidyAlt :: Int -> Maybe CoreExpr -> Maybe CoreExpr
_tidyAlt n (Just (Let (NonRec x e) rest))
  | Just (binds, body) <- takeBinds n rest
  = Just (Let (NonRec x e) (foldl wrapLet body (retarget binds)))
  where
    wrapLet acc (y, ey) = Let (NonRec y ey) acc

    -- binders are taken from the last binding's expression
    retarget xes = case reverse xes of
      []             -> []
      (_, lastE) : _ -> mapSnd (replaceBinds (grabBinds lastE)) <$> xes

    replaceBinds bs (Case s b ty alts) = Case s b ty [ (con, bs, rhs) | (con, _, rhs) <- alts ]
    replaceBinds bs (Tick tk inner)    = Tick tk (replaceBinds bs inner)
    replaceBinds _  other              = other

    grabBinds (Case _ _ _ ((_, bs, _) : _)) = bs
    grabBinds (Tick _ inner)                = grabBinds inner
    grabBinds _                             = []
_tidyAlt _ e
  = e
-- | If the variable's type is a tuple tycon application, return the number of
--   type arguments together with those arguments.
varTuple :: Var -> Maybe (Int, [Type])
varTuple x = case Var.varType x of
  TyConApp tc args | isTupleTyCon tc -> Just (length args, args)
  _                                  -> Nothing
-- | @takeBinds n e@ peels exactly @n@ consecutive non-recursive lets off the
--   front of @e@. It returns the peeled bindings in order together with the
--   remaining body.
--
--   Returns 'Nothing' when @n < 2@, or when @e@ does not start with @n@
--   non-recursive lets.
takeBinds :: Nat -> CoreExpr -> Maybe ([(Var, CoreExpr)], CoreExpr)
takeBinds n e
  | n < 2     = Nothing
  | otherwise = go n e
  where
    go 0 body = Just ([], body)
    go k (Let (NonRec x ex) body) = do
      -- one let consumed; peel the remaining k - 1 from its body
      (xes, rest) <- go (k - 1) body
      Just ((x, ex) : xes, rest)
    go _ _ = Nothing
-- | Check that the peeled bindings line up with the tuple's component types.
--
--   Requires the same count, pairwise 'eqType' between each binder's type and
--   the corresponding component, and every bound expression to be a
--   projection ('isProjection'). The count check runs first, so the zip only
--   sees equal-length lists.
matchTypes :: [(Var, CoreExpr)] -> [Type] -> Bool
matchTypes xes ts =
  sameArity && componentTypesAgree && all isProjection rhss
  where
    (binders, rhss)     = unzip xes
    sameArity           = length xes == length ts
    componentTypesAgree =
      all (uncurry eqType) (safeZipWithError "RW:matchTypes" (map Var.varType binders) ts)
-- | True when resugaring ('lift') recognises the expression as a tuple
--   projection pattern ('PatProject').
isProjection :: CoreExpr -> Bool
isProjection e
  | Just PatProject{} <- lift e = True
  | otherwise                   = False
-- | Search an expression for a tuple built from exactly the given variables
--   (compared with 'isVarTup'). Returns the tuple's variables if one is found.
--
--   The search looks through let bodies and through every non-DEFAULT case
--   alternative, stopping at the first hit.
hasTuple :: [Var] -> CoreExpr -> Maybe [Var]
hasTuple ys = search
  where
    search e = maybe (descend e) Just (isVarTup ys e)

    descend (Let _ body)      = search body
    descend (Case _ _ _ alts) = msum (map inAlt alts)
    descend _                 = Nothing

    inAlt (DEFAULT, _, _) = Nothing
    inAlt (_, _, rhs)     = search rhs
-- | @replaceTuple ys e e'@ replaces the tuple of @ys@ inside @e@ with @e'@,
--   renaming @ys@ to the tuple's own variables ('substTuple').
--
--   The walk goes through lets and cases. At each case:
--
--   * every non-DEFAULT alternative must lead to the tuple, otherwise the
--     whole result is 'Nothing';
--   * a DEFAULT alternative is kept, but an irrefutable-pattern error in it is
--     retyped to the type of @e'@ ('replaceIrrefutPat');
--   * the case's result type is recomputed by 'fixCase'.
replaceTuple :: [Var] -> CoreExpr -> CoreExpr -> Maybe CoreExpr
replaceTuple ys e e' = rewrite e
  where
    resultTy = CoreUtils.exprType e'

    rewrite ex = maybe (descend ex) (\xs -> Just (substTuple xs ys e')) (isVarTup ys ex)

    descend (Let bnd body)         = Let bnd <$> rewrite body
    descend (Case scrut b ty alts) = fixCase scrut b ty <$> traverse rewriteAlt alts
    descend _                      = Nothing

    rewriteAlt (DEFAULT, bs, rhs) = Just (DEFAULT, bs, replaceIrrefutPat resultTy rhs)
    rewriteAlt (con, bs, rhs)     = (\rhs' -> (con, bs, rhs')) <$> rewrite rhs
-- | Debug rendering of a Core expression.
--
--   Applications, variables, non-recursive lets, ticks (which are dropped) and
--   cases get a custom layout. Everything else falls back to 'showPpr'.
_showExpr :: CoreExpr -> String
_showExpr = render
  where
    render (App fun arg) =
      concat [render fun, " ", render arg]
    render (Var v) =
      _showVar v
    render (Let (NonRec b rhs) body) =
      concat ["Let ", _showVar b, " = ", render rhs, "\nIN ", render body]
    render (Tick _ body) =
      render body
    render (Case scrut b _ alts) =
      concat ["Case ", _showVar b, " = ", render scrut, " OF ", unlines (map renderAlt alts)]
    render other =
      showPpr other

    renderAlt (con, bs, rhs) =
      concat [showPpr con, unwords (map _showVar bs), " -> ", render rhs]
-- | Render a variable via its fixpoint symbol.
_showVar :: Var -> String
_showVar v = show (F.symbol v)
-- | Ignore the second argument and fail with the given message.
_errorSkip :: String -> a -> b
_errorSkip msg = const (error msg)
-- | Rebuild a 'Case' whose alternatives have been rewritten.
--
--   The result type is recomputed from the body of the first alternative,
--   because the rewritten bodies (see 'replaceTuple') may no longer have the
--   original case type. Core permits a case with no alternatives. In that
--   situation the supplied type is kept, instead of failing on an irrefutable
--   pattern.
fixCase :: CoreExpr -> Var -> Type -> ListNE (Alt Var) -> CoreExpr
fixCase e x t cs' = Case e x t' cs'
  where
    t' = case cs' of
      (_, _, body) : _ -> CoreUtils.exprType body
      []               -> t

-- | Lists that are expected, but not statically guaranteed, to be non-empty.
type ListNE a = [a]
-- | Retype an irrefutable-pattern error call to @t@ (see 'replaceIrrefutPat'').
--
--   For @(\z -> body) arg@ the rewrite is first tried on @body@. Otherwise,
--   or if that fails, it is tried on the whole expression. If neither applies,
--   the expression is returned unchanged.
replaceIrrefutPat :: Type -> CoreExpr -> CoreExpr
replaceIrrefutPat t expr = case expr of
  App (Lam z body) arg
    | Just body' <- replaceIrrefutPat' t body
    -> App (Lam z body') arg
  _ | Just expr' <- replaceIrrefutPat' t expr
    -> expr'
    | otherwise
    -> expr
replaceIrrefutPat' :: Type -> CoreExpr -> Maybe CoreExpr
replaceIrrefutPat' t e
| (Var x, rep:_:args) <- collectArgs e
, isIrrefutErrorVar x
= Just (MkCore.mkCoreApps (Var x) (rep : Type t : args))
| otherwise
= Nothing
-- | Is this GHC's irrefutable-pattern-failure error identifier?
isIrrefutErrorVar :: Var -> Bool
isIrrefutErrorVar = (MkCore.iRREFUT_PAT_ERROR_ID ==)
-- | @substTuple xs ys e@ replaces each occurrence of @ys !! i@ in @e@ with
--   @xs !! i@.
substTuple :: [Var] -> [Var] -> CoreExpr -> CoreExpr
substTuple xs ys = CoreSubst.substExpr Outputable.empty subst
  where
    subst = mkSubst ys xs
-- | Build a Core substitution sending each @y@ to the matching @Var x@.
--   The pairing goes through 'safeZipWithError' with tag "RW:mkSubst";
--   presumably it fails on a length mismatch, so confirm in Liquid.Misc.
mkSubst :: [Var] -> [Var] -> CoreSubst.Subst
mkSubst ys xs =
  CoreSubst.extendIdSubstList CoreSubst.emptySubst (safeZipWithError "RW:mkSubst" ys (map Var xs))
-- | If the expression is a tuple of variables ('isTuple') matching @xs@
--   under 'eqVars', return the tuple's own variables.
isVarTup :: [Var] -> CoreExpr -> Maybe [Var]
isVarTup xs e = case isTuple e of
  Just ys | eqVars xs ys -> Just ys
  _                      -> Nothing
-- | Compare two variable lists element-wise by their 'show' rendering, not by
--   'Var' identity.
eqVars :: [Var] -> [Var] -> Bool
eqVars xs ys = map show xs == map show ys
-- | Recognise an application of a tuple constructor ('isTupleId') whose
--   trailing half of arguments are all variables, and return those variables.
--   The leading half is presumably the type arguments; confirm.
isTuple :: CoreExpr -> Maybe [Var]
isTuple e = case collectArgs e of
  (Var con, args) | isTupleId con -> mapM isVar (secondHalf args)
  _                               -> Nothing
-- | Extract the variable from a bare 'Var' expression.
isVar :: CoreExpr -> Maybe Var
isVar e = case e of
  Var x -> Just x
  _     -> Nothing
-- | Drop the first @length xs `div` 2@ elements. For odd lengths the middle
--   element is kept.
secondHalf :: [a] -> [a]
secondHalf xs = snd (splitAt (length xs `div` 2) xs)