-- | Pull-based byte sources and the helpers built around them: constructors
--   that make a 'Source' from a byte string, a file handle or a connection,
--   consumers that drain or collect a 'Source', and functions that stream a
--   'Source' out to a connection (optionally with chunked framing).
module Network.MiniHTTP.HTTPConnection
(
Source
, SourceResult(..)
, bsSource
, hSource
, nullSource
, sourceToLBS
, sourceToBS
, connSource
, connChunkedSource
, connEOFSource
, sourceDrain
, streamSource
, streamSourceChunked
, readIG
, sourceIG
, maybeRead
, sslToBaseConnection
) where
import Control.Concurrent.STM
import qualified Control.Exception as Exception
import Control.Monad (liftM)
import qualified Data.ByteString as B
import Data.ByteString.Char8 ()
import Data.ByteString.Internal (c2w, w2c)
import qualified Data.ByteString.Lazy.Internal as BL
import qualified Data.Binary.Strict.IncrementalGet as IG
import Data.IORef
import Data.Int (Int64)
import System.IO
import Text.Printf (printf)
import System.IO.Unsafe (unsafeInterleaveIO)
import qualified Network.Connection as C
import Network.Socket as Socket
import Network.MiniHTTP.Marshal (parseChunkHeader)
import qualified OpenSSL.Session as SSL
-- | A pull-based stream of bytes: every time the action is run it produces
--   the next 'SourceResult'.
type Source = IO SourceResult
-- | The outcome of pulling once from a 'Source'.
data SourceResult
  = SourceError              -- ^ the source failed
  | SourceEOF                -- ^ the source has nothing more to yield
  | SourceData B.ByteString  -- ^ the next block of bytes
  deriving (Show)
-- | Build a 'Source' that yields the given byte string on the first pull
--   and 'SourceEOF' on every pull after that.
bsSource :: B.ByteString -> IO Source
bsSource bs = do
  pending <- newIORef (SourceData bs)
  let pull = do
        current <- readIORef pending
        -- Whatever was pending, the next pull sees end-of-stream.
        writeIORef pending SourceEOF
        return current
  return pull
-- | Build a 'Source' over the inclusive byte range @(from, to)@ of a handle.
--
--   The handle is seeked to @from@ when the source is created. Each pull
--   reads at most 128 KiB and never reads past offset @to@. A pull yields
--   'SourceEOF' once offset @to + 1@ has been reached, and 'SourceError' if
--   the handle returns no bytes before that point or if reading raises an
--   'Exception.IOException'.
hSource :: (Int64, Int64)  -- ^ first and last byte offsets, both inclusive
        -> Handle          -- ^ handle to read from
        -> IO Source
hSource (from, to) handle = do
  bytesSoFar <- newIORef (from :: Int64)
  hSeek handle AbsoluteSeek (fromIntegral from)
  return $ Exception.catch (pull bytesSoFar) onIOError
  where
    pull bytesSoFar = do
      done <- readIORef bytesSoFar
      -- Bytes still owed to the caller; zero means the range is exhausted.
      let remaining = (to + 1) - done
      bytes <- B.hGet handle $ min (128 * 1024) (fromIntegral remaining)
      if B.null bytes
        then return $ if remaining == 0 then SourceEOF else SourceError
        else do
          modifyIORef bytesSoFar (+ fromIntegral (B.length bytes))
          return $ SourceData bytes

    -- The handler's exception type must be fixed for 'Exception.catch' to
    -- type-check; only I/O failures are turned into 'SourceError'.
    onIOError :: Exception.IOException -> IO SourceResult
    onIOError _ = return SourceError
-- | A 'Source' with no data: every pull yields 'SourceEOF'.
nullSource :: Source
nullSource = return SourceEOF
-- | Build a 'Source' that reads from a connection until reading fails.
--
--   Each pull asks the connection for up to 1024 bytes and yields them as
--   'SourceData'. If the read raises an 'Exception.IOException' the pull
--   yields 'SourceEOF' instead.
--   NOTE(review): presumably the connection signals a close by raising an
--   I/O error -- confirm against "Network.Connection".
connEOFSource :: C.Connection -> IO Source
connEOFSource conn = return pull
  where
    pull = Exception.catch (liftM SourceData $ C.read conn 1024) onIOError

    -- Same exceptions the old Prelude @catch@ handled (IOError only).
    onIOError :: Exception.IOException -> IO SourceResult
    onIOError _ = return SourceEOF
-- | Build a 'Source' for a body of exactly @n@ bytes, of which @bs@ is a
--   prefix the caller already holds and the rest is read from the
--   connection.
--
--   If @bs@ already has length @n@ the connection is never touched.
--   Otherwise the first pull yields @bs@, later pulls read at most 32 KiB
--   (and never more than is still missing) from the connection, and once
--   @n@ bytes have been yielded every pull yields 'SourceEOF'. An empty
--   read before that point yields 'SourceError'.
connSource :: Int64         -- ^ total number of bytes in the body
           -> B.ByteString  -- ^ leading bytes of the body already in hand
           -> C.Connection  -- ^ connection supplying the remaining bytes
           -> IO Source
