{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module AsyncRattus.Plugin.Strictify
  (checkStrictData, SCxt (..)) where
import Prelude hiding ((<>))
import Control.Monad
import AsyncRattus.Plugin.Utils

import GHC.Plugins
import GHC.Types.Tickish

-- | Context threaded through 'checkStrictData'.
data SCxt = SCxt
  { srcSpan :: !SrcSpan
    -- ^ Location at which warnings are reported; replaced whenever the
    -- traversal enters a 'SourceNote' tick.
  }

-- | Checks whether the given expression uses non-strict data types
-- and issues a warning if it finds any such use.
--
-- Every sub-expression of the Core term is visited. At each application
-- @f x@ a warning is emitted when the argument @x@ is a value (not a type
-- argument) of lifted kind whose type is not accepted by 'isStrict', unless
-- @x@ is a literal (see 'isLit') or a call to @Control.DeepSeq.force@.
-- Warnings are reported at the most recent source location stored in the
-- 'SCxt'.
checkStrictData :: SCxt -> CoreExpr -> CoreM ()
checkStrictData ss (Let (NonRec _ rhs) body) = do
  checkStrictData ss rhs
  checkStrictData ss body
checkStrictData ss (Let (Rec binds) body) = do
  mapM_ (checkStrictData ss . snd) binds
  checkStrictData ss body
checkStrictData ss (Case scrut _ _ alts) = do
  checkStrictData ss scrut
  mapM_ ((\(_, _, rhs) -> checkStrictData ss rhs) . getAlt) alts
checkStrictData ss (Lam _ body) = checkStrictData ss body
checkStrictData ss (Cast e _) = checkStrictData ss e
checkStrictData ss (Tick tickish e) =
  case tickish of
    -- A source note tells us where we are: report warnings found
    -- underneath it at that location.
    SourceNote loc _ -> checkStrictData (ss {srcSpan = fromRealSrcSpan loc}) e
    -- Any other tick (e.g. profiling or HPC coverage ticks) carries no
    -- location we use, but the wrapped expression must still be checked
    -- rather than silently skipped.
    _ -> checkStrictData ss e
checkStrictData ss (App e1 e2)
  -- Applications of pushCallStack (presumably HasCallStack plumbing) are
  -- not inspected at all, including their arguments.
  | isPushCallStack e1 = return ()
  | otherwise = do
      -- argTy is bound lazily: GHC's exprType must not be evaluated on a
      -- type argument, and the '&&' chain rejects those via isType first.
      let argTy = exprType e2
          lazyArg =
            not (isType e2)
              && tcIsLiftedTypeKind (typeKind argTy)
              && not (isStrict argTy)
              && not (isDeepseqForce e2)
              && not (isLit e2)
      when lazyArg $
        printMessage SevWarning (srcSpan ss) $
          text "The use of lazy type " <> ppr argTy
            <> " may lead to memory leaks. Use Control.DeepSeq.force on lazy types."
      checkStrictData ss e1
      checkStrictData ss e2
-- Remaining forms (variables, literals, types, coercions) have no
-- sub-expressions to check.
checkStrictData _ _ = return ()

-- | Returns 'True' for Core literals and for desugared string literals.
--
-- GHC desugars a string literal to an application of
-- @GHC.CString.unpackCString#@ (ASCII-only contents) or
-- @GHC.CString.unpackCStringUtf8#@ (contents needing UTF-8 encoding) to a
-- primitive literal. Both forms are recognised, so non-ASCII string
-- literals are treated the same as ASCII ones.
isLit :: CoreExpr -> Bool
isLit Lit{} = True
isLit (App (Var v) Lit{})
  | Just (name, modName) <- getNameModule v =
      modName == "GHC.CString"
        && (name == "unpackCString#" || name == "unpackCStringUtf8#")
isLit _ = False


-- | Recognises (possibly partial) applications of
-- @GHC.Stack.Types.pushCallStack@. Arguments are peeled off one at a time
-- until the head of the application is reached; only a variable head with
-- that module and name qualifies.
isPushCallStack :: CoreExpr -> Bool
isPushCallStack expr =
  case expr of
    App fun _ -> isPushCallStack fun
    Var v -> maybe False isPush (getNameModule v)
    _ -> False
  where
    isPush (name, modName) =
      modName == "GHC.Stack.Types" && name == "pushCallStack"

-- | Recognises a saturated call to @Control.DeepSeq.force@. In Core this is
-- a variable applied to exactly three arguments (matched structurally, not
-- inspected); the head variable's module and name must match.
isDeepseqForce :: CoreExpr -> Bool
isDeepseqForce expr
  | App (App (App (Var fn) _) _) _ <- expr
  , Just (name, modName) <- getNameModule fn
  = modName == "Control.DeepSeq" && name == "force"
  | otherwise = False