module Network.Wai.Parse
( parseHttpAccept
, parseRequestBody
, RequestBodyType (..)
, getRequestBodyType
, sinkRequestBody
, conduitRequestBody
, BackEnd
, lbsBackEnd
, tempFileBackEnd
, tempFileBackEndOpts
, Param
, File
, FileInfo (..)
, parseContentType
#if TEST
, Bound (..)
, findBound
, sinkTillBound
, killCR
, killCRLF
, takeLine
#endif
) where
import qualified Data.ByteString.Search as Search
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Char8 as S8
import Data.Word (Word8)
import Data.Maybe (fromMaybe)
import Data.List (sortBy)
import Data.Function (on)
import System.Directory (removeFile, getTemporaryDirectory)
import System.IO (hClose, openBinaryTempFile)
import Network.Wai
import Data.Conduit
import Data.Conduit.Internal (sinkToPipe)
import qualified Data.Conduit.List as CL
import qualified Data.Conduit.Binary as CB
import Control.Monad.IO.Class (liftIO)
import qualified Network.HTTP.Types as H
import Data.Either (partitionEithers)
import Control.Monad (when, unless)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Resource (allocate, release, register)
#if MIN_VERSION_conduit(1, 0, 0)
import Data.Conduit.Internal (Pipe (NeedInput, HaveOutput), (>+>), withUpstream, Sink (..), injectLeftovers, ConduitM (..))
import Data.Void (Void)
#endif
-- | Split a 'S.ByteString' at the first occurrence of the given byte.
--
-- The first component is everything before that byte, the second is
-- everything after it; the byte itself is dropped. When the byte does not
-- occur, the whole input is returned as the first component and the second
-- is empty.
breakDiscard :: Word8 -> S.ByteString -> (S.ByteString, S.ByteString)
breakDiscard sep input = (before, S.drop 1 fromSep)
  where
    (before, fromSep) = S.break (== sep) input
