{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.Wai.Handler.Warp.Recv (
receive
, receiveBuf
, makeReceiveN
, makePlainReceiveN
, spell
) where
import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Data.IORef
import Foreign.C.Error (eAGAIN, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import qualified GHC.IO.Exception as E
import Network.Socket (Socket)
import qualified System.IO.Error as E
#if MIN_VERSION_network(3,1,0)
import Network.Socket (withFdSocket)
#else
import Network.Socket (fdSocket)
#endif
import System.Posix.Types (Fd(..))
import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types
#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Wai.Handler.Warp.Windows
#endif
-- | Build a \"receive exactly N bytes\" action on top of the given
-- receiving functions.
--
-- The first argument seeds a fresh 'IORef' that carries leftover bytes
-- between successive calls of the returned action.
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO (BufSize -> IO ByteString)
makeReceiveN bs0 recv recvBuf = mkReceiver <$> newIORef bs0
  where
    mkReceiver cacheRef = receiveN cacheRef recv recvBuf
-- | Like 'makeReceiveN', but for a plain 'Socket': the receiving functions
-- are 'receive' (with a freshly created buffer pool) and 'receiveBuf'.
--
-- The 'ByteString' argument seeds the leftover cache.
makePlainReceiveN :: Socket -> ByteString -> IO (BufSize -> IO ByteString)
makePlainReceiveN s bs0 = do
    cacheRef <- newIORef bs0
    bufPool <- newBufferPool
    let chunkRecv = receive s bufPool
        exactRecv = receiveBuf s
    return (receiveN cacheRef chunkRecv exactRecv)
-- | Receive @size@ bytes, using the bytes cached in @ref@ first.
--
-- The cache is read, 'spell' is asked for @size@ bytes, and whatever
-- 'spell' reports as leftover is written back to @ref@.
--
-- Any /synchronous/ exception raised along the way is turned into an empty
-- 'ByteString' (best effort). Asynchronous exceptions are re-thrown: they
-- are requests from outside this thread (kill, timeout, interrupt) and
-- must not be disguised as end-of-input.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> BufSize -> IO ByteString
receiveN ref recv recvBuf size = E.handle handler $ do
    cached <- readIORef ref
    (bs, leftover) <- spell cached size recv recvBuf
    writeIORef ref leftover
    return bs
  where
    handler :: E.SomeException -> IO ByteString
    handler e
      -- Let asynchronous exceptions propagate to the caller unchanged.
      | Just (E.SomeAsyncException _) <- E.fromException e = E.throwIO e
      | otherwise = return ""
-- | Produce @siz0@ bytes, starting from the already-available @init0@.
--
-- The result is @(requested bytes, leftover)@. Three cases:
--
-- * @init0@ already holds enough: it is split at @siz0@, no I/O is done.
-- * @siz0 <= 4096@: chunks are pulled with @recv@ until enough bytes have
--   arrived; surplus bytes of the last chunk become the leftover.
-- * otherwise: a buffer of @siz0@ bytes is obtained from 'mallocBS',
--   @init0@ is copied to its front and @recvBuf@ is asked to fill the rest.
--
-- If @recv@ yields an empty chunk or @recvBuf@ returns 'False', the result
-- is @(\"\", \"\")@; bytes gathered up to that point are discarded.
spell :: ByteString -> BufSize -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
spell init0 siz0 recv recvBuf
  | siz0 <= have = return (BS.splitAt siz0 init0)
  | siz0 <= 4096 = gather [init0] missing
  | otherwise = do
      whole@(PS fptr _ _) <- mallocBS siz0
      filled <- withForeignPtr fptr $ \base -> do
          rest <- copy base init0
          recvBuf rest missing
      return $ if filled then (whole, "") else ("", "")
  where
    -- Bytes already in hand, and bytes still to be fetched.
    have = BS.length init0
    missing = siz0 - have

    -- Chunks are accumulated newest-first; 'step' decides what to do with
    -- the chunk that has just arrived.
    gather acc want = do
        chunk <- recv
        step acc want chunk (BS.length chunk)

    step acc want chunk got
      | got == 0 = return ("", "")
      | got >= want =
          let (taken, leftover) = BS.splitAt want chunk
           in return (BS.concat (reverse (taken : acc)), leftover)
      | otherwise = gather (chunk : acc) (want - got)
-- | Receive one chunk from the socket into a buffer provided by
-- 'withBufferPool'.
--
-- 'receiveloop' is invoked once with the pool's buffer and its size; the
-- number of bytes it reports is handed back to 'withBufferPool'.
--
-- An 'E.IOException' whose error type is 'E.InvalidArgument' is converted
-- into an empty 'ByteString'; every other 'E.IOException' is re-thrown.
receive :: Socket -> BufferPool -> Recv
receive sock pool = E.handle onIOError $ withBufferPool pool $ \(buf, cap) -> do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    got <- receiveloop fd buf (fromIntegral cap)
    return (fromIntegral got)
  where
    onIOError :: E.IOException -> IO ByteString
    onIOError err =
        if E.ioeGetErrorType err == E.InvalidArgument
            then return ""
            else E.throwIO err
-- | Fill the given buffer with exactly @siz0@ bytes read from the socket.
--
-- 'receiveloop' is called repeatedly, advancing the write position by the
-- number of bytes each call delivers. Returns 'True' once nothing remains
-- to be read, and 'False' as soon as a call delivers zero bytes.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    fillFrom fd buf0 siz0
  where
    fillFrom fd dst remaining
      | remaining == 0 = return True
      | otherwise = do
          got <- fromIntegral <$> receiveloop fd dst (fromIntegral remaining)
          if got == 0
              then return False
              else fillFrom fd (dst `plusPtr` got) (remaining - got)
-- | Read up to @size@ bytes from the descriptor @sock@ into @ptr@ and
-- return the count reported by the underlying call.
--
-- When the call reports @-1@ and the errno is 'eAGAIN', the thread waits
-- with 'threadWaitRead' and then tries again. Any other errno is raised
-- through 'throwErrno'.
receiveloop :: CInt -> Ptr Word8 -> CSize -> IO CInt
receiveloop sock ptr size = do
#ifdef mingw32_HOST_OS
    bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "recv" (FD sock 1) (castPtr ptr) 0 size
#else
    bytes <- c_recv sock (castPtr ptr) size 0
#endif
    if bytes /= (-1)
        then return bytes
        else do
            errno <- getErrno
            if errno /= eAGAIN
                then throwErrno "receiveloop"
                else do
                    -- No data available yet: wait, then retry.
                    threadWaitRead (Fd sock)
                    receiveloop sock ptr size
#ifndef mingw32_HOST_OS
-- | Raw binding to the C function @recv@, compiled only when not building
-- for Windows (@mingw32_HOST_OS@).
--
-- 'receiveloop' calls it with the socket descriptor, the destination
-- buffer, the buffer size and @0@ as the last argument, and treats a
-- result of @-1@ as a failure to be inspected via errno.
--
-- NOTE(review): imported as @unsafe@, which presumably relies on the
-- descriptor being non-blocking so the call returns promptly — confirm.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif