module Network.Wai.Middleware.Logging
  ( addThreadContext
  , addThreadContextFromRequest
  , requestLogger
  , requestLoggerWith

  -- * Configuration
  , Config
  , defaultConfig
  , setConfigLogSource
  , setConfigGetClientIp
  , setConfigGetDestinationIp
  ) where

import Prelude

import Blammo.Logging
import Control.Applicative ((<|>))
import Control.Arrow ((***))
import Control.Monad.IO.Unlift (withRunInIO)
import Data.Aeson
import qualified Data.Aeson.Compat as Key
import qualified Data.Aeson.Compat as KeyMap
import Data.ByteString (ByteString)
import qualified Data.CaseInsensitive as CI
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Text (Text, pack)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Network.HTTP.Types.Header (Header, HeaderName)
import Network.HTTP.Types.Status (Status(..))
import Network.Wai
  ( Middleware
  , Request
  , Response
  , rawPathInfo
  , rawQueryString
  , remoteHost
  , requestHeaders
  , requestMethod
  , responseHeaders
  , responseStatus
  )
import qualified System.Clock as Clock

-- | Add context to any logging done from the request-handling thread
--
-- The same @context@ is attached for every request. Use
-- 'addThreadContextFromRequest' when the context depends on the 'Request'.
addThreadContext :: [Pair] -> Middleware
addThreadContext context = addThreadContextFromRequest (const context)

-- | 'addThreadContext', but have the 'Request' available
--
-- The context computed from each incoming 'Request' is installed with
-- 'withThreadContext' around the wrapped application's handling of that
-- request, including its call to @respond@.
addThreadContextFromRequest :: (Request -> [Pair]) -> Middleware
addThreadContextFromRequest toContext app request =
  withThreadContext (toContext request) . app request

-- | Log requests (more accurately, responses) as they happen
--
-- Uses 'defaultConfig'; see 'requestLoggerWith' to customize. Each response
-- is logged after it has been sent, at a level chosen from its status code:
-- error for 5xx, warn for 4xx other than 404, and debug for 404 and
-- everything else.
--
-- In JSON format, logged messages look like:
--
-- > {
-- >   ...
-- >   message: {
-- >     text: "GET /foo/bar => 200 OK",
-- >     meta: {
-- >       method: "GET",
-- >       path: "/foo/bar",
-- >       query: "?baz=bat&quix=quo",
-- >       status: {
-- >         code: 200,
-- >         message: "OK"
-- >       },
-- >       clientIp: "203.0.113.7",
-- >       destinationIp: "10.0.1.23",
-- >       durationMs: 1322.2,
-- >       requestHeaders: {
-- >         authorization: "***",
-- >         accept: "text/html",
-- >         cookie: "***"
-- >       },
-- >       responseHeaders: {
-- >         set-cookie: "***",
-- >         expires: "never"
-- >       }
-- >     }
-- >   }
-- > }
--
-- Header names are case-folded (so appear lower-case), and the values of
-- the @authorization@ and @cookie@ request headers and the @set-cookie@
-- response header are replaced with @***@.
requestLogger :: HasLogger env => env -> Middleware
requestLogger = requestLoggerWith defaultConfig

-- | Settings for 'requestLoggerWith'
--
-- Start from 'defaultConfig' and adjust with the @setConfig*@ functions.
data Config = Config
  { cLogSource :: LogSource
  -- ^ Source attached to every request log message
  , cGetClientIp :: Request -> Text
  -- ^ Produces the @clientIp@ field of each log message
  , cGetDestinationIp :: Request -> Maybe Text
  -- ^ Produces the @destinationIp@ field of each log message
  }

-- | The 'Config' used by 'requestLogger'
--
-- * Log source: @requestLogger@
-- * Client IP: the first non-blank, comma-separated entry of
--   @x-forwarded-for@, else @x-real-ip@, else 'show' of 'remoteHost'
-- * Destination IP: the @x-real-ip@ header, if present
--
-- NOTE(review): 'show' of a 'remoteHost' address likely includes the port
-- (e.g. @1.2.3.4:5678@); confirm whether that is wanted in @clientIp@.
defaultConfig :: Config
defaultConfig = Config
  { cLogSource = "requestLogger"
  , cGetClientIp = clientIp
  , cGetDestinationIp = realIp
  }
 where
  clientIp :: Request -> Text
  clientIp req = fromMaybe (peerAddress req) (forwardedFor req <|> realIp req)

  forwardedFor :: Request -> Maybe Text
  forwardedFor req = firstValue =<< lookupRequestHeader "x-forwarded-for" req

  realIp :: Request -> Maybe Text
  realIp = lookupRequestHeader "x-real-ip"

  peerAddress :: Request -> Text
  peerAddress = pack . show . remoteHost

  -- x-forwarded-for may list several hops; take the first non-blank one
  firstValue :: Text -> Maybe Text
  firstValue = find (not . T.null) . map T.strip . T.splitOn ","

-- | Value of the first request header with the given name (compared
-- case-insensitively as a 'HeaderName'), decoded leniently as UTF-8
lookupRequestHeader :: HeaderName -> Request -> Maybe Text
lookupRequestHeader name req = decodeUtf8 <$> lookup name (requestHeaders req)

-- | Change the source used for log messages
--
-- Default is @requestLogger@.
--
setConfigLogSource :: LogSource -> Config -> Config
setConfigLogSource source config = config { cLogSource = source }

