module Snap.Internal.Http.Server.Config
  
  ( ConfigLog(..)
  , Config(..)
  , ProxyType(..)
  , emptyConfig
  , defaultConfig
  , commandLineConfig
  , extendedCommandLineConfig
  , completeConfig
  , optDescrs
  , fmapOpt
  , getAccessLog
  , getBind
  , getCompression
  , getDefaultTimeout
  , getErrorHandler
  , getErrorLog
  , getHostname
  , getLocale
  , getOther
  , getPort
  , getProxyType
  , getSSLBind
  , getSSLCert
  , getSSLChainCert
  , getSSLKey
  , getSSLPort
  , getVerbose
  , getStartupHook
  , getUnixSocket
  , getUnixSocketAccessMode
  , setAccessLog
  , setBind
  , setCompression
  , setDefaultTimeout
  , setErrorHandler
  , setErrorLog
  , setHostname
  , setLocale
  , setOther
  , setPort
  , setProxyType
  , setSSLBind
  , setSSLCert
  , setSSLChainCert
  , setSSLKey
  , setSSLPort
  , setVerbose
  , setUnixSocket
  , setUnixSocketAccessMode
  , setStartupHook
  , StartupInfo(..)
  , getStartupSockets
  , getStartupConfig
  
  , emptyStartupInfo
  , setStartupSockets
  , setStartupConfig
  ) where
