-- | Functions and utilities to detect the importent modules, classes -- and types of the plugin. module Control.Supermonad.Plugin.Detect ( -- * Supermonad Class Detection supermonadModuleName , bindClassName, returnClassName , findSupermonadModule , isBindClass, isReturnClass , isSupermonadModule , findBindClass, findReturnClass , findSupermonads , checkSupermonadInstances -- * Identity Type Detection , identityModuleName , identityTyConName , findIdentityModule , findIdentityTyCon , functorClassName, functorModuleName -- * General detection utilities , findInstancesInScope ) where import Data.List ( find ) import Data.Maybe ( catMaybes, listToMaybe ) import qualified Data.Set as S import qualified Data.Map as M import Control.Monad ( forM, liftM ) import BasicTypes ( Arity ) import TcRnTypes ( TcGblEnv(..) , TcTyThing(..) , ImportAvails( imp_mods ) ) import Type ( TyThing(..) ) import TyCon ( TyCon ) import TcPluginM ( TcPluginM , getEnvs, getInstEnvs , tcLookup ) import Name ( nameModule , getOccName ) import OccName ( OccName , occNameString, mkTcOcc ) import RdrName ( GlobalRdrElt(..) , Parent( NoParent ) , lookupGlobalRdrEnv ) import Module ( Module, ModuleName , moduleName , moduleEnvKeys , mkModuleName ) import Class ( Class(..) , className, classArity ) import InstEnv ( ClsInst(..) , instEnvElts , ie_global , classInstances ) import PrelNames ( mAIN_NAME ) import Outputable ( SDoc, ($$), (<>), text, vcat, ppr ) --import qualified Outputable as O --import Control.Supermonad.Plugin.Log ( printObj, printObjTrace ) import Control.Supermonad.Plugin.Wrapper ( UnitId, baseUnitId, moduleUnitId, isImportedFrom ) import Control.Supermonad.Plugin.Instance ( instanceTyArgs , isMonoTyConInstance , isPolyTyConInstance ) import Control.Supermonad.Plugin.Utils ( collectTopTyCons ) -- ----------------------------------------------------------------------------- -- Constant Names (Magic Numbers...) 
-- -----------------------------------------------------------------------------

-- | Name of the "Control.Supermonad" module.
supermonadModuleName :: String
supermonadModuleName = "Control.Supermonad"

-- | Name of the "Control.Supermonad.Constrained" module.
supermonadCtModuleName :: String
supermonadCtModuleName = "Control.Supermonad.Constrained"

-- | Name of the 'Control.Supermonad.Bind' type class.
bindClassName :: String
bindClassName = "Bind"

-- | Name of the 'Control.Supermonad.Return' type class.
returnClassName :: String
returnClassName = "Return"

-- | Name of the "Data.Functor.Identity" module.
identityModuleName :: String
identityModuleName = "Data.Functor.Identity"

-- | Name of the 'Data.Functor.Identity.Identity' type constructor.
identityTyConName :: String
identityTyConName = "Identity"

-- | Name of the "Control.Supermonad.Prelude" module.
supermonadPreludeModuleName :: String
supermonadPreludeModuleName = "Control.Supermonad.Prelude"

-- | Name of the "Control.Supermonad.Constrained.Prelude" module.
supermonadCtPreludeModuleName :: String
supermonadCtPreludeModuleName = "Control.Supermonad.Constrained.Prelude"

-- | Name of the 'Data.Functor.Functor' class.
functorClassName :: String
functorClassName = "Functor"

-- | Name of the "Data.Functor" module.
functorModuleName :: String
functorModuleName = "Data.Functor"

-- -----------------------------------------------------------------------------
-- Supermonad Class Detection
-- -----------------------------------------------------------------------------

-- | Checks if a module providing the supermonad classes is imported.
--
--   Exactly one of the two variants has to be imported:
--
--   * the unconstrained one ("Control.Supermonad" or "Control.Supermonad.Prelude"), or
--
--   * the constrained one ("Control.Supermonad.Constrained" or
--     "Control.Supermonad.Constrained.Prelude").
--
--   Finding neither or both of them is reported as an error ('Left').
findSupermonadModule :: TcPluginM (Either SDoc Module)
findSupermonadModule = do
  eSmUnCtMdl <- findSupermonadUnCtModule
  eSmCtMdl   <- findSupermonadCtModule
  return $ case (eSmUnCtMdl, eSmCtMdl) of
    (Right _ , Left _errCt) -> eSmUnCtMdl
    (Left _err, Right _   ) -> eSmCtMdl
    (Left err, Left errCt)  ->
      Left $ text "Could not find supermonad or constrained supermonad modules!"
             $$ err $$ errCt
    (Right _, Right _)      ->
      Left $ text "Found unconstrained and constrained supermonad modules!"

