module Hint.Annotations (
    getModuleAnnotations,
    getValAnnotations
) where

import Data.Data
import GHC.Serialized

import Hint.Base
import qualified Hint.GHC as GHC

#if MIN_VERSION_ghc(9,0,0)
import GHC.Driver.Types (hsc_mod_graph, ms_mod)
import GHC.Types.Annotations
import GHC.Utils.Monad (concatMapM)
#else
import Annotations
import HscTypes (hsc_mod_graph, ms_mod)
import MonadUtils (concatMapM)
#endif

-- | Get the annotations attached to every module in the current module
-- graph whose name is exactly the given string.
--
-- The first argument is never inspected (it is matched by @_@). It only
-- fixes the result type @a@ at which annotation payloads are looked up and
-- deserialised, so a value such as @(undefined :: T)@ may be passed.
getModuleAnnotations :: (Data a, MonadInterpreter m) => a -> String -> m [a]
getModuleAnnotations _ modName = do
    summaries <- GHC.mgModSummaries . hsc_mod_graph <$> runGhc GHC.getSession
    let hasName  = (== modName) . GHC.moduleNameString . GHC.moduleName . ms_mod
        matching = filter hasName summaries
    -- Collect the annotations of each matching module, in module-graph order.
    concatMapM (anns . ModuleTarget . ms_mod) matching

-- | Get the annotations attached to the value(s) named by the given string.
--
-- The string is resolved with 'GHC.parseName', which may return more than
-- one 'GHC.Name'. The annotations of every resolved name are concatenated.
-- As with 'getModuleAnnotations', the first argument is ignored and only
-- fixes the annotation type @a@.
getValAnnotations :: (Data a, MonadInterpreter m) => a -> String -> m [a]
getValAnnotations _ s = do
    names <- runGhc1 GHC.parseName s
    concatMapM (anns . NamedTarget) names

-- | Look up the global annotations on an annotation target, decoding each
-- serialised payload at type @a@ with 'deserializeWithData'.
--
-- The GHC work is lifted into the interpreter monad via 'runGhc1'.
anns :: (MonadInterpreter m, Data a) => AnnTarget GHC.Name -> m [a]
anns = runGhc1 (GHC.findGlobalAnns deserializeWithData)