{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TcPluginM.Extra
(
newWanted
, newGiven
, newDerived
#if __GLASGOW_HASKELL__ < 711
, newWantedWithProvenance
#endif
, evByFiat
#if __GLASGOW_HASKELL__ < 711
, failWithProvenace
#endif
, lookupModule
, lookupName
, tracePlugin
#if __GLASGOW_HASKELL__ >= 804
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
#endif
)
where
#if __GLASGOW_HASKELL__ < 711
import Data.Maybe (mapMaybe)
#endif
#if __GLASGOW_HASKELL__ < 711
import BasicTypes (TopLevelFlag (..))
#endif
#if MIN_VERSION_ghc(8,5,0)
import CoreSyn (Expr(..))
#endif
import Coercion (Role (..), mkUnivCo)
import FastString (FastString, fsLit)
import Module (Module, ModuleName)
import Name (Name)
import OccName (OccName)
import Outputable (($$), (<+>), empty, ppr, text)
import Panic (panicDoc)
#if __GLASGOW_HASKELL__ >= 711
import TcEvidence (EvTerm (..))
#else
import TcEvidence (EvTerm (..), TcCoercion (..))
import TcMType (newEvVar)
#endif
#if __GLASGOW_HASKELL__ < 711
import TcPluginM (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
tcPluginTrace, unsafeTcPluginTcM)
import TcRnTypes (Ct, CtEvidence (..), CtLoc, TcIdBinder (..), TcLclEnv (..),
TcPlugin (..), TcPluginResult (..), ctEvId, ctEvLoc, ctLoc,
ctLocEnv, mkNonCanonical, setCtLocEnv)
#else
import TcPluginM (FindResult (..), TcPluginM, findImportedModule, lookupOrig,
tcPluginTrace)
import qualified TcPluginM
import TcRnTypes (CtEvidence (..), CtLoc,
TcPlugin (..), TcPluginResult (..))
#endif
#if __GLASGOW_HASKELL__ < 802
import TcPluginM (tcPluginIO)
#endif
#if __GLASGOW_HASKELL__ >= 711
import TyCoRep (UnivCoProvenance (..))
import Type (PredType, Type)
#else
import Type (EqRel (..), PredTree (..), PredType, Type, classifyPredType)
import Var (varType)
#endif
#if __GLASGOW_HASKELL__ >= 804
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import Data.Maybe (mapMaybe)
import TcRnTypes (Ct (..), ctLoc, ctEvId, mkNonCanonical)
import TcType (TcTyVar, TcType)
import Type (mkPrimEqPred)
import TyCoRep (Type (..))
#endif
#if __GLASGOW_HASKELL__ < 802
import Data.IORef (readIORef)
import Control.Monad (unless)
import StaticFlags (initStaticOpts, v_opt_C_ready)
#endif
#if __GLASGOW_HASKELL__ >= 711
-- | Uni-directional pattern that matches a 'Found' lookup result and binds
-- only its second field (the 'Module'), ignoring the first.
pattern FoundModule :: Module -> FindResult
pattern FoundModule a <- Found _ a

-- | Identity function. Call sites apply it to the value bound by
-- 'FoundModule'; it returns its argument unchanged.
fr_mod :: a -> a
fr_mod x = x
#endif
#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED newWantedWithProvenance "No longer available in GHC 8.0+" #-}
-- | Create a new Wanted constraint for the given predicate whose location
-- records the evidence variable of an existing Wanted constraint.
--
-- The first argument must be a 'CtWanted'; any other kind of evidence
-- results in a panic. The new constraint's location is the location of the
-- existing constraint with that constraint's evidence id pushed onto the
-- front of @tcl_bndrs@ (as a 'NotTopLevel' 'TcIdBndr').
newWantedWithProvenance :: CtEvidence
                        -> PredType
                        -> TcPluginM CtEvidence
newWantedWithProvenance parent pty = case parent of
  CtWanted {} -> do
    freshVar <- unsafeTcPluginTcM (newEvVar pty)
    return CtWanted { ctev_pred = pty
                    , ctev_evar = freshVar
                    , ctev_loc  = childLoc
                    }
  _ -> panicDoc "newWantedWithProvenance: not a Wanted: " (ppr parent)
  where
    -- These bindings are lazy, so they are only forced on the Wanted branch.
    parentLoc = ctEvLoc parent
    parentEnv = ctLocEnv parentLoc
    recorded  = TcIdBndr (ctEvId parent) NotTopLevel
    childEnv  = parentEnv { tcl_bndrs = recorded : tcl_bndrs parentEnv }
    childLoc  = setCtLocEnv parentLoc childEnv
#endif
-- | Create a new Wanted constraint with the given location and predicate.
--
-- On newer compilers this forwards to the function of the same name in
-- "TcPluginM"; on older ones a fresh evidence variable is allocated and the
-- 'CtWanted' record is assembled by hand.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newWanted loc pty = TcPluginM.newWanted loc pty
#else
newWanted loc pty = fmap wrap (unsafeTcPluginTcM (newEvVar pty))
  where
    wrap evar = CtWanted { ctev_pred = pty
                         , ctev_evar = evar
                         , ctev_loc  = loc
                         }
#endif
-- | Create a new Given constraint with the given location, predicate and
-- evidence term.
--
-- In the newest compiler branch only an 'EvExpr' evidence term is accepted
-- (its payload is forwarded to "TcPluginM"); any other term causes a panic.
-- In the middle branch the call is forwarded unchanged, and in the oldest
-- branch the 'CtGiven' record is built directly.
newGiven :: CtLoc -> PredType -> EvTerm -> TcPluginM CtEvidence
#if MIN_VERSION_ghc(8,5,0)
newGiven loc pty evtm = case evtm of
  EvExpr expr -> TcPluginM.newGiven loc pty expr
  _           -> panicDoc "newGiven: not an EvExpr: " (ppr evtm)
#elif __GLASGOW_HASKELL__ >= 711
newGiven loc pty evtm = TcPluginM.newGiven loc pty evtm
#else
newGiven loc pty evtm = return given
  where
    given = CtGiven { ctev_pred = pty
                    , ctev_evtm = evtm
                    , ctev_loc  = loc
                    }
#endif
-- | Create a new Derived constraint with the given location and predicate.
--
-- Forwards to "TcPluginM" on newer compilers; on older ones the 'CtDerived'
-- record is built directly.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
#if __GLASGOW_HASKELL__ >= 711
newDerived loc pty = TcPluginM.newDerived loc pty
#else
newDerived loc pty = return derived
  where
    derived = CtDerived { ctev_pred = pty
                        , ctev_loc  = loc
                        }
#endif
-- | Build an evidence term relating two types at 'Nominal' role, using a
-- universal coercion ('mkUnivCo') labelled with the given name.
--
-- No proof is constructed: the coercion is simply asserted between the two
-- types supplied.
evByFiat :: String -- ^ Label attached to the coercion's provenance
         -> Type   -- ^ Left-hand side type
         -> Type   -- ^ Right-hand side type
         -> EvTerm
evByFiat name t1 t2 = toEvTerm (mkUnivCo provenance Nominal t1 t2)
  where
    -- How the label is passed to 'mkUnivCo' depends on the compiler version.
#if __GLASGOW_HASKELL__ >= 711
    provenance = PluginProv name
#else
    provenance = fsLit name
#endif
    -- How the coercion is wrapped into an 'EvTerm' also depends on the
    -- compiler version.
#if MIN_VERSION_ghc(8,5,0)
    toEvTerm = EvExpr . Coercion
#elif __GLASGOW_HASKELL__ >= 711
    toEvTerm = EvCoercion
#else
    toEvTerm = EvCoercion . TcCoercion
#endif
#if __GLASGOW_HASKELL__ < 711
{-# DEPRECATED failWithProvenace "No longer available in GHC 8.0+" #-}
-- | Report a contradiction consisting of the given constraint together with
-- "parent" constraints reconstructed from its location.
--
-- The parents are obtained from the 'NotTopLevel' binders stored in the
-- @tcl_bndrs@ of the constraint's location environment: every such binder
-- whose type classifies as a nominal equality predicate is turned into a
-- non-canonical Wanted constraint at the same location.
failWithProvenace :: Ct -> TcPluginM TcPluginResult
failWithProvenace ct = return (TcPluginContradiction (ct : parents))
  where
    loc = ctLoc ct

    parents =
      [ mkNonCanonical (CtWanted bndrTy bndr loc)
      | TcIdBndr bndr NotTopLevel <- tcl_bndrs (ctLocEnv loc)
      , let bndrTy = varType bndr
      , isNominalEq (classifyPredType bndrTy)
      ]

    isNominalEq = \case
      EqPred NomEq _ _ -> True
      _                -> False
#endif
-- | Find a module by name.
--
-- The module is first looked up in the given package. If that lookup does
-- not succeed, a second lookup is made using the package name @this@. If
-- both fail, the function panics.
lookupModule :: ModuleName -- ^ Name of the module
             -> FastString -- ^ Name of the package to try first
             -> TcPluginM Module
lookupModule mod_nm pkg = do
  inPackage <- findImportedModule mod_nm (Just pkg)
  case moduleOf inPackage of
    Just md -> return md
    Nothing -> do
      inThis <- findImportedModule mod_nm (Just (fsLit "this"))
      maybe (panicDoc "Unable to resolve module looked up by plugin: "
                      (ppr mod_nm))
            return
            (moduleOf inThis)
  where
    -- Extract the module from a successful lookup result, if any.
#if __GLASGOW_HASKELL__ >= 711
    moduleOf (FoundModule h) = Just (fr_mod h)
#else
    moduleOf (Found _ md)    = Just md
#endif
    moduleOf _               = Nothing
-- | Look up the 'Name' of an 'OccName' in a 'Module'; a thin wrapper around
-- 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName = lookupOrig
-- | Wrap a type-checker plugin so that each of its three phases emits a
-- trace message (via 'tcPluginTrace') tagged with the given string.
--
-- * init: runs 'initializeStaticFlags', traces, then runs the wrapped init.
-- * solve: traces the given, derived and wanted constraints, runs the
--   wrapped solver, traces its result, and returns that result unchanged.
-- * stop: traces, then runs the wrapped stop action.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin { tcPluginInit  = pluginInit
                       , tcPluginSolve = pluginSolve
                       , tcPluginStop  = pluginStop
                       } =
  TcPlugin { tcPluginInit  = tracedInit
           , tcPluginSolve = tracedSolve
           , tcPluginStop  = tracedStop
           }
  where
    -- Emit a trace message whose title is the label followed by the tag.
    traceAs label = tcPluginTrace (label ++ s)

    tracedInit = do
      initializeStaticFlags
      traceAs "tcPluginInit " empty
      pluginInit

    tracedStop st = do
      traceAs "tcPluginStop " empty
      pluginStop st

    tracedSolve st given derived wanted = do
      traceAs "tcPluginSolve start "
              (text "given =" <+> ppr given
               $$ text "derived =" <+> ppr derived
               $$ text "wanted =" <+> ppr wanted)
      result <- pluginSolve st given derived wanted
      traceResult result
      return result

    traceResult (TcPluginOk solved new) =
      traceAs "tcPluginSolve ok "
              (text "solved =" <+> ppr solved
               $$ text "new =" <+> ppr new)
    traceResult (TcPluginContradiction bad) =
      traceAs "tcPluginSolve contradiction "
              (text "bad =" <+> ppr bad)
-- | On compilers older than 8.2, run 'initStaticOpts' unless the
-- @v_opt_C_ready@ reference already holds 'True'. On newer compilers this
-- does nothing.
initializeStaticFlags :: TcPluginM ()
#if __GLASGOW_HASKELL__ < 802
initializeStaticFlags =
  tcPluginIO (readIORef v_opt_C_ready >>= \ready -> unless ready initStaticOpts)
#else
initializeStaticFlags = pure ()
#endif
#if __GLASGOW_HASKELL__ >= 804
-- | Rewrite a list of Given constraints using the substitution derived from
-- those same constraints (see @mkSubst'@).
--
-- The substitution entries are sorted and grouped by variable:
--
-- * Variables with exactly one entry form the substitution that is applied
--   (with 'substCt') to every original constraint.
-- * Variables with exactly two entries each produce one extra Given equating
--   the two right-hand sides; these extra constraints come first in the
--   result.
-- * Groups of three or more entries match neither case and contribute
--   nothing.
flattenGivens
  :: [Ct]
  -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt flat ++ map (substCt singles) givens
  where
    entryVar = fst . fst

    -- Substitution entries bucketed by the variable they rewrite.
    grouped = groupBy ((==) `on` entryVar) (sortOn entryVar (mkSubst' givens))

    (flat, unshared) = partition ((>= 2) . length) grouped

    -- Plain (variable, type) pairs for variables that occur only once.
    singles = map fst (concat unshared)

    -- The new constraint reuses the evidence id and location of the first
    -- entry's constraint.
    flatToCt = \case
      [((_, lhs), ct), ((_, rhs), _)] ->
        Just (mkNonCanonical
                (CtGiven (mkPrimEqPred lhs rhs) (ctEvId ct) (ctLoc ct)))
      _ -> Nothing
-- | Build a substitution from a list of constraints, keeping alongside each
-- entry the constraint it came from.
--
-- Constraints for which 'mkSubst' yields nothing are skipped. The list is
-- folded from the right: each new entry has the entries accumulated so far
-- applied to its right-hand side, and is in turn applied to the right-hand
-- sides of those accumulated entries.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' cts = foldr extend [] (mapMaybe mkSubst cts)
  where
    extend :: ((TcTyVar,TcType),Ct)
           -> [((TcTyVar,TcType),Ct)]
           -> [((TcTyVar,TcType),Ct)]
    extend ((tv, ty), ct) acc = newEntry : rewritten
      where
        newEntry  = ((tv, substType (map fst acc) ty), ct)
        rewritten = map (first (second (substType [(tv, ty)]))) acc
-- | Turn a single constraint into a substitution entry, paired with the
-- constraint itself.
--
-- * A 'CTyEqCan' maps its type variable to its right-hand side.
-- * A 'CFunEqCan' maps its @cc_fsk@ variable to the application of its
--   function to its arguments.
-- * Any other constraint yields 'Nothing'.
mkSubst
  :: Ct
  -> Maybe ((TcTyVar, TcType),Ct)
mkSubst ct = case ct of
  CTyEqCan  { cc_tyvar = tv, cc_rhs = rhs } ->
    Just ((tv, rhs), ct)
  CFunEqCan { cc_fsk = fsk, cc_fun = fam, cc_tyargs = args } ->
    Just ((fsk, TyConApp fam args), ct)
  _ ->
    Nothing
-- | Apply a substitution to the predicate stored in a constraint's evidence
-- (@ctev_pred@ of @cc_ev@). Every other field of the constraint is left as
-- it was.
substCt
  :: [(TcTyVar, TcType)]
  -> Ct
  -> Ct
substCt subst ct = ct { cc_ev = rewritten }
  where
    evidence  = cc_ev ct
    rewritten = evidence { ctev_pred = substType subst (ctev_pred evidence) }
-- | Apply a substitution, given as an association list, to a type.
--
-- Type variables found in the list are replaced by their mapped type (the
-- first match wins, as with 'lookup'). The substitution is pushed through
-- type applications, type-constructor applications, function types and the
-- type part of casts. @forall@ types, literals, coercion types and the
-- coercion inside a cast are returned untouched.
substType
  :: [(TcTyVar, TcType)]
  -> TcType
  -> TcType
substType subst = go
  where
    go ty = case ty of
      TyVarTy v        -> maybe ty id (lookup v subst)
      AppTy fun arg    -> AppTy (go fun) (go arg)
      TyConApp tc args -> TyConApp tc (map go args)
      -- Nothing is substituted underneath a binder.
      ForAllTy _ _     -> ty
      FunTy arg res    -> FunTy (go arg) (go res)
      LitTy _          -> ty
      -- Only the type is rewritten; the coercion is kept as it is.
      CastTy inner co  -> CastTy (go inner) co
      CoercionTy _     -> ty
#endif