-- | Parse the value of an HTTP @Accept@ header into a list of media types,
-- most preferred first.
--
-- The header is split on commas. For every entry, spaces are removed and
-- everything from @;q=@ onward is cut off and read as the quality value
-- (a missing or unparseable quality counts as 1.0).
--
-- Entries are ordered by descending quality. Ties are broken by
-- specificity: every semicolon left in the entry counts plus one, every
-- star (wildcard) counts minus one, and higher scores come first.
parseHttpAccept :: S.ByteString -> [S.ByteString]
parseHttpAccept = map fst
                . sortBy (rcompare `on` snd)
                . map (addSpecificity . grabQ)
                . S.split 44 -- comma
  where
    -- Reversed comparison, so that sortBy yields descending order.
    rcompare :: (Double, Int) -> (Double, Int) -> Ordering
    rcompare = flip compare

    -- Pair the quality with a specificity score used as tie breaker.
    addSpecificity (s, q) =
        let semicolons = S.count 0x3B s
            stars = S.count 0x2A s
         in (s, (q, semicolons - stars))

    -- Separate the media type from its quality value.
    grabQ s =
        let (s', q) = S.breakSubstring (S8.pack ";q=") (S.filter (/= 0x20) s)
            q' = S.takeWhile (/= 0x3B) (S.drop 3 q)
         in (s', readQ q')

    -- Fall back to the default quality when the value cannot be read.
    readQ s = case reads $ S8.unpack s of
                (x, _):_ -> x
                _ -> 1.0
-- | Backend that keeps the uploaded content in memory: all incoming chunks
-- are collected and returned as one lazy 'L.ByteString'. Both arguments
-- are ignored.
lbsBackEnd :: Monad m => ignored1 -> ignored2 -> Sink S.ByteString m L.ByteString
lbsBackEnd _ _ = do
    chunks <- CL.consume
    return (L.fromChunks chunks)
-- | Backend that writes the uploaded content to a temporary file and
-- returns its path. This is 'tempFileBackEndOpts' with the directory taken
-- from 'getTemporaryDirectory' and the file name template @webenc.buf@.
tempFileBackEnd :: MonadResource m => ignored1 -> ignored2 -> Sink S.ByteString m FilePath
tempFileBackEnd unusedName unusedInfo =
    tempFileBackEndOpts getTemporaryDirectory "webenc.buf" unusedName unusedInfo
-- | Backend that writes the uploaded content to a temporary file, with a
-- configurable directory and file name template. Returns the path of the
-- file that was written.
--
-- The handle is closed as soon as the content has been written. Removal of
-- the file is registered with the surrounding resource scope instead of
-- being performed here.
tempFileBackEndOpts :: MonadResource m
                    => IO FilePath -- ^ action yielding the directory to create the file in
                    -> String -- ^ file name template handed to 'openBinaryTempFile'
                    -> ignored1
                    -> ignored2
                    -> Sink S.ByteString m FilePath
tempFileBackEndOpts getTmpDir pattern _ _ = do
    (handleKey, (path, handle)) <- lift $ allocate openTemp closeTemp
    -- Deleting the file is left to the resource scope.
    _ <- lift $ register $ removeFile path
    CB.sinkHandle handle
    -- All content is written: close the handle now instead of at scope exit.
    lift $ release handleKey
    return path
  where
    openTemp = getTmpDir >>= \dir -> openBinaryTempFile dir pattern
    closeTemp (_, h) = hClose h
-- | Information on an uploaded file. The type parameter @c@ is the
-- representation of the file content, which depends on the 'BackEnd' used.
data FileInfo c = FileInfo
    { fileName :: S.ByteString
      -- ^ Value of the @filename@ attribute of the part's
      -- @Content-Disposition@ header.
    , fileContentType :: S.ByteString
      -- ^ Value of the part's @Content-Type@ header;
      -- @application/octet-stream@ when the part has none.
    , fileContent :: c
      -- ^ The content as produced by the backend (@()@ while the backend
      -- has not run yet).
    }
    deriving (Eq, Show)
-- | A plain form parameter: field name paired with its value.
type Param = (S.ByteString, S.ByteString)
-- | An uploaded file: form field name paired with the file's 'FileInfo',
-- whose content has type @y@.
type File y = (S.ByteString, FileInfo y)
-- | A file backend: decides how the content of an uploaded file is stored.
-- It is given the form field name and the file's 'FileInfo' (without
-- content) and returns a sink that consumes the file's bytes and produces
-- the stored representation.
type BackEnd a = S.ByteString -- ^ form field name
              -> FileInfo () -- ^ file name and content type
              -> Sink S.ByteString (ResourceT IO) a
-- | The request body encodings this module can parse. 'Multipart' carries
-- the value of the @boundary@ attribute of the @Content-Type@ header.
data RequestBodyType = UrlEncoded | Multipart S.ByteString
-- | Determine how the body of a request is encoded by inspecting its
-- @Content-Type@ header.
--
-- Returns 'Nothing' when the header is missing, when the media type is
-- neither @application/x-www-form-urlencoded@ nor @multipart/form-data@,
-- or when a multipart content type has no @boundary@ attribute.
getRequestBodyType :: Request -> Maybe RequestBodyType
getRequestBodyType req =
    lookup "Content-Type" (requestHeaders req) >>= classify . parseContentType
  where
    classify (mediaType, attrs)
        | mediaType == "application/x-www-form-urlencoded" = Just UrlEncoded
        | mediaType == "multipart/form-data" = fmap Multipart (lookup "boundary" attrs)
        | otherwise = Nothing