import           Control.Exception          (SomeException)
import           Control.Monad              (when)
import           Data.Bits                  ((.&.))
import           Data.ByteString            (ByteString)
import qualified Data.ByteString.Char8      as S
import qualified Data.ByteString.Lazy.Char8 as L
import qualified Data.CaseInsensitive       as CI
import           Data.Function              (on)
import           Data.List                  (foldl')
import qualified Data.Map                   as Map
import           Data.Maybe                 (isJust, isNothing)
#if !MIN_VERSION_base(4,8,0)
import           Data.Monoid                (Monoid (..))
#endif
import           Data.Monoid                (Last (Last, getLast))
import qualified Data.Text                  as T
import qualified Data.Text.Encoding         as T
#if MIN_VERSION_base(4,7,0)
import           Data.Typeable              (Typeable)
#else
import           Data.Typeable              (TyCon, Typeable, Typeable1 (..), mkTyCon3, mkTyConApp)
#endif
import           Network                    (Socket)
import           Numeric                    (readOct, showOct)
#if !MIN_VERSION_base(4,6,0)
import           Prelude                    hiding (catch)
#endif
import           System.Console.GetOpt      (ArgDescr (..), ArgOrder (Permute), OptDescr (..), getOpt, usageInfo)
import           System.Environment         hiding (getEnv)
#ifndef PORTABLE
import           Data.Char                  (isAlpha)
import           System.Posix.Env           (getEnv)
#endif
import           System.Exit                (exitFailure)
import           System.IO                  (hPutStrLn, stderr)
import           Data.ByteString.Builder    (Builder, byteString, stringUtf8, toLazyByteString)
import qualified System.IO.Streams          as Streams
import           Snap.Core                  (MonadSnap, Request (rqClientAddr, rqClientPort, rqParams, rqPostParams), emptyResponse, finishWith, getsRequest, logError, setContentLength, setContentType, setResponseBody, setResponseStatus)
import           Snap.Internal.Debug        (debug)
-- | How the server should recover the real client address when it runs
-- behind a proxy (selected with the @--proxy@ command-line option).
data ProxyType
    = NoProxy          -- ^ not behind a proxy (presumably the peer address is used as-is)
    | HaProxy          -- ^ the connection speaks the haproxy PROXY protocol
    | X_Forwarded_For  -- ^ behind an HTTP reverse proxy that sets X-Forwarded-For
    deriving (Show, Eq, Typeable)
-- | Destination for the access log or the error log.
data ConfigLog
    = ConfigNoLog                        -- ^ logging disabled
    | ConfigFileLog FilePath             -- ^ append to the file at this path
    | ConfigIoLog (ByteString -> IO ())  -- ^ hand messages to a custom action
-- | Human-readable description of a log destination; a custom handler is
-- opaque, so only its presence is reported.
instance Show ConfigLog where
    show cl = case cl of
        ConfigNoLog        -> "no log"
        ConfigFileLog path -> "log to file " ++ show path
        ConfigIoLog _      -> "custom logging handler"
-- | Server configuration. Every field is optional: 'Nothing' means the value
-- is not set at this layer, which lets several configurations be stacked
-- with the 'Monoid' instance (for each field the later 'Just' wins). The
-- parameter @m@ is the handler monad used by 'errorHandler'; @a@ is the type
-- of application-specific settings kept in 'other'.
data Config m a = Config
    { hostname       :: Maybe ByteString                   -- ^ local hostname (@--hostname@)
    , accessLog      :: Maybe ConfigLog                    -- ^ access log destination
    , errorLog       :: Maybe ConfigLog                    -- ^ error log destination
    , locale         :: Maybe String                       -- ^ locale name; may be filled from @LANG@ by 'extendedCommandLineConfig'
    , port           :: Maybe Int                          -- ^ HTTP port to listen on
    , bind           :: Maybe ByteString                   -- ^ address for the HTTP listener
    , sslport        :: Maybe Int                          -- ^ SSL port to listen on
    , sslbind        :: Maybe ByteString                   -- ^ address for the SSL listener
    , sslcert        :: Maybe FilePath                     -- ^ path to the SSL certificate (PEM)
    , sslchaincert   :: Maybe Bool                         -- ^ 'True' when the certificate file holds the complete chain
    , sslkey         :: Maybe FilePath                     -- ^ path to the SSL private key (PEM)
    , unixsocket     :: Maybe FilePath                     -- ^ path of a unix socket to listen on
    , unixaccessmode :: Maybe Int                          -- ^ unix socket permission bits ('setUnixSocketAccessMode' masks to 0o777)
    , compression    :: Maybe Bool                         -- ^ whether to gzip-compress responses
    , verbose        :: Maybe Bool                         -- ^ whether to print status updates to stderr
    , errorHandler   :: Maybe (SomeException -> m ())      -- ^ handler for exceptions thrown by web handlers (see 'defaultErrorHandler')
    , defaultTimeout :: Maybe Int                          -- ^ default timeout, in seconds
    , other          :: Maybe a                            -- ^ application-specific settings
    , proxyType      :: Maybe ProxyType                    -- ^ how to find the client address behind a proxy
    , startupHook    :: Maybe (StartupInfo m a -> IO ())   -- ^ action given a 'StartupInfo'; when it runs is not visible in this module
    }
-- Typeable: derived directly on base 4.7 and later; older base versions
-- need the hand-written Typeable1 instance below.
#if MIN_VERSION_base(4,7,0)
  deriving Typeable
#else
configTyCon :: TyCon
configTyCon = mkTyCon3 "snap-server" "Snap.Http.Server.Config" "Config"
instance (Typeable1 m) => Typeable1 (Config m) where
    typeOf1 _ = mkTyConApp configTyCon [typeOf1 (undefined :: m ())]
#endif
-- | Multi-line, human-readable dump of a configuration, one field per line
-- after a @Config:@ header. The error handler, the startup hook and the
-- 'other' field are left out (they have no usable 'Show' here); the unix
-- socket access mode is rendered in octal with a leading 0.
instance Show (Config m a) where
    show cfg = unlines ("Config:" : map (uncurry (++)) rows)
      where
        rows =
            [ ("hostname: "      , show (hostname cfg))
            , ("accessLog: "     , show (accessLog cfg))
            , ("errorLog: "      , show (errorLog cfg))
            , ("locale: "        , show (locale cfg))
            , ("port: "          , show (port cfg))
            , ("bind: "          , show (bind cfg))
            , ("sslport: "       , show (sslport cfg))
            , ("sslbind: "       , show (sslbind cfg))
            , ("sslcert: "       , show (sslcert cfg))
            , ("sslchaincert: "  , show (sslchaincert cfg))
            , ("sslkey: "        , show (sslkey cfg))
            , ("unixsocket: "    , show (unixsocket cfg))
            , ("unixaccessmode: ", octalMode (unixaccessmode cfg))
            , ("compression: "   , show (compression cfg))
            , ("verbose: "       , show (verbose cfg))
            , ("defaultTimeout: ", show (defaultTimeout cfg))
            , ("proxyType: "     , show (proxyType cfg))
            ]
        octalMode = maybe "Nothing" (\mode -> "Just 0" ++ showOct mode "")
-- | A configuration with every field set to 'Nothing'; the same value as
-- 'mempty'. Handy as a base for record updates that set only a few fields.
emptyConfig :: Config m a
emptyConfig = mempty
-- | 'mempty' leaves every field unset. Combining two configurations works
-- field by field and is right-biased: each field takes the right-hand value
-- when it is 'Just' and otherwise keeps the left-hand value (the semantics of
-- 'Last'). Later configuration layers therefore override earlier ones.
instance Monoid (Config m a) where
    mempty = Config
        { hostname       = Nothing
        , accessLog      = Nothing
        , errorLog       = Nothing
        , locale         = Nothing
        , port           = Nothing
        , bind           = Nothing
        , sslport        = Nothing
        , sslbind        = Nothing
        , sslcert        = Nothing
        , sslchaincert   = Nothing
        , sslkey         = Nothing
        , unixsocket     = Nothing
        , unixaccessmode = Nothing
        , compression    = Nothing
        , verbose        = Nothing
        , errorHandler   = Nothing
        , defaultTimeout = Nothing
        , other          = Nothing
        , proxyType      = Nothing
        , startupHook    = Nothing
        }
#if !MIN_VERSION_base(4,11,0)
    mappend = combineConfig
#endif

#if MIN_VERSION_base(4,11,0)
-- Since base 4.11 'Semigroup' is a superclass of 'Monoid', so this instance
-- is required for 'Config' to be a 'Monoid' at all; 'mappend' then defaults
-- to the '(<>)' defined here.
instance Semigroup (Config m a) where
    (<>) = combineConfig
#endif

-- | Field-wise, right-biased merge shared by the instances above.
combineConfig :: forall m a . Config m a -> Config m a -> Config m a
combineConfig lhs rhs = Config
    { hostname       = ov hostname
    , accessLog      = ov accessLog
    , errorLog       = ov errorLog
    , locale         = ov locale
    , port           = ov port
    , bind           = ov bind
    , sslport        = ov sslport
    , sslbind        = ov sslbind
    , sslcert        = ov sslcert
    , sslchaincert   = ov sslchaincert
    , sslkey         = ov sslkey
    , unixsocket     = ov unixsocket
    , unixaccessmode = ov unixaccessmode
    , compression    = ov compression
    , verbose        = ov verbose
    , errorHandler   = ov errorHandler
    , defaultTimeout = ov defaultTimeout
    , other          = ov other
    , proxyType      = ov proxyType
    , startupHook    = ov startupHook
    }
  where
    -- Take the right-hand field if it is set, otherwise the left-hand one.
    ov :: (Config m a -> Maybe x) -> Maybe x
    ov f = getLast $! (mappend `on` (Last . f)) lhs rhs
-- | Built-in defaults: listen on 0.0.0.0 as "localhost", log to
-- @log/access.log@ and @log/error.log@, compress responses, print status
-- updates, use 'defaultErrorHandler' and a 60 second timeout. No port is set
-- here ('completeConfig' falls back to 8000), and the SSL settings are
-- explicitly left unset.
defaultConfig :: MonadSnap m => Config m a
defaultConfig = mempty
    { hostname       = Just "localhost"
    , bind           = Just "0.0.0.0"
    , locale         = Just "en_US"
    , defaultTimeout = Just 60
    , compression    = Just True
    , verbose        = Just True
    , accessLog      = Just (ConfigFileLog "log/access.log")
    , errorLog       = Just (ConfigFileLog "log/error.log")
    , errorHandler   = Just defaultErrorHandler
    , sslbind        = Nothing
    , sslcert        = Nothing
    , sslkey         = Nothing
    , sslchaincert   = Nothing
    }
-- | Local hostname (@--hostname@).
getHostname :: Config m a -> Maybe ByteString
getHostname Config { hostname = v } = v

-- | Access log destination.
getAccessLog :: Config m a -> Maybe ConfigLog
getAccessLog Config { accessLog = v } = v

-- | Error log destination.
getErrorLog :: Config m a -> Maybe ConfigLog
getErrorLog Config { errorLog = v } = v

-- | Locale name.
getLocale :: Config m a -> Maybe String
getLocale Config { locale = v } = v

-- | HTTP port.
getPort :: Config m a -> Maybe Int
getPort Config { port = v } = v

-- | Bind address of the HTTP listener.
getBind :: Config m a -> Maybe ByteString
getBind Config { bind = v } = v

-- | SSL port.
getSSLPort :: Config m a -> Maybe Int
getSSLPort Config { sslport = v } = v

-- | Bind address of the SSL listener.
getSSLBind :: Config m a -> Maybe ByteString
getSSLBind Config { sslbind = v } = v

-- | Path to the SSL certificate.
getSSLCert :: Config m a -> Maybe FilePath
getSSLCert Config { sslcert = v } = v

-- | Whether the certificate file contains the complete chain.
getSSLChainCert :: Config m a -> Maybe Bool
getSSLChainCert Config { sslchaincert = v } = v

-- | Path to the SSL private key.
getSSLKey :: Config m a -> Maybe FilePath
getSSLKey Config { sslkey = v } = v

-- | Path of the unix socket.
getUnixSocket :: Config m a -> Maybe FilePath
getUnixSocket Config { unixsocket = v } = v

-- | Access mode bits of the unix socket.
getUnixSocketAccessMode :: Config m a -> Maybe Int
getUnixSocketAccessMode Config { unixaccessmode = v } = v

-- | Whether responses are gzip-compressed.
getCompression :: Config m a -> Maybe Bool
getCompression Config { compression = v } = v

-- | Whether status updates are printed to stderr.
getVerbose :: Config m a -> Maybe Bool
getVerbose Config { verbose = v } = v

-- | Handler for exceptions escaping web handlers.
getErrorHandler :: Config m a -> Maybe (SomeException -> m ())
getErrorHandler Config { errorHandler = v } = v

-- | Default timeout in seconds.
getDefaultTimeout :: Config m a -> Maybe Int
getDefaultTimeout Config { defaultTimeout = v } = v

-- | Application-specific settings.
getOther :: Config m a -> Maybe a
getOther Config { other = v } = v

-- | Proxy mode.
getProxyType :: Config m a -> Maybe ProxyType
getProxyType Config { proxyType = v } = v

-- | Startup hook.
getStartupHook :: Config m a -> Maybe (StartupInfo m a -> IO ())
getStartupHook Config { startupHook = v } = v
-- | Set the local hostname.
setHostname :: ByteString -> Config m a -> Config m a
setHostname name cfg = cfg { hostname = Just name }

-- | Set the access log destination.
setAccessLog :: ConfigLog -> Config m a -> Config m a
setAccessLog dest cfg = cfg { accessLog = Just dest }

-- | Set the error log destination.
setErrorLog :: ConfigLog -> Config m a -> Config m a
setErrorLog dest cfg = cfg { errorLog = Just dest }

-- | Set the locale name.
setLocale :: String -> Config m a -> Config m a
setLocale loc cfg = cfg { locale = Just loc }

-- | Set the HTTP port.
setPort :: Int -> Config m a -> Config m a
setPort p cfg = cfg { port = Just p }

-- | Set the bind address of the HTTP listener.
setBind :: ByteString -> Config m a -> Config m a
setBind addr cfg = cfg { bind = Just addr }

-- | Set the SSL port.
setSSLPort :: Int -> Config m a -> Config m a
setSSLPort p cfg = cfg { sslport = Just p }

-- | Set the bind address of the SSL listener.
setSSLBind :: ByteString -> Config m a -> Config m a
setSSLBind addr cfg = cfg { sslbind = Just addr }

-- | Set the path to the SSL certificate.
setSSLCert :: FilePath -> Config m a -> Config m a
setSSLCert path cfg = cfg { sslcert = Just path }

-- | Declare whether the certificate file contains the complete chain.
setSSLChainCert :: Bool -> Config m a -> Config m a
setSSLChainCert chained cfg = cfg { sslchaincert = Just chained }

-- | Set the path to the SSL private key.
setSSLKey :: FilePath -> Config m a -> Config m a
setSSLKey path cfg = cfg { sslkey = Just path }

-- | Set the path of the unix socket.
setUnixSocket :: FilePath -> Config m a -> Config m a
setUnixSocket path cfg = cfg { unixsocket = Just path }

-- | Set the unix socket access mode; only the permission bits (0o777) are kept.
setUnixSocketAccessMode :: Int -> Config m a -> Config m a
setUnixSocketAccessMode mode cfg = cfg { unixaccessmode = Just (mode .&. 0o777) }

-- | Enable or disable gzip compression of responses.
setCompression :: Bool -> Config m a -> Config m a
setCompression on' cfg = cfg { compression = Just on' }

-- | Enable or disable status output on stderr.
setVerbose :: Bool -> Config m a -> Config m a
setVerbose loud cfg = cfg { verbose = Just loud }

-- | Set the handler for exceptions escaping web handlers.
setErrorHandler :: (SomeException -> m ()) -> Config m a -> Config m a
setErrorHandler handler cfg = cfg { errorHandler = Just handler }

-- | Set the default timeout in seconds.
setDefaultTimeout :: Int -> Config m a -> Config m a
setDefaultTimeout secs cfg = cfg { defaultTimeout = Just secs }

-- | Set the application-specific settings.
setOther :: a -> Config m a -> Config m a
setOther extra cfg = cfg { other = Just extra }

-- | Set the proxy mode.
setProxyType :: ProxyType -> Config m a -> Config m a
setProxyType mode cfg = cfg { proxyType = Just mode }

-- | Set the startup hook.
setStartupHook :: (StartupInfo m a -> IO ()) -> Config m a -> Config m a
setStartupHook hook cfg = cfg { startupHook = Just hook }
-- | What the startup hook receives: the configuration in effect together
-- with a list of sockets.
data StartupInfo m a = StartupInfo
    { startupHookConfig  :: Config m a  -- ^ the configuration in effect
    , startupHookSockets :: [Socket]    -- ^ sockets; presumably the listening sockets the server bound (confirm with the caller)
    }
