{-# LANGUAGE CPP #-}
module GhcUtil (withGhc) where

import           GHC.Paths (libdir)
#if __GLASGOW_HASKELL__ < 707
import           Control.Exception
import           GHC hiding (flags)
import           DynFlags (dopt_set)
#else
import           GHC
#if __GLASGOW_HASKELL__ < 900
import           DynFlags (gopt_set)
#else
import           GHC.Driver.Session (gopt_set)
#endif
#endif

#if __GLASGOW_HASKELL__ < 900
import           Panic (throwGhcException)
#else
import           GHC.Utils.Panic (throwGhcException)
#endif

#if __GLASGOW_HASKELL__ < 900
import           MonadUtils (liftIO)
#else
import           GHC.Utils.Monad (liftIO)
#endif

import           System.Exit (exitFailure)

#if __GLASGOW_HASKELL__ < 702
import           StaticFlags (v_opt_C_ready)
import           Data.IORef (writeIORef)
#elif __GLASGOW_HASKELL__ < 707
import           StaticFlags (saveStaticFlagGlobals, restoreStaticFlagGlobals)
#elif __GLASGOW_HASKELL__ < 801
import           StaticFlags (discardStaticFlags)
#endif


-- | Save static flag globals, run the action, and restore them afterwards.
--
-- Only older compilers need any work here; on GHC >= 7.7 this is the
-- identity on the action.
bracketStaticFlags :: IO a -> IO a
#if __GLASGOW_HASKELL__ < 702
-- GHC < 7.2 does not provide saveStaticFlagGlobals/restoreStaticFlagGlobals,
-- so we need to modify v_opt_C_ready directly.
bracketStaticFlags action = action `finally` writeIORef v_opt_C_ready False
#elif __GLASGOW_HASKELL__ < 707
bracketStaticFlags action = bracket saveStaticFlagGlobals restoreStaticFlagGlobals (const action)
#else
bracketStaticFlags action = action
#endif

-- | Run a 'Ghc' action. If it raises a 'SourceError', print the error and
-- terminate the process with 'exitFailure'.
handleSrcErrors :: Ghc a -> Ghc a
handleSrcErrors action' = flip handleSourceError action' $ \err -> do
#if __GLASGOW_HASKELL__ < 702
  printExceptionAndWarnings err
#else
  printException err
#endif
  liftIO exitFailure

-- | Run a GHC action in Haddock mode.
--
-- The command-line @flags@ are processed in two stages: first as static
-- flags ('handleStaticFlags'), then inside a GHC session rooted at 'libdir'
-- as dynamic flags ('handleDynamicFlags'). The arguments left over after
-- flag parsing are passed to @action@. Any 'SourceError' it raises is
-- printed and the process exits ('handleSrcErrors').
withGhc :: [String] -> ([String] -> Ghc a) -> IO a
withGhc flags action = bracketStaticFlags $ do
  flags_ <- handleStaticFlags flags

  runGhc (Just libdir) $ do
    handleDynamicFlags flags_ >>= handleSrcErrors . action

-- | Pre-process command-line arguments before dynamic-flag parsing.
--
-- * GHC < 7.7: parse static flags with 'parseStaticFlags' and keep the
--   remaining arguments.
-- * GHC 7.7 .. 8.0: drop static flags with 'discardStaticFlags'.
-- * Newer GHC: no static-flag handling; every argument is passed through.
--
-- In all cases the result is wrapped with 'noLoc' so it can be fed to
-- 'parseDynamicFlags'.
handleStaticFlags :: [String] -> IO [Located String]
#if __GLASGOW_HASKELL__ < 707
handleStaticFlags flags = fst `fmap` parseStaticFlags (map noLoc flags)
#elif __GLASGOW_HASKELL__ < 801
handleStaticFlags flags = return $ map noLoc $ discardStaticFlags flags
#else
handleStaticFlags flags = return (map noLoc flags)
#endif

-- | Parse dynamic GHC flags and install them in the current session.
--
-- The starting point is the session's current 'DynFlags' with
-- 'setHaddockMode' applied. The given arguments are parsed on top of it with
-- 'parseDynamicFlags', and the result is stored with 'setSessionDynFlags'.
-- The function returns the arguments that were not consumed as flags
-- (normally source files).
--
-- If any leftover argument starts with @-@, it is treated as an unknown
-- option and a 'UsageError' is raised via 'throwGhcException'.
handleDynamicFlags :: GhcMonad m => [Located String] -> m [String]
handleDynamicFlags flags = do
#if __GLASGOW_HASKELL__ >= 901
  -- On these GHC versions 'parseDynamicFlags' takes a 'Logger' first.
  logger <- getLogger
  let parseDynamicFlags' = parseDynamicFlags logger
#else
  let parseDynamicFlags' = parseDynamicFlags
#endif
  baseDynFlags <- setHaddockMode `fmap` getSessionDynFlags
  (dynflags, locSrcs, _) <- parseDynamicFlags' baseDynFlags flags
  _ <- setSessionDynFlags dynflags

  -- We basically do the same thing as `ghc/Main.hs` to distinguish
  -- "unrecognised flags" from source files.
  let srcs         = map unLoc locSrcs
      unknown_opts = [ f | f@('-':_) <- srcs ]
  case unknown_opts of
    opt : _ -> throwGhcException (UsageError ("unrecognized option `" ++ opt ++ "'"))
    []      -> return srcs

-- | Configure 'DynFlags' for documentation-only processing.
--
-- This does four things:
--
-- * enables 'Opt_Haddock';
-- * disables code generation (@backend = NoBackend@, or
--   @hscTarget = HscNothing@ on older GHC);
-- * selects 'CompManager' mode;
-- * disables linking ('NoLink').
setHaddockMode :: DynFlags -> DynFlags
#if __GLASGOW_HASKELL__ < 707
setHaddockMode dynflags = (dopt_set dynflags Opt_Haddock) {
#else
setHaddockMode dynflags = (gopt_set dynflags Opt_Haddock) {
#endif
#if __GLASGOW_HASKELL__ >= 901
      backend   = NoBackend
#else
      hscTarget = HscNothing
#endif
    , ghcMode   = CompManager
    , ghcLink   = NoLink
    }