-- | Checks if the module "Control.Supermonad" or "Control.Supermonad.Prelude"
--   is imported and, if so, returns either.
findSupermonadUnCtModule :: TcPluginM (Either SDoc Module)
findSupermonadUnCtModule =
  getModuleWithFallback supermonadModuleName supermonadPreludeModuleName

-- | Checks if the module "Control.Supermonad.Constrained" or
--   "Control.Supermonad.Constrained.Prelude" is imported and, if so, returns either.
findSupermonadCtModule :: TcPluginM (Either SDoc Module)
findSupermonadCtModule =
  getModuleWithFallback supermonadCtModuleName supermonadCtPreludeModuleName

-- | Looks up the imported module with the first name (from any package).
--   If it is not imported, the module with the second name is looked up
--   instead and that result (including its error) is returned.
getModuleWithFallback :: String -> String -> TcPluginM (Either SDoc Module)
getModuleWithFallback primaryName fallbackName = do
  ePrimary <- getModule Nothing primaryName
  case ePrimary of
    Left _err -> getModule Nothing fallbackName
    Right _   -> return ePrimary

-- | Check if the given module is one of the supermonad modules
--   (unconstrained or constrained, including their preludes).
--   Modules named @Main@ are accepted as well.
--   NOTE(review): presumably so the classes can be defined in a standalone
--   test program — confirm.
isSupermonadModule :: Module -> Bool
isSupermonadModule mdl = moduleName mdl `elem` acceptedMdlNames
  where
    acceptedMdlNames = mAIN_NAME : fmap mkModuleName
      [ supermonadModuleName
      , supermonadPreludeModuleName
      , supermonadCtModuleName
      , supermonadCtPreludeModuleName
      ]

-- | Checks if the given class matches the shape of the 'Control.Supermonad.Bind'
--   type class (name @Bind@, arity 3) and is defined in a supermonad module.
isBindClass :: Class -> Bool
isBindClass cls = isClass cls isSupermonadModule bindClassName 3

-- | Checks if the given class matches the shape of the 'Control.Supermonad.Return'
--   type class (name @Return@, arity 1) and is defined in a supermonad module.
isReturnClass :: Class -> Bool
isReturnClass cls = isClass cls isSupermonadModule returnClassName 1

-- | Checks if a type class matching the shape and name of the
-- 'Control.Supermonad.Bind' type class is in scope.
findBindClass :: TcPluginM (Maybe Class)
findBindClass = findClass isBindClass

-- | Checks if a type class matching the shape and name of the
--   'Control.Supermonad.Return' type class is in scope.
findReturnClass :: TcPluginM (Maybe Class)
findReturnClass = findClass isReturnClass

-- -----------------------------------------------------------------------------
-- Identity Type Detection
-- -----------------------------------------------------------------------------

-- | Checks if the module "Data.Functor.Identity" from the @base@ package
--   is imported and, if so, returns the module. Otherwise falls back to the
--   supermonad module found by 'findSupermonadModule'. Fails only if
--   neither module can be found.
findIdentityModule :: TcPluginM (Either SDoc Module)
findIdentityModule = do
  mdls <- findModules [getModule (Just baseUnitId) identityModuleName, findSupermonadModule]
  case mdls of
    []        -> return $ Left $ text $
                   "Could not find module '" ++ identityModuleName
                   ++ "' or a supermonad module."
    (mdl : _) -> return $ Right mdl

-- | Tries to find the 'Data.Functor.Identity.Identity' type constructor in the
--   imported modules. Only looks for imports through specific modules: the one
--   returned by 'findIdentityModule' and the supermonad module.
findIdentityTyCon :: TcPluginM (Maybe TyCon)
findIdentityTyCon = do
  mdls <- findModules [findIdentityModule, findSupermonadModule]
  case mdls of
    -- No candidate modules: skip the environment lookup entirely.
    [] -> return Nothing
    _  -> findTyConByNameAndModule (mkTcOcc identityTyConName) mdls

-- -----------------------------------------------------------------------------
-- Local Utility Functions
-- -----------------------------------------------------------------------------

-- | Runs all of the given module searches (in order) and returns the
--   modules that were found, dropping every search that failed.
findModules :: [TcPluginM (Either SDoc Module)] -> TcPluginM [Module]
findModules findMdls = do
  eitherMdls <- sequence findMdls
  -- The pattern 'Right mdl' silently skips every 'Left' result.
  return [ mdl | Right mdl <- eitherMdls ]