-- | A 'StartupInfo' with an empty configuration and no sockets.
emptyStartupInfo :: StartupInfo m a
emptyStartupInfo = StartupInfo { startupHookConfig = emptyConfig
                               , startupHookSockets = [] }

-- | The sockets carried by a 'StartupInfo'.
getStartupSockets :: StartupInfo m a -> [Socket]
getStartupSockets info = startupHookSockets info

-- | The configuration carried by a 'StartupInfo'.
getStartupConfig :: StartupInfo m a -> Config m a
getStartupConfig info = startupHookConfig info

-- | Replace the sockets in a 'StartupInfo'.
setStartupSockets :: [Socket] -> StartupInfo m a -> StartupInfo m a
setStartupSockets socks info = info { startupHookSockets = socks }

-- | Replace the configuration in a 'StartupInfo'.
setStartupConfig :: Config m a -> StartupInfo m a -> StartupInfo m a
setStartupConfig cfg info = info { startupHookConfig = cfg }
-- | Fill in what a server needs in order to start.
--
-- The given configuration is layered over 'defaultConfig'. If the result
-- has no HTTP port, no complete SSL listener (port, bind address, key and
-- certificate all set) and no unix socket, a notice is printed to stderr and
-- the HTTP port defaults to 8000.
completeConfig :: (MonadSnap m) => Config m a -> IO (Config m a)
completeConfig config = do
    when needsPort $
        hPutStrLn stderr "no port specified, defaulting to port 8000"
    return $! withDefaults `mappend` portFallback
  where
    withDefaults = defaultConfig `mappend` config
    hasSSL = all ($ withDefaults) [ isJust . getSSLPort
                                  , isJust . getSSLBind
                                  , isJust . getSSLKey
                                  , isJust . getSSLCert ]
    hasUnix = isJust (getUnixSocket withDefaults)
    needsPort = isNothing (getPort withDefaults) && not hasSSL && not hasUnix
    portFallback
        | needsPort = emptyConfig { port = Just 8000 }
        | otherwise = emptyConfig
