{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Recv (
    receive
  , receiveBuf
  , makeReceiveN
  , makePlainReceiveN
  , spell
  ) where

import qualified Control.Exception as E
import qualified Data.ByteString as BS
import Data.IORef
import Foreign.C.Error (eAGAIN, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import qualified GHC.IO.Exception as E
import Network.Socket (Socket)
import qualified System.IO.Error as E
#if MIN_VERSION_network(3,1,0)
import Network.Socket (withFdSocket)
#else
import Network.Socket (fdSocket)
#endif
import System.Posix.Types (Fd(..))

import Network.Wai.Handler.Warp.Buffer
import Network.Wai.Handler.Warp.Imports
import Network.Wai.Handler.Warp.Types

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Wai.Handler.Warp.Windows
#endif

----------------------------------------------------------------

-- | Build a size-driven receiving function on top of two lower-level
--   receivers.
--
--   A fresh 'IORef' is created to hold data that has been received but
--   not yet handed out; it is seeded with @bs0@. The returned function
--   is 'receiveN' partially applied to that reference, to @recv@ and to
--   @recvBuf@.
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO (BufSize -> IO ByteString)
makeReceiveN bs0 recv recvBuf = do
    ref <- newIORef bs0
    return $ receiveN ref recv recvBuf

-- | This function returns a receiving function for a plain socket.
--   The two underlying receiving functions are created internally:
--   'receive' (backed by a freshly created buffer pool) and 'receiveBuf'.
--   The returned function efficiently manages received data,
--   which is initialized by the second argument.
--   The returned function may allocate a byte string with malloc().
makePlainReceiveN :: Socket -> ByteString -> IO (BufSize -> IO ByteString)
makePlainReceiveN s bs0 = do
    ref <- newIORef bs0
    pool <- newBufferPool
    return $ receiveN ref (receive s pool) (receiveBuf s)

-- | Receive @size@ bytes, consuming previously cached data first.
--
--   The cached bytes are read from @ref@ and passed to 'spell' together
--   with both receivers. The leftover returned by 'spell' is written back
--   to @ref@ and the consumed part is returned.
--
--   Any exception raised while doing so is swallowed and an empty
--   'ByteString' is returned instead. In that case @ref@ is only updated
--   if the exception happened after the write.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> BufSize -> IO ByteString
receiveN ref recv recvBuf size = E.handle handler $ do
    cached <- readIORef ref
    (bs, leftover) <- spell cached size recv recvBuf
    writeIORef ref leftover
    return bs
  where
    -- Deliberately catches everything: callers see "" on any failure.
    handler :: E.SomeException -> IO ByteString
    handler _ = return ""

----------------------------------------------------------------

-- | Produce exactly @siz0@ bytes, starting from the already received
--   @init0@, and return them paired with any surplus bytes.
--
--   Three strategies are used, chosen by @siz0@:
--
--   * @siz0@ does not exceed the length of @init0@: @init0@ is simply
--     split at @siz0@; no receiver is called.
--
--   * @siz0@ is at most 4096: @recv@ is called repeatedly and the chunks
--     are accumulated until enough bytes have arrived. The last chunk is
--     split so that the surplus becomes the second component.
--     If @recv@ yields an empty chunk first, the result is @("", "")@
--     and everything accumulated so far is dropped.
--
--   * otherwise: a byte string is obtained from @mallocBS siz0@, @init0@
--     is copied to its start with @copy@, and @recvBuf@ is asked to fill
--     the remaining @siz0 - length init0@ bytes. When @recvBuf@ reports
--     failure the result is @("", "")@.
spell :: ByteString -> BufSize -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
spell init0 siz0 recv recvBuf
  | siz0 <= len0 = return $ BS.splitAt siz0 init0
  -- fixme: hard coding 4096
  | siz0 <= 4096 = loop [init0] (siz0 - len0)
  | otherwise    = do
      bs@(PS fptr _ _) <- mallocBS siz0
      withForeignPtr fptr $ \ptr -> do
          -- NOTE(review): copy presumably returns the address just past
          -- the copied bytes; confirm in the Buffer module.
          ptr' <- copy ptr init0
          full <- recvBuf ptr' (siz0 - len0)
          if full then
              return (bs, "")
            else
              return ("", "") -- fixme
  where
    len0 = BS.length init0

    -- bss holds the chunks received so far, newest first;
    -- siz is the number of bytes still missing.
    loop bss siz = recv >>= step
      where
        step bs
          | len == 0  = return ("", "")
          | len >= siz =
              let (consume, leftover) = BS.splitAt siz bs
                  ret = BS.concat $ reverse (consume : bss)
              in return (ret, leftover)
          | otherwise = loop (bs : bss) (siz - len)
          where
            len = BS.length bs

-- | Receive one chunk from the socket using a buffer taken from @pool@.
--
-- The raw file descriptor of @sock@ is passed to 'receiveloop' together
-- with the pointer and size supplied by @withBufferPool@; the number of
-- bytes read is handed back to @withBufferPool@, which produces the
-- resulting 'ByteString'.
--
-- The timeout manager may close the socket.
-- In that case, an error of "Bad file descriptor" occurs.
-- We ignore it because we expect TimeoutThread.
-- Concretely, an 'E.IOException' whose error type is 'E.InvalidArgument'
-- is turned into an empty 'ByteString'; any other 'E.IOException' is
-- rethrown.
receive :: Socket -> BufferPool -> Recv
receive sock pool = E.handle handler $ withBufferPool pool $ \ (ptr, size) -> do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    let size' = fromIntegral size
    fromIntegral <$> receiveloop fd ptr size'
  where
    handler :: E.IOException -> IO ByteString
    handler e
      | E.ioeGetErrorType e == E.InvalidArgument = return ""
      | otherwise                                = E.throwIO e

-- | Fill the buffer @buf0@ with exactly @siz0@ bytes read from the socket.
--
--   'receiveloop' is called repeatedly, advancing the write position by
--   the number of bytes obtained each time, until nothing remains to be
--   read. Returns 'True' once the requested amount has been read
--   (immediately so when @siz0@ is 0) and 'False' as soon as a read
--   yields 0 bytes.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    loop fd buf0 siz0
  where
    -- Nothing left to read: the buffer is full.
    loop _  _   0   = return True
    loop fd buf siz = do
        n <- fromIntegral <$> receiveloop fd buf (fromIntegral siz)
        -- fixme: what should we do in the case of n == 0
        if n == 0 then
            return False
          else
            loop fd (buf `plusPtr` n) (siz - n)

-- | Read at most @size@ bytes from the descriptor @sock@ into @ptr@ and
--   return the number of bytes read.
--
--   When the underlying call reports -1 with errno @EAGAIN@, the calling
--   thread waits with 'threadWaitRead' until the descriptor is readable
--   and the read is retried. Any other errno is raised through
--   'throwErrno' with the label @"receiveloop"@.
receiveloop :: CInt -> Ptr Word8 -> CSize -> IO CInt
receiveloop sock ptr size = do
#ifdef mingw32_HOST_OS
    bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "recv" (FD sock 1) (castPtr ptr) 0 size
#else
    bytes <- c_recv sock (castPtr ptr) size 0
#endif
    if bytes == -1 then do
        errno <- getErrno
        if errno == eAGAIN then do
            -- No data available yet: block this thread until readable.
            threadWaitRead (Fd sock)
            receiveloop sock ptr size
          else
            throwErrno "receiveloop"
       else
        return bytes

#ifndef mingw32_HOST_OS
-- Direct binding to the C function recv, used by receiveloop on
-- non-Windows hosts. Arguments: descriptor, destination buffer,
-- buffer length, flags.
-- fixme: the type of the return value
-- (POSIX declares the result of recv as ssize_t, while this binding
-- uses CInt.)
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif