{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Snap.Internal.Util.FileUploads
(
handleFileUploads
, handleMultipart
, PartProcessor
, PartInfo(..)
, PartDisposition(..)
, toPartDisposition
, UploadPolicy(..)
, defaultUploadPolicy
, doProcessFormInputs
, setProcessFormInputs
, getMaximumFormInputSize
, setMaximumFormInputSize
, getMaximumNumberOfFormInputs
, setMaximumNumberOfFormInputs
, getMinimumUploadRate
, setMinimumUploadRate
, getMinimumUploadSeconds
, setMinimumUploadSeconds
, getUploadTimeout
, setUploadTimeout
, PartUploadPolicy(..)
, disallow
, allowWithMaximumSize
, FileUploadException(..)
, fileUploadExceptionReason
, BadPartException(..)
, PolicyViolationException(..)
) where
import Control.Applicative (Alternative ((<|>)), Applicative ((*>), (<*), pure))
import Control.Arrow (Arrow (first))
import Control.Exception.Lifted (Exception, SomeException (..), bracket, catch, fromException, mask, throwIO, toException)
import qualified Control.Exception.Lifted as E (try)
import Control.Monad (Functor (fmap), Monad ((>>=), return), MonadPlus (mzero), guard, liftM, sequence, void, when, (>=>))
import Data.Attoparsec.ByteString.Char8 (Parser, isEndOfLine, string, takeWhile)
import qualified Data.Attoparsec.ByteString.Char8 as Atto (try)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as S
import Data.ByteString.Internal (c2w)
import qualified Data.CaseInsensitive as CI (mk)
import Data.Int (Int, Int64)
import Data.List (concat, find, map, (++))
import qualified Data.Map as Map (insertWith', size)
import Data.Maybe (Maybe (..), fromMaybe, isJust, maybe)
import Data.Text (Text)
import qualified Data.Text as T (concat, pack, unpack)
import qualified Data.Text.Encoding as TE (decodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Typeable (Typeable, cast)
import Prelude (Bool (..), Double, Either (..), Eq (..), FilePath, IO, Ord (..), Show (..), String, const, either, flip, fst, id, max, not, otherwise, snd, ($), ($!), (.), (^), (||))
import Snap.Core (HasHeaders (headers), Headers, MonadSnap, Request (rqParams, rqPostParams), getHeader, getRequest, getTimeoutModifier, putRequest, runRequestBody)
import Snap.Internal.Parsing (crlf, fullyParse, pContentTypeWithParameters, pHeaders, pValueWithParameters)
import qualified Snap.Types.Headers as H (fromList)
import System.Directory (removeFile)
import System.FilePath ((</>))
import System.IO (BufferMode (NoBuffering), Handle, hClose, hSetBuffering)
import System.IO.Streams (InputStream, MatchInfo (..), TooManyBytesReadException, search)
import qualified System.IO.Streams as Streams
import System.IO.Streams.Attoparsec (parseFromStream)
import System.PosixCompat.Temp (mkstemp)
-- | Streams uploaded files from a @multipart/form-data@ request body into
-- temporary files and hands each one to a user-supplied handler.
--
-- For every part passed on by 'handleMultipart', @partPolicy@ decides what
-- happens:
--
--   * if the policy disallows the part, the handler receives a 'Left'
--     'PolicyViolationException' and the part is not written to disk;
--   * otherwise the part is copied into a fresh temporary file under
--     @tmpdir@ and the handler receives @'Right' path@. The file is removed
--     as soon as the handler returns, so the handler must copy or move it
--     to keep it;
--   * if the part produces more bytes than the policy allows, reading is
--     aborted and the handler receives a 'Left' 'PolicyViolationException'.
--
-- Non-file form inputs are treated according to the 'UploadPolicy' (see
-- 'handleMultipart'). Returns the handler's results in body order.
handleFileUploads ::
       (MonadSnap m) =>
       FilePath                 -- ^ directory for temporary files
    -> UploadPolicy             -- ^ general upload policy
    -> (PartInfo -> PartUploadPolicy)
                                -- ^ per-part upload policy
    -> (PartInfo -> Either PolicyViolationException FilePath -> IO a)
                                -- ^ handler for each file part
    -> m [a]
handleFileUploads tmpdir uploadPolicy partPolicy partHandler =
    handleMultipart uploadPolicy go
  where
    go partInfo stream = maybe disallowed takeIt mbFs
      where
        ctText = partContentType partInfo
        fnText = fromMaybe "" $ partFileName partInfo

        -- These bytes come straight from the client's part headers and
        -- need not be valid UTF-8. 'TE.decodeUtf8' would throw a pure
        -- exception when the error message below is forced; lenient
        -- decoding substitutes U+FFFD for bad bytes instead.
        ct = TE.decodeUtf8With lenientDecode ctText
        fn = TE.decodeUtf8With lenientDecode fnText

        (PartUploadPolicy mbFs) = partPolicy partInfo

        -- Accepted part: cap the stream at maxSize bytes and write it out.
        takeIt maxSize = do
            str' <- Streams.throwIfProducesMoreThan maxSize stream
            fileReader tmpdir partHandler partInfo str' `catch` tooMany maxSize

        -- The size cap was exceeded while copying the part.
        tooMany maxSize (_ :: TooManyBytesReadException) =
            partHandler partInfo
              (Left $
               PolicyViolationException $
               T.concat [ "File \""
                        , fn
                        , "\" exceeded maximum allowable size "
                        , T.pack $ show maxSize ])

        -- The per-part policy rejected this part outright.
        disallowed =
            partHandler partInfo
              (Left $
               PolicyViolationException $
               T.concat [ "Policy disallowed upload of file \""
                        , fn
                        , "\" with content-type \""
                        , ct
                        , "\"" ] )
-- | A handler for one part of a multipart body. It receives the part's
-- metadata and an 'InputStream' over the part's body bytes, and produces a
-- result in 'IO'.
type PartProcessor a
    =  PartInfo
    -> InputStream ByteString
    -> IO a
-- | Parses a @multipart/form-data@ request body and runs a 'PartProcessor'
-- on its parts, returning the processor results in body order.
--
-- As implemented below:
--
--   * If the request's content type is not @multipart/form-data@, the
--     action fails via 'guard' (i.e. 'mzero' in the Snap monad).
--   * A missing @boundary@ parameter raises 'BadPartException'.
--   * The body is wrapped in 'Streams.throwIfTooSlow' using the policy's
--     minimum upload rate and grace seconds. The timeout-bump action is
--     also passed to it; that action raises the request timeout to at
--     least 'uploadTimeout'.
--   * When 'doProcessFormInputs' is set, parts that are not files are read
--     into memory, bounded by 'getMaximumFormInputSize'. They are then
--     inserted into both 'rqPostParams' and 'rqParams' instead of being
--     handed to the processor. Too many such inputs raise
--     'PolicyViolationException'.
handleMultipart ::
       (MonadSnap m) =>
       UploadPolicy        -- ^ global upload policy
    -> PartProcessor a     -- ^ handler for parts not captured as form inputs
    -> m [a]
handleMultipart uploadPolicy origPartHandler = do
    hdrs <- liftM headers getRequest
    let (ct, mbBoundary) = getContentType hdrs

    -- getTimeoutModifier yields a function applied to the current timeout.
    -- Composing with 'max' means bumpTimeout can only raise the timeout,
    -- and only up to at least uploadTimeout.
    tickleTimeout <- liftM (. max) getTimeoutModifier
    let bumpTimeout = tickleTimeout $ uploadTimeout uploadPolicy

    -- Either split parts into form captures vs. files, or treat every
    -- part as a "file" for the caller's processor.
    let partHandler = if doProcessFormInputs uploadPolicy
                        then captureVariableOrReadFile
                                 (getMaximumFormInputSize uploadPolicy)
                                 origPartHandler
                        else \x y -> liftM File $ origPartHandler x y

    guard (ct == "multipart/form-data")

    boundary <- maybe (throwIO $ BadPartException
                       "got multipart/form-data without boundary")
                      return
                      mbBoundary

    -- The whole body is parsed here; all captures are collected before
    -- procCaptures below looks at them.
    captures <- runRequestBody (proc bumpTimeout boundary partHandler)

    procCaptures captures id

  where
    uploadRate  = minimumUploadRate uploadPolicy
    uploadSecs  = minimumUploadSeconds uploadPolicy
    maxFormVars = maximumNumberOfFormInputs uploadPolicy

    -- Rate-limit the body, then split it into parts.
    proc bumpTimeout boundary partHandler =
        Streams.throwIfTooSlow bumpTimeout uploadRate uploadSecs >=>
        internalHandleMultipart boundary partHandler

    -- Walk the captures in order. File results are accumulated in a
    -- difference list (dl); form inputs are stored in the request.
    procCaptures [] dl = return $! dl []
    procCaptures ((File x):xs) dl = procCaptures xs (dl . (x:))
    procCaptures ((Capture k v):xs) dl = do
        rq <- getRequest
        -- NOTE(review): Map.size counts distinct parameter names already
        -- in rqPostParams, not the number of inputs seen. A repeated field
        -- name therefore does not raise the count, and any input (even a
        -- repeated name) is rejected once maxFormVars distinct names exist.
        -- Confirm this is the intended meaning of the limit.
        when (Map.size (rqPostParams rq) >= maxFormVars)
          $ throwIO . PolicyViolationException
          $ T.concat [ "number of form inputs exceeded maximum of "
                     , T.pack $ show maxFormVars ]
        putRequest $ modifyParams (ins k v) rq
        procCaptures xs dl

    -- flip (++) puts the new value after existing ones, so repeated
    -- fields keep body order.
    ins k v = Map.insertWith' (flip (++)) k [v]

    modifyParams f r = r { rqPostParams = f $ rqPostParams r
                         , rqParams     = f $ rqParams r
                         }
-- | The disposition type of a multipart part, taken from its
-- @Content-Disposition@ header. Any type that is not recognised is kept
-- verbatim in 'DispositionOther'.
data PartDisposition
    = DispositionAttachment          -- ^ @attachment@
    | DispositionFile                -- ^ @file@
    | DispositionFormData            -- ^ @form-data@
    | DispositionOther ByteString    -- ^ any other disposition type
  deriving (Eq, Show)

-- | Metadata describing a single part of a multipart request body.
data PartInfo = PartInfo
    { partFieldName   :: !ByteString
      -- ^ @name@ parameter of @Content-Disposition@ (empty if absent)
    , partFileName    :: !(Maybe ByteString)
      -- ^ @filename@ parameter of @Content-Disposition@, if present
    , partContentType :: !ByteString
      -- ^ MIME type of the part (@text/plain@ when missing or unparseable)
    , partDisposition :: !PartDisposition
      -- ^ disposition type of the part
    , partHeaders     :: !Headers
      -- ^ every header of the part
    }
  deriving (Show)
-- | Classifies a @Content-Disposition@ type token.
--
-- Disposition types are case-insensitive (RFC 2183, RFC 6266), so
-- @Form-Data@ and @form-data@ both map to 'DispositionFormData'. Tokens
-- that are not recognised are returned unchanged, in their original case,
-- inside 'DispositionOther'.
toPartDisposition :: ByteString -> PartDisposition
toPartDisposition s
    | token == "attachment" = DispositionAttachment
    | token == "file"       = DispositionFile
    | token == "form-data"  = DispositionFormData
    | otherwise             = DispositionOther s
  where
    -- Case-folded view of the token, used only for comparison.
    token = CI.mk s
-- | Umbrella exception for every file-upload error.
--
-- The concrete exception types of this module wrap themselves in
-- 'FileUploadException' when converted to 'SomeException'. Catching
-- 'FileUploadException' therefore catches any of them.
data FileUploadException =
    forall e . (ExceptionWithReason e, Show e) =>
        WrappedFileUploadException e
  deriving (Typeable)

-- | Exceptions that can describe themselves as human-readable 'Text'.
class Exception e => ExceptionWithReason e where
    exceptionReason :: e -> Text

instance Show FileUploadException where
    show (WrappedFileUploadException inner) = show inner

instance Exception FileUploadException

-- | Returns the human-readable reason of whichever exception is wrapped.
fileUploadExceptionReason :: FileUploadException -> Text
fileUploadExceptionReason (WrappedFileUploadException inner) =
    exceptionReason inner

-- | 'toException' for upload exceptions: wraps the value in
-- 'FileUploadException' before boxing it into 'SomeException'.
uploadExceptionToException :: ExceptionWithReason e => e -> SomeException
uploadExceptionToException ex = toException (WrappedFileUploadException ex)

-- | 'fromException' for upload exceptions: unwraps a 'FileUploadException'
-- and then checks whether the inner value has the requested type.
uploadExceptionFromException :: ExceptionWithReason e => SomeException -> Maybe e
uploadExceptionFromException boxed =
    case fromException boxed of
      Nothing                                 -> Nothing
      Just (WrappedFileUploadException inner) -> cast inner
-- | Raised when a part of a multipart body is malformed (for example, a
-- missing boundary or oversized part headers).
data BadPartException = BadPartException
    { badPartExceptionReason :: Text   -- ^ what was wrong with the part
    }
  deriving (Typeable)

-- Routed through 'FileUploadException' so it is caught as one.
instance Exception BadPartException where
    toException   = uploadExceptionToException
    fromException = uploadExceptionFromException

instance ExceptionWithReason BadPartException where
    exceptionReason (BadPartException why) = T.concat ["Bad part: ", why]

instance Show BadPartException where
    show ex = T.unpack (exceptionReason ex)
-- | Raised (or handed to a part handler) when an upload breaks the
-- configured 'UploadPolicy' or 'PartUploadPolicy'.
data PolicyViolationException = PolicyViolationException
    { policyViolationExceptionReason :: Text   -- ^ description of the violation
    }
  deriving (Typeable)

-- Routed through 'FileUploadException' so it is caught as one.
instance Exception PolicyViolationException where
    toException ex@(PolicyViolationException _) = uploadExceptionToException ex
    fromException = uploadExceptionFromException

instance ExceptionWithReason PolicyViolationException where
    exceptionReason (PolicyViolationException why) =
        T.concat ["File upload policy violation: ", why]

instance Show PolicyViolationException where
    show (PolicyViolationException why) =
        "File upload policy violation: " ++ T.unpack why
-- | Global policy applied by 'handleMultipart' to a multipart request.
data UploadPolicy = UploadPolicy
    { processFormInputs         :: Bool
      -- ^ capture non-file parts into the request parameters
    , maximumFormInputSize      :: Int64
      -- ^ maximum size, in bytes, of one captured form input
    , maximumNumberOfFormInputs :: Int
      -- ^ limit compared against the number of distinct POST parameter names
    , minimumUploadRate         :: Double
      -- ^ minimum rate passed to 'Streams.throwIfTooSlow'
      --   (bytes per second — TODO confirm)
    , minimumUploadSeconds      :: Int
      -- ^ grace period passed to 'Streams.throwIfTooSlow'
      --   (seconds — TODO confirm)
    , uploadTimeout             :: Int
      -- ^ request timeout is raised to at least this while the body is read
    }

-- | Default policy:
--
--   * form inputs are captured;
--   * each input may be at most 128 KiB;
--   * at most 10 inputs are allowed;
--   * the minimum rate is 1000 (bytes/sec, TODO confirm) after a
--     10-second grace period;
--   * the upload timeout is 20.
defaultUploadPolicy :: UploadPolicy
defaultUploadPolicy = UploadPolicy
    { processFormInputs         = True
    , maximumFormInputSize      = 2 ^ (17 :: Int)
    , maximumNumberOfFormInputs = 10
    , minimumUploadRate         = 1000
    , minimumUploadSeconds      = 10
    , uploadTimeout             = 20
    }
-- | Are non-file form inputs captured into the request parameters?
doProcessFormInputs :: UploadPolicy -> Bool
doProcessFormInputs policy = processFormInputs policy

-- | Turn capturing of non-file form inputs on or off.
setProcessFormInputs :: Bool -> UploadPolicy -> UploadPolicy
setProcessFormInputs enabled policy = policy { processFormInputs = enabled }

-- | Maximum size, in bytes, of a single captured form input.
getMaximumFormInputSize :: UploadPolicy -> Int64
getMaximumFormInputSize policy = maximumFormInputSize policy

-- | Set the maximum size, in bytes, of a single captured form input.
setMaximumFormInputSize :: Int64 -> UploadPolicy -> UploadPolicy
setMaximumFormInputSize limit policy = policy { maximumFormInputSize = limit }

-- | Limit on the number of captured form inputs.
getMaximumNumberOfFormInputs :: UploadPolicy -> Int
getMaximumNumberOfFormInputs policy = maximumNumberOfFormInputs policy

-- | Set the limit on the number of captured form inputs.
setMaximumNumberOfFormInputs :: Int -> UploadPolicy -> UploadPolicy
setMaximumNumberOfFormInputs limit policy =
    policy { maximumNumberOfFormInputs = limit }

-- | Minimum upload rate enforced while reading the body.
getMinimumUploadRate :: UploadPolicy -> Double
getMinimumUploadRate policy = minimumUploadRate policy

-- | Set the minimum upload rate enforced while reading the body.
setMinimumUploadRate :: Double -> UploadPolicy -> UploadPolicy
setMinimumUploadRate rate policy = policy { minimumUploadRate = rate }

-- | Grace period before the minimum upload rate is enforced.
getMinimumUploadSeconds :: UploadPolicy -> Int
getMinimumUploadSeconds policy = minimumUploadSeconds policy

-- | Set the grace period before the minimum upload rate is enforced.
setMinimumUploadSeconds :: Int -> UploadPolicy -> UploadPolicy
setMinimumUploadSeconds secs policy = policy { minimumUploadSeconds = secs }

-- | Lower bound the request timeout is raised to during the upload.
getUploadTimeout :: UploadPolicy -> Int
getUploadTimeout policy = uploadTimeout policy

-- | Set the lower bound the request timeout is raised to during the upload.
setUploadTimeout :: Int -> UploadPolicy -> UploadPolicy
setUploadTimeout secs policy = policy { uploadTimeout = secs }
-- | Per-part decision used by 'handleFileUploads'. 'Nothing' rejects the
-- part; @'Just' n@ accepts it up to @n@ bytes.
data PartUploadPolicy = PartUploadPolicy (Maybe Int64)

-- | Reject the part.
disallow :: PartUploadPolicy
disallow = PartUploadPolicy Nothing

-- | Accept the part, allowing at most the given number of bytes.
allowWithMaximumSize :: Int64 -> PartUploadPolicy
allowWithMaximumSize limit = PartUploadPolicy (Just limit)
-- | Wraps a file 'PartProcessor' so that non-file parts are captured as
-- form variables instead.
--
-- A part counts as a file when it has a @filename@ parameter or its
-- disposition is 'DispositionFile'. File parts go to @fileHandler@ and the
-- result is wrapped in 'File'. Any other part is read fully into memory,
-- up to @maxSize@ bytes, and returned as 'Capture' with its field name.
-- A larger part raises 'PolicyViolationException'.
captureVariableOrReadFile ::
       Int64                     -- ^ maximum size of one form input, in bytes
    -> PartProcessor a           -- ^ handler for file parts
    -> PartProcessor (Capture a)
captureVariableOrReadFile maxSize fileHandler partInfo stream =
    if isFile
      then liftM File $ fileHandler partInfo stream
      else variable `catch` handler
  where
    isFile = isJust (partFileName partInfo) ||
             partDisposition partInfo == DispositionFile

    variable = do
        x <- liftM S.concat $
             Streams.throwIfProducesMoreThan maxSize stream >>= Streams.toList
        return $! Capture fieldName x

    fieldName = partFieldName partInfo

    -- The field name is client-controlled and may not be valid UTF-8.
    -- The Text inside the exception is forced lazily, so a strict decode
    -- would crash whoever later shows the exception. Decode leniently.
    handler (_ :: TooManyBytesReadException) =
        throwIO $ PolicyViolationException $
        T.concat [ "form input '"
                 , TE.decodeUtf8With lenientDecode fieldName
                 , "' exceeded maximum permissible size ("
                 , T.pack $ show maxSize
                 , " bytes)" ]
-- | Outcome of processing one part when form inputs are being captured.
data Capture a
    = Capture ByteString ByteString   -- ^ form field name and its full value
    | File a                          -- ^ result produced by the file handler

-- | Copies a part's body into a fresh temporary file under @tmpdir@, closes
-- the file, and then calls @partProc@ with @'Right' path@. The temporary
-- file is deleted by 'withTempFile' once @partProc@ returns.
fileReader :: FilePath
           -> (PartInfo -> Either PolicyViolationException FilePath -> IO a)
           -> PartProcessor a
fileReader tmpdir partProc partInfo input =
    withTempFile tmpdir "snap-upload-" writeAndHandOff
  where
    writeAndHandOff (path, hdl) = do
        hSetBuffering hdl NoBuffering
        Streams.handleToOutputStream hdl >>= Streams.connect input
        hClose hdl
        partProc partInfo (Right path)
-- | Core multipart parser. Splits the body on @boundary@ and runs
-- @clientHandler@ on each part, concatenating all results in body order.
--
-- A part whose content type is @multipart/mixed@ is parsed recursively
-- using its own boundary. Each of its sub-parts is reported with the
-- outer part's field name.
internalHandleMultipart ::
       ByteString                                    -- ^ boundary value
    -> (PartInfo -> InputStream ByteString -> IO a)  -- ^ per-part handler
    -> InputStream ByteString                        -- ^ request body
    -> IO [a]
internalHandleMultipart !boundary clientHandler !stream = go
  where
    go = do
        -- Skip any preamble up to and including the first "--boundary".
        _ <- parseFromStream (parseFirstBoundary boundary) stream
        -- Later delimiters look like "\r\n--boundary"; split the rest of
        -- the body on them.
        bmstream <- search (fullBoundary boundary) stream
        liftM concat $ processParts goPart bmstream

    pBoundary !b = Atto.try $ do
      _ <- string "--"
      string b

    fullBoundary !b = S.concat ["\r\n", "--", b]
    -- One whole line, including its terminator.
    pLine = takeWhile (not . isEndOfLine . c2w) <* eol
    -- Try "--boundary" here; otherwise drop one line and try again.
    parseFirstBoundary !b = pBoundary b <|> (pLine *> parseFirstBoundary b)

    -- Parse a part's header block, reading at most mAX_HDRS_SIZE bytes.
    -- Exceeding that is reported as a BadPartException.
    takeHeaders !str = hdrs `catch` handler
      where
        hdrs = do
            str' <- Streams.throwIfProducesMoreThan mAX_HDRS_SIZE str
            liftM toHeaders $ parseFromStream pHeadersWithSeparator str'
        handler (_ :: TooManyBytesReadException) =
            throwIO $ BadPartException "headers exceeded maximum size"

    -- Handle one top-level part: recurse into multipart/mixed, otherwise
    -- call the client handler once.
    goPart !str = do
        hdrs <- takeHeaders str
        let (contentType, mboundary) = getContentType hdrs
        let (fieldName, fileName, disposition) = getFieldHeaderInfo hdrs
        if contentType == "multipart/mixed"
          then maybe (throwIO $ BadPartException $
                      "got multipart/mixed without boundary")
                     (processMixed fieldName str)
                     mboundary
          else do
            let info = PartInfo fieldName fileName contentType disposition hdrs
            liftM (:[]) $ clientHandler info str

    -- Split a multipart/mixed part on its own boundary.
    processMixed !fieldName !str !mixedBoundary = do
        _ <- parseFromStream (parseFirstBoundary mixedBoundary) str
        bm <- search (fullBoundary mixedBoundary) str
        processParts (mixedStream fieldName) bm

    -- Sub-part of a multipart/mixed part. Any "name" parameter of the
    -- sub-part is ignored in favour of the enclosing part's field name.
    mixedStream !fieldName !str = do
        hdrs <- takeHeaders str
        let (contentType, _) = getContentType hdrs
        let (_, fileName, disposition) = getFieldHeaderInfo hdrs
        let info = PartInfo fieldName fileName contentType disposition hdrs
        clientHandler info str
-- | Extracts the MIME type and the optional @boundary@ parameter from the
-- @Content-Type@ header. A missing or unparseable header is treated as
-- @text/plain@ with no parameters.
getContentType :: Headers
               -> (ByteString, Maybe ByteString)
getContentType hdrs = (mimeType, findParam "boundary" mimeParams)
  where
    rawValue = fromMaybe "text/plain" (getHeader "content-type" hdrs)
    (!mimeType, !mimeParams) =
        either (const ("text/plain", [])) id
               (fullyParse rawValue pContentTypeWithParameters)

-- | Reads the @Content-Disposition@ header and returns three values:
--
--   * the @name@ parameter (empty when absent);
--   * the @filename@ parameter, if any;
--   * the parsed disposition type.
--
-- A missing or unparseable header is treated as disposition @unknown@
-- with no parameters.
getFieldHeaderInfo :: Headers -> (ByteString, Maybe ByteString, PartDisposition)
getFieldHeaderInfo hdrs =
    ( fromMaybe "" (findParam "name" params)
    , findParam "filename" params
    , toPartDisposition dispType
    )
  where
    rawValue = fromMaybe "unknown" (getHeader "content-disposition" hdrs)
    (!dispType, params) =
        either (const ("unknown", [])) id
               (fullyParse rawValue pValueWithParameters)
-- | Value of the first pair whose key equals the given key, if any.
findParam :: (Eq a) => a -> [(a, b)] -> Maybe b
findParam key pairs = fmap snd (find (\(k, _) -> k == key) pairs)
-- | Views a 'search' result stream as the bytes of one segment. The
-- returned stream yields unmatched chunks until the next boundary match,
-- then reports end of stream. The match itself is consumed from the
-- underlying stream.
partStream :: InputStream MatchInfo -> IO (InputStream ByteString)
partStream matches = Streams.makeInputStream nextChunk
  where
    nextChunk = Streams.read matches >>= \r -> return $! (r >>= payload)
    payload m = case m of
        NoMatch chunk -> return chunk
        _             -> mzero
-- | Runs @partFunc@ on each boundary-delimited segment of a 'search'ed
-- stream and collects the results in order.
--
-- Each segment starts right after a boundary:
--
--   * a line terminator there means a part follows;
--   * @--@ means the closing delimiter was reached, and processing stops.
--
-- After @partFunc@ returns, any unread bytes of the segment are skipped,
-- so the next segment starts at the next boundary.
processParts :: (InputStream ByteString -> IO a)
             -> InputStream MatchInfo
             -> IO [a]
processParts partFunc stream = go id
  where
    -- Nothing when the closing delimiter is seen, otherwise the result.
    part pStream = do
        isLast <- parseFromStream pBoundaryEnd pStream
        if isLast
          then return Nothing
          else do
            !x <- partFunc pStream
            Streams.skipToEof pStream
            return $! Just x

    -- soFar is a difference list of results accumulated so far.
    go !soFar = partStream stream >>=
                part >>=
                maybe (return $ soFar []) (\x -> go (soFar . (x:)))

    -- False: a line end follows the boundary (more parts).
    -- True: "--" follows it (end of the multipart body).
    pBoundaryEnd = (eol *> pure False) <|> (string "--" *> pure True)
-- | A line terminator: a bare LF, or CRLF. LF is tried first.
eol :: Parser ByteString
eol = unixEol <|> windowsEol
  where
    unixEol    = string "\n"
    windowsEol = string "\r\n"

-- | A header block followed by the blank line that ends it.
pHeadersWithSeparator :: Parser [(ByteString,ByteString)]
pHeadersWithSeparator = do
    hs <- pHeaders
    _ <- crlf
    return hs

-- | Builds a 'Headers' map, making header names case-insensitive.
toHeaders :: [(ByteString,ByteString)] -> Headers
toHeaders = H.fromList . map (first CI.mk)

-- | Upper bound, in bytes (32 KiB), on a single part's header block.
mAX_HDRS_SIZE :: Int64
mAX_HDRS_SIZE = 32768
-- | Creates a temporary file under @dir@ whose name starts with @prefix@,
-- and runs @action@ on its path and open handle.
--
-- Afterwards the handle is closed and the file is deleted, even if
-- @action@ throws. Errors raised during that cleanup are ignored.
--
-- NOTE(review): 'mkstemp' appears to append its own random suffix, so the
-- literal X's added here likely stay in the file name. Confirm against the
-- unix-compat docs.
withTempFile :: FilePath
             -> String
             -> ((FilePath, Handle) -> IO a)
             -> IO a
withTempFile dir prefix action =
    mask $ \unmask -> bracket acquire release (unmask . action)
  where
    acquire = mkstemp (dir </> (prefix ++ "XXXXXXX"))
    release (path, hdl) =
        sequence [ignoreErrors (hClose hdl), ignoreErrors (removeFile path)]
    tryAny :: IO z -> IO (Either SomeException z)
    tryAny = E.try
    ignoreErrors = void . tryAny