{-# LANGUAGE UndecidableInstances, TypeFamilies #-}
module Control.Monad.Log
    ( MonadLog(..)
    , Level(..)
    , LoggingConf(..)
    , Logged(..)
    , LIO
    , withLogging
    , withLogging_
    , logOptions
    , execWithParser
    , execWithParser_
    , PanicCall(..)
    , panic
    ) where

import BasePrelude           hiding ( try, catchIOError )
import Control.Monad.Base           ( MonadBase(..) )
import Control.Monad.Catch
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Trans.Reader
import Control.Monad.Trans.RWS.Strict ( RWST )
import GitVersion                   ( gitFullVersion )
import Options.Applicative
import Paths_biohazard              ( version )
import Streaming
import System.IO                    ( hPutStr, hPutStrLn, hFlush, stderr, openFile, IOMode(..) )

import qualified Data.Vector                 as V

-- | How serious a logged message is.  Constructors are listed from least
-- to most severe, which is also the order used by the derived 'Ord',
-- 'Enum' and 'Ix' instances.  The remarks on printing and collecting
-- refer to the command-line defaults set up by 'logOptions'.
data Level
    = Debug     -- ^ Only useful for debugging.  Neither printed nor collected by default.
    | Info      -- ^ Purely informative, e.g. progress reports.  Not printed by default.
    | Notice    -- ^ Remarkable, but harmless.  Printed, but not collected by default.
    | Warning   -- ^ Unexpected, but usually not a problem.  Printed and collected by default.
    | Error     -- ^ Recoverable error; by default leads to @ExitFailure 1@.  Printed and collected.
    deriving ( Bounded, Enum, Eq, Ix, Ord, Show )

-- | Wraps a string in the ANSI colour escape sequence for the given
-- 'Level' and appends the reset sequence.
color_coded :: Level -> String -> String
color_coded lvl txt = prefix ++ txt ++ "\27[0m"
  where
    prefix = case lvl of
        Debug   -> "\27[90m"        -- gray
        Info    -> "\27[34m"        -- blue
        Notice  -> "\27[32;1m"      -- bold green
        Warning -> "\27[33m"        -- yellow
        Error   -> "\27[31;1m"      -- bold red

-- | Monads that can log.  Any 'Exception' can serve as a message; logging
-- it reports and/or collects it, but is not meant to abort the
-- surrounding computation.
class Monad m => MonadLog m where
    -- | Logs a message.  Whether it gets printed and/or stored depends on
    -- the logging configuration in effect.
    logMsg :: Exception e
           => Level     -- ^ severity of the message
           -> e         -- ^ the message itself
           -> m ()

    -- | Replaces the progress indicator.  The text should fit on a single
    -- line and must not contain line feeds, since it is meant to be
    -- overwritten over and over.
    logString_ :: String -> m ()

    -- | Prints a progress message that stays on the user's terminal
    -- instead of being overwritten.
    logStringLn :: String -> m ()

-- | Logging passes straight through a strict 'RWST' layer to the
-- underlying monad.
instance (MonadLog m, Monoid w) => MonadLog (RWST r w s m) where
    logMsg lvl  = lift . logMsg lvl
    logString_  = lift . logString_
    logStringLn = lift . logStringLn


-- | Adds logging to any 'MonadIO' type.  Messages that are severe enough
-- to be reported are printed to stderr as soon as they are logged.  The
-- run also remembers whether any message was severe enough to warrant a
-- failing exit code.  The advantage over @WarningT IO@ is that this
-- bookkeeping survives even if the computation ends with an exception.
-- Progress indicators go to the controlling terminal and are discarded
-- if there is none.
newtype Logged m a = Logged
    { runLogged :: ReaderT (LoggingConf, Journal) m a }
  deriving ( Functor, Applicative, Alternative, Monad, MonadTrans, MonadIO
           , MonadThrow, MonadCatch, MonadMask, MFunctor )

-- | Control operations are delegated to the wrapped 'ReaderT'.
instance MonadTransControl Logged where
    type StT Logged a = StT (ReaderT (LoggingConf, Journal)) a
    liftWith body = defaultLiftWith Logged runLogged body
    restoreT      = defaultRestoreT Logged

-- | Base-monad actions are lifted through the underlying monad.
instance MonadBase b m => MonadBase b (Logged m) where
    liftBase act = lift (liftBase act)

-- | Lifting control operations from the base monad reuses the
-- 'MonadTransControl' instance, so the monadic state is that of the
-- wrapped 'ReaderT'.
instance MonadBaseControl b m => MonadBaseControl b (Logged m) where
    type StM (Logged m) a = StM (ReaderT (LoggingConf, Journal) m) a
    liftBaseWith runner   = defaultLiftBaseWith runner
    restoreM st           = defaultRestoreM st

-- | Primitive state-token operations run in the underlying monad.
instance PrimMonad m => PrimMonad (Logged m) where
    type PrimState (Logged m) = PrimState m
    primitive act             = lift (primitive act)

-- | 'Logged' specialised to 'IO'; the monad in which 'execWithParser'
-- runs the program's main function.
type LIO = Logged IO

-- | Settings that decide what happens to logged messages.  Every level
-- threshold is inclusive: a message is affected when its 'Level' is at
-- least the threshold.
data LoggingConf = LoggingConf
    { reporting_level :: Level  -- ^ messages at or above this level are printed to stderr
    , logging_level   :: Level  -- ^ messages at or above this level are remembered for the final summary
    , error_level     :: Level  -- ^ messages at or above this level make 'withLogging' report @ExitFailure 1@
    , max_log_size    :: Int    -- ^ how many messages to remember per level
    , want_progress   :: Bool   -- ^ whether to open the controlling terminal for progress output
    } deriving Show

-- | Mutable state shared by all logging calls of one 'withLogging' run.
-- The vectors hold one entry per 'Level', indexed by 'fromEnum'.
data Journal = Journal
    { logged_messages :: V.Vector (IORef [SomeException])  -- ^ per level: retained messages, most recent first
    , num_messages    :: V.Vector (IORef Int)              -- ^ per level: message counter, compared against 'max_log_size'
    , error_exit      :: IORef Bool                        -- ^ set once a message at or above 'error_level' was logged
    , cterminal       :: Maybe Handle                      -- ^ terminal for progress output, if one was opened
    , spinner         :: IORef String                      -- ^ current progress line, redrawn after other output
    }

instance MonadIO m => MonadLog (Logged m) where
    -- Prints the message if its level reaches 'reporting_level', stores it
    -- if it reaches 'logging_level' (only the first 'max_log_size' per level
    -- are kept; the rest are merely counted), and marks the run as failed
    -- if it reaches 'error_level'.
    logMsg lv e = Logged $ ReaderT $ \(LoggingConf{..},Journal{..}) -> do
        when (lv >= reporting_level) $ liftIO $ do
            -- Clear the spinner line so the message does not mix with it.
            forM_ cterminal $ \h -> tryIO $ hPutStr h "\r\27[K" >> hFlush h
            pn <- getProgName
            hPutStrLn stderr $ color_coded lv $ printf "%s: [%s] %s" pn (show lv) (displayException e)
            hFlush stderr
            -- Redraw the spinner with line wrapping disabled.  Terminal errors
            -- are ignored, exactly as when clearing it above:  logging must
            -- not abort the computation.
            forM_ cterminal $ \h -> readIORef spinner >>= \s ->
                tryIO $ hPutStr h ("\27[?7l" ++ s ++ "\27[?7h") >> hFlush h
        when (lv >= logging_level) $ liftIO $ do
            -- Count every message, but keep only the first 'max_log_size' of
            -- each level.  The final summary in 'withLogging' reports the
            -- difference as "(and N more)", which needs the full count.
            keep <- atomicModifyIORef' (num_messages V.! fromEnum lv)
                        (\num -> (succ num, num < max_log_size))
            when keep $
                atomicModifyIORef' (logged_messages V.! fromEnum lv)
                    (\es -> (toException e : es, ()))
        when (lv >= error_level) $ liftIO $
            atomicWriteIORef error_exit True

    -- Replaces the progress line:  remembers the new text (prefixed with the
    -- program name, unless empty) for later redraws, clears the line and
    -- prints the text with line wrapping disabled.  Does nothing if there
    -- is no terminal.
    logString_ m = Logged $ ReaderT $ \(LoggingConf{..},Journal{..}) ->
        liftIO $ forM_ cterminal $ \h -> do
            pn <- getProgName
            let s = if null m then m else pn ++ ": " ++ m
            writeIORef spinner s
            tryIO $ hPutStr h ("\r\27[K\27[?7l" ++ s ++ "\27[?7h") >> hFlush h

    -- Prints a line that stays on the terminal, then redraws the current
    -- progress line below it.  Does nothing if there is no terminal.
    logStringLn m = Logged $ ReaderT $ \(LoggingConf{..},Journal{..}) ->
        liftIO $ forM_ cterminal $ \h -> do
            s <- readIORef spinner
            tryIO $ hPutStr h ("\r\27[K" ++ m ++ "\n\27[?7l" ++ s ++ "\27[?7h") >> hFlush h


-- | Like 'withLogging', but on failure exits the program with the
-- computed 'ExitCode' instead of returning it.
withLogging_ :: (MonadIO m, MonadMask m) => LoggingConf -> Logged m a -> m a
withLogging_ conf act = withLogging conf act >>= either (liftIO . exitWith) pure

-- | Runs a 'Logged' computation, then prints a summary of the collected
-- messages to stderr and decides on the outcome:
--
-- * @Left (ExitFailure 2)@ if the computation threw an exception
--   (including a user interrupt),
--
-- * @Left (ExitFailure 1)@ if any message reached 'error_level',
--
-- * @Right x@ with the computation's result otherwise.
--
-- If 'want_progress' is set, the controlling terminal (@\/dev\/tty@) is
-- opened for progress output and closed again at the end.
withLogging :: (MonadIO m, MonadMask m) => LoggingConf -> Logged m a -> m (Either ExitCode a)
withLogging conf (Logged k) = do
    journal <- let n = fromEnum (maxBound :: Level) - fromEnum (minBound :: Level) + 1
               in liftIO $ Journal <$> V.replicateM n (newIORef [])
                                   <*> V.replicateM n (newIORef 0)
                                   <*> newIORef False
                                   <*> bool (pure Nothing) (tryIO $ openFile "/dev/tty" WriteMode) (want_progress conf)
                                   <*> newIORef []

    r  <- try $ runReaderT k (conf,journal)
    liftIO $ do
        ws  <- V.mapM readIORef (logged_messages journal)
        nws <- V.mapM readIORef (num_messages journal)
        pn  <- getProgName

        -- Leave any spinner text on its own line, then release the terminal.
        forM_ (cterminal journal) $ \h -> do
            s <- readIORef (spinner journal)
            tryIO $ unless (null s) (hPutStrLn h []) >> hClose h

        summarize ws nws pn "%s: there were %d warnings\n"
                  [ l | l <- [minBound ..], l < error_level conf ]
        summarize ws nws pn "%s: there were %d (non-catastrophic) errors\n"
                  [error_level conf ..]

        case r of
          Left  e -> do case fromException e of
                            Just UserInterrupt -> hPutStrLn stderr $ pn ++ ": Interrupted"
                            _                  -> hPutStrLn stderr $ pn ++ ": catastrophic error: " ++ displayException e
                        return . Left $ ExitFailure 2

          Right x -> bool (Right x) (Left $ ExitFailure 1) <$> readIORef (error_exit journal)
  where
    -- Summarizes the messages of the given levels:  a header (a printf
    -- format taking the program name and the count) with the number of
    -- messages counted, the retained messages in the order they were
    -- logged, and how many more were counted than retained.  Prints
    -- nothing if no message was counted.
    summarize :: V.Vector [SomeException] -> V.Vector Int -> String -> String -> [Level] -> IO ()
    summarize ws nws pn header levels = do
        -- The journal prepends new messages; reverse to restore log order.
        let kept  = [ (l,e) | l <- levels, e <- reverse (ws V.! fromEnum l) ]
            total = sum [ nws V.! fromEnum l | l <- levels ]
            extra = total - length kept
        unless (total == 0) $ do
            hPrintf stderr header pn total
            forM_ kept $ \(l,e) -> hPutStrLn stderr . color_coded l $ displayException e
            unless (extra <= 0 || null kept) $
                hPrintf stderr "(and %d more)\n" extra


-- | General wrapper around main.  Runs a command line parser extended with
-- the standard options (logging and usage related), runs the actual main
-- function under 'withLogging', and prints collected warnings and caught
-- exceptions.  It then returns the main function's result, or exits with
-- @ExitFailure 2@ if an exception was caught, or with @ExitFailure 1@ if a
-- message reached the configured error level.
execWithParser_ :: Parser a -> Maybe Version -> Maybe String -> InfoMod (a,LoggingConf) -> (a -> LIO b) -> IO b
execWithParser_ opts prog_ver prog_git_ver inf k =
    execWithParser opts prog_ver prog_git_ver inf k >>= either exitWith pure

-- | Like 'execWithParser_', but returns the failing 'ExitCode' instead of
-- exiting.  Besides the logging options, the parser understands
-- @-V@/@--version@, which prints program and library versions, and the
-- usual help option.
execWithParser :: Parser a -> Maybe Version -> Maybe String -> InfoMod (a,LoggingConf)
               -> (a -> LIO b) -> IO (Either ExitCode b)
execWithParser opts prog_ver prog_git_ver inf k = do
    pname <- getProgName
    let progSuffix = maybe "" (('-':) . showVersion) prog_ver
        progRev    = fromMaybe "release" prog_git_ver
        libRev     = fromMaybe "release" gitFullVersion

        versionString :: String
        versionString = printf "%s%s (%s) using biohazard-%s (%s)" pname
                            progSuffix progRev (showVersion version) libRev

        versionFlag = infoOption versionString (short 'V' <> long "version" <> help "Print version number and exit")
        fullParser  = (,) <$> opts <*> logOptions <* versionFlag <* helper
    (parsed, logConf) <- execParser (info fullParser inf)
    withLogging logConf (k parsed)

-- | Parser for the standard logging options.  Each of the three level
-- thresholds starts at a default, and every occurrence of the matching
-- flag moves it one step up or down.  The result saturates at the ends
-- of 'Level'.
--
-- NOTE(review): 'error_level' defaults to 'Error', the highest level, so
-- @--warn-ignore@ cannot raise it any further.  It only has an effect
-- after a preceding @--warn-error@.
logOptions :: Parser LoggingConf
logOptions =
    LoggingConf
    <$> adjustable Notice  ("quiet",         "Print only important messages")
                           ("verbose",       "Print also trivial messages")
    <*> adjustable Warning ("drop-errors",   "Remember only critical messages")
                           ("keep-warnings", "Remember also minor messages")
    <*> adjustable Error   ("warn-ignore",   "Fail only after critical errors")
                           ("warn-error",    "Fail also after warnings")
    <*> option auto (long "journal-size" <> metavar "NUM" <> help "Hold up to NUM errors in memory" <> value 20)
    <*> switch (long "progress" <> help "Print progress reports to the terminal")
  where
    -- Applies the given up/down flags, in command-line order, to a
    -- starting level; each occurrence moves the level by one step.
    adjustable :: Level -> (String, String) -> (String, String) -> Parser Level
    adjustable start (upFlag, upHelp) (downFlag, downHelp) =
        foldl (&) start <$> many
            (flag' raise (long upFlag <> help upHelp) <|>
             flag' lower (long downFlag <> help downHelp))

    raise, lower :: (Enum a, Bounded a, Eq a) => a -> a
    raise a = if a == maxBound then a else succ a
    lower a = if a == minBound then a else pred a


-- | An exception that can be thrown when defining a dedicated exception
-- type does not seem warranted.  It just carries a message, and that
-- message is all 'displayException' shows.
data PanicCall = PanicCall String
    deriving (Show, Typeable)

instance Exception PanicCall where
    displayException (PanicCall message) = message

-- | Throws a 'PanicCall' carrying the given message, from any 'MonadIO'.
panic :: MonadIO m => String -> m a
panic msg = liftIO (throwIO (PanicCall msg))

-- | Runs an 'IO' action and wraps its result in 'Just'.  If the action
-- throws an 'IOException', the result is 'Nothing' instead.  Other
-- exceptions propagate.
tryIO :: IO k -> IO (Maybe k)
tryIO action = fmap Just action `catchIOError` const (return Nothing)