-- | Checks if the module with the given name is imported and,
-- if so, returns that module.
getModule :: Maybe UnitId -> String -> TcPluginM (Either SDoc Module)
getModule pkgKeyToFind mdlNameToFind = do
  (gblEnv, _lclEnv) <- getEnvs
  -- Only consider modules that are imported by the module being type checked.
  let mdls = moduleEnvKeys $ imp_mods $ tcg_imports $ gblEnv
  case find (isModule . splitModule) mdls of
    Just mdl -> return $ Right mdl
    Nothing  -> return $ Left $ text $ "Could not find module '" ++ mdlNameToFind ++ "'"
  where
    targetMdlName :: ModuleName
    targetMdlName = mkModuleName mdlNameToFind

    -- A 'Nothing' package key accepts the module from any package.
    isModule :: (UnitId, ModuleName) -> Bool
    isModule (pkgKey, mdlName) =
      maybe True (pkgKey ==) pkgKeyToFind && mdlName == targetMdlName

    splitModule :: Module -> (UnitId, ModuleName)
    splitModule mdl = (moduleUnitId mdl, moduleName mdl)

-- For some reason this version also found modules that are not in the
-- imports.
{-
-- | Checks if the module with the given name is imported and,
-- if so, returns that module.
getModule :: Maybe UnitId -> String -> TcPluginM (Either SDoc Module)
getModule pkgKeyToFind mdlNameToFind = do
  mdlResult <- findImportedModule (mkModuleName mdlNameToFind) Nothing -- From "TcPluginM"
  case mdlResult of
    Found _mdlLoc mdl ->
      if maybe True (moduleUnitId mdl ==) pkgKeyToFind then
        return $ Right mdl
      else
        return $ Left $ text "Package key of found module does not match the requested key:"
                     $$ text "Found: " <> ppr (moduleUnitId mdl)
                     $$ text "Requested: " <> ppr pkgKeyToFind
    NoPackage pkgKey -> return $ Left $ text "Found module, but missing its package: " <> ppr pkgKey
    FoundMultiple mdls -> return $ Left $ text ("Module '" ++ mdlNameToFind ++ "' appears in several packages:")
                                       $$ ppr (fmap snd mdls)
    NotFound {} -> return $ Left $ text $ "Module was not found: " ++ mdlNameToFind
-}

-- | Checks if a type class matching the shape of the given
--   predicate is in scope. The class is discovered through its instances,
--   so a class without any instance in scope is not found.
findClass :: (Class -> Bool) -> TcPluginM (Maybe Class)
findClass isClass' = do
  let isCls = isClass' . is_cls
  envs <- fst <$> getEnvs
  -- This is needed while compiling the package itself...
  let foundInstsLcl =  (filter isCls . instEnvElts . tcg_inst_env $ envs)
                    ++ (filter isCls . tcg_insts $ envs)
  -- This is needed while compiling an external package depending on it...
  foundInstsGbl <- filter isCls . instEnvElts . ie_global <$> getInstEnvs
  -- The class of the first matching instance wins.
  return $ is_cls <$> listToMaybe (foundInstsLcl ++ foundInstsGbl)

-- | Check if the given class has the given name, arity and if the classes
--   module fulfills the given predicate.
isClass :: Class -> (Module -> Bool) -> String -> Arity -> Bool
isClass cls isModule targetClassName targetArity =
  let clsName    = className cls
      clsMdl     = nameModule clsName
      clsNameStr = occNameString $ getOccName clsName
      clsArity   = classArity cls
  in isModule clsMdl && clsNameStr == targetClassName && clsArity == targetArity

-- | Try to find a type constructor given its name and the modules it
--   is exported from. The type constructor needs to be imported from
--   one of these modules. Returns the first match, if any.
findTyConByNameAndModule :: OccName -> [Module] -> TcPluginM (Maybe TyCon)
findTyConByNameAndModule occName mdls = do
  -- Look at the global environment of names that are in scope.
  rdrEnv <- tcg_rdr_env . fst <$> getEnvs
  -- Search for things that have the same name as what we are looking for.
  let envResultElem = lookupGlobalRdrEnv rdrEnv occName
  -- Only keep things that are imported from one of the given modules and have
  -- no parents, because type constructors are declared on top-level.
  let relResults = filter (\e -> any (e `isImportedFrom`) mdls && hasNoParent e) envResultElem
  -- Find all the typed things that have the same name as the stuff we found.
  -- Also directly convert them into type constructors if possible.
  mTyCons <- forM relResults $ liftM tcTyThingToTyCon . tcLookup . gre_name
  -- Only keep those things that actually were type constructors.
  let tyCons = catMaybes mTyCons
  -- In theory, we should not find more than one type constructor,
  -- because two unqualified type constructors with the same name would
  -- clash in the source module. Should several remain anyway, the first
  -- one is returned.
  return $ listToMaybe tyCons