-- | Encode a 'String' as UTF-8 bytes.
bsFromString :: String -> ByteString
bsFromString str = T.encodeUtf8 (T.pack str)
-- | Decode UTF-8 bytes into a 'String'.
--
-- Total: when the input is not valid UTF-8, each byte is instead mapped to
-- the 'Char' with the same code (Latin-1), rather than throwing as
-- 'T.decodeUtf8' does. This matters because the error handler uses this to
-- render diagnostic output, and must not itself fail on malformed bytes.
toString :: ByteString -> String
toString bs = either (const (S.unpack bs)) T.unpack (T.decodeUtf8' bs)
-- | Command-line option descriptors for the settings in 'Config', for use
-- with 'getOpt' (see 'extendedCommandLineConfig').
--
-- Each parsed option yields @Just cfg@, where @cfg@ is 'mempty' with only the
-- corresponding field set, so parsed options can be folded with 'mappend'.
-- The @--help@ option yields 'Nothing', which 'extendedCommandLineConfig'
-- turns into a usage message. Malformed numeric, octal or proxy arguments
-- are reported via 'error' with a message naming the offending option.
optDescrs :: forall m a . MonadSnap m =>
             Config m a
             -- ^ configuration defaults; layered over 'defaultConfig' and
             -- used only to render the @default ...@ hints in the help text
          -> [OptDescr (Maybe (Config m a))]
optDescrs defaults =
    [ Option "" ["hostname"]
             (ReqArg (Just . setConfig setHostname . bsFromString) "NAME")
             $ "local hostname" ++ defaultC getHostname
    , Option "b" ["address"]
             (ReqArg (\s -> Just $ mempty { bind = Just $ bsFromString s })
                     "ADDRESS")
             $ "address to bind to" ++ defaultO bind
    , Option "p" ["port"]
             (ReqArg (\s -> Just $ mempty { port = Just $ parseNum "--port" s })
                     "PORT")
             $ "port to listen on" ++ defaultO port
    , Option "" ["ssl-address"]
             (ReqArg (\s -> Just $ mempty { sslbind = Just $ bsFromString s })
                     "ADDRESS")
             $ "ssl address to bind to" ++ defaultO sslbind
    , Option "" ["ssl-port"]
             (ReqArg (\s -> Just $ mempty { sslport = Just $ parseNum "--ssl-port" s })
                     "PORT")
             $ "ssl port to listen on" ++ defaultO sslport
    , Option "" ["ssl-cert"]
             (ReqArg (\s -> Just $ mempty { sslcert = Just s}) "PATH")
             $ "path to ssl certificate in PEM format" ++ defaultO sslcert
    , Option [] ["ssl-chain-cert"]
             (NoArg $ Just $ setConfig setSSLChainCert True)
             $ "certificate file contains complete certificate chain"
               ++ defaultChain
    , Option [] ["no-ssl-chain-cert"]
             (NoArg $ Just $ setConfig setSSLChainCert False)
             $ "certificate file contains only the site certificate"
               ++ defaultChain
    , Option [] ["ssl-key"]
             (ReqArg (\s -> Just $ mempty { sslkey = Just s}) "PATH")
             $ "path to ssl private key in PEM format" ++ defaultO sslkey
    , Option "" ["access-log"]
             (ReqArg (Just . setConfig setAccessLog . ConfigFileLog) "PATH")
             $ "access log" ++ defaultC getAccessLog
    , Option "" ["error-log"]
             (ReqArg (Just . setConfig setErrorLog . ConfigFileLog) "PATH")
             $ "error log" ++ defaultC getErrorLog
    , Option "" ["no-access-log"]
             (NoArg $ Just $ setConfig setAccessLog ConfigNoLog)
             "don't have an access log"
    , Option "" ["no-error-log"]
             (NoArg $ Just $ setConfig setErrorLog ConfigNoLog)
             "don't have an error log"
    , Option "c" ["compression"]
             (NoArg $ Just $ setConfig setCompression True)
             $ "use gzip compression on responses" ++
               defaultB getCompression "compressed" "uncompressed"
    , Option "t" ["timeout"]
             (ReqArg (\t -> Just $ mempty {
                              defaultTimeout = Just $ parseNum "--timeout" t
                            }) "SECS")
             $ "set default timeout in seconds" ++ defaultC defaultTimeout
    , Option "" ["no-compression"]
             (NoArg $ Just $ setConfig setCompression False)
             $ "serve responses uncompressed" ++
               defaultB compression "compressed" "uncompressed"
    , Option "v" ["verbose"]
             (NoArg $ Just $ setConfig setVerbose True)
             $ "print server status updates to stderr" ++
               defaultC getVerbose
    , Option "q" ["quiet"]
             (NoArg $ Just $ setConfig setVerbose False)
             $ "do not print anything to stderr" ++
               defaultB getVerbose "verbose" "quiet"
    , Option "" ["proxy"]
             (ReqArg (Just . setConfig setProxyType . parseProxy . CI.mk)
                     "X_Forwarded_For")
             $ concat [ "Set --proxy=X_Forwarded_For if your snap application \n"
                      , "is behind an HTTP reverse proxy to ensure that \n"
                      , "rqClientAddr is set properly.\n"
                      , "Set --proxy=haproxy to use the haproxy protocol\n("
                      , "http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt)"
                      , defaultC getProxyType ]
    , Option "" ["unix-socket"]
             (ReqArg (Just . setConfig setUnixSocket) "PATH")
             $ concat ["Absolute path to unix socket file. "
                      , "File will be removed if already exists"]
    , Option "" ["unix-socket-mode"]
             (ReqArg (Just . setConfig setUnixSocketAccessMode . parseOctal)
                     "MODE")
             $ concat ["Access mode for unix socket in octal, for example 0760.\n"
                      ," Default is system specific."]
    , Option "h" ["help"]
             (NoArg Nothing)
             "display this help and exit"
    ]
  where
    parseProxy s | s == "NoProxy"         = NoProxy
                 | s == "X_Forwarded_For" = X_Forwarded_For
                 | s == "haproxy"         = HaProxy
                 | otherwise = error $ concat [
                         "Error (--proxy): expected one of 'NoProxy', "
                       , "'X_Forwarded_For', or 'haproxy'. Got '"
                       , CI.original s
                       , "'"
                       ]

    -- Parse a numeric argument accepting the same input as 'read' (trailing
    -- whitespace allowed), but fail with a message naming the option instead
    -- of the generic read failure.
    parseNum :: Read n => String -> String -> n
    parseNum opt s =
        case [ v | (v, rest) <- reads s, ("", "") <- lex rest ] of
            [v] -> v
            _   -> error $ concat [ "Error (", opt
                                  , "): expected an integer. Got '", s, "'" ]

    -- The whole argument must be octal (trailing whitespace allowed):
    -- readOct stops at the first non-octal digit, so without the remainder
    -- check an input such as 789 would be accepted as mode 7.
    parseOctal s =
        case [ v | (v, rest) <- readOct s, ("", "") <- lex rest ] of
            [v] | v >= 0 && v <= 0o777 -> v
            _ -> error $ "Error (--unix-socket-mode): expected octal access mode"

    setConfig f c  = f c mempty
    conf           = defaultConfig `mappend` defaults

    -- defaultB renders its first label for True; sslchaincert = True means
    -- the certificate file holds the complete chain.
    defaultChain = defaultB sslchaincert "complete certificate chain"
                                         "site certificate only"

    defaultB :: (Config m a -> Maybe Bool) -> String -> String -> String
    defaultB f y n = (maybe "" (\b -> ", default " ++ if b
                                                        then y
                                                        else n) $ f conf) :: String
    defaultC :: (Show b) => (Config m a -> Maybe b) -> String
    defaultC f     = maybe "" ((", default " ++) . show) $ f conf
    defaultO :: (Show b) => (Config m a -> Maybe b) -> String
    defaultO f     = maybe ", default off" ((", default " ++) . show) $ f conf
-- | Default handler for exceptions escaping a web handler.
--
-- Logs the exception together with a description of the request, with all
-- query and POST parameter values replaced by @...@, and finishes with a
-- plain-text 500 response whose body names the exception but not the
-- request.
defaultErrorHandler :: MonadSnap m => SomeException -> m ()
defaultErrorHandler e = do
    debug "Snap.Http.Server.Config errorHandler:"
    scrubbed <- getsRequest hideParamValues
    let logged = strictBytes (requestErrorMessage scrubbed e)
    debug (toString logged)
    logError logged
    finishWith (errorResponse emptyResponse)
  where
    errorResponse = setContentType "text/plain; charset=utf-8"
                  . setContentLength (fromIntegral (S.length bodyBytes))
                  . setResponseStatus 500 "Internal Server Error"
                  . setResponseBody writeBody

    -- Parameter values may hold secrets, so only the keys are kept.
    hideParamValues r =
        r { rqPostParams = Map.map (const ["..."]) (rqPostParams r)
          , rqParams     = Map.map (const ["..."]) (rqParams r) }

    writeBody out = Streams.write (Just bodyBuilder) out >> return out

    strictBytes = S.concat . L.toChunks . toLazyByteString

    bodyBytes   = strictBytes bodyBuilder
    bodyBuilder = byteString "A web handler threw an exception. Details:\n"
                  `mappend` stringUtf8 (show e)
-- | Build a configuration from the process command line using the standard
-- options in 'optDescrs'. On a parse error or @--help@ a usage message is
-- printed and the program exits (see 'extendedCommandLineConfig').
commandLineConfig :: MonadSnap m
                  => Config m a
                      -- ^ default configuration. The parsed options are
                      -- layered on top of it, and it is also used to render
                      -- the @default ...@ hints in the help text. Using
                      -- 'emptyConfig' is usually fine, because the result is
                      -- completed over 'defaultConfig' by 'completeConfig'.
                  -> IO (Config m a)
commandLineConfig defaults =
    extendedCommandLineConfig (optDescrs defaults) keepLater defaults
  where
    -- The options in 'optDescrs' never set the 'other' field, so this
    -- combining function is never called. It is a total, right-biased
    -- choice (matching the 'Config' monoid) rather than 'undefined', so a
    -- future option that does set 'other' cannot crash the program.
    keepLater _ later = later
-- | Parse the command line with a caller-supplied option list and layer the
-- result over the given defaults.
--
-- On a parse error, or when an option yields 'Nothing' (as @--help@ does in
-- 'optDescrs'), a usage message is printed to stderr and the program exits
-- via 'exitFailure'. Otherwise the parsed options are folded left to right,
-- so later options win; when two of them both set 'other', the values are
-- merged with the combining function. Unless built with @PORTABLE@, the
-- locale is taken from the @LANG@ environment variable (cut at the first
-- character that is neither a letter nor an underscore). It overrides the
-- defaults but not the parsed options. The result goes through
-- 'completeConfig'.
extendedCommandLineConfig :: MonadSnap m
                          => [OptDescr (Maybe (Config m a))]
                             -- ^ the complete option list, for example
                             -- 'optDescrs' plus application options
                          -> (a -> a -> a)
                             -- ^ merges two values of the 'other' field
                             -- when two parsed options both set it (the
                             -- earlier value is the first argument)
                          -> Config m a
                             -- ^ defaults underneath the parsed options
                          -> IO (Config m a)
extendedCommandLineConfig opts combiningFunction defaults = do
    argv     <- getArgs
    progName <- getProgName
    parsed   <- case parseArgs argv of
                    Left errs -> usage progName errs
                    Right cfg -> return cfg
#ifndef PORTABLE
    lang <- getEnv "LANG"
    completeConfig $ mconcat [ defaults
                             , mempty { locale = fmap upToUtf8 lang }
                             , parsed
                             ]
#else
    completeConfig $ mconcat [defaults, parsed]
#endif
  where
    -- Left carries getOpt errors; an empty Left means some option returned
    -- Nothing (the help option), so only the usage text is shown.
    parseArgs argv = case getOpt Permute opts argv of
        (found, _, []) -> maybe (Left [])
                                (Right . foldl' combine mempty)
                                (sequence found)
        (_, _, errs)   -> Left errs

    usage progName errs = do
        let header = "Usage:\n  " ++ progName ++ " [OPTION...]\n\nOptions:"
        hPutStrLn stderr (concat errs ++ usageInfo header opts)
        exitFailure

#ifndef PORTABLE
    -- Keep only the leading letters and underscores, e.g. en_US.UTF-8 gives en_US.
    upToUtf8 = takeWhile $ \c -> isAlpha c || '_' == c
#endif

    -- Later options override earlier ones; when both carry an 'other' value
    -- they are merged with the combining function instead.
    combine !older !newer = older `mappend` newer `mappend` mergedOther
      where
        mergedOther = mempty { other = bothOther }
        bothOther = case (getOther older, getOther newer) of
            (Just x, Just y) -> Just $! combiningFunction x y
            _                -> Nothing
-- | Map a function over the value produced by an argument descriptor,
-- keeping any argument-name placeholder unchanged.
fmapArg :: (a -> b) -> ArgDescr a -> ArgDescr b
fmapArg f descr = case descr of
    NoArg x              -> NoArg (f x)
    ReqArg parse argName -> ReqArg (f . parse) argName
    OptArg parse argName -> OptArg (f . parse) argName
-- | Map a function over the value produced by an option descriptor, leaving
-- its short flags, long flags and help text unchanged.
fmapOpt :: (a -> b) -> OptDescr a -> OptDescr b
fmapOpt f opt = case opt of
    Option shorts longs arg help -> Option shorts longs (fmapArg f arg) help
-- | Log message describing an exception raised while handling a request:
-- client address and port, a dump of the request, then the exception.
--
-- NOTE(review): the request dump is @fromShow (show req)@, which shows an
-- already-rendered String a second time, so it appears quoted with escaped
-- newlines. This also keeps raw control characters out of the log; confirm
-- which behaviour is intended before changing it.
requestErrorMessage :: Request -> SomeException -> Builder
requestErrorMessage req e = origin `mappend` dump `mappend` details
  where
    origin  = mconcat [ byteString "During processing of request from "
                      , byteString (rqClientAddr req)
                      , byteString ":"
                      , fromShow (rqClientPort req)
                      ]
    dump    = mconcat [ byteString "\nrequest:\n"
                      , fromShow (show req)
                      , byteString "\n"
                      ]
    details = byteString "A web handler threw an exception. Details:\n"
              `mappend` fromShow e
-- | Render a value with 'show' and encode the result as UTF-8.
fromShow :: Show a => a -> Builder
fromShow x = stringUtf8 (show x)