{-# LANGUAGE OverloadedStrings, CPP #-}
module Yesod.Core.Internal.Request
( parseWaiRequest
, RequestBodyContents
, FileInfo
, fileName
, fileContentType
, fileMove
, mkFileInfoLBS
, mkFileInfoFile
, mkFileInfoSource
, FileUpload (..)
, tooLargeResponse
, tokenKey
, langKey
, textQueryString
, randomString
) where
import Data.String (IsString)
import Control.Arrow (second)
import qualified Network.Wai.Parse as NWP
import qualified Network.Wai as W
import Web.Cookie (parseCookiesText)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy.Char8 as LS8
import Data.Text (Text, pack)
import Network.HTTP.Types (queryToQueryText, Status (Status))
import Data.Maybe (fromMaybe, catMaybes)
import qualified Data.ByteString.Lazy as L
import qualified Data.Set as Set
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8With, decodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Conduit
import Data.Word (Word8, Word64)
import Control.Exception (throwIO)
import Control.Monad ((<=<), liftM)
import Yesod.Core.Types
import qualified Data.Map as Map
import Data.IORef
import qualified Data.Vector.Storable as V
import Data.ByteString.Internal (ByteString (PS))
import qualified Data.Word8 as Word8
-- | Wrap a request so that reading its body fails once more than @maxLen@
-- bytes have been received.
--
-- The returned request shares a mutable byte budget (initially @maxLen@)
-- across all reads of 'W.requestBody'. Each chunk is charged against the
-- budget; the first chunk that does not fit causes an 'HCWai' exception
-- carrying 'tooLargeResponse' to be thrown instead of the chunk being
-- returned.
limitRequestBody :: Word64 -> W.Request -> IO W.Request
limitRequestBody maxLen req = do
    ref <- newIORef maxLen
    return req
        { W.requestBody = do
            bs <- W.requestBody req
            remaining <- readIORef ref
            let len = fromIntegral $ S8.length bs
                -- Bytes received so far, including the chunk that crossed
                -- the limit. @remaining <= maxLen@ always holds, so the
                -- unsigned subtraction cannot wrap. Reporting only @len@
                -- (the size of the last chunk) could yield a message that
                -- claims a body smaller than the limit.
                received = maxLen - remaining + len
            if remaining < len
                then throwIO $ HCWai $ tooLargeResponse maxLen received
                else do
                    -- Only reached when @len <= remaining@, so this cannot
                    -- underflow; store it evaluated rather than as a thunk.
                    writeIORef ref $! remaining - len
                    return bs
        }
-- | Plain-text @413@ response explaining that the request body exceeded the
-- permitted size.
--
-- The first argument is the configured maximum in bytes, the second is the
-- body size (in bytes) to report to the client.
tooLargeResponse :: Word64 -> Word64 -> W.Response
tooLargeResponse maxLen bodyLen =
    W.responseLBS status headers message
  where
    status = Status 413 "Too Large"
    headers = [("Content-Type", "text/plain")]
    shown = LS8.pack . show
    message = mconcat
        [ "Request body too large to be processed. The maximum size is "
        , shown maxLen
        , " bytes; your request body was "
        , shown bodyLen
        , " bytes. If you're the developer of this site, you can configure the maximum length with the `maximumContentLength` or `maximumContentLengthIO` function on the Yesod typeclass."
        ]
-- | Turn a WAI request plus the current session into a 'YesodRequest'.
--
-- Arguments, in order: the incoming request, the session map, whether a
-- request token should be used, and an optional maximum body size in bytes.
--
-- The result is split so that a random generator is only demanded when it is
-- needed: 'Left' is returned when the token is already known (or tokens are
-- disabled), 'Right' when a fresh token must be produced from the supplied
-- @IO Int@ generator.
parseWaiRequest :: W.Request
                -> SessionMap
                -> Bool
                -> Maybe Word64
                -> Either (IO YesodRequest) (IO Int -> IO YesodRequest)
parseWaiRequest env session useToken mmaxBodySize =
    either
        (Left . buildRequest)
        (\makeToken -> Right (buildRequest <=< makeToken))
        tokenSource
  where
    -- Assemble the final request once the token (if any) is known.
    buildRequest tok = do
        limitedEnv <- case mmaxBodySize of
            Nothing -> return env
            Just maxBytes -> limitRequestBody maxBytes env
        return YesodRequest
            { reqGetParams = queryParams
            , reqCookies = cookies
            , reqWaiRequest = limitedEnv
            , reqLangs = preferredLangs
            , reqToken = tok
            , reqSession = sessionForRequest
            , reqAccept = httpAccept env
            }

    headerValue name = lookup name (W.requestHeaders env)

    queryParams = textQueryString env

    cookies = maybe [] parseCookiesText (headerValue "Cookie")

    -- The token entry is not exposed through the request's session when
    -- tokens are in use.
    sessionForRequest
        | useToken = Map.delete tokenKey session
        | otherwise = session

    -- Languages from the Accept-Language header, in the order
    -- 'NWP.parseHttpAccept' returns them.
    acceptedLangs =
        map (pack . S8.unpack)
            (maybe [] NWP.parseHttpAccept (headerValue "Accept-Language"))

    -- Explicit overrides come first: query string, then cookie, then session.
    overrideLangs = catMaybes
        [ lookup langKey queryParams
        , lookup langKey cookies
        , fmap (decodeUtf8With lenientDecode) (Map.lookup langKey session)
        ]

    preferredLangs =
        addTwoLetters (id, Set.empty) (overrideLangs ++ acceptedLangs)

    -- Left: token already determined. Right: needs a random generator.
    tokenSource
        | not useToken = Left Nothing
        | otherwise =
            maybe
                (Right (fmap Just . randomString 40))
                (Left . Just . decodeUtf8With lenientDecode)
                (Map.lookup tokenKey session)
-- | The request's query string as textual key/value pairs. A parameter that
-- has no value is paired with the empty string.
textQueryString :: W.Request -> [(Text, Text)]
textQueryString req = map (second orEmpty) pairs
  where
    pairs = queryToQueryText (W.queryString req)
    orEmpty = fromMaybe ""
-- | Content types listed in the request's @Accept@ header, as parsed by
-- 'NWP.parseHttpAccept'. A missing header is treated as an empty one.
httpAccept :: W.Request -> [ContentType]
httpAccept req = NWP.parseHttpAccept acceptHeader
  where
    acceptHeader = fromMaybe S8.empty (lookup "Accept" (W.requestHeaders req))
-- | Append two-character fallbacks to a list of language codes.
--
-- The input list is emitted unchanged and in order. Every code longer than
-- two characters additionally queues its first two characters (e.g.
-- @"en-US"@ queues @"en"@); codes of two characters or fewer are recorded as
-- already present. After the input is exhausted the queued fallbacks are
-- appended, skipping any that were already present and any that were already
-- appended, so each fallback shows up at most once.
--
-- The first argument is the accumulator: a difference list of queued
-- fallbacks and the set of short codes seen so far. Callers normally start
-- with @(id, Set.empty)@.
--
-- NOTE(review): any code longer than two characters is cut to two, so a
-- three-letter code such as @"fil"@ queues @"fi"@ — confirm this is intended.
addTwoLetters :: ([Text] -> [Text], Set.Set Text) -> [Text] -> [Text]
addTwoLetters (toAdd, exist) [] = appendNew exist (toAdd [])
  where
    -- Keep only fallbacks not seen before, remembering each one emitted so
    -- that e.g. @"en-US"@ and @"en-GB"@ do not both append @"en"@.
    appendNew _ [] = []
    appendNew seen (c:cs)
        | c `Set.member` seen = appendNew seen cs
        | otherwise = c : appendNew (Set.insert c seen) cs
addTwoLetters (toAdd, exist) (l:ls) =
    l : addTwoLetters (toAdd', exist') ls
  where
    (toAdd', exist')
        | T.length l > 2 = (toAdd . (T.take 2 l:), exist)
        | otherwise = (toAdd, Set.insert l exist)
-- | Generate a string of @len@ characters drawn from @A-Z@, @a-z@ and @0-9@,
-- using @gen@ as the source of integers.
--
-- Each draw is reduced modulo 64. Values 0-61 map onto the 62 permitted
-- characters; 62 and 63 are rejected and a new value is drawn.
randomString :: Monad m => Int -> m Int -> m Text
randomString len gen =
    liftM (decodeUtf8 . fromByteVector) (V.replicateM len nextChar)
  where
    -- Draw integers until one lands in the accepted range.
    nextChar = gen >>= pick . fromIntegral . (`mod` 64)

    pick slot
        | slot < 26 = return (slot + Word8._A)
        | slot < 52 = return (slot + Word8._a - 26)
        | slot < 62 = return (slot + Word8._0 - 52)
        | otherwise = nextChar
-- | Build a 'ByteString' from a storable vector of bytes by wrapping the
-- foreign pointer, offset and length that 'V.unsafeToForeignPtr' yields for
-- the vector.
fromByteVector :: V.Vector Word8 -> ByteString
fromByteVector vec =
    case V.unsafeToForeignPtr vec of
        (buf, start, size) -> PS buf start size
{-# INLINE fromByteVector #-}
-- | Build a 'FileInfo' from a name, a content type and in-memory lazy
-- contents. The file's source streams the given bytes, and moving the file
-- writes them to the destination path.
mkFileInfoLBS :: Text -> Text -> L.ByteString -> FileInfo
mkFileInfoLBS name ct lbs =
    FileInfo name ct (sourceLazy lbs) writeTo
  where
    writeTo dst = L.writeFile dst lbs
-- | Build a 'FileInfo' from a name, a content type and the path of a file on
-- disk. The file's source streams from that path, and moving the file
-- streams its contents into the destination path.
mkFileInfoFile :: Text -> Text -> FilePath -> FileInfo
mkFileInfoFile name ct fp =
    FileInfo name ct (sourceFile fp) copyTo
  where
    copyTo dst = runConduitRes (sourceFile fp .| sinkFile dst)
-- | Build a 'FileInfo' from a name, a content type and a conduit source
-- producing the contents. Moving the file runs the source into the
-- destination path.
mkFileInfoSource :: Text -> Text -> ConduitT () ByteString (ResourceT IO) () -> FileInfo
mkFileInfoSource name ct src =
    FileInfo name ct src drainTo
  where
    drainTo dst = runConduitRes (src .| sinkFile dst)
-- | Key under which the request token is stored in the session.
-- Polymorphic in the string type so it can be used wherever an 'IsString'
-- instance is expected.
tokenKey :: IsString a => a
tokenKey = "_TOKEN"
-- | Name used to look up an explicit language override. 'parseWaiRequest'
-- checks it as a query-string parameter, a cookie name and a session key.
-- Polymorphic in the string type so it can be used wherever an 'IsString'
-- instance is expected.
langKey :: IsString a => a
langKey = "_LANG"