{-# LANGUAGE CPP #-}
{-# LANGUAGE RecordWildCards #-}
module Overloaded.Plugin.TcPlugin (
tcPlugin,
) where
import Data.Maybe (mapMaybe)
import qualified GHC.Compat.All as GHC
#if MIN_VERSION_ghc(9,0,0)
import qualified GHC.Tc.Plugin as Plugins
#else
import qualified TcPluginM as Plugins
#endif
import Overloaded.Plugin.TcPlugin.Ctx
import Overloaded.Plugin.HasField
import Overloaded.Plugin.HasConstructor
-- | The type-checker plugin exported by this module.
--
-- * 'tcPluginInit' builds the 'PluginCtx' that every solver call receives.
-- * 'tcPluginSolve' runs the @HasField@ and @HasConstructor@ solvers.
-- * The stop action ignores the context and does nothing.
tcPlugin :: GHC.TcPlugin
tcPlugin = GHC.TcPlugin
    { GHC.tcPluginInit  = tcPluginInit
    , GHC.tcPluginSolve = tcPluginSolve
    , GHC.tcPluginStop  = \_ctx -> pure ()
    }
-- | Solve the wanted constraints this plugin knows how to discharge.
--
-- The current 'DynFlags', family-instance environments and global
-- reader environment are fetched once and shared by both solvers.
--
-- The @HasField@ solver ('solveHasField') and the @HasConstructor@ solver
-- ('solveHasConstructor') are then run, in that order, over the same
-- wanteds. Each solver yields @(result, constraint)@ pairs. A result of
-- @Just (evidence, extra)@ means the constraint was solved with
-- @evidence@; @extra@ is a list of further constraints.
--
-- Everything is reported back through 'GHC.TcPluginOk':
--
-- * solved constraints are paired with their evidence;
-- * the concatenated @extra@ constraints are passed as the new-work list.
--
-- The first two constraint lists (givens and deriveds in GHC's
-- 'GHC.TcPluginSolver' API) are not inspected.
tcPluginSolve :: PluginCtx -> GHC.TcPluginSolver
tcPluginSolve ctx _givens _deriveds wanteds = do
    dflags      <- Plugins.unsafeTcPluginTcM GHC.getDynFlags
    famInstEnvs <- Plugins.getFamInstEnvs
    rdrEnv      <- Plugins.unsafeTcPluginTcM GHC.getGlobalRdrEnv

    solvedHasField       <- solveHasField       ctx dflags famInstEnvs rdrEnv wanteds
    solvedHasConstructor <- solveHasConstructor ctx dflags famInstEnvs rdrEnv wanteds

    let solved    = solvedHasField ++ solvedHasConstructor
        evidence  = mapMaybe solvedWithEvidence solved
        extraWork = concat (mapMaybe extraConstraints solved)

    pure (GHC.TcPluginOk evidence extraWork)
  where
    -- Pair a solved constraint's evidence with the constraint itself;
    -- unsolved entries are dropped.
    solvedWithEvidence :: (Maybe (ev, new), ct) -> Maybe (ev, ct)
    solvedWithEvidence (result, ct) = fmap (\(ev, _) -> (ev, ct)) result

    -- The additional constraints a solver emitted for a solved entry;
    -- unsolved entries contribute nothing.
    extraConstraints :: (Maybe (ev, new), ct) -> Maybe new
    extraConstraints (result, _) = fmap snd result