{-# LANGUAGE DeriveDataTypeable #-}
module Network.Wai.Request
( appearsSecure
, guessApproot
, RequestSizeException(..)
, requestSizeCheck
) where
import Control.Exception (Exception, throwIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C
import Data.IORef (atomicModifyIORef', newIORef)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word64)
import Network.HTTP.Types (HeaderName)
import Network.Wai
-- | Decide whether a request looks like it arrived over an encrypted
-- connection.
--
-- 'isSecure' is consulted first. If it is 'False', the request headers are
-- scanned, in order, for any of the proxy/CGI-style hints listed in
-- @proxyHints@. The first hint whose header is present and whose value
-- satisfies its predicate makes the result 'True'. A missing header never
-- matches.
--
-- The header checks accept whatever values appear in 'requestHeaders'. The
-- result is therefore only as trustworthy as the component that set those
-- headers. When a genuine secure transport is required, rely on 'isSecure'
-- alone.
appearsSecure :: Request -> Bool
appearsSecure request =
    isSecure request || any headerIndicatesSecure proxyHints
  where
    headers = requestHeaders request

    -- A hint matches only when its header exists and the predicate accepts
    -- the header's value.
    headerIndicatesSecure :: (HeaderName, ByteString -> Bool) -> Bool
    headerIndicatesSecure (name, accepts) =
        maybe False accepts (lookup name headers)

    -- Header names paired with the value test that signals a secure scheme.
    proxyHints :: [(HeaderName, ByteString -> Bool)]
    proxyHints =
        [ ("HTTPS", (== "on"))
        , ("HTTP_X_FORWARDED_SSL", (== "on"))
        , ("HTTP_X_FORWARDED_SCHEME", (== "https"))
          -- Comma-separated value: only the first entry is compared.
        , ("HTTP_X_FORWARDED_PROTO", \value -> take 1 (C.split ',' value) == ["https"])
        , ("X-Forwarded-Proto", (== "https"))
        ]
-- | Build a best-guess application root URL for the request.
--
-- The scheme is @https://@ when 'appearsSecure' holds and @http://@
-- otherwise. It is followed by the request's host header value, or by
-- @localhost@ when 'requestHeaderHost' is 'Nothing'.
guessApproot :: Request -> ByteString
guessApproot req = S.append scheme host
  where
    scheme
        | appearsSecure req = "https://"
        | otherwise = "http://"
    host = fromMaybe "localhost" (requestHeaderHost req)
-- | Exception raised when reading the body of a request wrapped by
-- 'requestSizeCheck' would exceed the configured limit.
newtype RequestSizeException
    = RequestSizeException Word64 -- ^ the size limit (in bytes) that was exceeded
    deriving (Eq, Ord, Typeable)

instance Exception RequestSizeException

-- | Renders as @Request Body is larger than <limit> bytes.@
instance Show RequestSizeException where
    -- The limit is a non-negative 'Word64', so the surrounding precedence
    -- can never require parentheses; 'shows' renders it directly.
    showsPrec _ (RequestSizeException limit) =
        showString "Request Body is larger than "
            . shows limit
            . showString " bytes."
-- | Wrap a request so that reading its body fails with
-- 'RequestSizeException' once more than @maxSize@ bytes are involved.
--
-- * 'KnownLength' above the limit: the body action is replaced by one that
--   throws immediately. A known length within the limit leaves the request
--   untouched.
--
-- * 'ChunkedBody': every chunk read through the new body action is added to
--   a running byte count held in an 'IORef'. The chunk that pushes the total
--   past @maxSize@ causes a throw instead of being returned.
--
-- The exception always carries @maxSize@ and is raised with 'throwIO' from
-- inside the body-reading 'IO' action, not from this function.
requestSizeCheck :: Word64 -> Request -> IO Request
requestSizeCheck maxSize req =
    case requestBodyLength req of
        KnownLength len
            | len > maxSize -> return req { requestBody = tooLarge }
            | otherwise -> return req
        ChunkedBody -> do
            consumed <- newIORef 0
            let limitedBody = do
                    chunk <- requestBody req
                    -- Strict update: add this chunk's length and return the
                    -- new running total.
                    total <- atomicModifyIORef' consumed $ \seen ->
                        let seen' = seen + fromIntegral (S.length chunk)
                        in (seen', seen')
                    if total > maxSize then tooLarge else return chunk
            return req { requestBody = limitedBody }
  where
    -- Shared failure action for both the known-length and chunked paths.
    tooLarge :: IO a
    tooLarge = throwIO (RequestSizeException maxSize)