{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternSynonyms #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module Internal
(
newWanted
, newGiven
, newDerived
, evByFiat
, lookupModule
, lookupName
, tracePlugin
, flattenGivens
, mkSubst
, mkSubst'
, substType
, substCt
)
where
import GHC.Tc.Plugin (TcPluginM, lookupOrig, tcPluginTrace)
import qualified GHC.Tc.Plugin as TcPlugin
(newDerived, newWanted, getTopEnv, tcPluginIO, findImportedModule)
import GHC.Tc.Types (TcPlugin(..), TcPluginResult(..))
import Control.Arrow (first, second)
import Data.Function (on)
import Data.List (groupBy, partition, sortOn)
import GHC.Tc.Utils.TcType (TcType)
import Data.Maybe (mapMaybe)
import GhcApi.Constraint (Ct(..), CtEvidence(..), CtLoc)
import GhcApi.GhcPlugins
import Internal.Type (substType)
import Internal.Constraint (newGiven, flatToCt, mkSubst, overEvidencePredType)
import Internal.Evidence (evByFiat)
{-# ANN fr_mod "HLint: ignore Use camelCase" #-}

-- | Matches a 'Found' module-lookup result, binding the 'Module' and
-- discarding the other field of 'Found'. This is a match-only
-- (unidirectional) synonym, so it cannot be used to build a 'FindResult'.
pattern FoundModule :: Module -> FindResult
pattern FoundModule found <- Found _ found

-- | Identity; 'lookupModule' applies it to the module matched by
-- 'FoundModule'.
-- NOTE(review): presumably a compatibility shim so call sites look the same
-- across GHC versions where this was a record selector — confirm.
fr_mod :: a -> a
fr_mod m = m
-- | Create evidence for a new [W]anted constraint with predicate @pty@ at
-- location @loc@. This is a thin wrapper around 'TcPlugin.newWanted'.
newWanted :: CtLoc -> PredType -> TcPluginM CtEvidence
newWanted loc pty = TcPlugin.newWanted loc pty
-- | Create evidence for a new [D]erived constraint with predicate @pty@ at
-- location @loc@. This is a thin wrapper around 'TcPlugin.newDerived'.
newDerived :: CtLoc -> PredType -> TcPluginM CtEvidence
newDerived loc pty = TcPlugin.newDerived loc pty
-- | Find a module by name.
--
-- The module is first searched for with 'findPluginModule'. If that does not
-- produce a 'Found' result, it is searched for with
-- 'TcPlugin.findImportedModule' using the package qualifier @"this"@. If
-- neither search succeeds, the compiler is aborted via 'panicDoc'.
lookupModule :: ModuleName -- ^ Name of the module to find
             -> FastString -- ^ Package name; not used by this implementation
             -> TcPluginM Module
lookupModule mod_nm _pkg = do
  hscEnv <- TcPlugin.getTopEnv
  viaPlugin <- TcPlugin.tcPluginIO (findPluginModule hscEnv mod_nm)
  maybe tryHomePackage pure (fromResult viaPlugin)
  where
    -- Extract the module from a successful lookup, if there was one.
    fromResult (FoundModule m) = Just (fr_mod m)
    fromResult _               = Nothing

    -- Second attempt: look the module up with the "this" package qualifier.
    tryHomePackage = do
      viaHome <- TcPlugin.findImportedModule mod_nm (Just (fsLit "this"))
      maybe giveUp pure (fromResult viaHome)

    giveUp = panicDoc "Couldn't find module" (ppr mod_nm)
-- | Resolve the 'Name' for an 'OccName' defined in the given 'Module'.
-- This is a thin wrapper around 'lookupOrig'.
lookupName :: Module -> OccName -> TcPluginM Name
lookupName md occ = lookupOrig md occ
-- | Wrap a 'TcPlugin' so that its initialisation, every solver invocation and
-- its shutdown are reported through 'tcPluginTrace'. Every trace header ends
-- with the supplied label.
--
-- The wrapped plugin's init, solve and stop actions are still run, and the
-- solver's result is returned unchanged.
tracePlugin :: String -> TcPlugin -> TcPlugin
tracePlugin s TcPlugin{..} =
  TcPlugin
    { tcPluginInit  = traceInit
    , tcPluginSolve = traceSolve
    , tcPluginStop  = traceStop
    }
  where
    -- Append the plugin label to a phase header.
    withLabel phase = phase ++ s

    traceInit = do
      initializeStaticFlags
      tcPluginTrace (withLabel "tcPluginInit ") empty
      tcPluginInit

    traceStop st = do
      tcPluginTrace (withLabel "tcPluginStop ") empty
      tcPluginStop st

    traceSolve st given derived wanted = do
      -- Report the inputs before delegating to the wrapped solver.
      tcPluginTrace (withLabel "tcPluginSolve start ")
        (    text "given =" <+> ppr given
          $$ text "derived =" <+> ppr derived
          $$ text "wanted =" <+> ppr wanted
        )
      result <- tcPluginSolve st given derived wanted
      -- Report the outcome, then hand the result back untouched.
      case result of
        TcPluginOk solved new ->
          tcPluginTrace (withLabel "tcPluginSolve ok ")
            (text "solved =" <+> ppr solved $$ text "new =" <+> ppr new)
        TcPluginContradiction bad ->
          tcPluginTrace (withLabel "tcPluginSolve contradiction ")
            (text "bad =" <+> ppr bad)
      pure result
-- | A no-op action; 'tracePlugin' runs it before tracing initialisation.
-- NOTE(review): presumably kept for parity with older GHC versions that
-- needed static flags initialised — confirm before removing.
initializeStaticFlags :: TcPluginM ()
initializeStaticFlags = pure ()
-- | Flatten constraints against the type-variable substitution derived from
-- them. The input is presumably the [G]iven constraints, per the name —
-- NOTE(review): confirm with callers.
--
-- The bindings produced by 'mkSubst'' are grouped by type variable:
--
-- * variables bound two or more times are passed, group by group, to
--   'flatToCt', and any constraints it yields come first in the result;
--
-- * variables bound exactly once form a substitution that is applied, via
--   'substCt', to every input constraint. These follow, in input order.
flattenGivens :: [Ct] -> [Ct]
flattenGivens givens =
  mapMaybe flatToCt clashing ++ map (substCt unambiguous) givens
  where
    -- The type variable a binding is for.
    boundVar = fst . fst

    -- Sorting first makes groupBy collect all bindings for one variable.
    grouped = groupBy ((==) `on` boundVar) (sortOn boundVar (mkSubst' givens))

    (clashing, singletons) = partition ((>= 2) . length) grouped

    unambiguous = map fst (concat singletons)
-- | Build a substitution from constraints. 'mkSubst' extracts a
-- @(type variable, type)@ binding from each constraint that yields one.
--
-- Bindings are processed from right to left. Each binding's right-hand side
-- is rewritten with all of the (already processed) bindings after it. Then
-- the binding itself, with its original right-hand side, is applied to the
-- right-hand sides of those later bindings. The result lists bindings in
-- the same order as the input constraints.
mkSubst' :: [Ct] -> [((TcTyVar,TcType),Ct)]
mkSubst' cts = propagate (mapMaybe mkSubst cts)
  where
    propagate :: [((TcTyVar,TcType),Ct)] -> [((TcTyVar,TcType),Ct)]
    propagate [] = []
    propagate (((tv, ty), ct) : later) =
        ((tv, substType (map fst done) ty), ct)
          : map (first (second (substType [(tv, ty)]))) done
      where
        done = propagate later
-- | Apply a type-variable substitution to the predicate type carried by a
-- constraint's evidence (via 'overEvidencePredType').
substCt :: [(TcTyVar, TcType)] -> Ct -> Ct
substCt = overEvidencePredType . substType