{-# LANGUAGE TypeApplications #-}

module Imp where

import qualified Control.Monad.Catch as Exception
import qualified Data.Data as Data
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified GHC.Hs as Hs
import qualified GHC.Plugins as Plugin
import qualified Imp.Exception.ShowHelp as ShowHelp
import qualified Imp.Exception.ShowVersion as ShowVersion
import qualified Imp.Extra.Exception as Exception
import qualified Imp.Extra.HsModule as HsModule
import qualified Imp.Extra.HsParsedModule as HsParsedModule
import qualified Imp.Extra.ImportDecl as ImportDecl
import qualified Imp.Extra.ParsedResult as ParsedResult
import qualified Imp.Ghc as Ghc
import qualified Imp.Type.Config as Config
import qualified Imp.Type.Context as Context
import qualified Imp.Type.Flag as Flag
import qualified System.Exit as Exit
import qualified System.IO as IO

-- | The GHC plugin exported by this module.
--
-- It starts from 'Plugin.defaultPlugin' and overrides two fields:
--
-- * 'Plugin.parsedResultAction' runs 'parsedResultAction' on each parsed
--   module, which is where missing imports are added.
-- * 'Plugin.pluginRecompile' is set to 'Plugin.purePlugin'. Per GHC's plugin
--   API, this declares that enabling the plugin does not by itself force
--   recompilation.
plugin :: Plugin.Plugin
plugin =
  Plugin.defaultPlugin
    { Plugin.parsedResultAction = parsedResultAction,
      Plugin.pluginRecompile = Plugin.purePlugin
    }

-- | Parsed-result hook installed by 'plugin'.
--
-- Rewrites the module inside the 'Plugin.ParsedResult' with
-- @'imp' commandLineOptions@, running in 'IO' and lifted into
-- 'Plugin.Hsc'. The module summary argument is ignored.
--
-- Any exception raised while rewriting (for example by option parsing
-- inside 'imp') is handed to 'handleException'. That function reports the
-- exception and exits the process instead of returning.
--
-- NOTE(review): the handler is at 'Exception.SomeException', so it also
-- intercepts asynchronous exceptions (e.g. user interrupts); confirm this
-- is intended.
parsedResultAction ::
  [Plugin.CommandLineOption] ->
  modSummary ->
  Plugin.ParsedResult ->
  Plugin.Hsc Plugin.ParsedResult
parsedResultAction commandLineOptions _ parsedResult =
  Plugin.liftIO
    . Exception.handle handleException
    $ ParsedResult.overModule
      (HsParsedModule.overModule (imp commandLineOptions))
      parsedResult

-- | Report an exception and terminate.
--
-- Writes 'Exception.displayException' of the exception to standard error,
-- then exits with the code chosen by 'exceptionToExitCode'. It never
-- returns normally, hence the fully polymorphic result type.
handleException :: Exception.SomeException -> IO a
handleException e = do
  IO.hPutStrLn IO.stderr (Exception.displayException e)
  Exit.exitWith (exceptionToExitCode e)

-- | Choose the process exit code for an exception.
--
-- Help and version requests ('ShowHelp.ShowHelp' and
-- 'ShowVersion.ShowVersion') are not failures, so they map to
-- 'Exit.ExitSuccess'. Every other exception maps to @'Exit.ExitFailure' 1@.
exceptionToExitCode :: Exception.SomeException -> Exit.ExitCode
exceptionToExitCode e
  | isInformational = Exit.ExitSuccess
  | otherwise = Exit.ExitFailure 1
  where
    isInformational =
      Exception.isType @ShowHelp.ShowHelp e
        || Exception.isType @ShowVersion.ShowVersion e

-- | Add an import for every module that the declarations refer to but the
-- module does not already import.
--
-- Steps:
--
-- 1. Parse the plugin arguments into flags, then a config, then a context.
--    Each step may throw through 'Exception.MonadThrow'.
-- 2. Collect every 'Plugin.ModuleName' that occurs anywhere in the module's
--    declarations, using 'biplate'.
-- 3. Pass that set, together with the context's alias map, to
--    'updateImports'.
--
-- The module's own name is removed from the collected set. Haskell lets a
-- module qualify its own top-level entities with its own name (e.g. @Foo.x@
-- inside @module Foo@). Without this, such a reference would make the
-- module import itself.
imp ::
  (Exception.MonadThrow m) =>
  [String] ->
  Plugin.Located Ghc.HsModulePs ->
  m (Plugin.Located Ghc.HsModulePs)
imp arguments lHsModule = do
  flags <- Flag.fromArguments arguments
  config <- Config.fromFlags flags
  context <- Context.fromConfig config
  let aliases = Context.aliases context
      hsModule = Plugin.unLoc lHsModule
      -- Nothing for a module without a header; then nothing is excluded.
      selfName = Plugin.unLoc <$> Hs.hsmodName hsModule
      moduleNames =
        maybe id Set.delete selfName
          . Set.fromList @Plugin.ModuleName
          . biplate
          $ Hs.hsmodDecls hsModule
  pure $ fmap (HsModule.overImports $ updateImports aliases moduleNames) lHsModule

-- | Collect every value of type @b@ found below the root of @x@.
--
-- Traversal is generic, via 'Data.gmapQ'. Behaviour:
--
-- * The root value itself is never tested against @b@; only its
--   descendants are.
-- * When a child can be 'Data.cast' to @b@, it is collected and not searched
--   further, so a @b@ nested inside another @b@ is not reported.
-- * Otherwise the child's own children are searched recursively.
-- * Results appear in the left-to-right order of 'Data.gmapQ'.
biplate :: (Data.Data a, Data.Data b) => a -> [b]
biplate x =
  concat $
    Data.gmapQ
      (\child -> maybe (biplate child) pure (Data.cast child))
      x

-- | Append an import for each wanted module that is not already imported.
--
-- Parameters:
--
-- * @aliases@ is passed through to 'createImport' to build each new import.
-- * @want@ is the set of module names the code refers to.
-- * @imports@ is the module's existing import list. It is kept unchanged
--   and in order; new imports follow it, in 'Set.toList' order.
--
-- An existing import counts as covering the name returned by
-- 'ImportDecl.toModuleName'.
--
-- NOTE(review): whether an aliased import (@import M as A@) covers @A@ or
-- @M@ depends on 'ImportDecl.toModuleName'; confirm there.
updateImports ::
  Map.Map Plugin.ModuleName Plugin.ModuleName ->
  Set.Set Plugin.ModuleName ->
  [Hs.LImportDecl Hs.GhcPs] ->
  [Hs.LImportDecl Hs.GhcPs]
updateImports aliases want imports =
  imports <> fmap (Hs.noLocA . createImport aliases) (Set.toList missing)
  where
    have = Set.fromList $ fmap (ImportDecl.toModuleName . Plugin.unLoc) imports
    missing = Set.difference want have

-- | Build the import declaration that brings @target@ into scope.
--
-- @aliases@ maps the name used in code to the module that should actually
-- be imported:
--
-- * If @target@ has an entry, that module is imported with @as target@.
-- * Otherwise @target@ itself is imported, with no @as@ clause.
--
-- The base declaration comes from 'Ghc.newImportDecl'.
--
-- NOTE(review): that helper presumably produces a qualified import; confirm
-- in "Imp.Ghc".
createImport ::
  Map.Map Plugin.ModuleName Plugin.ModuleName ->
  Plugin.ModuleName ->
  Hs.ImportDecl Hs.GhcPs
createImport aliases target =
  (Ghc.newImportDecl source)
    { Hs.ideclAs =
        if source == target
          then Nothing
          else Just (Hs.noLocA target)
    }
  where
    source = Map.findWithDefault target target aliases