{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE CPP             #-}
module Overloaded.Plugin.TcPlugin.Ctx where

import qualified GHC.Compat.All  as GHC
import           GHC.Compat.Expr

#if MIN_VERSION_ghc(9,0,0)
import qualified GHC.Tc.Plugin as Plugins
#else
import qualified TcPluginM as Plugins
#endif

import Overloaded.Plugin.Diagnostics
import Overloaded.Plugin.Names

-- | Classes resolved once at plugin start-up (see 'tcPluginInit') and
-- carried around for the rest of the plugin's run.
data PluginCtx = PluginCtx
    { hasPolyFieldCls :: GHC.Class
      -- ^ The class looked up under the name @HasField@ in the module
      -- named by 'ghcRecordsCompatMN'.
    , hasPolyConCls   :: GHC.Class
      -- ^ The class looked up under the name @HasConstructor@ in the
      -- module named by 'overloadedConstructorsMN'.
    }

-- | Initialise the typechecker plugin by building its 'PluginCtx'.
--
-- Resolves the two classes stored in the context:
--
-- * @HasField@, from the module named by 'ghcRecordsCompatMN';
-- * @HasConstructor@, from the module named by 'overloadedConstructorsMN'.
--
-- If either module cannot be found, a @Cannot find module@ error is
-- reported through 'putError' and the action then aborts via 'fail'.
tcPluginInit :: GHC.TcPluginM PluginCtx
tcPluginInit = do
    -- DynFlags are only needed to render the error message below.
    dflags <- GHC.unsafeTcPluginTcM GHC.getDynFlags

    -- Resolve a module name to a 'GHC.Module'; anything other than
    -- 'GHC.Found' is reported as an error and aborts initialisation.
    let findModule :: GHC.ModuleName -> Plugins.TcPluginM GHC.Module
        findModule m = do
            im <- Plugins.findImportedModule m Nothing
            case im of
                GHC.Found _ md -> return md
                _              -> do
                    Plugins.tcPluginIO $ putError dflags noSrcSpan $
                        GHC.text "Cannot find module" GHC.<+> GHC.ppr m
                    fail "panic!"

    hasPolyFieldCls <- do
        md <- findModule ghcRecordsCompatMN
        Plugins.tcLookupClass =<< Plugins.lookupOrig md (GHC.mkTcOcc "HasField")

    hasPolyConCls <- do
        md <- findModule overloadedConstructorsMN
        Plugins.tcLookupClass =<< Plugins.lookupOrig md (GHC.mkTcOcc "HasConstructor")

    -- RecordWildCards: fields are filled from the like-named bindings above.
    return PluginCtx {..}