{-# LANGUAGE OverloadedStrings, CPP, TupleSections #-}

module Network.Wai.Logger.Apache (
    IPAddrSource(..)
  , apacheLogStr
  , serverpushLogStr
  ) where

#ifndef MIN_VERSION_base
#define MIN_VERSION_base(x,y,z) 1
#endif
#ifndef MIN_VERSION_wai
#define MIN_VERSION_wai(x,y,z) 1
#endif

import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as BS
import Data.List (find)
import Data.Maybe (fromMaybe)
#if MIN_VERSION_base(4,5,0)
import Data.Monoid ((<>), First (..))
#else
import Data.Monoid (mappend)
#endif
import Network.HTTP.Types (Status, statusCode)
import Network.HTTP.Types.Header (HeaderName)
import Network.Wai (Request(..))
import Network.Wai.Logger.IP
import System.Log.FastLogger

-- $setup
-- >>> :set -XOverloadedStrings
-- >>> import Network.Wai (defaultRequest)

-- | Selects where the client IP address written to the log comes from.
data IPAddrSource
  = -- | Use the peer address of the HTTP connection.
    FromSocket
  | -- | Use the @X-Real-IP@ or @X-Forwarded-For@ request header.
    --
    -- Whichever of the two headers appears first in the ordered list of
    -- request headers is the one that is used.
    --
    -- When @X-Forwarded-For@ is the one chosen, its value is treated as a
    -- comma-separated list of IP addresses and only the left-most entry is
    -- kept (that entry is most likely to be the actual client IP address).
    FromHeader
  | -- | Use one of the given custom HTTP headers; handy behind a proxy.
    --
    -- The header value is treated as a comma-separated list of IP addresses
    -- and only the left-most entry is kept (that entry is most likely to be
    -- the actual client IP address).
    --
    -- A value holding just a single IP address works as expected too.
    FromHeaderCustom [HeaderName]
  | -- | Same as 'FromHeader', except that the peer address is used when
    -- neither header is present.
    FromFallback
  | -- | Full control: the supplied function computes the address from the
    -- 'Request', and the 'ByteString' it returns is logged as the IP source
    -- address.
    FromRequest (Request -> ByteString)

-- | Apache style log format.
--
-- One line is produced per request, laid out as follows and terminated by a
-- newline:
--
-- > ADDR - USER [TIME] "METHOD PATH VERSION" STATUS SIZE "REFERER" "USER-AGENT"
--
-- * @ADDR@ is chosen according to the 'IPAddrSource'.
-- * @USER@ is @-@ when the user getter returns 'Nothing'.
-- * @PATH@ is 'rawPathInfo' immediately followed by 'rawQueryString'.
-- * @VERSION@ and @STATUS@ are rendered with 'show' (of 'httpVersion' and of
--   'statusCode' respectively).
-- * @SIZE@ is @-@ when no size is given.
-- * @REFERER@ and @USER-AGENT@ are empty when the header is absent.
apacheLogStr
  :: ToLogStr user
  => IPAddrSource             -- ^ where to take the client address from
  -> (Request -> Maybe user)  -- ^ extracts the user name to log, if any
  -> FormattedTime            -- ^ already formatted timestamp
  -> Request                  -- ^ the request being logged
  -> Status                   -- ^ response status
  -> Maybe Integer            -- ^ response body size, if known
  -> LogStr
apacheLogStr ipsrc userget tmstr req status msize =
      toLogStr (getSourceIP ipsrc req)
  <> " - "
  <> maybe "-" toLogStr (userget req)
  <> " ["
  <> toLogStr tmstr
  <> "] \""
  <> toLogStr (requestMethod req)
  <> " "
  <> toLogStr path
  <> " "
  <> toLogStr (show (httpVersion req))
  <> "\" "
  <> toLogStr (show (statusCode status))
  <> " "
  <> toLogStr (maybe "-" show msize)
  <> " \""
  <> toLogStr (fromMaybe "" mr)
  <> "\" \""
  <> toLogStr (fromMaybe "" mua)
  <> "\"\n"
  where
    -- Raw (undecoded) path together with the raw query string.
    path :: ByteString
    path = rawPathInfo req <> rawQueryString req
#if !MIN_VERSION_base(4,5,0)
    -- '<>' is only imported from "Data.Monoid" on base >= 4.5.
    (<>) = mappend
#endif
#if MIN_VERSION_wai(3,2,0)
    -- wai >= 3.2 exposes these two headers as dedicated 'Request' fields.
    mr :: Maybe ByteString
    mr  = requestHeaderReferer req
    mua :: Maybe ByteString
    mua = requestHeaderUserAgent req
#else
    mr  = lookup "referer" $ requestHeaders req
    mua = lookup "user-agent" $ requestHeaders req
#endif

-- | HTTP/2 Push log format in the Apache style.
--
-- One line is produced per pushed resource, laid out as follows and
-- terminated by a newline:
--
-- > ADDR - USER [TIME] "PUSH PATH HTTP/2" 200 SIZE "REFERER" "USER-AGENT"
--
-- * @ADDR@ is chosen according to the 'IPAddrSource'.
-- * @USER@ is @-@ when the user getter returns 'Nothing'.
-- * The method (@PUSH@), protocol (@HTTP/2@) and status (@200@) are fixed
--   text.
-- * @REFERER@ is the 'rawPathInfo' of the given request (without its query
--   string).
-- * @USER-AGENT@ is empty when the header is absent.
serverpushLogStr
  :: ToLogStr user
  => IPAddrSource             -- ^ where to take the client address from
  -> (Request -> Maybe user)  -- ^ extracts the user name to log, if any
  -> FormattedTime            -- ^ already formatted timestamp
  -> Request                  -- ^ the request supplying address, user, referer and user agent
  -> ByteString               -- ^ path of the pushed resource
  -> Integer                  -- ^ size of the pushed resource
  -> LogStr
serverpushLogStr ipsrc userget tmstr req path size =
      toLogStr (getSourceIP ipsrc req)
  <> " - "
  <> maybe "-" toLogStr (userget req)
  <> " ["
  <> toLogStr tmstr
  <> "] \"PUSH "
  <> toLogStr path
  <> " HTTP/2\" 200 "
  <> toLogStr (show size)
  <> " \""
  <> toLogStr ref
  <> "\" \""
  <> toLogStr (fromMaybe "" mua)
  <> "\"\n"
  where
    -- The referer column carries the raw path of the given request.
    ref :: ByteString
    ref  = rawPathInfo req
#if !MIN_VERSION_base(4,5,0)
    -- '<>' is only imported from "Data.Monoid" on base >= 4.5.
    (<>) = mappend
#endif
#if MIN_VERSION_wai(3,2,0)
    -- wai >= 3.2 exposes the user agent as a dedicated 'Request' field.
    mua :: Maybe ByteString
    mua = requestHeaderUserAgent req
#else
    mua = lookup "user-agent" $ requestHeaders req
#endif

-- | Compute the client address to log for a request, dispatching on the
-- requested 'IPAddrSource'.
--
-- For 'FromHeaderCustom', @"-"@ is returned when none of the listed headers
-- is present in the request.
getSourceIP :: IPAddrSource -> Request -> ByteString
getSourceIP FromSocket            = getSourceFromSocket
getSourceIP FromHeader            = getSourceFromHeader
getSourceIP FromFallback          = getSourceFromFallback
getSourceIP (FromHeaderCustom hs) = fromMaybe "-" . getSourceFromHeaderCustom hs
getSourceIP (FromRequest fromReq) = fromReq

-- | The peer address of the connection ('remoteHost'), rendered with
-- 'showSockAddr'.
--
-- >>> getSourceFromSocket defaultRequest
-- "0.0.0.0"
getSourceFromSocket :: Request -> ByteString
getSourceFromSocket = BS.pack . showSockAddr . remoteHost

-- | The client address taken from @X-Real-IP@ or @X-Forwarded-For@ (see
-- 'getSource'), or @"-"@ when neither header is present.
--
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("X-Real-IP", "127.0.0.1") ] }
-- "127.0.0.1"
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("X-Forwarded-For", "127.0.0.1") ] }
-- "127.0.0.1"
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("Something", "127.0.0.1") ] }
-- "-"
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [] }
-- "-"
--
-- 'getSourceFromHeader' uses the first instance of either @"X-Real-IP"@ or
-- @"X-Forwarded-For"@ that it finds in the ordered header list:
--
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("X-Real-IP", "1.2.3.4"), ("X-Forwarded-For", "5.6.7.8") ] }
-- "1.2.3.4"
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8"), ("X-Real-IP", "1.2.3.4") ] }
-- "5.6.7.8"
--
-- 'getSourceFromHeader' handles pulling out the first IP in the
-- comma-separated IP list in X-Forwarded-For:
--
-- >>> getSourceFromHeader defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8, 10.11.12.13, 1.2.3.4") ] }
-- "5.6.7.8"
getSourceFromHeader :: Request -> ByteString
getSourceFromHeader = fromMaybe "-" . getSource