-- | Split a @Content-Type@ header value into the media type and its
-- attributes.
--
-- The media type is everything before the first semicolon, returned
-- unchanged. The remainder is cut at semicolons; each piece is split at
-- its first equals sign into a key and a value, both with leading and
-- trailing spaces removed. A piece without an equals sign gets an empty
-- value.
parseContentType :: S.ByteString -> (S.ByteString, [(S.ByteString, S.ByteString)])
parseContentType raw = (mediaType, collect (S.drop 1 afterType))
  where
    semicolon, equals, space :: Word8
    semicolon = 59
    equals = 61
    space = 32

    (mediaType, afterType) = S.break (== semicolon) raw

    -- Walk the attribute section one semicolon-delimited piece at a time.
    collect bs
        | S.null bs = []
        | otherwise =
            let (piece, more) = S.break (== semicolon) bs
             in keyValue piece : collect (S.drop 1 more)

    keyValue piece =
        let (k, v) = S.break (== equals) piece
         in (trim k, trim (S.drop 1 v))

    -- Remove spaces from both ends.
    trim = S.dropWhile (== space) . fst . S.spanEnd (== space)
-- | Parse the body of a request into plain parameters and uploaded files,
-- storing file contents with the given backend.
--
-- When 'getRequestBodyType' does not recognise the request, the body is
-- not read and both lists are empty.
parseRequestBody :: BackEnd y
                 -> Request
                 -> ResourceT IO ([Param], [File y])
parseRequestBody backend req =
    maybe (return ([], [])) run (getRequestBodyType req)
  where
    run bodyType = do
        pieces <- requestBody req $$ conduitRequestBody backend bodyType =$ CL.consume
        return (partitionEithers pieces)
-- | Sink that parses a request body of the given type and returns the
-- plain parameters and the uploaded files as two separate lists.
sinkRequestBody :: BackEnd y
                -> RequestBodyType
                -> Sink S.ByteString (ResourceT IO) ([Param], [File y])
sinkRequestBody backend bodyType = do
    pieces <- conduitRequestBody backend bodyType =$ CL.consume
    return (partitionEithers pieces)
-- | Conduit that parses a request body of the given type, yielding 'Left'
-- for every plain parameter and 'Right' for every uploaded file.
--
-- A URL-encoded body is read completely before any parameter is yielded
-- and never produces files, so the backend is unused in that case. A
-- multipart body is handed to 'parsePieces' with the boundary prefixed by
-- two dashes.
conduitRequestBody :: BackEnd y
                   -> RequestBodyType
                   -> Conduit S.ByteString (ResourceT IO) (Either Param (File y))
conduitRequestBody backend bodyType =
    case bodyType of
        UrlEncoded -> do
            chunks <- CL.consume
            mapM_ (yield . Left) (H.parseSimpleQuery (S.concat chunks))
        Multipart bound ->
            parsePieces backend (S.append (S8.pack "--") bound)
-- | Read a single line from the input stream.
--
-- Input is consumed up to and including the first line feed (byte 10);
-- whatever follows the line feed in the same chunk is pushed back as
-- leftover. The result is the line without the line feed and without a
-- trailing carriage return.
--
-- When the input ends before a line feed is seen, the bytes collected so
-- far are pushed back as leftover and 'Nothing' is returned.
#if MIN_VERSION_conduit(1, 0, 0)
takeLine :: Monad m => Consumer S.ByteString m (Maybe S.ByteString)
#else
takeLine :: Monad m => Pipe S.ByteString S.ByteString o u m (Maybe S.ByteString)
#endif
takeLine =
    loop id
  where
    -- prepend puts the bytes read so far in front of the next chunk.
    loop prepend = do
        mchunk <- await
        case mchunk of
            Nothing -> do
                leftover (prepend S.empty)
                return Nothing
            Just chunk ->
                let (line, fromLF) = S.break (== 10) (prepend chunk)
                 in if S.null fromLF
                        then loop (S.append line)
                        else do
                            when (S.length fromLF > 1) $ leftover $ S.drop 1 fromLF
                            return $ Just $ killCR line
-- | Read consecutive lines with 'takeLine' until an empty line or the end
-- of the input is reached. The terminating empty line is consumed but not
-- included in the result.
#if MIN_VERSION_conduit(1, 0, 0)
takeLines :: Consumer S.ByteString (ResourceT IO) [S.ByteString]
#else
takeLines :: Pipe S.ByteString S.ByteString o u (ResourceT IO) [S.ByteString]
#endif
takeLines =
    takeLine >>= collect
  where
    collect (Just line)
        | not (S.null line) = do
            more <- takeLines
            return (line : more)
    collect _ = return []
-- | Parse the parts of a multipart body.
--
-- For every part this reads one line (expected to be the boundary line),
-- then the header lines up to the first empty line, then the part's
-- content up to the next occurrence of @bound@:
--
-- * a part whose @Content-Disposition@ has both a @name@ and a @filename@
--   attribute is streamed into the backend and yielded as 'Right';
--
-- * a part with a @name@ but no @filename@ is collected in memory and
--   yielded as 'Left';
--
-- * any other part is consumed and dropped.
--
-- Parsing stops when no header lines could be read or when the content of
-- a part was not followed by the boundary.
parsePieces :: BackEnd y
            -> S.ByteString
#if MIN_VERSION_conduit(1, 0, 0)
            -> ConduitM S.ByteString (Either Param (File y)) (ResourceT IO) ()
#else
            -> Pipe S.ByteString S.ByteString (Either Param (File y)) u (ResourceT IO) ()
