{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module PinnedWarnings
  ( plugin
  ) where

import           Control.Concurrent.MVar
import           Control.Monad
import           Control.Monad.IO.Class
import           Data.IORef
import           Data.List
import qualified Data.Map.Strict as M
import           Data.Maybe
import qualified Data.Set as S
import           Data.String (fromString)
import qualified System.Directory as Dir
import           System.IO.Unsafe (unsafePerformIO)

import qualified Internal.FixWarnings as FW
import qualified Internal.GhcFacade as Ghc
import           Internal.Types

-- | A mutable global variable used to track warnings during and after
-- compilations, keyed by the module's source file.
--
-- 'unsafePerformIO' is used to create a single process-wide 'MVar'; the
-- @NOINLINE@ pragma is required so that GHC does not inline the definition
-- and accidentally create more than one 'MVar'.
globalState :: MVar (M.Map ModuleFile WarningsWithModDate)
globalState = unsafePerformIO $ newMVar mempty
{-# NOINLINE globalState #-}

--------------------------------------------------------------------------------
-- Plugin
--------------------------------------------------------------------------------

-- | The pinned-warnings compiler plugin.
--
-- * 'Ghc.driverPlugin' pushes a log hook onto the 'Ghc.HscEnv' logger so that
--   emitted warnings can be captured ('addWarningCapture'). This replaces the
--   older @dynFlagsPlugin@ approach, which is being removed from GHC.
-- * 'Ghc.parsedResultAction' drops the previously pinned warnings of a module
--   before it is recompiled ('resetPinnedWarnsForMod').
-- * The type checker plugin solves the @ShowWarnings@, @FixWarnings@ and
--   @ClearWarnings@ constraints ('tcPlugin').
plugin :: Ghc.Plugin
plugin =
  Ghc.defaultPlugin
    { Ghc.tcPlugin           = const $ Just tcPlugin
    , Ghc.parsedResultAction = const resetPinnedWarnsForMod
    , Ghc.driverPlugin       = const (pure . addWarningCapture)
    , Ghc.pluginRecompile    = Ghc.purePlugin
    }

-- | Type checker plugin that looks up the marker classes once
-- ('initTcPlugin') and then inspects the wanted constraints on every solver
-- invocation ('checkWanteds'). Given constraints (and, on older GHCs, the
-- derived ones / on newer GHCs the evidence bindings) are ignored.
tcPlugin :: Ghc.TcPlugin
tcPlugin =
  Ghc.TcPlugin
    { Ghc.tcPluginInit  = initTcPlugin
    , Ghc.tcPluginSolve = \pluginState _ _ -> checkWanteds pluginState
    , Ghc.tcPluginStop  = const $ pure ()
#if MIN_VERSION_ghc(9,4,0)
    , Ghc.tcPluginRewrite = mempty
#endif
    }

-- | State shared across the invocations of the type checker plugin.
data PluginState =
  MkPluginState
    { showWarningsClass  :: Ghc.TyCon
      -- ^ type constructor of the @ShowWarnings@ class
    , fixWarningsClass   :: Ghc.TyCon
      -- ^ type constructor of the @FixWarnings@ class
    , clearWarningsClass :: Ghc.TyCon
      -- ^ type constructor of the @ClearWarnings@ class
    , counterRef         :: IORef Int
      -- ^ number of marker-class constraints solved so far; used to decide
      -- on which iteration an action is triggered
    }

-- | Resolve the three marker classes from the @ShowWarnings@ module and
-- start the constraint counter at zero.
initTcPlugin :: Ghc.TcPluginM PluginState
initTcPlugin =
  MkPluginState
    <$> lookupClass "ShowWarnings"
    <*> lookupClass "FixWarnings"
    <*> lookupClass "ClearWarnings"
    <*> Ghc.tcPluginIO (newIORef 0)

-- | Get a reference to a class from the @ShowWarnings@ module.
--
-- On GHC < 9.4 the module is looked up in the @pinned-warnings@ package;
-- from 9.4 onwards no package qualifier is used. Calls 'error' if the module
-- cannot be found.
lookupClass :: String -> Ghc.TcPluginM Ghc.TyCon
lookupClass className = do
  result <- Ghc.findImportedModule
              (Ghc.mkModuleName "ShowWarnings")
#if MIN_VERSION_ghc(9,4,0)
              Ghc.NoPkgQual
#else
              (Just "pinned-warnings")
#endif

  case result of
    Ghc.Found _ mod' -> do
      name <- Ghc.lookupOrig mod' $ Ghc.mkTcOcc className
      Ghc.classTyCon <$> Ghc.tcLookupClass name

    _ -> error "ShowWarnings module not found"

-- | If any wanted constraints are for 'ShowWarnings', 'FixWarnings' or
-- 'ClearWarnings', solve them and run the associated action.
--
-- Each matching constraint is solved with a unit dictionary. Other
-- constraints are left untouched. A single counter, shared by all three
-- classes, is incremented for every matching constraint. Each action only
-- runs when the counter has a specific value at the time the constraint is
-- seen:
--
-- * @ShowWarnings@: inject the pinned warnings when the counter is 2
-- * @FixWarnings@: run 'fixWarnings' when the counter is 0
-- * @ClearWarnings@: run 'clearWarnings' when the counter is 0
checkWanteds :: PluginState
             -> [Ghc.Ct]
             -> Ghc.TcPluginM Ghc.TcPluginResult'
checkWanteds pluginState
    = fmap (flip Ghc.TcPluginOk [] . catMaybes)
    . traverse go
  where
    go ct@(Ghc.CDictCan' _ cls _)
      | tyCon == showWarningsClass pluginState =
          -- for some reason warnings only appear if they are added on
          -- particular iterations.
          solveAndRunOn 2 addWarningsToContext
      | tyCon == fixWarningsClass pluginState =
          solveAndRunOn 0 (Ghc.tcPluginIO fixWarnings)
      | tyCon == clearWarningsClass pluginState =
          solveAndRunOn 0 (Ghc.tcPluginIO clearWarnings)
      where
        tyCon = Ghc.classTyCon cls

        -- Run @action@ if the counter currently equals @targetCount@, then
        -- bump the counter and solve the constraint with a unit dictionary.
        solveAndRunOn targetCount action = do
          counter <- Ghc.tcPluginIO $ readIORef (counterRef pluginState)
          when (counter == targetCount) action
          incrementCounter
          pure $ Just (Ghc.EvExpr Ghc.unitExpr, ct)

    -- Not one of our marker classes (or not a dictionary constraint at all).
    go _ = pure Nothing

    incrementCounter :: Ghc.TcPluginM ()
    incrementCounter =
      Ghc.tcPluginIO $ modifyIORef' (counterRef pluginState) succ

-- | Add warnings from the global state back into the GHC context.
--
-- Entries for deleted modules are pruned first. Then every pinned warning
-- from every module is merged into the messages held by the local
-- environment's 'Ghc.tcl_errs' reference.
addWarningsToContext :: Ghc.TcPluginM ()
addWarningsToContext = do
  errsRef <- Ghc.tcl_errs . snd <$> Ghc.getEnvs

  Ghc.tcPluginIO pruneDeleted
  pinnedWarns <- Ghc.listToBag . map unWarning
               . foldMap (foldMap S.toList . warningsMap)
             <$> Ghc.tcPluginIO (readMVar globalState)

  Ghc.tcPluginIO . atomicModifyIORef' errsRef
#if MIN_VERSION_ghc(9,6,0)
    $ \messages ->
        (Ghc.mkMessages ((fmap . fmap) Ghc.mkTcRnUnknownMessage pinnedWarns)
          `Ghc.unionMessages` messages, ())
#elif MIN_VERSION_ghc(9,4,0)
    $ \messages ->
        (Ghc.mkMessages ((fmap . fmap) Ghc.TcRnUnknownMessage pinnedWarns)
          `Ghc.unionMessages` messages, ())
#elif MIN_VERSION_ghc(9,2,0)
    $ \messages ->
        (Ghc.mkMessages pinnedWarns `Ghc.unionMessages` messages, ())
#endif

-- | Remove warnings for modules that no longer exist.
--
-- Also drops entries whose warning map is empty. Each remaining key is
-- checked with 'Dir.doesFileExist', and entries whose file is missing are
-- removed from 'globalState'.
pruneDeleted :: IO ()
pruneDeleted = modifyMVar_ globalState $ \warns -> do
  -- remove keys that have no warnings
  let warns' = M.filter (not . null . warningsMap) warns

  -- keys whose source file is no longer on disk
  deletedMods <- filterM (fmap not . Dir.doesFileExist) (M.keys warns')

  pure $ foldl' (flip M.delete) warns' deletedMods

-- | Removes currently pinned warnings for a module.
-- This occurs before any new warnings are captured for the module; the
-- parsed module itself is passed through unchanged.
--
-- NOTE(review): the key used is 'Ghc.ms_hspp_file'. Confirm that it matches
-- the file name reported in warning spans (used by
-- 'addWarningToGlobalState') for modules that go through CPP.
resetPinnedWarnsForMod
  :: Ghc.ModSummary
#if MIN_VERSION_ghc(9,4,0)
  -> Ghc.ParsedResult
  -> Ghc.Hsc Ghc.ParsedResult
#else
  -> Ghc.HsParsedModule
  -> Ghc.Hsc Ghc.HsParsedModule
#endif
resetPinnedWarnsForMod modSummary parsedModule = do
  let modFile = fromString $ Ghc.ms_hspp_file modSummary

  liftIO . modifyMVar_ globalState
    $ pure . M.delete modFile

  pure parsedModule

-- | Taps into the log action to capture the warnings that GHC emits.
--
-- A hook is pushed onto the session's logger. For every warning whose span
-- has real start and end locations and a file name, the warning is recorded
-- in the global state. Every message, captured or not, is still forwarded to
-- the original log action.
#if MIN_VERSION_ghc(9,4,0)
addWarningCapture :: Ghc.HscEnv -> Ghc.HscEnv
addWarningCapture hscEnv =
  hscEnv
    { Ghc.hsc_logger = Ghc.pushLogHook warningsHook (Ghc.hsc_logger hscEnv)
    }
  where
    warningsHook :: Ghc.LogAction -> Ghc.LogAction
    warningsHook logAction dynFlags messageClass srcSpan sdoc = do
      case messageClass of
#if MIN_VERSION_ghc(9,6,0)
        Ghc.MCDiagnostic Ghc.SevWarning _ _
#else
        Ghc.MCDiagnostic Ghc.SevWarning _
#endif
          | Ghc.RealSrcLoc start _ <- Ghc.srcSpanStart srcSpan
          , Ghc.RealSrcLoc end _ <- Ghc.srcSpanEnd srcSpan
          , Just modFile <- Ghc.srcSpanFileName_maybe srcSpan
          -> do
            let diag =
                  Ghc.DiagnosticMessage
                    { Ghc.diagMessage = Ghc.mkSimpleDecorated sdoc
                    , Ghc.diagReason = Ghc.WarningWithoutFlag
                    , Ghc.diagHints = []
                    }
                diagOpts = Ghc.initDiagOpts $ Ghc.hsc_dflags hscEnv
                warn = Warning $
                  Ghc.mkMsgEnvelope diagOpts srcSpan Ghc.neverQualify diag
            addWarningToGlobalState start end modFile warn
        _ -> pure ()
      logAction dynFlags messageClass srcSpan sdoc
#elif MIN_VERSION_ghc(9,2,0)
addWarningCapture :: Ghc.HscEnv -> Ghc.HscEnv
addWarningCapture hscEnv =
  hscEnv
    { Ghc.hsc_logger = Ghc.pushLogHook warningsHook (Ghc.hsc_logger hscEnv)
    }
  where
    warningsHook :: Ghc.LogAction -> Ghc.LogAction
    warningsHook logAction dynFlags warnReason severity srcSpan sdoc = do
      case severity of
        Ghc.SevWarning
          | Ghc.RealSrcLoc start _ <- Ghc.srcSpanStart srcSpan
          , Ghc.RealSrcLoc end _ <- Ghc.srcSpanEnd srcSpan
          , Just modFile <- Ghc.srcSpanFileName_maybe srcSpan
          -> do
            let warn = Warning $ Ghc.mkPlainWarnMsg srcSpan sdoc
            addWarningToGlobalState start end modFile warn
        _ -> pure ()

      logAction dynFlags warnReason severity srcSpan sdoc
#endif

-- | Adds a warning to the global state variable.
--
-- The warning is only recorded if the file still exists on disk. It is
-- stored under the file path, keyed by its (start, end) span, and merged with
-- any existing entry for that file via '(<>)'. The entry's 'lastUpdated' is
-- set to the file's current modification time.
addWarningToGlobalState
  :: Ghc.RealSrcLoc -- ^ start location
  -> Ghc.RealSrcLoc -- ^ end location
  -> Ghc.FastString -- ^ path of the source file the warning refers to
  -> Warning
  -> IO ()
addWarningToGlobalState start end modFile warn = do
  let wrappedWarn = M.singleton (start, end) (S.singleton warn)
      file = Ghc.unpackFS modFile

  exists <- Dir.doesFileExist file
  when exists $ do
    fileModifiedAt <- Dir.getModificationTime file
    let entry =
          MkWarningsWithModDate
            { lastUpdated = fileModifiedAt
            , warningsMap = MonoidMap wrappedWarn
            }
    modifyMVar_ globalState $ pure . M.insertWith (<>) file entry

-- | Prune deleted modules, then run 'FW.fixWarning' over every remaining
-- module entry, storing the updated entries back into 'globalState'.
fixWarnings :: IO ()
fixWarnings = do
  pruneDeleted

  modifyMVar_ globalState $
    M.traverseWithKey FW.fixWarning

-- | Discard all pinned warnings for every module.
clearWarnings :: IO ()
clearWarnings =
  void $ swapMVar globalState M.empty