-- | Try to convert the given typed thing into a type constructor.
tcTyThingToTyCon :: TcTyThing -> Maybe TyCon
tcTyThingToTyCon (AGlobal (ATyCon tc)) = Just tc
tcTyThingToTyCon _ = Nothing

-- | Check if the given element has no parents.
hasNoParent :: GlobalRdrElt -> Bool
hasNoParent rdrElt = case gre_par rdrElt of
  NoParent -> True
  _        -> False

-- | Returns a list of all instances for the given class that are currently in scope.
findInstancesInScope :: Class -> TcPluginM [ClsInst]
findInstancesInScope cls = do
  instEnvs <- getInstEnvs
  return $ classInstances instEnvs cls

-- | Check if there are any supermonad instances that clearly
--   do not belong to a specific supermonad, i.e. instances for which
--   'isPolyTyConInstance' holds. Each offending instance is returned
--   together with an error message. 'Bind' instances come first.
checkSupermonadInstances
  :: Class -- ^ 'Control.Supermonad.Bind' type class.
  -> Class -- ^ 'Control.Supermonad.Return' type class.
  -> TcPluginM [(ClsInst, SDoc)]
checkSupermonadInstances bindCls returnCls = do
  bindInsts   <- findInstancesInScope bindCls
  returnInsts <- findInstancesInScope returnCls
  let polyInsts =  filter (isPolyTyConInstance bindCls  ) bindInsts
                ++ filter (isPolyTyConInstance returnCls) returnInsts
  return $ fmap (\inst -> (inst, text "Not a valid supermonad instance: " $$ ppr inst)) polyInsts

-- | Constructs the map between type constructors and their supermonad instance.
--   Every type constructor must have exactly one 'Bind' and exactly one
--   'Return' instance; all violations are returned as errors together with
--   the offending type constructor.
findSupermonads
  :: Class -- ^ 'Control.Supermonad.Bind' type class.
  -> Class -- ^ 'Control.Supermonad.Return' type class.
  -> TcPluginM (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
     -- ^ Association between type constructor and its
     --   'Control.Supermonad.Bind' and 'Control.Supermonad.Return'
     --   instance in that order.
findSupermonads bindCls returnCls = do
  bindInsts   <- findInstancesInScope bindCls
  returnInsts <- findInstancesInScope returnCls
  -- Collect all type constructors that are used for supermonads.
  let supermonadTyCons = S.unions $ fmap instTopTyCons $ bindInsts ++ returnInsts
  -- Find the supermonad instances of each type constructor and combine the
  -- per-constructor maps and error lists.
  return $ mconcat $ fmap (findSupermonad bindInsts returnInsts) $ S.toList supermonadTyCons
  where
    findSupermonad :: [ClsInst] -> [ClsInst] -> TyCon -> (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
    findSupermonad bindInsts returnInsts tc =
      case ( filter (isMonoTyConInstance tc bindCls) bindInsts
           , filter (isMonoTyConInstance tc returnCls) returnInsts ) of
        ([bindInst], [returnInst]) -> (M.singleton tc (bindInst, returnInst), [])
        ([], _) -> findError tc $ text "Missing 'Bind' instance for supermonad '" <> ppr tc <> text "'."
        (_, []) -> findError tc $ text "Missing 'Return' instance for supermonad '" <> ppr tc <> text "'."
        -- Both lists are non-empty and at least one has several elements:
        -- only report the class(es) that actually have duplicates.
        (bindInsts', returnInsts') -> findError tc $ vcat
          $  multipleInstsMsg tc "Bind" bindInsts'
          ++ multipleInstsMsg tc "Return" returnInsts'

    -- Error section listing the given instances, produced only when there
    -- is more than one of them.
    multipleInstsMsg :: TyCon -> String -> [ClsInst] -> [SDoc]
    multipleInstsMsg tc clsLabel insts@(_ : _ : _) =
      [ text ("Multiple '" ++ clsLabel ++ "' instances for supermonad '") <> ppr tc <> text "':"
        $$ vcat (fmap ppr insts) ]
    multipleInstsMsg _ _ _ = []

    findError :: TyCon -> SDoc -> (M.Map TyCon (ClsInst, ClsInst), [(TyCon, SDoc)])
    findError tc msg = (M.empty, [(tc, msg)])

-- | Collect the top-level type constructors in the arguments
--   of the given instance.
instTopTyCons :: ClsInst -> S.Set TyCon
instTopTyCons = collectTopTyCons . instanceTyArgs