#endif
parsePieces sink bound =
    loop
  where
    loop = do
        -- Skip the rest of the boundary line, then read the part headers.
        _boundLine <- takeLine
        res' <- takeLines
        unless (null res') $ do
            let ls' = map parsePair res'
            -- Extract content type, field name and optional file name;
            -- Nothing when Content-Disposition or its name is missing.
            let x = do
                    cd <- lookup contDisp ls'
                    let ct = lookup contType ls'
                    let attrs = parseAttrs cd
                    name <- lookup "name" attrs
                    return (ct, name, lookup "filename" attrs)
            case x of
                -- File upload: hand the content to the backend.
                Just (mct, name, Just filename) -> do
                    let ct = fromMaybe "application/octet-stream" mct
                        fi0 = FileInfo filename ct ()
                    (wasFound, y) <- sinkTillBound' bound name fi0 sink
                    yield $ Right (name, fi0 { fileContent = y })
                    when wasFound loop
                -- Plain parameter: accumulate the chunks as a difference
                -- list and concatenate them at the end.
                Just (_ct, name, Nothing) -> do
                    let seed = id
                    let iter front bs = return $ front . (:) bs
                    (wasFound, front) <- sinkTillBound bound iter seed
                    let bs = S.concat $ front []
                    let x' = (name, bs)
                    yield $ Left x'
                    when wasFound loop
                -- Unusable part: consume its content and ignore it.
                _ -> do
                    let seed = ()
                        iter () _ = return ()
                    (wasFound, ()) <- sinkTillBound bound iter seed
                    when wasFound loop
      where
        contDisp = S8.pack "Content-Disposition"
        contType = S8.pack "Content-Type"
        -- Split a header line at the first colon (byte 58) and drop the
        -- spaces (byte 32) in front of the value.
        parsePair s =
            let (x, y) = breakDiscard 58 s
             in (x, S.dropWhile (== 32) y)
-- | Result of searching a chunk of input for a boundary marker, as
-- produced by 'findBound'.
--
-- * 'FoundBound': the marker occurs in the chunk; the fields hold the
--   bytes before it and the bytes after it.
--
-- * 'NoBound': the marker does not occur and the chunk does not end with
--   the beginning of the marker.
--
-- * 'PartialBound': the chunk ends with the beginning of the marker, so
--   more input is needed to decide.
data Bound = FoundBound S.ByteString S.ByteString
           | NoBound
           | PartialBound
    deriving (Eq, Show)
-- | Search for the boundary marker @b@ inside the chunk @bs@.
--
-- Returns 'FoundBound' with the bytes before and after the marker when it
-- occurs in @bs@. Otherwise the tail of @bs@ is inspected: when some
-- non-empty suffix of @bs@ is a prefix of @b@, the marker may continue in
-- the next chunk and 'PartialBound' is returned. In all other cases the
-- result is 'NoBound'.
findBound :: S.ByteString -> S.ByteString -> Bound
findBound b bs = handleBreak $ Search.breakOn b bs
  where
    bLen = S.length b
    bsLen = S.length bs

    handleBreak (h, t)
        | S.null t = go [lowBound .. bsLen - 1]
        | otherwise = FoundBound h $ S.drop bLen t

    -- Earliest position at which a marker reaching the end of bs can start.
    lowBound = max 0 $ bsLen - bLen

    -- Try every candidate start position, leftmost first.
    go [] = NoBound
    go (i:is)
        | S.drop i bs `S.isPrefixOf` b =
            let endI = i + bLen
             in if endI > bsLen
                    then PartialBound
                    -- Defensive: a complete match would normally have been
                    -- reported by the substring search already.
                    else FoundBound (S.take i bs) (S.drop endI bs)
        | otherwise = go is
-- | Feed the bytes up to the next boundary into the backend sink.
--
-- The input is passed through 'conduitTillBound' and the resulting stream
-- is consumed by the sink obtained from applying the backend to the field
-- name and file info. The result pairs the 'Bool' returned by
-- 'conduitTillBound' with the value produced by the backend.
sinkTillBound' :: S.ByteString -- ^ boundary marker
               -> S.ByteString -- ^ form field name, passed to the backend
               -> FileInfo () -- ^ file info, passed to the backend
               -> BackEnd y
#if MIN_VERSION_conduit(1, 0, 0)
               -> ConduitM S.ByteString o (ResourceT IO) (Bool, y)
#else
               -> Pipe S.ByteString S.ByteString o u (ResourceT IO) (Bool, y)
#endif
sinkTillBound' bound name fi sink =
#if MIN_VERSION_conduit(1, 0, 0)
    ConduitM $ anyOutput $
#endif
    conduitTillBound bound >+> withUpstream (fix $ sink name fi)
  where
#if MIN_VERSION_conduit(1, 0, 0)
    -- Turn the backend sink into a raw pipe that can sit downstream of
    -- conduitTillBound.
    fix :: Sink S8.ByteString (ResourceT IO) y -> Pipe Void S8.ByteString Void Bool (ResourceT IO) y
    fix (ConduitM p) = ignoreTerm >+> injectLeftovers p
    -- Forward every input value; finish with () once upstream is done,
    -- discarding the upstream result.
    ignoreTerm = await' >>= maybe (return ()) (\x -> yield' x >> ignoreTerm)
    await' = NeedInput (return . Just) (const $ return Nothing)
    yield' = HaveOutput (return ()) (return ())
    -- Discard whatever the pipe outputs and keep only its final result, so
    -- that the output type is unconstrained.
    anyOutput p = p >+> dropInput
    dropInput = NeedInput (const dropInput) return
#else
    fix = sinkToPipe
#endif
-- | Pass input downstream until the boundary marker is reached.
--
-- Returns 'True' when the marker was found; the bytes following it are
-- pushed back as leftover, and a line break directly in front of it is
-- removed from the output. Returns 'False' when the input ended without
-- the marker; any bytes still held back are emitted first.
--
-- While no marker is in sight, chunks are forwarded as they arrive, except
-- that the last two bytes of a chunk ending in a carriage return or line
-- feed are held back, because they may turn out to be the line break that
-- belongs to the marker.
conduitTillBound :: Monad m
                 => S.ByteString -- ^ boundary marker
#if MIN_VERSION_conduit(1, 0, 0)
                 -> Pipe S.ByteString S.ByteString S.ByteString () m Bool
#else
                 -> Pipe S.ByteString S.ByteString S.ByteString u m Bool
#endif
conduitTillBound bound =
#if MIN_VERSION_conduit(1, 0, 0)
    unConduitM $