-- | The client address taken from @X-Real-IP@ or @X-Forwarded-For@ (see
-- 'getSource'), falling back on the peer address ('getSourceFromSocket')
-- when neither header is present.
--
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("X-Real-IP", "127.0.0.1") ] }
-- "127.0.0.1"
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("X-Forwarded-For", "127.0.0.1") ] }
-- "127.0.0.1"
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("Something", "127.0.0.1") ] }
-- "0.0.0.0"
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [] }
-- "0.0.0.0"
--
-- 'getSourceFromFallback' uses the first instance of either @"X-Real-IP"@ or
-- @"X-Forwarded-For"@ that it finds in the ordered header list:
--
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("X-Real-IP", "1.2.3.4"), ("X-Forwarded-For", "5.6.7.8") ] }
-- "1.2.3.4"
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8"), ("X-Real-IP", "1.2.3.4") ] }
-- "5.6.7.8"
--
-- 'getSourceFromFallback' handles pulling out the first IP in the
-- comma-separated IP list in X-Forwarded-For:
--
-- >>> getSourceFromFallback defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8, 10.11.12.13, 1.2.3.4") ] }
-- "5.6.7.8"
getSourceFromFallback :: Request -> ByteString
getSourceFromFallback req = fromMaybe (getSourceFromSocket req) $ getSource req

