{-# LANGUAGE OverloadedStrings #-}
module PinnedWarnings
  ( plugin
  ) where

import           Control.Concurrent.MVar
import           Control.Monad
import           Control.Monad.IO.Class
import qualified Data.ByteString.Char8 as BS
import           Data.IORef
import           Data.List
import qualified Data.Map.Strict as M
import           Data.Maybe
import qualified Data.Set as S
import qualified System.Directory as Dir
import           System.IO.Unsafe (unsafePerformIO)

import qualified GhcFacade as Ghc

-- | Key used for per-module entries in 'globalState': a module's source file
-- path (as returned by 'Ghc.ms_hspp_file') packed into a 'BS.ByteString'.
type ModuleFile = BS.ByteString

-- | The infamous mutable global trick.
--
-- Needed to track the pinned warnings during and after compilation: it maps
-- each module's source file to the warnings recorded for it by
-- 'insertModuleWarnings', and is read back by 'addWarningsToContext' and
-- pruned by 'pruneDeleted'. It starts out empty.
globalState :: MVar (M.Map ModuleFile Ghc.WarningMessages)
globalState = unsafePerformIO $ newMVar mempty
-- NOINLINE ensures exactly one MVar is ever created; if GHC inlined the
-- 'unsafePerformIO' call, different use sites could get different MVars.
{-# NOINLINE globalState #-}

-- | Plugin entry point.
--
-- Starts from 'Ghc.defaultPlugin' and installs two hooks, both of which
-- ignore the plugin's command-line options:
--
-- * 'Ghc.tcPlugin': the type checker plugin 'tcPlugin', which solves
--   'ShowWarnings' constraints and re-injects pinned warnings.
-- * 'Ghc.typeCheckResultAction': 'insertModuleWarnings', which pins the
--   warnings of each module after it has been type checked.
plugin :: Ghc.Plugin
plugin =
  Ghc.defaultPlugin
    { Ghc.tcPlugin = \_opts -> Just tcPlugin
    , Ghc.typeCheckResultAction = \_opts -> insertModuleWarnings
    }

-- | The type checker plugin.
--
-- Its state, built by 'initTcPlugin', is the 'ShowWarnings' class 'Ghc.TyCon'
-- together with a counter. Only the wanted constraints are passed on to
-- 'checkWanteds'. Stopping the plugin does nothing.
tcPlugin :: Ghc.TcPlugin
tcPlugin =
  Ghc.TcPlugin
    { Ghc.tcPluginInit  = initTcPlugin
    , Ghc.tcPluginSolve = \(sw, counterRef) _givens _deriveds wanteds ->
        checkWanteds sw counterRef wanteds
    , Ghc.tcPluginStop  = \_state -> pure ()
    }

-- | Build the type checker plugin's state.
--
-- The state is the 'ShowWarnings' class 'Ghc.TyCon', looked up first, paired
-- with a fresh counter initialised to 0.
initTcPlugin :: Ghc.TcPluginM (Ghc.TyCon, IORef Int)
initTcPlugin = do
  showWarningsTyCon <- lookupShowWarnings
  counterRef <- Ghc.tcPluginIO $ newIORef 0
  pure (showWarningsTyCon, counterRef)

-- | Gets a reference to the 'ShowWarnings' constraint.
--
-- Finds module @ShowWarnings@ in package @pinned-warnings@ and returns the
-- 'Ghc.TyCon' of the class named @ShowWarnings@ defined there. Calls 'error'
-- if the module cannot be found.
lookupShowWarnings :: Ghc.TcPluginM Ghc.TyCon
lookupShowWarnings = do
  result <- Ghc.findImportedModule
              (Ghc.mkModuleName "ShowWarnings")
              (Just "pinned-warnings")

  case result of
    -- Bound as 'showWarningsMod' rather than 'mod' to avoid shadowing
    -- Prelude's 'mod' (flagged by -Wname-shadowing).
    Ghc.Found _ showWarningsMod -> do
      name <- Ghc.lookupOrig showWarningsMod $ Ghc.mkTcOcc "ShowWarnings"
      Ghc.classTyCon <$> Ghc.tcLookupClass name

    _ -> error "ShowWarnings module not found"

-- | If any wanted constraints are for 'ShowWarnings', then inject any pinned
-- warnings into GHC.
--
-- Each wanted 'Ghc.CDictCan' whose class is the given 'ShowWarnings' tycon
-- (@sw@) is solved with a unit evidence term. Each such constraint also bumps
-- the counter, which starts at 0 (see 'initTcPlugin'). The pinned warnings are
-- added only when the counter reads 2, i.e. at most once per plugin run, on
-- the third 'ShowWarnings' constraint seen. All other constraints are left
-- unsolved, and no new constraints are emitted.
checkWanteds :: Ghc.TyCon
             -> IORef Int
             -> [Ghc.Ct]
             -> Ghc.TcPluginM Ghc.TcPluginResult
checkWanteds sw counterRef wanteds = do
  solved <- catMaybes <$> traverse go wanteds
  pure $ Ghc.TcPluginOk solved []
  where
    go ct@Ghc.CDictCan { Ghc.cc_class = cls }
      | Ghc.classTyCon cls == sw = do
          counter <- Ghc.tcPluginIO $ readIORef counterRef

          -- for some reason warnings only appear if they are added on
          -- particular iterations.
          when (counter == 2) addWarningsToContext

          Ghc.tcPluginIO $ modifyIORef' counterRef succ

          pure $ Just (Ghc.EvExpr Ghc.unitExpr, ct)

    go _ = pure Nothing

-- | Add warnings from the global state back into the GHC context.
--
-- First runs 'pruneDeleted'. It then merges every warning still stored in
-- 'globalState', across all modules, into the warnings held in the current
-- local environment's error reference ('Ghc.tcl_errs'). The errors component
-- is left unchanged.
addWarningsToContext :: Ghc.TcPluginM ()
addWarningsToContext = do
  errsRef <- Ghc.tcl_errs . snd <$> Ghc.getEnvs

  pruneDeleted
  pinnedWarns <- Ghc.listToBag . foldMap Ghc.bagToList
             <$> Ghc.tcPluginIO (readMVar globalState)

  Ghc.tcPluginIO . atomicModifyIORef' errsRef $ \(warnings, errors) ->
    ((Ghc.unionBags pinnedWarns warnings, errors), ())

-- | Remove warnings for modules that no longer exist.
--
-- An entry of 'globalState' is dropped when its key, unpacked back to a file
-- path, does not name an existing file. The 'MVar' is held for the whole
-- check-and-update, so no other writer can interleave with it.
pruneDeleted :: Ghc.TcPluginM ()
pruneDeleted = Ghc.tcPluginIO . modifyMVar_ globalState $ \warns -> do
  deletedMods <-
    filterM (fmap not . Dir.doesFileExist . BS.unpack) (M.keys warns)

  -- Remove all deleted modules at once instead of folding 'M.delete'.
  pure $ M.withoutKeys warns (S.fromList deletedMods)

-- | After type checking a module, pin any warnings pertaining to it.
--
-- Reads the warnings currently held in the local environment's error
-- reference and keeps those whose source span is a real span in this
-- module's file. It stores them in 'globalState' under that file, replacing
-- any previously pinned warnings for it. The 'Ghc.TcGblEnv' is returned
-- unchanged.
insertModuleWarnings :: Ghc.ModSummary -> Ghc.TcGblEnv -> Ghc.TcM Ghc.TcGblEnv
insertModuleWarnings modSummary tcGblEnv = do
  lclErrsRef <- Ghc.tcl_errs . Ghc.env_lcl <$> Ghc.getEnv
  (warns, _errors) <- liftIO $ readIORef lclErrsRef

  let modFile :: ModuleFile
      -- NOTE(review): 'BS.pack' from Data.ByteString.Char8 keeps only the low
      -- 8 bits of each Char. If 'Ghc.bytesFS'' yields UTF-8 bytes, files
      -- whose paths contain non-ASCII characters will never match below.
      -- Confirm against GhcFacade.
      modFile = BS.pack $ Ghc.ms_hspp_file modSummary

      -- Warnings without a real source span are never pinned. The binder
      -- is named 'realSpan' to avoid shadowing Prelude's 'span'.
      onlyThisMod w =
        case Ghc.errMsgSpan w of
          Ghc.RealSrcSpan' realSpan ->
            Ghc.bytesFS' (Ghc.srcSpanFile realSpan) == modFile
          _ -> False

      warnsForMod = Ghc.filterBag onlyThisMod warns

  -- Replace any existing pinned warnings with new ones for this module
  liftIO . modifyMVar_ globalState $ pure . M.insert modFile warnsForMod

  pure tcGblEnv