#endif
    go id
  where
    -- front puts the bytes held back so far in front of the next chunk.
    go front = await >>= maybe (close front) (push front)

    -- End of input: flush whatever is still held back.
    close front = do
        let bs = front S.empty
        unless (S.null bs) $ yield bs
        return False

    push front bs' = do
        let bs = front bs'
        case findBound bound bs of
            FoundBound before after -> do
                let before' = killCRLF before
                yield before'
                leftover after
                return True
            NoBound -> do
                let (toEmit, front') =
                        if endsInLineBreak bs
                            then let (x, y) = S.splitAt (S.length bs - 2) bs
                                  in (x, S.append y)
                            else (bs, id)
                yield toEmit
                go front'
            -- The marker may continue in the next chunk: emit nothing yet.
            PartialBound -> go $ S.append bs

    -- True when the last byte is a carriage return (13) or line feed (10).
    endsInLineBreak bs =
        not (S.null bs) && (S.last bs == 13 || S.last bs == 10)
-- | Fold over the bytes up to the next boundary.
--
-- The input is passed through 'conduitTillBound' and every chunk it emits
-- is combined into the accumulator with the given step function, starting
-- from the seed. The result pairs the 'Bool' returned by
-- 'conduitTillBound' with the final accumulator.
sinkTillBound :: S.ByteString -- ^ boundary marker
              -> (x -> S.ByteString -> IO x) -- ^ step function
              -> x -- ^ initial accumulator
#if MIN_VERSION_conduit(1, 0, 0)
              -> Consumer S.ByteString (ResourceT IO) (Bool, x)
#else
              -> Pipe S.ByteString S.ByteString o u (ResourceT IO) (Bool, x)
#endif
sinkTillBound bound iter seed0 =
#if MIN_VERSION_conduit(1, 0, 0)
    ConduitM $
#endif
    (conduitTillBound bound >+> (withUpstream $ ij $ CL.foldM iter' seed0))
  where
    -- Lift the IO step function into the monad the fold runs in.
    iter' a b = liftIO $ iter a b
#if MIN_VERSION_conduit(1, 0, 0)
    -- Turn the folding sink into a raw pipe that can sit downstream of
    -- conduitTillBound.
    ij (ConduitM p) = ignoreTerm >+> injectLeftovers p
    -- Forward every input value; finish with () once upstream is done,
    -- discarding the upstream result.
    ignoreTerm = await' >>= maybe (return ()) (\x -> yield' x >> ignoreTerm)
    await' = NeedInput (return . Just) (const $ return Nothing)
    yield' = HaveOutput (return ()) (return ())
#else
    ij = id
#endif
-- | Parse a semicolon-separated list of @key=value@ attributes, such as
-- the value of a @Content-Disposition@ header.
--
-- Leading spaces are removed from keys and values. A value longer than two
-- bytes that starts and ends with a double quote has the quotes removed;
-- shorter values, including an empty pair of quotes, are left untouched.
-- A piece without an equals sign yields an empty value.
parseAttrs :: S.ByteString -> [(S.ByteString, S.ByteString)]
parseAttrs = map parseOne . S.split 59 -- semicolon
  where
    trimLeft = S.dropWhile (== 32)

    unquote s
        | S.length s > 2 && S.head s == 34 && S.last s == 34 = S.tail (S.init s)
        | otherwise = s

    parseOne piece = (trimLeft key, unquote (trimLeft val))
      where
        (key, val) = breakDiscard 61 piece
-- | Remove a trailing line break. When the input ends with a line feed
-- (byte 10), that byte is dropped, together with a carriage return that
-- directly precedes it. Any other input is returned unchanged.
killCRLF :: S.ByteString -> S.ByteString
killCRLF bs
    | not (S.null bs) && S.last bs == 10 = killCR (S.init bs)
    | otherwise = bs
-- | Remove a single trailing carriage return (byte 13), if present.
killCR :: S.ByteString -> S.ByteString
killCR bs
    | not (S.null bs) && S.last bs == 13 = S.init bs
    | otherwise = bs