module GHC.TcPluginM.Extra
(
newWanted
, newGiven
, newDerived
#if __GLASGOW_HASKELL__ < 711
, newWantedWithProvenance
#endif
, evByFiat
#if __GLASGOW_HASKELL__ < 711
, failWithProvenace
#endif
, lookupModule
, lookupName
, tracePlugin
)
where
#if __GLASGOW_HASKELL__ < 711
import Data.Maybe (mapMaybe)
#endif
#if __GLASGOW_HASKELL__ < 711
import BasicTypes (TopLevelFlag (..))
#endif
import Coercion (Role (..), mkUnivCo)
import FastString (FastString, fsLit)
import Module (Module, ModuleName)
import Name (Name)
import OccName (OccName)
import Outputable (($$), (<+>), empty, ppr, text)
import Panic (panicDoc)
#if __GLASGOW_HASKELL__ >= 711
import TcEvidence (EvTerm (..))
#else
import TcEvidence (EvTerm (..), TcCoercion (..))
import TcMType (newEvVar)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcPluginM (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
tcPluginTrace, unsafeTcPluginTcM)
import TcRnTypes (Ct, CtEvidence (..), CtLoc, TcIdBinder (..), TcLclEnv (..),
TcPlugin (..), TcPluginResult (..), ctEvId, ctEvLoc, ctLoc,
ctLocEnv, mkNonCanonical, setCtLocEnv)
#else
import TcPluginM (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
tcPluginTrace)
import qualified TcPluginM
import TcRnTypes (CtEvidence (..), CtLoc,
TcPlugin (..), TcPluginResult (..))
#endif
#if __GLASGOW_HASKELL__ < 802
import TcPluginM (tcPluginIO)
#endif
#if __GLASGOW_HASKELL__ >= 711
import TyCoRep (UnivCoProvenance (..))
import Type (PredType, Type)
#else
import Type (EqRel (..), PredTree (..), PredType, Type, classifyPredType)
import Var (varType)
#endif
#if __GLASGOW_HASKELL__ < 802
import Data.IORef (readIORef)
import Control.Monad (unless)
import StaticFlags (initStaticOpts, v_opt_C_ready)
#endif
#if __GLASGOW_HASKELL__ >= 711
-- | Compatibility shim for GHC >= 7.11: matches a successful 'Found' lookup
-- result and binds its 'Module' field, ignoring the first field. Together
-- with 'fr_mod' it lets 'lookupModule' handle a successful lookup with a
-- single spelling on this compiler branch.
pattern FoundModule :: Module -> FindResult
pattern FoundModule md <- Found _ md

-- | Identity function; exists only so that the GHC >= 7.11 branch of
-- 'lookupModule' can write @fr_mod h@ uniformly.
fr_mod :: a -> a
fr_mod md = md
#endif
#if __GLASGOW_HASKELL__ < 711
-- | Create a new Wanted constraint for the predicate @p@ that remembers the
-- Wanted @ev@ it was derived from.
--
-- The new constraint's location is @ev@'s location, with @ev@'s evidence
-- variable pushed onto the local binders ('tcl_bndrs') as a non-top-level
-- 'TcIdBndr'. 'failWithProvenace' later reads binders of exactly this shape
-- back out to report the parent constraints.
--
-- Panics (via 'panicDoc') if @ev@ is not a Wanted.
newWantedWithProvenance :: CtEvidence
                        -> PredType
                        -> TcPluginM CtEvidence
newWantedWithProvenance ev@(CtWanted {}) p = do
    evVar <- unsafeTcPluginTcM (newEvVar p)
    return CtWanted { ctev_pred = p
                    , ctev_evar = evVar
                    , ctev_loc  = loc'
                    }
  where
    loc        = ctEvLoc ev
    env        = ctLocEnv loc
    -- The parent's evidence variable, recorded as a local binder.
    parentBndr = TcIdBndr (ctEvId ev) NotTopLevel
    loc'       = setCtLocEnv loc (env { tcl_bndrs = parentBndr : tcl_bndrs env })
newWantedWithProvenance ev _ =
    panicDoc "newWantedWithProvenance: not a Wanted: " (ppr ev)
#endif
-- | Create a new Wanted constraint for the predicate @pty@ at location @loc@.
--
-- On GHC >= 7.11 this defers to 'TcPluginM.newWanted'. On older compilers
-- a fresh evidence variable is allocated with 'newEvVar' and wrapped in a
-- 'CtWanted'.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newWanted loc pty = TcPluginM.newWanted loc pty
#else
newWanted loc pty = fmap mkWanted (unsafeTcPluginTcM (newEvVar pty))
  where
    mkWanted evVar = CtWanted { ctev_pred = pty
                              , ctev_evar = evVar
                              , ctev_loc  = loc
                              }
#endif
-- | Create a new Given constraint for the predicate @pty@ at location @loc@,
-- justified by the evidence term @evtm@.
--
-- On GHC >= 7.11 this defers to 'TcPluginM.newGiven'. On older compilers
-- the 'CtGiven' is built directly.
newGiven :: CtLoc -> PredType -> EvTerm -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newGiven loc pty evtm = TcPluginM.newGiven loc pty evtm
#else
newGiven loc pty evtm =
    pure CtGiven { ctev_pred = pty, ctev_evtm = evtm, ctev_loc = loc }
