{-# LANGUAGE OverloadedStrings #-}
module Rattus.Plugin.Strictify where
import Prelude hiding ((<>))
import Rattus.Plugin.Utils
import GhcPlugins
-- | Context threaded through 'strictifyExpr' while it walks a Core expression.
data SCxt = SCxt
  { srcSpan :: SrcSpan
    -- ^ Location attached to the lazy-type warning; replaced whenever a
    -- source-note tick is entered.
  , checkStrictData :: Bool
    -- ^ Whether applications are checked for arguments of lazy type; it is
    -- switched off below an application that already produced the warning.
  }
-- | Rewrite a Core expression so that bound values are evaluated eagerly.
--
-- The traversal performs the following rewrites:
--
-- * A non-recursive @let@ becomes a @case@ on the bound expression with a
--   single @DEFAULT@ alternative.
--
-- * A lambda over a value binder of lifted kind receives a fresh binder whose
--   value is scrutinised by a @case@ (re-binding the original binder) before
--   the body runs.
--
-- * For an application, when 'checkStrictData' is set and the argument is a
--   non-type expression of lifted kind whose type is rejected by 'isStrict',
--   a warning is reported at 'srcSpan' and the check is switched off for both
--   sub-expressions.
--
-- * A source-note tick updates 'srcSpan' for everything below it; every other
--   tick is kept as is and its body is traversed with the unchanged context.
--
-- * Recursive @let@s, @case@s and casts are traversed structurally; any
--   remaining expression is returned unchanged.
strictifyExpr :: SCxt -> CoreExpr -> CoreM CoreExpr
strictifyExpr ss (Let (NonRec b e1) e2) = do
  e1' <- strictifyExpr ss e1
  e2' <- strictifyExpr ss e2
  -- A Core 'Case' evaluates its scrutinee, which forces the bound expression.
  return (Case e1' b (exprType e2) [(DEFAULT, [], e2')])
strictifyExpr ss (Case e b t alts) = do
  e' <- strictifyExpr ss e
  alts' <- mapM strictifyAlt alts
  return (Case e' b t alts')
  where
    -- Only the right-hand side of an alternative is rewritten.
    strictifyAlt (con, args, rhs) =
      (\rhs' -> (con, args, rhs')) <$> strictifyExpr ss rhs
strictifyExpr ss (Let (Rec es) e) = do
  es' <- mapM (\(bndr, rhs) -> (,) bndr <$> strictifyExpr ss rhs) es
  e' <- strictifyExpr ss e
  return (Let (Rec es') e')
strictifyExpr ss (Lam b e)
  | not (isCoVar b) && not (isTyVar b) && tcIsLiftedTypeKind (typeKind (varType b)) = do
      e' <- strictifyExpr ss e
      -- The lambda binds a fresh variable; the case re-binds the original
      -- binder to its evaluated value, so the body needs no renaming.
      b' <- mkSysLocalM (fsLit "strict") (varType b)
      return (Lam b' (Case (varToCoreExpr b') b (exprType e) [(DEFAULT, [], e')]))
  | otherwise = Lam b <$> strictifyExpr ss e
strictifyExpr ss (Cast e c) =
  (\e' -> Cast e' c) <$> strictifyExpr ss e
strictifyExpr ss (Tick t@(SourceNote sp _) e) =
  Tick t <$> strictifyExpr (ss{srcSpan = RealSrcSpan sp}) e
-- Ticks other than source notes carry no location to record, but the
-- expression they wrap still has to be strictified.
strictifyExpr ss (Tick t e) =
  Tick t <$> strictifyExpr ss e
strictifyExpr ss (App e1 e2)
  | lazyArgument = do
      printMessage SevWarning (srcSpan ss)
        (text "The use of lazy type " <> ppr (exprType e2) <> " may lead to memory leaks")
      -- Warn once per application spine: nested sub-expressions are not
      -- checked again.
      descend (ss{checkStrictData = False})
  | otherwise = descend ss
  where
    -- The order of the conjuncts matters: 'exprType' must not be applied to
    -- a type argument, so 'isType' is tested first.
    lazyArgument =
      checkStrictData ss
        && not (isType e2)
        && tcIsLiftedTypeKind (typeKind (exprType e2))
        && not (isStrict (exprType e2))
    descend cx = do
      e1' <- strictifyExpr cx e1
      e2' <- strictifyExpr cx e2
      return (App e1' e2')
strictifyExpr _ss e = return e
-- | Test whether the head of an expression is a variable accepted by
-- 'isDelayVar'.
--
-- The search descends into the function position of applications and looks
-- through casts and ticks; any other non-variable expression yields 'False'.
isDelayApp :: Expr b -> Bool
isDelayApp expr = case expr of
  Var v        -> isDelayVar v
  App fun _    -> isDelayApp fun
  Cast inner _ -> isDelayApp inner
  Tick _ inner -> isDelayApp inner
  _            -> False
-- | Decide whether a variable is one of the names @Delay@, @delay@ or @Box@
-- defined in the module @Rattus.Internal@ or @Rattus.Primitives@.
--
-- Variables whose name has no defining module are rejected.
--
-- The module restriction applies to every accepted name. Previously @&&@
-- bound tighter than the surrounding @||@, so a variable called @Delay@ or
-- @delay@ was accepted regardless of the module it came from.
--
-- NOTE(review): the original condition listed @"delay"@ twice; the second
-- occurrence was possibly meant to be @"box"@ -- confirm before extending the
-- set of accepted names.
isDelayVar :: Var -> Bool
isDelayVar v =
  case nameModule_maybe name of
    Nothing -> False
    Just m  -> hasDelayName && fromRattus (moduleNameString (moduleName m))
  where
    name = varName v

    occ :: String
    occ = getOccString name

    hasDelayName :: Bool
    hasDelayName = occ == "Delay" || occ == "delay" || occ == "Box"

    fromRattus :: String -> Bool
    fromRattus modName =
      modName == "Rattus.Internal" || modName == "Rattus.Primitives"
-- | 'True' for case expressions and lambdas, looking through ticks and casts.
-- Every other expression is classified by 'isType'.
isCase expr = case expr of
  Case{}       -> True
  Lam{}        -> True
  Tick _ inner -> isCase inner
  Cast inner _ -> isCase inner
  _            -> isType expr
-- | Test whether an expression mentions @GHC.TopHandler.runMainIO@.
--
-- Both sides of an application are searched, and casts and ticks are looked
-- through. Any other expression matches only when its pretty-printed form is
-- exactly the string @GHC.TopHandler.runMainIO@.
isTophandler expr = case expr of
  App fun arg  -> isTophandler fun || isTophandler arg
  Cast inner _ -> isTophandler inner
  Tick _ inner -> isTophandler inner
  _            -> showSDocUnsafe (ppr expr) == "GHC.TopHandler.runMainIO"