{-# LANGUAGE TypeApplications #-}

module Imp where

import qualified Control.Monad.Catch as Exception
import qualified Data.Data as Data
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import qualified Data.Set as Set
import qualified GHC.Hs as Hs
import qualified GHC.Plugins as Plugin
import qualified Imp.Exception.ShowHelp as ShowHelp
import qualified Imp.Exception.ShowVersion as ShowVersion
import qualified Imp.Extra.Exception as Exception
import qualified Imp.Extra.HsModule as HsModule
import qualified Imp.Extra.HsParsedModule as HsParsedModule
import qualified Imp.Extra.ImportDecl as ImportDecl
import qualified Imp.Extra.ParsedResult as ParsedResult
import qualified Imp.Extra.SrcSpanAnnN as SrcSpanAnnN
import qualified Imp.Ghc as Ghc
import qualified Imp.Type.Config as Config
import qualified Imp.Type.Context as Context
import qualified Imp.Type.Flag as Flag
import qualified System.Exit as Exit
import qualified System.IO as IO

-- | The GHC plugin entry point.
--
-- It hooks into the parsed-result stage through 'parsedResultAction' and is
-- declared pure ('Plugin.purePlugin') for GHC's recompilation checking.
plugin :: Plugin.Plugin
plugin =
  Plugin.defaultPlugin
    { Plugin.parsedResultAction = parsedResultAction,
      Plugin.pluginRecompile = Plugin.purePlugin
    }

-- | Rewrites the module held in a 'Plugin.ParsedResult' using 'imp'.
--
-- The command line options given to the plugin are forwarded to 'imp'; the
-- module summary argument is ignored. Any exception raised while rewriting
-- is passed to 'handleException', which prints it to stderr and calls
-- 'Exit.exitWith'.
parsedResultAction ::
  [Plugin.CommandLineOption] ->
  modSummary ->
  Plugin.ParsedResult ->
  Plugin.Hsc Plugin.ParsedResult
parsedResultAction commandLineOptions _ =
  Plugin.liftIO
    . Exception.handle handleException
    -- Drill down ParsedResult -> HsParsedModule -> Located module and run
    -- 'imp' (instantiated at IO) on the innermost value.
    . ParsedResult.overModule (HsParsedModule.overModule $ imp commandLineOptions)

-- | Reports an exception and terminates.
--
-- The exception's 'Exception.displayException' text is written to stderr,
-- then 'Exit.exitWith' is called with the code chosen by
-- 'exceptionToExitCode'. The polymorphic result type reflects that this
-- action never returns normally.
handleException :: Exception.SomeException -> IO a
handleException e = do
  IO.hPutStrLn IO.stderr $ Exception.displayException e
  Exit.exitWith $ exceptionToExitCode e

-- | Chooses the process exit code for an exception.
--
-- 'ShowHelp.ShowHelp' and 'ShowVersion.ShowVersion' map to
-- 'Exit.ExitSuccess' (presumably because they are raised on an explicit
-- user request for help/version output — confirm in their modules); every
-- other exception maps to @'Exit.ExitFailure' 1@.
exceptionToExitCode :: Exception.SomeException -> Exit.ExitCode
exceptionToExitCode e
  | Exception.isType @ShowHelp.ShowHelp e = Exit.ExitSuccess
  | Exception.isType @ShowVersion.ShowVersion e = Exit.ExitSuccess
  | otherwise = Exit.ExitFailure 1

-- | Adds an import for every module that the given module refers to by a
-- qualified name but does not already import.
--
-- The arguments are turned into flags, then a config, then a context; any
-- failure in those steps is thrown in @m@. Qualified names are collected from
-- the module's declarations ('Hs.hsmodDecls') only. When a module is used
-- several times, the location attached to its new import is chosen by
-- 'SrcSpanAnnN.leftmostSmallest'. The context's aliases decide which module
-- is actually imported for a given qualifier (see 'updateImports').
imp ::
  (Exception.MonadThrow m) =>
  [String] ->
  Plugin.Located Ghc.HsModulePs ->
  m (Plugin.Located Ghc.HsModulePs)
imp arguments lHsModule = do
  flags <- Flag.fromArguments arguments
  config <- Config.fromFlags flags
  context <- Context.fromConfig config
  let hsModule = Plugin.unLoc lHsModule
      aliases = Context.aliases context
      -- A module may refer to its own top-level names qualified by its own
      -- name (e.g. @M.x@ inside module @M@). Adding @import M@ there would
      -- make the module import itself, which GHC rejects, so that name is
      -- removed from the wanted set. Modules without a header are untouched.
      self = Plugin.unLoc <$> Hs.hsmodName hsModule
      -- Keep only qualified names, paired with their source location.
      qualifiedModule lRdrName = case Plugin.unLoc lRdrName of
        Plugin.Qual moduleName _ -> Just (moduleName, Plugin.getLoc lRdrName)
        _ -> Nothing
      moduleNames =
        maybe id Map.delete self
          . Map.fromListWith SrcSpanAnnN.leftmostSmallest
          . Maybe.mapMaybe qualifiedModule
          . biplate
          $ Hs.hsmodDecls hsModule
  pure $ fmap (HsModule.overImports $ updateImports aliases moduleNames) lHsModule

-- | Collects every value of type @b@ found inside a value of type @a@, using
-- the generic traversal from "Data.Data".
--
-- Only descendants are examined, never the root value itself. Once a
-- descendant of type @b@ is found it is collected and not searched further,
-- so a @b@ nested inside another @b@ is not reported separately.
biplate :: (Data.Data a, Data.Data b) => a -> [b]
biplate = concat . Data.gmapQ (\d -> maybe (biplate d) pure $ Data.cast d)

-- | Appends an import for every wanted module that is not already imported.
--
-- @want@ maps each module name to the location to attach to its new import.
-- A module counts as already imported when 'ImportDecl.toModuleName' of some
-- existing import equals it (presumably the @as@ name when present — confirm
-- in "Imp.Extra.ImportDecl"). Existing imports are kept unchanged and in
-- order; new ones follow in ascending module-name order ('Map.toList').
-- @aliases@ is passed to 'createImport' to decide what each import names.
updateImports ::
  Map.Map Plugin.ModuleName Plugin.ModuleName ->
  Map.Map Plugin.ModuleName Hs.SrcSpanAnnN ->
  [Hs.LImportDecl Hs.GhcPs] ->
  [Hs.LImportDecl Hs.GhcPs]
updateImports aliases want imports =
  let have = Set.fromList $ fmap (ImportDecl.toModuleName . Plugin.unLoc) imports
      need = Map.toList $ Map.withoutKeys want have
      -- Convert the name annotation to a list-item annotation for the decl.
      locate (m, l) = Plugin.L (Hs.na2la l) $ createImport aliases m
   in imports <> fmap locate need

-- | Builds the import declaration that makes @target@ usable as a qualifier.
--
-- If @aliases@ maps @target@ to another module, that module is imported
-- @as target@; otherwise @target@ itself is imported with no @as@ clause.
-- The base declaration comes from 'Ghc.newImportDecl' (presumably a
-- qualified import — confirm in "Imp.Ghc").
createImport ::
  Map.Map Plugin.ModuleName Plugin.ModuleName ->
  Plugin.ModuleName ->
  Hs.ImportDecl Hs.GhcPs
createImport aliases target =
  let source = Map.findWithDefault target target aliases
   in (Ghc.newImportDecl source)
        { Hs.ideclAs =
            if source == target
              then Nothing
              else Just $ Hs.noLocA target
        }