module Network.Wai.Middleware.Logging
  ( addThreadContext
  , requestLogger
  ) where

import Prelude

import Blammo.Logging
import Control.Arrow ((***))
import Control.Monad.IO.Unlift (withRunInIO)
import Data.Aeson
import qualified Data.Aeson.Compat as Key
import qualified Data.Aeson.Compat as KeyMap
import qualified Data.CaseInsensitive as CI
import Data.Text (pack)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Network.HTTP.Types.Header (Header, HeaderName)
import Network.HTTP.Types.Status (Status(..))
import Network.Wai
  ( Middleware
  , Request
  , Response
  , rawPathInfo
  , rawQueryString
  , requestHeaders
  , requestMethod
  , responseHeaders
  , responseStatus
  )
import qualified System.Clock as Clock

-- | Add context to any logging done from the request-handling thread
--
-- The downstream application's handling of each request (including the call
-- to @respond@) runs inside 'withThreadContext' with the given pairs.
addThreadContext :: [Pair] -> Middleware
addThreadContext context app request =
  withThreadContext context . app request

-- | Log requests (more accurately, responses) as they happen
--
-- In JSON format, logged messages look like:
--
-- > {
-- >   ...
-- >   message: {
-- >     text: "GET /foo/bar => 200 OK",
-- >     meta: {
-- >       method: "GET",
-- >       path: "/foo/bar",
-- >       query: "?baz=bat&quix=quo",
-- >       status: {
-- >         code: 200,
-- >         message: "OK"
-- >       },
-- >       durationMs: 1322.2,
-- >       requestHeaders: {
-- >         authorization: "***",
-- >         accept: "text/html",
-- >         cookie: "***"
-- >       },
-- >       responseHeaders: {
-- >         set-cookie: "***",
-- >         expires: "never"
-- >       }
-- >     }
-- >   }
-- > }
--
-- Header names are logged case-folded (lower-case). The values of the
-- @Authorization@ and @Cookie@ request headers and the @Set-Cookie@ response
-- header are replaced with @"***"@.
--
-- @durationMs@ is measured on the monotonic clock. Timing starts just before
-- the application is invoked and ends after the server's @respond@ callback
-- returns, so it includes whatever work @respond@ itself performs.
--
-- A line is logged only once the application actually responds. If it throws
-- before calling @respond@, nothing is logged here.
requestLogger :: HasLogger env => env -> Middleware
requestLogger env app req respond =
  runLoggerLoggingT env $ withRunInIO $ \runInIO -> do
    begin <- getTime
    app req $ \resp -> do
      recvd <- respond resp
      duration <- toMillis . subtract begin <$> getTime
      -- Log after responding, then hand the server's receipt back unchanged.
      recvd <$ runInIO (logResponse duration req resp)
 where
  getTime :: IO Clock.TimeSpec
  getTime = Clock.getTime Clock.Monotonic

  toMillis :: Clock.TimeSpec -> Double
  toMillis x = fromIntegral (Clock.toNanoSecs x) / nsPerMs

-- | Emit one structured log line describing a completed request/response.
--
-- The log level is chosen from the response status code. The 404 case is
-- checked before the generic 4xx case:
--
-- * @>= 500@: 'logError'
-- * @404@: 'logDebug'
-- * any other @>= 400@: 'logWarn'
-- * everything else: 'logDebug'
--
-- The message text is @METHOD PATH => CODE REASON@. The metadata carries:
--
-- * method, path and query
-- * status code and reason
-- * the caller-supplied duration, under @durationMs@
-- * request and response headers, with @authorization@/@cookie@ and
--   @set-cookie@ redacted
--
-- The method, path, query string and status reason are raw bytes from the
-- wire and need not be valid UTF-8. They are therefore decoded leniently,
-- with invalid sequences replaced by U+FFFD. 'decodeUtf8' is not used here
-- because it throws on malformed input while the log line is rendered.
logResponse :: MonadLogger m => Double -> Request -> Response -> m ()
logResponse duration req resp
  | code >= 500 = logError $ message :# details
  | code == 404 = logDebug $ message :# details
  | code >= 400 = logWarn $ message :# details
  | otherwise = logDebug $ message :# details
 where
  status = responseStatus resp
  code = statusCode status

  -- Total decoder: never throws on invalid UTF-8.
  decodeLenient = decodeUtf8With lenientDecode

  message =
    decodeLenient (requestMethod req)
      <> " "
      <> decodeLenient (rawPathInfo req)
      <> " => "
      <> pack (show code)
      <> " "
      <> decodeLenient (statusMessage status)

  details =
    [ "method" .= decodeLenient (requestMethod req)
    , "path" .= decodeLenient (rawPathInfo req)
    , "query" .= decodeLenient (rawQueryString req)
    , "status" .= object
      [ "code" .= code
      , "message" .= decodeLenient (statusMessage status)
      ]
    , "durationMs" .= duration
    , "requestHeaders"
      .= headerObject ["authorization", "cookie"] (requestHeaders req)
    , "responseHeaders" .= headerObject ["set-cookie"] (responseHeaders resp)
    ]

-- | Render headers as a JSON object for structured logging.
--
-- Keys are the case-folded header names and values are the header values.
-- Any header whose name appears in @redact@ has its value replaced by
-- @"***"@. Because 'HeaderName' is a 'CI.CI' value, this comparison is
-- case-insensitive.
--
-- Header names and values are arbitrary bytes from the wire and need not be
-- UTF-8. They are therefore decoded leniently, with invalid sequences
-- replaced by U+FFFD. 'decodeUtf8' is not used because it throws on
-- malformed input.
--
-- NOTE(review): a repeated header name, such as several @Set-Cookie@
-- headers, collapses to a single key in the resulting object. Confirm which
-- value 'KeyMap.fromList' keeps.
headerObject :: [HeaderName] -> [Header] -> Value
headerObject redact = Object . KeyMap.fromList . map (toPair . redactValue)
 where
  -- Name -> lower-cased JSON key; raw value -> JSON string.
  toPair =
    Key.fromText . decodeLenient . CI.foldedCase *** String . decodeLenient

  redactValue (name, value)
    | name `elem` redact = (name, "***")
    | otherwise = (name, value)

  -- Total decoder: never throws on invalid UTF-8.
  decodeLenient = decodeUtf8With lenientDecode

-- | Number of nanoseconds in one millisecond. It is used as the divisor when
-- converting 'System.Clock.toNanoSecs' readings to milliseconds.
nsPerMs :: Double
nsPerMs = 1e6