{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module AsyncRattus.Plugin.Strictify
  (strictifyExpr, SCxt (..)) where
import Prelude hiding ((<>))
import AsyncRattus.Plugin.Utils

import GHC.Plugins
import GHC.Types.Tickish

-- | Context threaded through 'strictifyExpr'.
data SCxt = SCxt
  { srcSpan :: SrcSpan
    -- ^ Source location attached to any warning that is reported.
  , checkStrictData :: Bool
    -- ^ When @True@, arguments of lazy (non-strict) types are reported
    -- as warnings.
  }

-- | Transforms all functions into strict functions. If the
-- 'checkStrictData' field of the 'SCxt' argument is set to @True@,
-- then this function also checks for use of non-strict data types and
-- produces warnings if it finds any.
--
-- The traversal performs these rewrites:
--
--  * a non-recursive @let@ binding a term variable becomes a @case@, so
--    the right-hand side is evaluated before the body;
--  * a lambda over a lifted term variable scrutinises its argument with
--    a @case@ before running the body;
--  * every other supported construct is rebuilt from strictified
--    sub-expressions.
--
-- Source-note ticks update 'srcSpan' so that warnings point at the
-- innermost enclosing source location.
strictifyExpr :: SCxt -> CoreExpr -> CoreM CoreExpr
strictifyExpr ss (Let (NonRec b e1) e2)
  -- Core permits non-recursive lets that bind type or coercion
  -- variables; their right-hand side is a 'Type'/'Coercion' and cannot
  -- be the scrutinee of a 'case', so such lets are kept as they are.
  | isTyVar b || isCoVar b = do
      e2' <- strictifyExpr ss e2
      return (Let (NonRec b e1) e2')
  | otherwise = do
      e1' <- strictifyExpr ss e1
      e2' <- strictifyExpr ss e2
      -- A Core 'case' always evaluates its scrutinee, which makes the
      -- binding strict; 'b' becomes the case binder.
      return (Case e1' b (exprType e2) [mkAlt DEFAULT [] e2'])
strictifyExpr ss (Case e b t alts) = do
  e' <- strictifyExpr ss e
  alts' <- mapM (strictifyAlt . getAlt) alts
  return (Case e' b t alts')
  where
    strictifyAlt (con, args, rhs) = mkAlt con args <$> strictifyExpr ss rhs
strictifyExpr ss (Let (Rec binds) e) = do
  -- Recursive bindings stay a 'let' (a 'case' cannot bind recursively);
  -- only their right-hand sides and the body are strictified.
  binds' <- mapM (\(b, rhs) -> (,) b <$> strictifyExpr ss rhs) binds
  e' <- strictifyExpr ss e
  return (Let (Rec binds') e')
strictifyExpr ss (Lam b e)
  | not (isCoVar b)
  , not (isTyVar b)
  , tcIsLiftedTypeKind (typeKind (varType b)) = do
      e' <- strictifyExpr ss e
      -- The lambda binds a fresh variable b'; the original 'b' is reused
      -- as the case binder, so the body (which mentions 'b') sees the
      -- evaluated argument.
      b' <- mkSysLocalFromVar (fsLit "strict") b
      return (Lam b' (Case (varToCoreExpr b') b (exprType e) [mkAlt DEFAULT [] e']))
  | otherwise = Lam b <$> strictifyExpr ss e
strictifyExpr ss (Cast e c) =
  (\e' -> Cast e' c) <$> strictifyExpr ss e
strictifyExpr ss (Tick t@(SourceNote loc _) e) =
  -- Remember the innermost source location for warning messages.
  Tick t <$> strictifyExpr ss{srcSpan = fromRealSrcSpan loc} e
strictifyExpr ss (Tick t e) =
  -- Other ticks (profiling, HPC, breakpoints) must not stop the
  -- traversal; otherwise their subtrees would never be strictified.
  Tick t <$> strictifyExpr ss e
strictifyExpr ss (App e1 e2@Lit{}) = do
  e1' <- strictifyExpr ss e1
  return (App e1' e2)
strictifyExpr ss (App e1 e2)
  | checkStrictData ss
  , not (isType e2)
  , tcIsLiftedTypeKind (typeKind (exprType e2))
  , not (isStrict (exprType e2))
  , not (isDeepseqForce e2 || isLit e2) = do
      printMessage SevWarning (srcSpan ss)
        (text "The use of lazy type " <> ppr (exprType e2) <> " may lead to memory leaks. Use Control.DeepSeq.force on lazy types.")
      -- Checking is switched off below this application, so a single
      -- warning is reported for it and nothing inside it.
      let ss' = ss{checkStrictData = False}
      App <$> strictifyExpr ss' e1 <*> strictifyExpr ss' e2
  | otherwise =
      App <$> strictifyExpr ss e1 <*> strictifyExpr ss e2
strictifyExpr _ss e = return e

-- | Returns @True@ for literal expressions. These are either a Core 'Lit'
-- node or a desugared string literal: a call of
-- @GHC.CString.unpackCString#@ (plain ASCII strings) or
-- @GHC.CString.unpackCStringUtf8#@ (strings containing non-ASCII or NUL
-- characters) applied to a literal.
isLit :: CoreExpr -> Bool
isLit Lit{} = True
isLit (App (Var v) Lit{})
  | Just (name, modName) <- getNameModule v =
      modName == "GHC.CString"
        && (name == "unpackCString#" || name == "unpackCStringUtf8#")
isLit _ = False


-- | Recognises @Control.DeepSeq.force@ applied to exactly three arguments.
-- In Core these are, presumably, the type argument, the 'NFData'
-- dictionary and the value. Arguments built this way are not reported
-- as lazy by 'strictifyExpr'.
isDeepseqForce :: CoreExpr -> Bool
isDeepseqForce expr = case expr of
  App (App (App (Var fun) _) _) _ -> maybe False isForce (getNameModule fun)
  _ -> False
  where
    isForce (name, modName) = modName == "Control.DeepSeq" && name == "force"