-- | Change how the @clientIp@ field is determined
--
-- Default is looking up the first value in @x-forwarded-for@, then the
-- @x-real-ip@ header, then finally falling back to 'Network.Wai.remoteHost'.
--
setConfigGetClientIp :: (Request -> Text) -> Config -> Config
setConfigGetClientIp getClientIp config =
  config { cGetClientIp = getClientIp }

-- | Change how the @destinationIp@ field is determined
--
-- Default is looking up the @x-real-ip@ header.
--
-- __NOTE__: Our default uses a somewhat loose definition of /destination/. It
-- would be more accurate to report the resolved IP address of the @Host@
-- header, but we don't have that available. Our default of @x-real-ip@ favors
-- containerized Warp on AWS/ECS, where this value holds the ECS target
-- container's IP address. This is valuable debugging information and could, if
-- you squint, be considered a /destination/.
--
setConfigGetDestinationIp :: (Request -> Maybe Text) -> Config -> Config
setConfigGetDestinationIp getDestinationIp config =
  config { cGetDestinationIp = getDestinationIp }

-- | 'requestLogger', but with a custom 'Config'
--
-- Timing uses the monotonic clock, from just before the application is
-- invoked until just after @respond@ returns, so @durationMs@ includes the
-- time spent sending the response. The log message is written after the
-- response has been sent, using the logger from @env@.
--
-- NOTE: no exceptions are caught here, so a request whose application or
-- @respond@ throws is not logged.
requestLoggerWith :: HasLogger env => Config -> env -> Middleware
requestLoggerWith config env app req respond =
  runLoggerLoggingT env $ withRunInIO $ \runInIO -> do
    start <- now
    app req $ \resp -> do
      received <- respond resp
      end <- now
      runInIO $ logResponse config (elapsedMillis start end) req resp
      pure received
 where
  now :: IO Clock.TimeSpec
  now = Clock.getTime Clock.Monotonic

  elapsedMillis :: Clock.TimeSpec -> Clock.TimeSpec -> Double
  elapsedMillis start end =
    fromIntegral (Clock.toNanoSecs (end - start)) / nsPerMs

-- | Log one response with structured metadata
--
-- The level depends on the status code:
--
-- * 5xx: error
-- * 404: debug
-- * any other 4xx: warn
-- * everything else: debug
--
-- The message text is @METHOD PATH => CODE REASON@. The metadata carries
-- method, path, query string, status, client and destination IPs, the
-- given duration (milliseconds), and request/response headers with
-- credential-bearing ones redacted.
logResponse :: MonadLogger m => Config -> Double -> Request -> Response -> m ()
logResponse config duration req resp = case statusCode status of
  code
    | code >= 500 -> logErrorNS source entry
    | code == 404 -> logDebugNS source entry
    | code >= 400 -> logWarnNS source entry
    | otherwise -> logDebugNS source entry
 where
  source = cLogSource config
  status = responseStatus resp
  entry = summary :# meta

  method, path, reason :: Text
  method = decodeUtf8 $ requestMethod req
  path = decodeUtf8 $ rawPathInfo req
  reason = decodeUtf8 $ statusMessage status

  summary :: Text
  summary = mconcat
    [method, " ", path, " => ", pack (show $ statusCode status), " ", reason]

  meta =
    [ "method" .= method
    , "path" .= path
    , "query" .= decodeUtf8 (rawQueryString req)
    , "status" .= object ["code" .= statusCode status, "message" .= reason]
    , "clientIp" .= cGetClientIp config req
    , "destinationIp" .= cGetDestinationIp config req
    , "durationMs" .= duration
    , "requestHeaders"
      .= headerObject ["authorization", "cookie"] (requestHeaders req)
    , "responseHeaders"
      .= headerObject ["set-cookie"] (responseHeaders resp)
    ]

-- | Render headers as a JSON object keyed by case-folded header name
--
-- Values of headers whose name appears in the first argument are replaced
-- with @***@.
--
-- A header name may legitimately occur more than once (RFC 9110 §5.3).
-- Building the object directly would keep only the last occurrence and
-- silently drop the rest, so repeated values are first joined with @", "@
-- in their original order. Redaction happens after joining, so a redacted
-- header always renders as a single @***@.
headerObject :: [HeaderName] -> [Header] -> Value
headerObject redact =
  Object . KeyMap.fromList . map toPair . combineRepeated
 where
  toPair :: (HeaderName, Text) -> Pair
  toPair (name, value) =
    ( Key.fromText $ decodeUtf8 $ CI.foldedCase name
    , String $ if name `elem` redact then "***" else value
    )

  -- Gather every value sharing a (case-insensitive) name at the position of
  -- its first occurrence. Header lists are short, so the quadratic scan is
  -- fine and avoids extra dependencies.
  combineRepeated :: [Header] -> [(HeaderName, Text)]
  combineRepeated [] = []
  combineRepeated ((name, value) : rest) =
    (name, T.intercalate ", " (map decodeUtf8 (value : map snd same)))
      : combineRepeated others
   where
    same = filter ((== name) . fst) rest
    others = filter ((/= name) . fst) rest

-- | Number of nanoseconds in one millisecond, for turning 'Clock.toNanoSecs'
-- results into fractional milliseconds
nsPerMs :: Double
nsPerMs = 1e6

-- | Decode UTF-8 bytes to 'Text' without throwing: invalid sequences are
-- replaced (via 'lenientDecode') rather than raising a decoding error
decodeUtf8 :: ByteString -> Text
decodeUtf8 bytes = decodeUtf8With lenientDecode bytes