module Hint.Annotations (
    getModuleAnnotations,
    getValAnnotations
) where

import Data.Data
import GHC.Serialized

import Hint.Base
import qualified Hint.GHC as GHC

#if MIN_VERSION_ghc(9,2,0)
import GHC (ms_mod)
import GHC.Driver.Env (hsc_mod_graph)
#elif MIN_VERSION_ghc(9,0,0)
import GHC.Driver.Types (hsc_mod_graph, ms_mod)
#else
import HscTypes (hsc_mod_graph, ms_mod)
#endif

#if MIN_VERSION_ghc(9,0,0)
import GHC.Types.Annotations
import GHC.Utils.Monad (concatMapM)
#else
import Annotations
import MonadUtils (concatMapM)
#endif

-- | Get the annotations of type @a@ attached to the module whose name is
-- given as a string (e.g. @"Data.Foo"@).
--
-- The first argument is discarded by the @_@ pattern and never forced; it
-- exists only to fix the result type @a@, so a value such as
-- @(undefined :: MyAnn)@ is acceptable.
--
-- Every module summary in the current session's module graph whose module
-- name equals the requested name is consulted. The annotations found for
-- each are concatenated. If no loaded module matches, the result is @[]@.
getModuleAnnotations :: (Data a, MonadInterpreter m) => a -> String -> m [a]
getModuleAnnotations _ modName = do
    mods <- GHC.mgModSummaries . hsc_mod_graph <$> runGhc GHC.getSession
    let matching = filter ((== modName) . summaryName) mods
    concatMapM (anns . ModuleTarget . ms_mod) matching
  where
    -- Textual module name of a module summary, for comparison with the
    -- caller-supplied string.
    summaryName = GHC.moduleNameString . GHC.moduleName . ms_mod

-- | Get the annotations of type @a@ attached to the value (e.g. a top-level
-- function) whose name is given as a string.
--
-- The first argument is discarded by the @_@ pattern and never forced; it
-- only fixes the result type @a@.
--
-- The string is resolved with 'GHC.parseName', which can produce several
-- names. The annotations for all of them are concatenated. Any failure to
-- resolve the name is left to 'GHC.parseName'; it is not handled here.
getValAnnotations :: (Data a, MonadInterpreter m) => a -> String -> m [a]
getValAnnotations _ valName = do
    -- 'runGhc' takes a rank-2 (n-polymorphic) argument, so it is applied
    -- with '$' rather than composed with '.'.
    names <- runGhc $ GHC.parseName valName
    concatMapM (anns . NamedTarget) names

-- | Collect every annotation on @target@ visible to the session, decoding
-- each serialized payload with 'deserializeWithData'.
--
-- This is the shared worker behind the module- and value-annotation
-- lookups.
anns :: (MonadInterpreter m, Data a) => AnnTarget GHC.Name -> m [a]
anns target =
    -- Kept in applied form: 'runGhc' expects an n-polymorphic 'GhcT'
    -- action, which '$' supports but a point-free '.' composition does not.
    runGhc $ GHC.findGlobalAnns deserializeWithData target