-- | Look for @X-Real-IP@ or @X-Forwarded-For@ among the request headers.
-- The value of @X-Real-IP@ is used unchanged; the value of
-- @X-Forwarded-For@ is cut down to its first comma-separated entry with
-- 'firstIpInXFF'.
--
-- >>> getSource defaultRequest { requestHeaders = [ ("X-Real-IP", "127.0.0.1") ] }
-- Just "127.0.0.1"
-- >>> getSource defaultRequest { requestHeaders = [ ("X-Forwarded-For", "127.0.0.1") ] }
-- Just "127.0.0.1"
-- >>> getSource defaultRequest { requestHeaders = [ ("Something", "127.0.0.1") ] }
-- Nothing
-- >>> getSource defaultRequest
-- Nothing
--
-- 'getSource' uses the first instance of either @"X-Real-IP"@ or
-- @"X-Forwarded-For"@ that it finds in the ordered header list:
--
-- >>> getSource defaultRequest { requestHeaders = [ ("X-Real-IP", "1.2.3.4"), ("X-Forwarded-For", "5.6.7.8") ] }
-- Just "1.2.3.4"
-- >>> getSource defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8"), ("X-Real-IP", "1.2.3.4") ] }
-- Just "5.6.7.8"
--
-- 'getSource' handles pulling out the first IP in the comma-separated IP list
-- in X-Forwarded-For:
--
-- >>> getSource defaultRequest { requestHeaders = [ ("X-Forwarded-For", "5.6.7.8, 10.11.12.13, 1.2.3.4") ] }
-- Just "5.6.7.8"
getSource :: Request -> Maybe ByteString
getSource = getSourceFromHeaders [("x-real-ip", id), ("x-forwarded-for", firstIpInXFF)]

