{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.Wai.Handler.Warp.Recv (
receive
, receiveBuf
, makeReceiveN
, makePlainReceiveN
, spell
) where
import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Data.IORef
import Foreign.C.Error (Errno, eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import qualified GHC.IO.Exception as E
import Network.Socket (Socket)
import qualified System.IO.Error as E
#if MIN_VERSION_network(3,1,0)
import Network.Socket (withFdSocket)
#else
import Network.Socket (fdSocket)
#endif
import System.Posix.Types (Fd(..))

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Wai.Handler.Warp.Windows
#endif
-- | Build a @readN@-style receiver from an arbitrary pair of receive
--   functions.
--
--   The returned action, applied to @n@, yields exactly @n@ bytes, or
--   @\"\"@ when that many bytes could not be obtained (see 'receiveN').
--   Bytes read beyond @n@ are cached and served first by the next call.
--   @bs0@ seeds that cache, so bytes that were already read before this
--   receiver was created are not lost.
--
--   @recv@ is used for small requests and @recvBuf@ for large ones; see
--   'spell' for the split.
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO (BufSize -> IO ByteString)
makeReceiveN bs0 recv recvBuf = do
    ref <- newIORef bs0
    return $ receiveN ref recv recvBuf
-- | Build a @readN@-style receiver that reads straight from a 'Socket'.
--
--   Small requests are served by 'receive', which uses a buffer pool
--   created here and owned by the returned receiver. Large requests are
--   filled in place by 'receiveBuf'. @bs0@ seeds the cache of bytes
--   already read from the socket.
--
--   This is 'makeReceiveN' specialised to the socket-backed receive
--   functions, so both builders share a single wiring of 'receiveN'.
makePlainReceiveN :: Socket -> ByteString -> IO (BufSize -> IO ByteString)
makePlainReceiveN s bs0 = do
    pool <- newBufferPool
    makeReceiveN bs0 (receive s pool) (receiveBuf s)
-- | Read exactly @size@ bytes, serving bytes cached in @ref@ first.
--
--   The work is delegated to 'spell'. Whatever it reports as leftover is
--   stored back into @ref@ for the next call.
--
--   Any exception raised while reading is swallowed and reported as
--   @\"\"@. In that case @ref@ is not updated. @\"\"@ is also what 'spell'
--   returns when the input ends early, so callers see both situations the
--   same way.
--
--   NOTE(review): catching 'E.SomeException' also swallows asynchronous
--   exceptions (e.g. 'E.ThreadKilled'). This is kept as-is because callers
--   may rely on getting @\"\"@ instead of an exception. Confirm before
--   narrowing it.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> BufSize -> IO ByteString
receiveN ref recv recvBuf size = E.handle handler $ do
    cached <- readIORef ref
    (bs, leftover) <- spell cached size recv recvBuf
    writeIORef ref leftover
    return bs
  where
    handler :: E.SomeException -> IO ByteString
    handler _ = return ""
-- | Assemble exactly @siz0@ bytes, starting from the already-received
--   bytes @init0@ and reading more as needed.
--
--   Returns @(bytes, leftover)@. @bytes@ has length @siz0@, and @leftover@
--   holds any surplus read past it. If the input ends first, the result is
--   @(\"\", \"\")@ and the partial data is dropped. Input ends when @recv@
--   yields an empty chunk or @recvBuf@ reports the buffer was not filled.
--
--   Strategy, by requested size:
--
--   * @init0@ already holds @siz0@ bytes: split it, no I/O.
--   * @siz0 <= 4096@: call @recv@ repeatedly and concatenate the chunks.
--   * larger: allocate one @siz0@-byte buffer with 'mallocBS' and copy
--     @init0@ into its head. @recvBuf@ then fills exactly the missing byte
--     count in place, so nothing is read past @siz0@ and @leftover@ is
--     always empty on this path.
spell :: ByteString -> BufSize -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
spell init0 siz0 recv recvBuf
  | siz0 <= len0 = return $ BS.splitAt siz0 init0
  -- fixme: the 4096 threshold is hard-coded
  | siz0 <= 4096 = loop [init0] (siz0 - len0)
  | otherwise    = do
      bs@(PS fptr _ _) <- mallocBS siz0
      withForeignPtr fptr $ \ptr -> do
          -- presumably 'copy' returns the pointer just past the copied bytes
          ptr' <- copy ptr init0
          full <- recvBuf ptr' (siz0 - len0)
          -- fixme: on a short read the bytes received so far are dropped
          return $ if full then (bs, "") else ("", "")
  where
    len0 :: BufSize
    len0 = BS.length init0

    -- Collect chunks (newest first) until @siz@ more bytes are in hand.
    -- The first guard above guarantees @siz > 0@ on entry.
    loop :: [ByteString] -> BufSize -> IO (ByteString, ByteString)
    loop bss siz = do
        bs <- recv
        let len = BS.length bs
        if len == 0 then
            -- End of input before the request was satisfied.
            return ("", "")
          else if len >= siz then do
            let (consume, leftover) = BS.splitAt siz bs
                ret = BS.concat $ reverse (consume : bss)
            return (ret, leftover)
          else
            loop (bs : bss) (siz - len)
-- | Receive the next chunk of data from @sock@.
--
--   A buffer is borrowed from @pool@ through 'withBufferPool'. The callback
--   receives the buffer pointer and its capacity, reads up to that many
--   bytes with 'receiveloop', and returns the count. Presumably
--   'withBufferPool' turns that count into the returned 'ByteString';
--   confirm in "Network.Wai.Handler.Warp.Buffer".
--
--   An 'E.IOException' whose type is 'E.InvalidArgument' is reported as
--   @\"\"@ (end of input). Every other I/O error is re-thrown.
--
--   NOTE(review): the InvalidArgument case is presumably the EBADF raised
--   when reading from an already-closed socket. Confirm.
receive :: Socket -> BufferPool -> Recv
receive sock pool = E.handle handler $ withBufferPool pool $ \(ptr, size) -> do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    let size' = fromIntegral size :: CSize
    fromIntegral <$> receiveloop fd ptr size'
  where
    handler :: E.IOException -> IO ByteString
    handler e
      | E.ioeGetErrorType e == E.InvalidArgument = return ""
      | otherwise                                = E.throwIO e
-- | Fill exactly @siz0@ bytes starting at @buf0@ directly from the socket.
--
--   Returns 'True' once the whole region has been filled. Returns 'False'
--   if a @recv@ reports 0 bytes (end of stream) first; the bytes received
--   before that stay in the buffer. A non-positive @siz0@ needs no I/O and
--   returns 'True'.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    loop fd buf0 siz0
  where
    -- Read into @buf@, advancing it, until @siz@ bytes remain to be read.
    loop :: CInt -> RecvBuf
    loop fd buf siz
      -- Checking "<= 0" (not just "== 0") matters for this exported
      -- function. A negative request would never hit a zero base case, and
      -- 'fromIntegral' would turn it into a huge CSize, letting recv write
      -- far past the caller's buffer.
      | siz <= 0  = return True
      | otherwise = do
          n <- fromIntegral <$> receiveloop fd buf (fromIntegral siz)
          -- fixme: what should we do in the case of n == 0
          if n == 0
            then return False
            else loop fd (buf `plusPtr` n) (siz - n)
-- | Call @recv@ on the raw descriptor @sock@ until it yields a result.
--
--   Returns the byte count reported by @recv@ (0 means end of stream).
--
--   * @EAGAIN@ / @EWOULDBLOCK@ (no data available yet): block only the
--     current Haskell thread with 'threadWaitRead' until the descriptor is
--     readable, then retry.
--   * @EINTR@ (a signal interrupted the call before any data moved):
--     retry at once. Previously this surfaced as an exception and aborted
--     the read.
--   * Any other failure is thrown as an 'IOError' via 'throwErrno'.
receiveloop :: CInt -> Ptr Word8 -> CSize -> IO CInt
receiveloop sock ptr size = do
#ifdef mingw32_HOST_OS
    bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "recv" (FD sock 1) (castPtr ptr) 0 size
#else
    bytes <- c_recv sock (castPtr ptr) size 0
#endif
    if bytes == -1
      then getErrno >>= onError
      else return bytes
  where
    retry :: IO CInt
    retry = receiveloop sock ptr size

    onError :: Errno -> IO CInt
    onError errno
      | errno == eINTR = retry
      -- EAGAIN and EWOULDBLOCK may be distinct values on some platforms.
      | errno == eAGAIN || errno == eWOULDBLOCK = do
          threadWaitRead (Fd sock)
          retry
      | otherwise = throwErrno "receiveloop"
#ifndef mingw32_HOST_OS
-- | Raw binding to @recv@: descriptor, destination buffer, length, flags.
--
--   Declared @unsafe@, so the call does not release the capability while
--   it runs. Presumably this is acceptable because the descriptor is
--   non-blocking ('receiveloop' handles EAGAIN by waiting in Haskell).
--   Confirm.
--
--   fixme: @recv@ returns @ssize_t@, which is wider than 'CInt' on LP64
--   targets. This relies on request sizes staying below 2^31.
--   TODO confirm.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif