{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module GHC.TcPlugin.API.Names
( ResolveNames, resolveNames
, Wear, QualifiedName(..), NameResolution(..)
, Promoted
, Lookupable(..)
, Generically1(..)
) where
import Prelude
hiding ( lookup )
import Data.Coerce
( Coercible, coerce )
import Data.Kind
( Type, Constraint )
import GHC.Generics
( Generic(..)
#if MIN_VERSION_base(4,17,0)
, Generically1(..)
#endif
, (:+:)(..), (:*:)(..)
, K1(K1), M1(M1), U1(..), V1, Rec0
)
import GHC.TypeLits
( TypeError, ErrorMessage(..) )
import Control.Monad.Trans.State.Strict
( StateT, evalStateT, get, put )
import Control.Monad.Trans.Class
( MonadTrans(lift) )
#if MIN_VERSION_ghc(9,3,0)
import GHC.Iface.Errors
( cannotFindModule )
#elif MIN_VERSION_ghc(9,2,0)
import GHC.Iface.Load
( cannotFindModule )
#else
import GHC.Driver.Types
( hsc_dflags )
import GHC.Driver.Finder
( cannotFindModule )
import GHC.Driver.Session
( DynFlags )
#endif
import GHC.Utils.Panic
( pgmErrorDoc )
import GHC.Unit.Module.Name
( moduleNameString )
import GHC.Tc.Plugin
( getTopEnv )
import GHC.Types.Unique.FM
( addToUFM, addToUFM_C, lookupUFM, plusUFM, unitUFM )
import GHC.TcPlugin.API
hiding ( Type )
import GHC.TcPlugin.API.Internal
( MonadTcPlugin(liftTcPluginM) )
-- | The fully-qualified name of a @thing@ (such as a 'TyCon', 'DataCon' or
-- 'Class') that a plugin wants to look up.
--
-- The @thing@ parameter does not occur in any field: it only records, at the
-- type level, which kind of entity this name refers to.
type QualifiedName :: k -> Type
data QualifiedName thing
  = Qualified
    { name    :: String
      -- ^ The unqualified occurrence name, e.g. @"Maybe"@.
    , module' :: ModuleName
      -- ^ The module in which the name should be looked up.
    , package :: Maybe FastString
      -- ^ Optional package qualifier ('Nothing' when none is given).
    }
-- | Index describing whether a record holds unresolved 'QualifiedName's
-- ('Named') or the entities they were resolved to ('Resolved').
-- It is used promoted, at the type level, as the first argument of 'Wear'.
data NameResolution
  = Named
  | Resolved
-- | Type-level marker requesting the promoted form of a @thing@.
--
-- Only @'Promoted' 'DataCon'@ is supported: 'Wear' maps it to a 'TyCon'
-- (the promoted data constructor), and any other argument is rejected
-- with a custom type error.
type Promoted :: k -> Type
data Promoted thing
-- | The type of one field of a record of names, depending on whether the
-- record is 'Named' or 'Resolved':
--
--   * @Wear Named thing@ is the 'QualifiedName' used to look @thing@ up.
--   * @Wear Resolved ('Promoted' DataCon)@ is the 'TyCon' of the promoted
--     data constructor.
--   * @Wear Resolved ('Promoted' a)@ for any other @a@ is a type error.
--   * @Wear Resolved thing@ otherwise is the looked-up @thing@ itself.
--
-- This is a closed type family, so equation order matters: the
-- @'Promoted' 'DataCon'@ equation must precede the general 'Promoted'
-- error equation, and the catch-all equation must stay last.
type Wear :: forall k. NameResolution -> k -> Type
type family Wear n thing where
  Wear Named thing = QualifiedName thing
  Wear Resolved (Promoted DataCon) = TyCon
  Wear Resolved (Promoted a)
    = TypeError
        ( Text "Cannot promote " :<>: ShowType a :<>: Text "."
        :$$: Text "Can only promote 'DataCon's."
        )
  Wear Resolved thing = thing
-- | Partial inverse of @'Wear' 'Named'@: extracts @thing@ from
-- @'QualifiedName' thing@. There is no equation for any other argument,
-- so the family simply does not reduce on non-'QualifiedName' types.
type UnwearNamed :: Type -> Type
type family UnwearNamed loc where
  UnwearNamed (QualifiedName thing) = thing
-- | Entities that a type-checking plugin can look up by 'Name'.
--
-- The class parameter @a@ does not appear in 'mkOccName' and only appears
-- under the type family 'Wear' in 'lookup'. Both methods must therefore be
-- called with a type application, e.g. @mkOccName \@TyCon@.
#if MIN_VERSION_ghc(9,0,0)
type Lookupable :: forall {k}. k -> Constraint
#endif
class Lookupable a where
  -- | Build the 'OccName' for a name of this kind of entity, in the
  -- namespace matching that entity (see the instances).
  mkOccName :: String -> OccName
  -- | Look up the entity with the given 'Name'.
  lookup :: MonadTcPlugin m => Name -> m (Wear Resolved a)
-- Type constructors: type-constructor namespace, looked up as a 'TyCon'.
instance Lookupable TyCon where
  mkOccName str = mkTcOcc str
  lookup nm     = tcLookupTyCon nm

-- Data constructors: data-constructor namespace, looked up as a 'DataCon'.
instance Lookupable DataCon where
  mkOccName str = mkDataOcc str
  lookup nm     = tcLookupDataCon nm

-- Type classes: class namespace, looked up as a 'Class'.
instance Lookupable Class where
  mkOccName str = mkClsOcc str
  lookup nm     = tcLookupClass nm

-- A promoted data constructor is named exactly like the data constructor
-- itself; once found, it is turned into the 'TyCon' of its promotion.
instance Lookupable (Promoted DataCon) where
  mkOccName str = mkDataOcc str
  lookup nm     = promoteDataCon <$> tcLookupDataCon nm
-- | Records of names, indexed by 'NameResolution', whose 'Named' form can be
-- resolved into their 'Resolved' form.
--
-- An instance is provided for 'Generically1', which works field by field
-- through the record's 'Generic' representation.
type ResolveNames :: ( NameResolution -> Type ) -> Constraint
class ResolveNames f where
  -- | Resolve all the names. The result may be any type representationally
  -- equal to @f Resolved@ (presumably so that instances can be derived via
  -- 'Generically1' without coercing under the abstract monad @m@ -- confirm).
  resolve_names :: ( Coercible res ( f Resolved ), MonadTcPlugin m )
                => f Named -> m res
-- | Resolve every name in a 'Named' record, yielding the 'Resolved' record.
--
-- This is 'resolve_names' with its result type pinned to exactly
-- @f Resolved@.
resolveNames
  :: ( MonadTcPlugin m, ResolveNames f )
  => f Named
  -> m ( f Resolved )
resolveNames namedRecord = resolve_names namedRecord
-- | Generic name resolution for any record type @f@ that has 'Generic'
-- instances at both 'Named' and 'Resolved'.
--
-- The 'Named' record is converted to its generic representation. Each field
-- is then resolved with 'resolveName', which needs the field types to be
-- related by 'ResolveName'. The result is rebuilt as an @f Resolved@.
-- A fresh module cache ('emptyModules') is used for each call.
instance ( Generic (f Named)
         , Generic (f Resolved)
         , GTraversableC ResolveName (Rep (f Named)) (Rep (f Resolved))
         )
  => ResolveNames (Generically1 f) where
  resolve_names
    :: forall
#if MIN_VERSION_ghc(9,0,0)
        {m}
#else
        m
#endif
        res
    .  ( Coercible res ( Generically1 f Resolved ), MonadTcPlugin m )
    => Generically1 f Named -> m res
  -- Unwrap the newtype first so that 'from' uses the @Generic (f Named)@
  -- instance from the context. Calling 'from' on the 'Generically1' wrapper
  -- itself would require a @Generic (Generically1 f Named)@ instance. The
  -- context does not provide one, and base's own 'Generically1'
  -- (base >= 4.17) is not expected to define one.
  resolve_names ( Generically1 dat )
    = ( `evalStateT` emptyModules )
    $ coerce . to @(f Resolved)
    <$> gtraverseC @ResolveName resolveName ( from dat )
-- | Relates the 'Named' and 'Resolved' types of one record field.
--
-- @ResolveName a b@ holds when @a@ is @'Wear' 'Named' thing@ and @b@ is
-- @'Wear' 'Resolved' thing@ for a 'Lookupable' @thing@, where
-- @thing = 'UnwearNamed' a@.
--
-- This is a "class synonym": the class has no methods, and its single
-- instance repeats the superclass context verbatim (the two must be kept in
-- sync). That lets @ResolveName@ be passed unsaturated as a constraint
-- argument (as in @gtraverseC \@ResolveName@). Wherever it is assumed, the
-- superclass equalities and 'Lookupable' become available as givens.
type ResolveName :: Type -> Type -> Constraint
class ( a ~ Wear Named ( UnwearNamed a )
      , b ~ Wear Resolved ( UnwearNamed a )
      , Lookupable ( UnwearNamed a )
      )
  => ResolveName a b
instance ( a ~ Wear Named ( UnwearNamed a )
         , b ~ Wear Resolved ( UnwearNamed a )
         , Lookupable ( UnwearNamed a )
         )
  => ResolveName a b
-- | Resolve a single 'QualifiedName' in three steps:
--
--   1. Find the module it belongs to, going through the 'ImportedModules'
--      cache.
--   2. Find the original 'Name' of the occurrence in that module.
--   3. Look the entity up with 'lookup' at type @thing@.
resolveName :: forall thing m
            .  ResolveName ( Wear Named thing ) ( Wear Resolved thing )
            => MonadTcPlugin m
            => Wear Named thing
            -> StateT ImportedModules m ( Wear Resolved thing )
resolveName ( Qualified occ_str mod_name mb_pkg ) = do
  modl <- lookupModule mb_pkg mod_name
  -- The OccName must live in the namespace appropriate for @thing@,
  -- which is why both class methods are selected with @\@thing@.
  lift $ lookupOrig modl ( mkOccName @thing occ_str ) >>= lookup @thing
-- | Modules already found while resolving names, so that each module need
-- only be searched for once.
data ImportedModules
  = ImportedModules
    { home_modules :: UniqFM ModuleName Module
      -- ^ Modules looked up without a package qualifier.
    , pkg_modules  :: UniqFM FastString ( UniqFM ModuleName Module )
      -- ^ Modules looked up with a package qualifier, keyed first by the
      -- package name and then by the module name.
    }
-- | A module cache containing no modules at all.
emptyModules :: ImportedModules
emptyModules =
  ImportedModules
    { home_modules = emptyUFM
    , pkg_modules  = emptyUFM
    }
-- | Consult the module cache.
--
-- Without a package qualifier the unqualified table is searched. With one,
-- the package's table is found first and then searched. Returns 'Nothing'
-- if the module has not been cached.
lookupCachedModule :: Monad m => Maybe FastString -> ModuleName -> StateT ImportedModules m (Maybe Module)
lookupCachedModule mb_pkg mod_name = case mb_pkg of
  Nothing -> do
    cache <- get
    pure ( lookupUFM ( home_modules cache ) mod_name )
  Just pkg -> do
    cache <- get
    -- Two-level lookup: package first, then module within that package.
    pure ( lookupUFM ( pkg_modules cache ) pkg >>= \ per_pkg -> lookupUFM per_pkg mod_name )
-- | Record a found module in the cache, in the table that matches the
-- package qualifier.
insertCachedModule :: Monad m => Maybe FastString -> ModuleName -> Module -> StateT ImportedModules m ()
insertCachedModule mb_pkg mod_name md = case mb_pkg of
  Nothing -> do
    cache@ImportedModules { home_modules = homes } <- get
    put cache { home_modules = addToUFM homes mod_name md }
  Just pkg -> do
    cache@ImportedModules { pkg_modules = per_pkg } <- get
    -- If the package already has a table, merge the new singleton into it
    -- with 'plusUFM'; otherwise the singleton becomes its table.
    put cache { pkg_modules = addToUFM_C plusUFM per_pkg pkg ( unitUFM mod_name md ) }
-- | Find the 'Module' with the given name, optionally restricted to a
-- package ('Nothing' means no package qualifier).
--
-- The 'ImportedModules' cache is consulted first. On a miss the module is
-- searched for with 'findImportedModule', and a successful result is stored
-- in the cache, so later names from the same module skip the search.
--
-- If the module cannot be found, this raises a GHC program error via
-- 'pgmErrorDoc', attaching GHC's own "cannot find module" diagnostic.
lookupModule :: MonadTcPlugin m => Maybe FastString -> ModuleName -> StateT ImportedModules m Module
lookupModule mb_pkg mod_name = do
  cachedResult <- lookupCachedModule mb_pkg mod_name
  case cachedResult of
    -- Cache hit: the entry is already stored, nothing to update.
    Just res -> pure res
    Nothing -> do
      findResult <- lift $ findImportedModule mod_name mb_pkg
      case findResult of
        Found _ res -> do
          -- Populate the cache. Previously the result was only (re)inserted
          -- on a cache hit, so the cache never filled up.
          insertCachedModule mb_pkg mod_name res
          pure res
        other -> do
          hsc_env <- lift . liftTcPluginM $ getTopEnv
          let
            err_doc :: SDoc
#if MIN_VERSION_ghc(9,2,0)
            err_doc = cannotFindModule hsc_env mod_name other
#else
            err_doc = cannotFindModule dflags mod_name other
            dflags :: DynFlags
            dflags = hsc_dflags hsc_env
#endif
          pgmErrorDoc
            ( "GHC.TcPlugin.API: could not find module " <> mod_str <> " in " <> pkg_name )
            err_doc
  where
    pkg_name, mod_str :: String
    pkg_name = case mb_pkg of
      Just pkg -> "package " <> show pkg
      Nothing  -> "home package"
    mod_str = moduleNameString mod_name
-- | A traversal whose element function is polymorphic but constrained:
-- it must work for every pair of types @a@, @b@ satisfying @c a b@.
type TraversalC :: ( Type -> Type -> Constraint ) -> Type -> Type -> Type
type TraversalC c s t
  = forall f
  .  Applicative f
  => ( forall a b. c a b => a -> f b )
  -> s
  -> f t

-- | Traverse a generic representation @s@ into a representation @t@ of the
-- same shape. The constrained function is applied to every field ('Rec0'),
-- and the surrounding structure is rebuilt from the results.
type GTraversableC :: ( Type -> Type -> Constraint )
                   -> ( Type -> Type )
                   -> ( Type -> Type )
                   -> Constraint
class GTraversableC c s t where
  gtraverseC :: TraversalC c (s x) (t x)

-- Products: traverse the left factor, then the right factor.
instance ( GTraversableC c l l', GTraversableC c r r' )
      => GTraversableC c (l :*: r) (l' :*: r') where
  gtraverseC act (lhs :*: rhs) =
    (:*:) <$> gtraverseC @c act lhs <*> gtraverseC @c act rhs

-- Sums: traverse whichever alternative is present.
instance ( GTraversableC c l l', GTraversableC c r r' )
      => GTraversableC c (l :+: r) (l' :+: r') where
  gtraverseC act alt = case alt of
    L1 lhs -> L1 <$> gtraverseC @c act lhs
    R1 rhs -> R1 <$> gtraverseC @c act rhs

-- Metadata wrappers: look through them, keeping the metadata unchanged.
instance GTraversableC c s t
      => GTraversableC c (M1 i m s) (M1 i m t) where
  gtraverseC act (M1 inner) = fmap M1 ( gtraverseC @c act inner )

-- Constructors without fields: nothing to traverse.
instance GTraversableC c U1 U1 where
  gtraverseC _ _ = pure U1

-- Uninhabited representations: return the (impossible) value unchanged.
instance GTraversableC c V1 V1 where
  gtraverseC _ v = pure v

-- Fields: this is where the constrained function is actually applied.
instance c a b => GTraversableC c (Rec0 a) (Rec0 b) where
  gtraverseC act (K1 field) = K1 <$> act field
#if !MIN_VERSION_base(4,17,0)
-- | Fallback definition of 'Generically1' for @base@ older than 4.17, where
-- "GHC.Generics" does not export it (newer @base@ versions have it imported
-- at the top of this module instead).
--
-- The 'Generic' instance is newtype-derived, so the generic representation
-- of @Generically1 f a@ is that of @f a@.
type Generically1 :: ( k -> Type ) -> ( k -> Type )
newtype Generically1 f a = Generically1 ( f a )
  deriving newtype Generic
#endif