-- | Pull out the first IP in a comma-separated list of X-Forwarded-For IPs.
--
-- Everything up to (but not including) the first comma is returned.
--
-- >>> firstIpInXFF "1.2.3.4, 5.6.7.8, 10.11.12.13"
-- "1.2.3.4"
--
-- If there are no commas, just return the whole input ByteString:
--
-- >>> firstIpInXFF "5.6.7.8"
-- "5.6.7.8"
--
-- Note that this function doesn't make sure the input is actually an IP address:
--
-- >>> firstIpInXFF "hello, world"
-- "hello"
firstIpInXFF :: ByteString -> ByteString
firstIpInXFF = BS.takeWhile (/= ',')

-- | Scan the request headers in order and return the post-processed value of
-- the first one whose name occurs in the given list, or 'Nothing' when no
-- request header matches.
--
-- Each list entry pairs a header name with the function applied to that
-- header's value before it is returned.
getSourceFromHeaders :: [(HeaderName, ByteString -> ByteString)] -> Request -> Maybe ByteString
getSourceFromHeaders headerNamesAndPostProc req = getFirst $ foldMap f $ requestHeaders req
  where
    -- Take a header name and value from the request, and try match it against
    -- the list of headers and post-processing functions.  If it matches,
    -- return the ByteString resulting from applying the post-processing function
    -- to the header value.
    f :: (HeaderName, ByteString) -> First ByteString
    f (headerNameFromReq, headerValFromReq) =
      let maybePostProc = find (\(headerNameFromPostProc, _) -> headerNameFromReq == headerNameFromPostProc) headerNamesAndPostProc
      in First $ fmap (\(_, postProc) -> postProc headerValFromReq) maybePostProc

-- | Look for any of the given headers among the request headers and return
-- the first comma-separated entry of its value (see 'firstIpInXFF').
--
-- >>> getSourceFromHeaderCustom ["x-foobar"] defaultRequest { requestHeaders = [ ("X-catdog", "1.2.3.4"), ("X-Foobar", "5.6.7.8"), ("Other", "1.1.1.1") ] }
-- Just "5.6.7.8"
--
-- If none of the headers in the passed-in list are in the 'Request', then return 'Nothing':
--
-- >>> getSourceFromHeaderCustom ["x-foobar", "baz"] defaultRequest { requestHeaders = [ ("abb", "1.2.3.4"), ("xyz", "5.6.7.8") ] }
-- Nothing
--
-- 'getSourceFromHeaderCustom' uses the first instance of any header in the
-- passed in list that it finds in the ordered header list from the request:
--
-- >>> getSourceFromHeaderCustom ["x-foobar", "baz"] defaultRequest { requestHeaders = [ ("baz", "1.2.3.4"), ("x-foobar", "5.6.7.8") ] }
-- Just "1.2.3.4"
--
-- 'getSourceFromHeaderCustom' splits the value of the header it finds by @,@
-- and uses the first item. This makes it easy to use with headers like
-- @X-Forwarded-For@, which are expected to have a comma-separated list of IP
-- addresses:
--
-- >>> getSourceFromHeaderCustom ["x-foobar"] defaultRequest { requestHeaders = [ ("X-Foobar", "5.6.7.8, 10.11.12.13, 1.2.3.4") ] }
-- Just "5.6.7.8"
getSourceFromHeaderCustom :: [HeaderName] -> Request -> Maybe ByteString
getSourceFromHeaderCustom hs = getSourceFromHeaders (fmap (,firstIpInXFF) hs)