module Network.HTTP.Conduit.Response
( Response (..)
, getRedirectedRequest
, getResponse
, lbsResponse
) where
import Control.Arrow (first)
import Control.Monad (liftM)
import Control.Exception (throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Lazy as L
import qualified Data.CaseInsensitive as CI
import Data.Default (def)
import Data.Conduit
import Data.Conduit.Internal (ResumableSource (..), Pipe (..))
import qualified Data.Conduit.Zlib as CZ
import qualified Data.Conduit.List as CL
import qualified Network.HTTP.Types as W
import Network.URI (parseURIReference)
import Network.HTTP.Conduit.Types (Response (..), CookieJar)
import Network.HTTP.Conduit.Manager
import Network.HTTP.Conduit.Request
import Network.HTTP.Conduit.Util
import Network.HTTP.Conduit.Chunk
import Network.HTTP.Conduit.Parser (sinkHeaders)
import Data.Void (Void, absurd)
import System.Timeout.Lifted (timeout)
#if MIN_VERSION_conduit(1, 0, 0)
import Data.Conduit.Internal (ConduitM (..))
#endif
-- | Build the follow-up request for a redirect response.
--
-- Yields 'Nothing' when:
--
-- * the status code is outside the 3xx range,
--
-- * there is no @Location@ header, or
--
-- * the header value cannot be parsed as a URI reference or applied
--   relative to the original request.
--
-- For 302 and 303 the method becomes @GET@ and the request body is
-- replaced by an empty one. Every other 3xx code keeps the original method
-- and body. The given cookie jar is installed on the new request only when
-- the original request already carried one.
getRedirectedRequest :: Request m -> W.ResponseHeaders -> CookieJar -> Int -> Maybe (Request m)
getRedirectedRequest req hs cookie_jar code
    | code < 300 || code >= 400 = Nothing
    | otherwise = fmap rewrite target
  where
    -- Resolve the Location header against the original request.
    target = lookup "location" hs >>= parseURIReference . S8.unpack >>= setUriRelative req

    -- Adjust the redirected request according to the status code.
    rewrite r
        | code == 302 || code == 303 =
            r { method = "GET"
              , requestBody = RequestBodyBS ""
              , cookieJar = jar
              }
        | otherwise = r { cookieJar = jar }

    -- Carry the updated jar forward only if cookie handling was enabled.
    jar = case cookieJar req of
        Nothing -> Nothing
        Just _ -> Just cookie_jar
-- | Drain a streaming response body into a lazy 'L.ByteString'.
--
-- The body source is connected to 'CL.consume' with '$$+-'. The collected
-- chunks become the body of an otherwise unchanged response.
lbsResponse :: Monad m
            => Response (ResumableSource m S8.ByteString)
            -> m (Response L.ByteString)
lbsResponse res =
    liftM withChunks (responseBody res $$+- CL.consume)
  where
    withChunks chunks = res { responseBody = L.fromChunks chunks }
-- | Enforce a byte budget on a header-parsing sink.
--
-- Wraps the given 'Pipe' and tracks @len@, the number of bytes it may
-- still read. Each chunk fed to the inner sink is subtracted from the
-- budget. Bytes the inner sink returns via 'Leftover' are added back, so
-- only bytes the sink actually keeps count against it. When the budget is
-- exhausted (@len <= 0@) and the inner sink asks for more input,
-- 'OverlongHeaders' is thrown.
--
-- The wrapped pipe cannot produce output (its output type is 'Void'), so
-- the 'HaveOutput' case is unreachable and discharged with 'absurd'.
checkHeaderLength :: MonadResource m
                  => Int
                  -> Pipe S8.ByteString S8.ByteString Void u m r
                  -> Pipe S8.ByteString S8.ByteString Void u m r
-- Budget spent but the sink still wants input: the headers are too long.
checkHeaderLength len NeedInput{}
    | len <= 0 = liftIO $ throwIO OverlongHeaders
-- Intercept each incoming chunk and charge its length to the budget.
checkHeaderLength len (NeedInput pushI closeI) = NeedInput
    (\bs -> checkHeaderLength
        (len - S.length bs)
        (pushI bs)) closeI
checkHeaderLength len (PipeM msink) = PipeM (liftM (checkHeaderLength len) msink)
checkHeaderLength _ s@Done{} = s
checkHeaderLength _ (HaveOutput _ _ o) = absurd o
-- Unconsumed bytes pushed back upstream are refunded to the budget.
checkHeaderLength len (Leftover p i) = Leftover (checkHeaderLength (len + S.length i) p) i
-- | Parse an HTTP response (status line and headers) from @src1@ and return
-- a 'Response' whose body streams the rest of the message.
--
-- Arguments:
--
-- * @connRelease@ – called by the cleanup action with 'Reuse' when
--   keep-alive is allowed and the body was fully read, and with
--   'DontReuse' otherwise.
--
-- * @timeout''@ – optional limit, in microseconds (the unit taken by
--   'timeout'), on reading the status line and headers. Reading the body
--   is not covered. On expiry 'ResponseTimeout' is thrown.
--
-- * @req@ – the request that produced this response. Its method,
--   @rawBody@ and the 'needsGunzip' check influence body decoding.
--
-- * @src1@ – the raw byte stream from the connection.
--
-- The header section is limited to 4096 bytes via 'checkHeaderLength'.
--
-- Body handling:
--
-- * No body expected ('hasNoBody', or @Content-Length: 0@): the cleanup
--   runs immediately and an empty source is returned.
--
-- * @Transfer-Encoding: chunked@: decoded with 'chunkedConduit'.
--
-- * @Content-Length@ present: limited to exactly that many bytes by
--   'requireLength'.
--
-- * Otherwise: the rest of the connection stream is used unchanged.
--
-- The result is additionally gunzipped when 'needsGunzip' says so.
getResponse :: (MonadResource m, MonadBaseControl IO m)
            => ConnRelease m
            -> Maybe Int
            -> Request m
            -> Source m S8.ByteString
            -> m (Response (ResumableSource m S8.ByteString))
getResponse connRelease timeout'' req@(Request {..}) src1 = do
    -- Optional timeout wrapper. It is only applied to the header parse
    -- below.
    let timeout' =
            case timeout'' of
                Nothing -> id
                Just useconds -> \ma -> do
                    x <- timeout useconds ma
                    case x of
                        Nothing -> liftIO $ throwIO ResponseTimeout
                        Just y -> return y
    -- Parse status line and headers. src1 is resumed, so src2 continues
    -- at the first byte after the header section.
    (src2, ((vbs, sc, sm), hs)) <- timeout' $ src1 $$+
#if MIN_VERSION_conduit(1, 0, 0)
        ConduitM (checkHeaderLength 4096 $ unConduitM sinkHeaders)
#else
        (checkHeaderLength 4096 sinkHeaders)
#endif
    let version = if vbs == "1.1" then W.http11 else W.http10
    let s = W.Status sc sm
    -- Header names are case-insensitive from here on.
    let hs' = map (first CI.mk) hs
    -- Declared body length, when Content-Length is present and parses.
    let mcl = lookup "content-length" hs' >>= readDec . S8.unpack
    -- Keep-alive is allowed unless the server sent "Connection: close" or
    -- reported version "1.0".
    -- NOTE(review): the header *value* comparison is case-sensitive, so
    -- "Connection: Close" would not disable reuse — confirm this is intended.
    let toPut = Just "close" /= lookup "connection" hs' && vbs /= "1.0"
    -- Release the connection. It is offered for reuse only when keep-alive
    -- is allowed and the whole body was consumed.
    let cleanup bodyConsumed = connRelease $ if toPut && bodyConsumed then Reuse else DontReuse
    body <-
        if hasNoBody method sc || mcl == Just 0
            then do
                -- Nothing to read: release now and hand back an empty source.
                cleanup True
                (rsrc, ()) <- return () $$+ return ()
                return rsrc
            else do
                let isChunked = ("transfer-encoding", "chunked") `elem` hs'
                    -- Framing: de-chunk, or cap at Content-Length, or read
                    -- whatever remains on the connection.
                    src3 =
                        if isChunked
                            then fmapResume ($= chunkedConduit rawBody) src2
                            else
                                case mcl of
                                    Just len -> fmapResume ($= requireLength len) src2
                                    Nothing -> src2
                    -- Content decoding: gunzip when the request calls for it.
                    src4 =
                        if needsGunzip req hs'
                            then fmapResume ($= (if isChunked then ungzipChunked else CZ.ungzip)) src3
                            else src3
                return $ addCleanup' cleanup src4
    return $ Response s version hs' body def
  where
    -- Gunzip a chunked body. The decompressor is flushed after every input
    -- chunk so decoded bytes are emitted without waiting for more input.
    -- Window bits 31 selects gzip framing (zlib convention: 15 + 16).
    ungzipChunked =
        CL.concatMap (\x -> [Chunk x, Flush])
            =$= CZ.decompressFlush (CZ.WindowBits 31)
            =$= awaitForever unChunk
      where
        unChunk Flush = return ()
        unChunk (Chunk x) = yield x
    -- Transform the source inside a ResumableSource, keeping its finalizer.
    fmapResume f (ResumableSource src m) = ResumableSource (f src) m
    -- Attach the connection cleanup in two places:
    --   * to the source itself, via 'addCleanup';
    --   * to the ResumableSource finalizer, where it runs with False after
    --     the existing finalizer @m@.
    addCleanup' f (ResumableSource src m) = ResumableSource (addCleanup f src) (m >> f False)
-- | Pass through exactly @total@ bytes of a response body.
--
-- Chunks are forwarded downstream until @total@ bytes have been seen. A
-- chunk that crosses the boundary is split: its prefix is yielded and the
-- remainder is pushed back with 'leftover'.
--
-- If upstream ends before @total@ bytes arrive, 'ResponseBodyTooShort' is
-- thrown with two values: the expected length and the number of bytes
-- actually received.
requireLength :: MonadIO m => Int -> Conduit S.ByteString m S.ByteString
requireLength total =
    loop total
  where
    -- @remaining@ is the number of bytes still expected.
    loop 0 = return ()
    loop remaining =
        await >>= maybe tooShort go
      where
        -- Upstream finished early: report expected vs. received bytes.
        tooShort = liftIO $ throwIO $ ResponseBodyTooShort
            (fromIntegral total)
            (fromIntegral $ total - remaining)
        go bs =
            case compare remaining l of
                EQ -> yield bs
                -- Chunk overshoots the body: keep the excess for later
                -- consumers.
                LT -> do
                    let (x, y) = S.splitAt remaining bs
                    leftover y
                    yield x
                GT -> yield bs >> loop (remaining - l)
          where
            l = S.length bs