#endif
-- | Create a new Derived constraint for the predicate @pty@ at location @loc@.
--
-- On GHC >= 7.11 this defers to 'TcPluginM.newDerived'. On older compilers
-- the 'CtDerived' (which carries only a predicate and a location) is built
-- directly.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newDerived loc pty = TcPluginM.newDerived loc pty
#else
newDerived loc pty = pure CtDerived { ctev_pred = pty, ctev_loc = loc }
#endif
-- | Evidence for the nominal equality between @t1@ and @t2@, asserted
-- without proof.
--
-- The evidence is a 'Nominal' universal coercion built with 'mkUnivCo' and
-- labelled with @name@. On GHC >= 7.11 the label is a 'PluginProv'
-- provenance. On older compilers it is a 'FastString', and the coercion is
-- additionally wrapped in 'TcCoercion'.
evByFiat :: String -- ^ label identifying who asserted the equality
         -> Type   -- ^ left-hand side
         -> Type   -- ^ right-hand side
         -> EvTerm
#if __GLASGOW_HASKELL__ >= 711
evByFiat name t1 t2 =
    EvCoercion (mkUnivCo (PluginProv name) Nominal t1 t2)
#else
evByFiat name t1 t2 =
    EvCoercion (TcCoercion (mkUnivCo (fsLit name) Nominal t1 t2))
#endif
#if __GLASGOW_HASKELL__ < 711
-- | Report @ct@ as insoluble, together with the parent constraints it was
-- derived from.
--
-- Parents are recovered from the local binders ('tcl_bndrs') of @ct@'s
-- location environment. Every non-top-level 'TcIdBndr' whose type is a
-- nominal equality is turned back into a Wanted at @ct@'s location. This is
-- the binder shape that 'newWantedWithProvenance' records.
--
-- The function name's misspelling is kept because the name is exported.
failWithProvenace :: Ct -> TcPluginM TcPluginResult
failWithProvenace ct = return (TcPluginContradiction (ct : parents))
  where
    loc = ctLoc ct

    -- Evidence variables bound locally (not at top level).
    localIds = mapMaybe nonTopLevelId (tcl_bndrs (ctLocEnv loc))
    nonTopLevelId (TcIdBndr v NotTopLevel) = Just v
    nonTopLevelId _                        = Nothing

    isNomEq ty = case classifyPredType ty of
      EqPred NomEq _ _ -> True
      _                -> False

    parents = [ mkNonCanonical (CtWanted ty v loc)
              | v <- localIds
              , let ty = varType v
              , isNomEq ty
              ]
#endif
-- | Resolve the module @mod_nm@.
--
-- The module is first looked up in package @pkg@. If that fails, it is
-- looked up again with the package qualifier @"this"@ (presumably the home
-- package being compiled; NOTE(review): confirm against GHC's
-- PackageImports semantics).
--
-- Panics (via 'panicDoc') when neither lookup finds the module.
lookupModule :: ModuleName
             -> FastString
             -> TcPluginM Module
lookupModule mod_nm pkg = do
    inPkg <- findIn pkg
    case inPkg of
      Just md -> return md
      Nothing -> do
        inThis <- findIn (fsLit "this")
        case inThis of
          Just md -> return md
          Nothing -> panicDoc "Unable to resolve module looked up by plugin: "
                              (ppr mod_nm)
  where
    -- Look the module up under the given package qualifier.
    findIn p = fmap foundModule (findImportedModule mod_nm (Just p))

    -- Extract the module from a successful lookup. The success constructor
    -- differs between GHC versions, so the CPP lives here once instead of
    -- being repeated at every lookup site.
    foundModule :: FindResult -> Maybe Module
#if __GLASGOW_HASKELL__ >= 711
    foundModule (FoundModule h) = Just (fr_mod h)
#else
    foundModule (Found _ md) = Just md
#endif
    foundModule _ = Nothing
-- | Find the 'Name' of the entity with occurrence name @occ@ defined in
-- module @md@. This is a thin wrapper around 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig
-- | Wrap a type-checker plugin so that each of its phases emits a message
-- via 'tcPluginTrace', tagged with the label @s@.
--
-- * Init: calls 'initializeStaticFlags', traces, then runs the wrapped init.
-- * Solve: traces the given, derived and wanted constraints, runs the
--   wrapped solver, then traces its outcome (solved and new constraints, or
--   the contradictory ones). The solver's result is returned unchanged.
-- * Stop: traces, then runs the wrapped stop action.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s (TcPlugin { tcPluginInit  = innerInit
                        , tcPluginSolve = innerSolve
                        , tcPluginStop  = innerStop
                        }) =
    TcPlugin { tcPluginInit  = traceInit
             , tcPluginSolve = traceSolve
             , tcPluginStop  = traceStop
             }
  where
    traceInit = do
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty
      innerInit

    traceStop z = do
      tcPluginTrace ("tcPluginStop " ++ s) empty
      innerStop z

    traceSolve z given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s)
                    (text "given =" <+> ppr given
                     $$ text "derived =" <+> ppr derived
                     $$ text "wanted =" <+> ppr wanted)
      result <- innerSolve z given derived wanted
      traceOutcome result
      return result

    -- Report what the wrapped solver decided.
    traceOutcome (TcPluginOk solved new) =
      tcPluginTrace ("tcPluginSolve ok " ++ s)
                    (text "solved =" <+> ppr solved
                     $$ text "new =" <+> ppr new)
    traceOutcome (TcPluginContradiction bad) =
      tcPluginTrace ("tcPluginSolve contradiction " ++ s)
                    (text "bad =" <+> ppr bad)
-- | Make sure GHC's static options are initialised.
--
-- On GHC < 8.2 this reads 'v_opt_C_ready' and calls 'initStaticOpts' only
-- when that flag is not yet set. On newer compilers it does nothing.
initializeStaticFlags :: TcPluginM ()
#if __GLASGOW_HASKELL__ < 802
initializeStaticFlags =
    tcPluginIO (readIORef v_opt_C_ready >>= \ready -> unless ready initStaticOpts)
#else
initializeStaticFlags = pure ()
#endif