connSource n bs conn
  | fromIntegral (B.length bs) == n = bsSource bs
  | otherwise = do
      -- State: (has @bs@ been yielded yet?, bytes yielded so far).
      ref <- newIORef (False, 0 :: Int64)
      return $ do
        (doneBS, n') <- readIORef ref
        if not doneBS
          then do
            writeIORef ref (True, fromIntegral $ B.length bs)
            return $ SourceData bs
          else if n' == n
            then return SourceEOF
            else do
              bytes <- C.read conn $ min (32 * 1024) $ fromIntegral (n - n')
              if B.null bytes
                then return SourceError
                else do
                  writeIORef ref (doneBS, n' + fromIntegral (B.length bytes))
                  return $ SourceData bytes
-- | Build a 'Source' that decodes a chunked body read from a connection.
--
--   The source keeps one counter with three meanings:
--
--   * @0@: a chunk header must be parsed next (via 'readIG' with
--     'parseChunkHeader');
--   * @-1@: the zero-length final chunk has been seen, so every pull
--     yields 'SourceEOF';
--   * any other value: that many bytes of the current chunk are still
--     unread.
--
--   Chunk data is read at most 32 KiB per pull. A header that fails to
--   parse, or an empty read inside a chunk, yields 'SourceError'.
connChunkedSource :: C.Connection -> IO Source
connChunkedSource conn = do
  ref <- newIORef (0 :: Int64)
  let f = do
        remainingInThisChunk <- readIORef ref
        case remainingInThisChunk of
          0 -> do
            m <- readIG conn 16 256 parseChunkHeader
            case m of
              Nothing -> return SourceError
              Just n
                -- Final chunk: consume the 2 bytes that follow it
                -- (presumably the closing CRLF) and latch the EOF state.
                | n == 0 ->
                    C.reada conn 2 >> writeIORef ref (-1) >> return SourceEOF
                | otherwise -> writeIORef ref n >> f
          (-1) -> return SourceEOF
          _ -> do
            bytes <- C.read conn $ fromIntegral $ min remainingInThisChunk $ 32 * 1024
            if B.null bytes
              then return SourceError
              else do
                let stillRemaining =
                      remainingInThisChunk - fromIntegral (B.length bytes)
                if stillRemaining == 0
                  then do
                    -- Chunk finished: consume the 2 bytes that trail the
                    -- chunk data (presumably CRLF) and expect a header next.
                    C.reada conn 2
                    writeIORef ref 0
                  else writeIORef ref stillRemaining
                return $ SourceData bytes
  return f
-- | Pull from a 'Source', discarding the data, until it reports either
--   'SourceEOF' or 'SourceError'.
sourceDrain :: Source -> IO ()
sourceDrain s = s >>= continue
  where
    continue (SourceData _) = sourceDrain s
    -- End-of-stream and failure both simply stop the drain.
    continue _ = return ()
-- | Expose a 'Source' as a lazy byte string.
--
--   The first pull happens immediately; each further pull is deferred with
--   'unsafeInterleaveIO' until the corresponding tail of the result is
--   demanded. A 'SourceError' turns into a call to 'fail' at the point
--   where that pull is performed.
sourceToLBS :: Source -> IO BL.ByteString
sourceToLBS s = s >>= build
  where
    build SourceEOF = return BL.Empty
    build SourceError = fail "Error in reading from client"
    build (SourceData bs) = do
      rest <- unsafeInterleaveIO (sourceToLBS s)
      return (BL.Chunk bs rest)
-- | Collect at most @n@ bytes from a 'Source' into a strict byte string.
--
--   Pulling stops at 'SourceEOF' or as soon as @n@ bytes are available;
--   in the latter case the last block is truncated so that the result is
--   exactly @n@ bytes long and the surplus is dropped. Returns 'Nothing'
--   if the source reports 'SourceError'.
sourceToBS :: Int -> Source -> IO (Maybe B.ByteString)
sourceToBS n source = f 0
  where
    -- @soFar@ is the number of bytes collected by earlier pulls.
    f soFar = do
      s <- source
      case s of
        SourceEOF -> return $ Just B.empty
        SourceError -> return Nothing
        SourceData bs
          | B.length bs + soFar >= n ->
              return $ Just $ B.take (n - soFar) bs
          | otherwise -> do
              rest <- f (soFar + B.length bs)
              return $ fmap (B.append bs) rest
-- | Copy everything a 'Source' yields to a connection, unframed.
--
--   Each block is handed to 'C.writeAtLowWater' with the given low-water
--   mark. Returns 'True' when the source ends with 'SourceEOF' and 'False'
--   when it reports 'SourceError'.
streamSource :: Int -> C.Connection -> Source -> IO Bool
streamSource lowWater conn source = loop
  where
    loop = source >>= step

    step SourceEOF = return True
    step SourceError = return False
    step (SourceData bs) = do
      atomically $ C.writeAtLowWater lowWater conn bs
      loop
-- | Copy everything a 'Source' yields to a connection using chunked
--   framing.
--
--   Every non-empty block is written as its length in hexadecimal, CRLF,
--   the block itself, and CRLF. 'SourceEOF' writes the terminating
--   @0\\r\\n\\r\\n@ and returns 'True'; 'SourceError' returns 'False'
--   without writing a terminator.
--
--   Empty blocks are skipped: written out they would be indistinguishable
--   from the zero-length chunk that ends the body.
streamSourceChunked :: Int -> C.Connection -> Source -> IO Bool
streamSourceChunked lowWater conn source = do
  next <- source
  case next of
    SourceEOF -> do
      atomically $ C.writeAtLowWater lowWater conn "0\r\n\r\n"
      return True
    SourceError -> return False
    SourceData bs
      | B.null bs -> streamSourceChunked lowWater conn source
      | otherwise -> do
          -- The chunk size line is hexadecimal and ends with a single CRLF.
          atomically $ C.writeAtLowWater lowWater conn $ B.pack $ map c2w $
            printf "%x\r\n" $ B.length bs
          atomically $ C.writeAtLowWater lowWater conn bs
          atomically $ C.writeAtLowWater lowWater conn "\r\n"
          streamSourceChunked lowWater conn source
-- | Run an incremental parser against data read from a connection.
--
--   Blocks of up to @blockSize@ bytes are read and fed to the parser until
--   it finishes or fails. On success the unconsumed input is pushed back
--   onto the connection and the parsed value is returned.
--
--   Returns 'Nothing' when the parser fails, when a read returns no bytes
--   (treated as the end of the stream, as the connection-backed sources in
--   this module do), or when the running byte count has reached
--   @maxBytes@. The first block read is not included in that count.
readIG :: C.Connection  -- ^ connection to read from
       -> Int           -- ^ block size for each read
       -> Int           -- ^ limit on the bytes fed after the first block
       -> IG.Get a a    -- ^ parser to run
       -> IO (Maybe a)
readIG conn blockSize maxBytes parser = do
  first <- C.read conn blockSize
  if B.null first
    then return Nothing
    else step 0 (IG.runGet parser first)
  where
    step sofar result
      | sofar >= maxBytes = return Nothing
      | otherwise =
          case result of
            IG.Failed _ -> return Nothing
            IG.Partial cont -> do
              bs <- C.read conn blockSize
              -- Feeding an empty block would not advance @sofar@, so the
              -- loop could never terminate; give up instead.
              if B.null bs
                then return Nothing
                else step (sofar + B.length bs) (cont bs)
            IG.Finished rest value -> do
              atomically $ C.pushBack conn rest
              return $ Just value
-- | Run an incremental parser against the blocks yielded by a 'Source'.
--
--   The parser is started on empty input and then fed one block per pull.
--   Returns the parsed value when the parser finishes (any unconsumed input
--   is discarded). Returns 'Nothing' when the parser fails, when the source
--   reports 'SourceEOF' or 'SourceError' before the parser is done, or when
--   the number of bytes fed has reached @maxBytes@.
sourceIG :: Source
         -> Int
         -> IG.Get a a
         -> IO (Maybe a)
sourceIG source maxBytes parser = go 0 (IG.runGet parser B.empty)
  where
    -- The size limit is checked before the parser state is inspected.
    go consumed outcome
      | consumed >= maxBytes = return Nothing
      | otherwise = advance consumed outcome

    advance _ (IG.Failed _) = return Nothing
    advance _ (IG.Finished _ value) = return (Just value)
    advance consumed (IG.Partial cont) = do
      next <- source
      case next of
        SourceData bytes -> go (B.length bytes + consumed) (cont bytes)
        _ -> return Nothing
-- | Parse a byte string with 'reads', treating each byte as one character.
--
--   Succeeds only when there is exactly one parse and it consumes the whole
--   input; anything else gives 'Nothing'.
maybeRead :: Read a => B.ByteString -> Maybe a
maybeRead = exactlyOne . reads . map w2c . B.unpack
  where
    exactlyOne [(x, "")] = Just x
    exactlyOne _ = Nothing
-- | Wrap an SSL session as a 'C.BaseConnection'.
--
--   Reading delegates to 'SSL.read'. Writing delegates to 'SSL.write' and
--   reports the full length of the input as written. Closing performs a
--   unidirectional SSL shutdown and then closes the session's socket.
sslToBaseConnection :: SSL.SSL -> C.BaseConnection
sslToBaseConnection ssl = C.BaseConnection readBytes writeBytes closeConn
  where
    readBytes = SSL.read ssl

    writeBytes bs = SSL.write ssl bs >> return (B.length bs)

    closeConn = do
      SSL.shutdown ssl SSL.Unidirectional
      sClose (SSL.sslSocket ssl)