{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TcPluginM.Extra
(
newWanted
, newGiven
, newDerived
#if __GLASGOW_HASKELL__ < 711
, newWantedWithProvenance
#endif
, evByFiat
#if __GLASGOW_HASKELL__ < 711
, failWithProvenace
#endif
, lookupModule
, lookupName
, tracePlugin
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
)
where
#if __GLASGOW_HASKELL__ < 711
import Data.Maybe (mapMaybe)
#endif
#if __GLASGOW_HASKELL__ < 711
import BasicTypes (TopLevelFlag (..))
#endif
#if MIN_VERSION_ghc(8,5,0)
import CoreSyn (Expr(..))
#endif
import Coercion (Role (..), mkUnivCo)
import FastString (FastString, fsLit)
import Module (Module, ModuleName)
import Name (Name)
import OccName (OccName)
import Outputable (($$), (<+>), empty, ppr, text)
import Panic (panicDoc)
#if __GLASGOW_HASKELL__ >= 711
import TcEvidence (EvTerm (..))
#else
import TcEvidence (EvTerm (..), TcCoercion (..))
import TcMType (newEvVar)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcPluginM (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
tcPluginTrace, unsafeTcPluginTcM)
import TcRnTypes (Ct, CtEvidence (..), CtLoc, TcIdBinder (..), TcLclEnv (..),
TcPlugin (..), TcPluginResult (..), ctEvLoc,
ctLocEnv, setCtLocEnv)
#else
import TcPluginM (FindResult (..), TcPluginM, lookupOrig, tcPluginTrace)
import qualified TcPluginM
import qualified Finder
#if __GLASGOW_HASKELL__ < 809
import TcRnTypes (CtEvidence (..), CtLoc,
TcPlugin (..), TcPluginResult (..))
#else
import TcRnTypes (TcPlugin (..), TcPluginResult (..))
#endif
#endif
#if __GLASGOW_HASKELL__ < 802
import TcPluginM (tcPluginIO)
#endif
#if __GLASGOW_HASKELL__ >= 711
import TyCoRep (UnivCoProvenance (..))
import Type (PredType, Type)
#else
import Type (EqRel (..), PredTree (..), PredType, Type, classifyPredType)
import Var (varType)
#endif
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
#if __GLASGOW_HASKELL__ < 809
import TcRnTypes (Ct (..), ctLoc, ctEvId, mkNonCanonical)
#else
import Constraint
(Ct (..), CtEvidence (..), CtLoc, ctLoc, ctEvId, mkNonCanonical)
#endif
import TcType (TcTyVar, TcType)
#if __GLASGOW_HASKELL__ < 809
import Type (mkPrimEqPred)
#else
import Predicate (mkPrimEqPred)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcRnTypes (ctEvTerm)
import TypeRep (Type (..))
#else
import Data.Maybe (mapMaybe)
import TyCoRep (Type (..))
#endif
#if __GLASGOW_HASKELL__ < 802
import Data.IORef (readIORef)
import Control.Monad (unless)
import StaticFlags (initStaticOpts, v_opt_C_ready)
#endif
#if __GLASGOW_HASKELL__ >= 711
-- | Unidirectional pattern synonym over 'FindResult': it matches a 'Found'
-- result and binds only its second field (the 'Module'), ignoring the first.
pattern FoundModule :: Module -> FindResult
pattern FoundModule a <- Found _ a
-- | Identity function; 'lookupModule' applies it to the value bound by the
-- 'FoundModule' pattern.
fr_mod :: a -> a
fr_mod x = x
#endif
#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED newWantedWithProvenance "No longer available in GHC 8.0+" #-}
-- | Create a fresh [W]anted constraint for the given predicate whose location
-- records the originating Wanted: the evidence variable of the originating
-- constraint is pushed, as a non-top-level binder, onto the @tcl_bndrs@ of
-- the local environment stored in that constraint's location.
--
-- Panics when the supplied evidence is not a 'CtWanted'.
newWantedWithProvenance
  :: CtEvidence            -- ^ Originating Wanted constraint
  -> PredType              -- ^ Predicate of the new constraint
  -> TcPluginM CtEvidence
newWantedWithProvenance ev p = case ev of
  CtWanted {} -> do
    freshVar <- unsafeTcPluginTcM (newEvVar p)
    return CtWanted
      { ctev_pred = p
      , ctev_evar = freshVar
      , ctev_loc  = taggedLoc
      }
  _ -> panicDoc "newWantedWithProvenance: not a Wanted: " (ppr ev)
  where
    -- These bindings are only demanded in the 'CtWanted' branch.
    origLoc   = ctEvLoc ev
    origEnv   = ctLocEnv origLoc
    parent    = TcIdBndr (ctEvId ev) NotTopLevel
    taggedLoc =
      setCtLocEnv origLoc (origEnv { tcl_bndrs = parent : tcl_bndrs origEnv })
#endif
-- | Create a new [W]anted constraint for the given predicate at the given
-- location.
--
-- On GHC 8.0 and later this delegates to the function of the same name in
-- "TcPluginM"; on older compilers a fresh evidence variable is allocated and
-- the 'CtWanted' record is assembled by hand.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newWanted loc pty = TcPluginM.newWanted loc pty
#else
newWanted loc pty = mkWanted `fmap` unsafeTcPluginTcM (newEvVar pty)
  where
    mkWanted evVar = CtWanted
      { ctev_pred = pty
      , ctev_evar = evVar
      , ctev_loc  = loc
      }
#endif
-- | Create a new [G]iven constraint for the given predicate, location and
-- evidence term.
--
-- On GHC 8.5 and later only 'EvExpr' evidence is accepted: its payload is
-- handed to the "TcPluginM" function of the same name, and any other evidence
-- term causes a panic. On GHC 8.0 up to that point the call is forwarded
-- unchanged; on older compilers the 'CtGiven' record is assembled by hand.
newGiven :: CtLoc -> PredType -> EvTerm -> TcPluginM CtEvidence
#if MIN_VERSION_ghc(8,5,0)
newGiven loc pty evtm = case evtm of
  EvExpr ev -> TcPluginM.newGiven loc pty ev
  _         -> panicDoc "newGiven: not an EvExpr: " (ppr evtm)
#elif __GLASGOW_HASKELL__ >= 711
newGiven loc pty evtm = TcPluginM.newGiven loc pty evtm
#else
newGiven loc pty evtm = return given
  where
    given = CtGiven
      { ctev_pred = pty
      , ctev_evtm = evtm
      , ctev_loc  = loc
      }
#endif
-- | Create a new [D]erived constraint for the given predicate at the given
-- location.
--
-- On GHC 8.0 and later this delegates to the function of the same name in
-- "TcPluginM"; on older compilers the 'CtDerived' record is assembled by hand.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newDerived loc pty = TcPluginM.newDerived loc pty
#else
newDerived loc pty = return derived
  where
    derived = CtDerived
      { ctev_pred = pty
      , ctev_loc  = loc
      }
#endif
-- | Build evidence, by fiat, relating two types.
--
-- The evidence wraps a universal coercion ('mkUnivCo') between the two types
-- at role 'Nominal'; the given name is passed along as the provenance of that
-- coercion.
evByFiat
  :: String  -- ^ Name recorded as the provenance of the coercion
  -> Type    -- ^ First type handed to 'mkUnivCo'
  -> Type    -- ^ Second type handed to 'mkUnivCo'
  -> EvTerm
evByFiat name t1 t2 = toEvTerm (mkUnivCo provenance Nominal t1 t2)
  where
    -- How a coercion is injected into 'EvTerm' depends on the compiler.
#if MIN_VERSION_ghc(8,5,0)
    toEvTerm = EvExpr . Coercion
#elif __GLASGOW_HASKELL__ >= 711
    toEvTerm = EvCoercion
#else
    toEvTerm = EvCoercion . TcCoercion
#endif
    -- The representation of the provenance argument depends on the compiler
    -- as well.
#if __GLASGOW_HASKELL__ >= 711
    provenance = PluginProv name
#else
    provenance = fsLit name
#endif
#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED failWithProvenace "No longer available in GHC 8.0+" #-}
-- | Report a contradiction consisting of the given constraint together with
-- its recorded parents.
--
-- The parents are recovered from the non-top-level binders stored in the
-- @tcl_bndrs@ of the constraint's location environment: every such binder
-- whose type classifies as a nominal equality predicate is turned back into
-- a non-canonical Wanted constraint at the same location.
failWithProvenace :: Ct -> TcPluginM TcPluginResult
failWithProvenace ct = return (TcPluginContradiction (ct : parents))
  where
    loc = ctLoc ct

    parents =
      [ mkNonCanonical (CtWanted (varType bndr) bndr loc)
      | TcIdBndr bndr NotTopLevel <- tcl_bndrs (ctLocEnv loc)
      , isNominalEq (classifyPredType (varType bndr))
      ]

    isNominalEq (EqPred NomEq _ _) = True
    isNominalEq _                  = False
#endif
-- | Find a module by name.
--
-- The first attempt uses @Finder.findPluginModule@ on GHC 8.0 and later, and
-- 'findImportedModule' with the supplied package name on older compilers
-- (the package argument is only consulted there). When that attempt does not
-- find the module, a second lookup is made with 'findImportedModule' using
-- the package name @\"this\"@. If both attempts fail the function panics.
lookupModule
  :: ModuleName  -- ^ Name of the module
  -> FastString  -- ^ Name of the package containing the module
  -> TcPluginM Module
lookupModule mod_nm _pkg = do
#if __GLASGOW_HASKELL__ >= 711
  hsc_env <- TcPluginM.getTopEnv
  primary <- TcPluginM.tcPluginIO $ Finder.findPluginModule hsc_env mod_nm
#else
  primary <- findImportedModule mod_nm $ Just _pkg
#endif
  maybe tryThisPackage return (foundModule primary)
  where
    -- Second attempt: look the module up in the package called "this".
    tryThisPackage = do
      secondary <- TcPluginM.findImportedModule mod_nm $ Just $ fsLit "this"
      maybe notFound return (foundModule secondary)

    notFound = panicDoc "Unable to resolve module looked up by plugin: "
                        (ppr mod_nm)

    -- Extract the module from a successful lookup result, if any.
    foundModule :: FindResult -> Maybe Module
#if __GLASGOW_HASKELL__ >= 711
    foundModule (FoundModule h) = Just (fr_mod h)
#else
    foundModule (Found _ md)    = Just md
#endif
    foundModule _               = Nothing
-- | Look up the 'Name' of the given occurrence name in the given module;
-- this is 'lookupOrig' under another name.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig
-- | Wrap a 'TcPlugin' so that every stage emits trace output via
-- 'tcPluginTrace', labelled with the given string.
--
-- * Initialisation first runs 'initializeStaticFlags', then traces, then
--   runs the wrapped plugin's initialiser.
-- * Solving traces the given, derived and wanted constraints before calling
--   the wrapped solver, and afterwards traces either the solved and new
--   constraints or the contradicting ones. The solver's result is returned
--   unchanged.
-- * Stopping traces and then runs the wrapped plugin's stop action.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} =
  TcPlugin
    { tcPluginInit  = traceInit
    , tcPluginSolve = traceSolve
    , tcPluginStop  = traceStop
    }
  where
    traceInit = do
      initializeStaticFlags
      tcPluginTrace ("tcPluginInit " ++ s) empty
      tcPluginInit

    traceStop st = do
      tcPluginTrace ("tcPluginStop " ++ s) empty
      tcPluginStop st

    traceSolve st given derived wanted = do
      tcPluginTrace ("tcPluginSolve start " ++ s) $
           text "given ="   <+> ppr given
        $$ text "derived =" <+> ppr derived
        $$ text "wanted ="  <+> ppr wanted
      result <- tcPluginSolve st given derived wanted
      report result
      return result

    -- Trace the outcome of one solver invocation.
    report (TcPluginOk solved new) =
      tcPluginTrace ("tcPluginSolve ok " ++ s) $
           text "solved =" <+> ppr solved
        $$ text "new ="    <+> ppr new
    report (TcPluginContradiction bad) =
      tcPluginTrace ("tcPluginSolve contradiction " ++ s) $
        text "bad =" <+> ppr bad
-- | Make sure GHC's static flags are initialised.
--
-- On compilers before GHC 8.2 this reads @v_opt_C_ready@ and calls
-- @initStaticOpts@ only when that flag is not yet set; on later compilers
-- it does nothing.
initializeStaticFlags :: TcPluginM ()
#if __GLASGOW_HASKELL__ < 802
initializeStaticFlags =
  tcPluginIO (readIORef v_opt_C_ready >>= \ready -> unless ready initStaticOpts)
#else
initializeStaticFlags = return ()
#endif
-- | Apply the substitution induced by a set of given constraints to those
-- same constraints.
--
-- The substitution entries produced by @mkSubst'@ are grouped by the variable
-- they substitute:
--
-- * variables with a single entry make up the substitution that is applied,
--   via 'substCt', to every given;
-- * variables with exactly two entries additionally yield a new
--   non-canonical given equating the two right-hand sides, with evidence and
--   location taken from the constraint of the first entry;
-- * groups with more than two entries contribute nothing.
--
-- The newly created givens come first in the result, followed by the
-- substituted originals.
flattenGivens
  :: [Ct]
  -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt duplicated ++ map (substCt uniqueSubst) givens
  where
    substitutedVar = fst . fst

    -- Substitution entries grouped by the variable they replace.
    grouped =
      groupBy ((==) `on` substitutedVar)
              (sortOn substitutedVar (mkSubst' givens))

    (duplicated, singletons) = partition ((>= 2) . length) grouped

    uniqueSubst = map fst (concat singletons)

    -- Turn a pair of entries for the same variable into an equality between
    -- their right-hand sides.
    flatToCt :: [((TcTyVar,TcType),Ct)] -> Maybe Ct
    flatToCt [((_, lhs), ct), ((_, rhs), _)] =
      Just (mkNonCanonical
              (CtGiven (mkPrimEqPred lhs rhs) (givenEvidence ct) (ctLoc ct)))
    flatToCt _ = Nothing

    -- The evidence field of 'CtGiven' and its accessor differ between
    -- compiler versions.
#if MIN_VERSION_ghc(8,4,0)
    givenEvidence = ctEvId
#elif MIN_VERSION_ghc(8,0,0)
    givenEvidence = ctEvId . cc_ev
#else
    givenEvidence = ctEvTerm . cc_ev
#endif
-- | Build a substitution from a list of constraints, pairing every entry with
-- the constraint it was derived from.
--
-- Constraints for which 'mkSubst' yields nothing are skipped. The remaining
-- entries are combined from right to left: each new entry has the
-- already-accumulated substitution applied to its right-hand side, and is in
-- turn applied to the right-hand sides of the accumulated entries.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' cts = foldr extend [] (mapMaybe mkSubst cts)
  where
    extend :: ((TcTyVar,TcType),Ct)
           -> [((TcTyVar,TcType),Ct)]
           -> [((TcTyVar,TcType),Ct)]
    extend ((tv, ty), ct) acc =
      ((tv, substType (map fst acc) ty), ct)
        : [ ((v, substType [(tv, ty)] rhs), c) | ((v, rhs), c) <- acc ]
-- | Derive a single substitution entry from a constraint, paired with that
-- constraint.
--
-- * A 'CTyEqCan' maps its @cc_tyvar@ to its @cc_rhs@.
-- * A 'CFunEqCan' maps its @cc_fsk@ to the application of @cc_fun@ to
--   @cc_tyargs@.
-- * Any other constraint yields 'Nothing'.
mkSubst
  :: Ct
  -> Maybe ((TcTyVar, TcType),Ct)
mkSubst ct = case ct of
  CTyEqCan { cc_tyvar = tv, cc_rhs = rhs } ->
    Just ((tv, rhs), ct)
  CFunEqCan { cc_fsk = fsk, cc_fun = tc, cc_tyargs = args } ->
    Just ((fsk, TyConApp tc args), ct)
  _ ->
    Nothing
-- | Apply a substitution to a constraint by rewriting the predicate stored
-- in the constraint's evidence with 'substType'; everything else in the
-- constraint is left as it was.
substCt
  :: [(TcTyVar, TcType)]
  -> Ct
  -> Ct
substCt subst ct = ct { cc_ev = rewritten }
  where
    evidence  = cc_ev ct
    rewritten = evidence { ctev_pred = substType subst (ctev_pred evidence) }
-- | Apply a substitution, given as an association list, to a type.
--
-- Type variables are replaced by the first matching entry of the list (and
-- left alone when there is none). The substitution is pushed through type
-- applications, type-constructor applications, function types (on compilers
-- where they are a separate constructor) and the type underneath a cast.
-- Literal types, coercion types and anything under a @forall@ are returned
-- untouched.
substType
  :: [(TcTyVar, TcType)]
  -> TcType
  -> TcType
substType subst = go
  where
    go ty = case ty of
      TyVarTy v      -> maybe ty id (lookup v subst)
      AppTy t1 t2    -> AppTy (go t1) (go t2)
      TyConApp tc xs -> TyConApp tc (map go xs)
      -- Nothing is substituted underneath a binder.
      ForAllTy _ _   -> ty
#if __GLASGOW_HASKELL__ >= 809
      FunTy af t1 t2 -> FunTy af (go t1) (go t2)
#elif __GLASGOW_HASKELL__ >= 802 || __GLASGOW_HASKELL__ < 711
      FunTy t1 t2    -> FunTy (go t1) (go t2)
#endif
      LitTy _        -> ty
#if __GLASGOW_HASKELL__ > 711
      -- Only the type under the cast is rewritten; the coercion is kept.
      CastTy inner co -> CastTy (go inner) co
      CoercionTy _    -> ty
#endif