{-# LANGUAGE DeriveDataTypeable #-}
-- | Some helpers for interrogating a WAI 'Request'.

module Network.Wai.Request
    ( appearsSecure
    , guessApproot
    , RequestSizeException(..)
    , requestSizeCheck
    ) where

import Data.ByteString (ByteString)
import Data.Maybe (fromMaybe)
import Network.HTTP.Types (HeaderName)
import Network.Wai

import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import Control.Exception (Exception, throwIO)
import Data.Typeable (Typeable)
import Data.Word (Word64)
import Data.IORef (atomicModifyIORef', newIORef)


-- | Does this request appear to have been made over an SSL connection?
--
-- This function first checks @'isSecure'@, but also checks for headers that may
-- indicate a secure connection even in the presence of reverse proxies.
--
-- Note: these headers can be easily spoofed, so decisions which require a true
-- SSL connection (i.e. sending sensitive information) should only use
-- @'isSecure'@. This is not always the case though: for example, deciding to
-- force a non-SSL request to SSL by redirect. One can safely choose not to
-- redirect when the request /appears/ secure, even if it's actually not.
--
-- The headers consulted are:
--
-- * @HTTPS@ and @HTTP_X_FORWARDED_SSL@, matching the value @on@;
--
-- * @HTTP_X_FORWARDED_SCHEME@, matching the value @https@;
--
-- * @HTTP_X_FORWARDED_PROTO@ and @X-Forwarded-Proto@, whose value may be a
--   comma-separated list (e.g. @https,http@ after several proxies); only the
--   first entry is compared against @https@.
--
-- @since 3.0.7
appearsSecure :: Request -> Bool
appearsSecure request = isSecure request || any (uncurry matchHeader)
    [ ("HTTPS"                  , (== "on"))
    , ("HTTP_X_FORWARDED_SSL"   , (== "on"))
    , ("HTTP_X_FORWARDED_SCHEME", (== "https"))
    , ("HTTP_X_FORWARDED_PROTO" , firstProtoIsHttps)
      -- Used by Nginx and AWS ELB. Proxies may append to this header, so it
      -- is parsed the same way as the CGI-style variant above instead of
      -- requiring the whole value to be exactly "https".
    , ("X-Forwarded-Proto"      , firstProtoIsHttps)
    ]
  where
    -- Apply the predicate to the value of the first header named @h@;
    -- a missing header never matches.
    matchHeader :: HeaderName -> (ByteString -> Bool) -> Bool
    matchHeader h f = maybe False f $ lookup h $ requestHeaders request

    -- True when the first comma-separated entry of the value is exactly
    -- "https". A single value such as "https" is a one-element list, so this
    -- agrees with a plain equality test whenever no comma is present.
    firstProtoIsHttps :: ByteString -> Bool
    firstProtoIsHttps = (== ["https"]) . take 1 . C.split ','

-- | Guess the \"application root\" based on the given request.
--
-- The application root is the basis for forming URLs pointing at the current
-- application. For more information and relevant caveats, please see
-- "Network.Wai.Middleware.Approot".
--
-- The scheme is @https://@ when 'appearsSecure' holds and @http://@
-- otherwise. The rest is taken from 'requestHeaderHost', falling back to
-- @localhost@ when the request carries no host.
--
-- @since 3.0.7
guessApproot :: Request -> ByteString
guessApproot req = S.concat [scheme, authority]
  where
    scheme :: ByteString
    scheme
        | appearsSecure req = "https://"
        | otherwise         = "http://"

    authority :: ByteString
    authority = fromMaybe "localhost" (requestHeaderHost req)

-- | Raised while reading the body of a request wrapped by 'requestSizeCheck'
-- once that body is known to exceed the configured limit. The 'Word64' payload
-- is the limit (in bytes) that was passed to 'requestSizeCheck'.
--
-- @since 3.0.15
data RequestSizeException = RequestSizeException Word64
    deriving (Eq, Ord, Typeable)

instance Exception RequestSizeException

-- Human-readable rendering (not Haskell syntax); the default 'showsPrec'
-- simply prepends this string.
instance Show RequestSizeException where
    show (RequestSizeException limit) =
        concat ["Request Body is larger than ", show limit, " bytes."]

-- | Check request body size to avoid server crash when request is too large.
--
-- This function first checks @'requestBodyLength'@:
--
-- * if the content length is known and larger than the limit, every read of
--   @'requestBody'@ throws a 'RequestSizeException';
--
-- * if the content length is known and within the limit, the request is
--   returned unchanged;
--
-- * if the length is unknown (chunked), the bytes read so far are counted, and
--   the read that pushes the running total above the limit throws a
--   'RequestSizeException' instead of returning its chunk.
--
-- In both throwing cases the exception carries the given limit, and it is
-- raised in IO when the user reads the body with @'requestBody'@.
--
-- @since 3.0.15
requestSizeCheck :: Word64 -> Request -> IO Request
requestSizeCheck maxSize req = case requestBodyLength req of
    KnownLength len
        | len > maxSize -> return req { requestBody = tooLarge }
        | otherwise     -> return req
    ChunkedBody -> do
        consumed <- newIORef 0
        let limitedBody = do
                chunk <- requestBody req
                total <- atomicModifyIORef' consumed $ \seen ->
                    let seen' = seen + fromIntegral (S.length chunk)
                    in (seen', seen')
                if total > maxSize
                    then tooLarge
                    else return chunk
        return req { requestBody = limitedBody }
  where
    -- The body action used once the limit is exceeded.
    tooLarge :: IO ByteString
    tooLarge = throwIO (RequestSizeException maxSize)