{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Socket.BufferPool.Recv (
    receive
  , receiveBuf
  , makeReceiveN
  , makePlainReceiveN
  ) where

import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..))
import Data.IORef
import Foreign.C.Error (eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (Fd(..))

#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Socket.BufferPool.Windows
#endif

import Network.Socket.BufferPool.Types
import Network.Socket.BufferPool.Buffer

----------------------------------------------------------------

-- | Receive a single chunk from the socket into a buffer drawn from the
--   pool.
--
--   'withBufferPool' hands the callback a buffer and its capacity. The
--   callback issues one receive of at most that many bytes and reports how
--   many bytes were written (0 when the peer has closed the connection).
--   How that count becomes the resulting 'ByteString' is decided by
--   'withBufferPool'; the caller never manages the pool directly.
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool fillChunk
  where
    fillChunk :: Buffer -> BufSize -> IO BufSize
    fillChunk buf cap =
#if MIN_VERSION_network(3,1,0)
      withFdSocket sock $ \fd ->
        fromIntegral <$> tryRecv fd buf (fromIntegral cap)
#elif MIN_VERSION_network(3,0,0)
      do fd <- fdSocket sock
         fromIntegral <$> tryRecv fd buf (fromIntegral cap)
#else
      fromIntegral <$> tryRecv (fdSocket sock) buf (fromIntegral cap)
#endif

----------------------------------------------------------------

-- | Fill a caller-supplied buffer from the socket.
--
--   Repeatedly receives into @buf0@ until exactly @siz0@ bytes have been
--   written, then returns 'True'. If a receive yields 0 bytes (the peer
--   closed the connection) before that, it returns 'False'; the buffer then
--   holds only the bytes received so far.
--
--   A size of zero or less requests nothing and yields 'True' without
--   reading. A negative size must never reach 'tryRecv': converting it to
--   'CSize' would wrap to an enormous length and let @recv@ write past the
--   end of the buffer.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    loop fd buf0 siz0
  where
    loop :: CInt -> RecvBuf
    loop fd buf siz
      | siz <= 0 = return True
      | otherwise = do
          n <- fromIntegral <$> tryRecv fd buf (fromIntegral siz)
          if n == 0
            then return False -- end of stream before the buffer was full
            else loop fd (buf `plusPtr` n) (siz - n)

----------------------------------------------------------------

-- | Receive at most @size@ bytes from descriptor @sock@ into @ptr@.
--
--   Returns the number of bytes read; 0 means the peer closed the
--   connection. When the underlying call reports failure (@-1@), 'getErrno'
--   decides what happens next:
--
--   * 'eINTR': the call was interrupted by a signal and is simply retried.
--   * 'eAGAIN' or 'eWOULDBLOCK': no data is available yet on the
--     non-blocking socket. It waits with 'threadWaitRead' and retries.
--     POSIX permits either code here, and they need not be equal.
--   * Anything else is raised via 'throwErrno'.
tryRecv :: CInt -> Buffer -> CSize -> IO CInt
tryRecv sock ptr size = go
  where
    go :: IO CInt
    go = do
      bytes <- recvOnce
      if bytes == -1
        then getErrno >>= onError
        else return bytes

    onError errno
      | errno == eINTR = go
      | errno == eAGAIN || errno == eWOULDBLOCK = do
          threadWaitRead (Fd sock)
          go
      | otherwise = throwErrno "tryRecv"

    -- A single receive attempt; errno is inspected by 'go' right after it.
    recvOnce :: IO CInt
#ifdef mingw32_HOST_OS
    recvOnce = windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "tryRecv" (FD sock 1) (castPtr ptr) 0 size
#else
    recvOnce = c_recv sock (castPtr ptr) size 0
#endif

----------------------------------------------------------------

-- | Build a function that receives exactly N bytes, out of two lower-level
--   receiving functions.
--
--   The first argument is data that has already been received. It is
--   consumed before anything else; after that, the two functions are used:
--
--   * When N is less than or equal to 4096, chunks from the 'Recv' function
--     (backed by the buffer pool) are accumulated.
--   * Otherwise, a new N-byte buffer is allocated and filled through the
--     'RecvBuf' function. Per the original note, this allocation takes the
--     global lock.
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO RecvN
makeReceiveN bs0 recv recvBuf = build <$> newIORef bs0
  where
    build cache = receiveN cache recv recvBuf

-- | Build a function that receives exactly N bytes from a plain socket.
--
--   The two underlying receiving functions are created internally:
--   'receive' over a fresh buffer pool, and 'receiveBuf' on the same
--   socket.
--
--   * The second argument is the lower limit of the buffer pool.
--   * The third argument is the size of each buffer allocated in the pool.
--   * The fourth argument is data that has already been received.
--
--   The returned function behaves as described in 'makeReceiveN'.
makePlainReceiveN :: Socket -> Int -> Int -> ByteString -> IO RecvN
makePlainReceiveN sock lowerLimit bufSize bs0 = do
    cacheRef <- newIORef bs0
    pool <- newBufferPool lowerLimit bufSize
    let recvChunk = receive sock pool
        recvFill  = receiveBuf sock
    pure (receiveN cacheRef recvChunk recvFill)

-- | Receive exactly N bytes (the fourth argument).
--
--   Bytes cached in the 'IORef' are used first; the rest come from the
--   'Recv' / 'RecvBuf' functions via 'tryRecvN'. Whatever is left over after
--   taking N bytes is stored back into the 'IORef' for the next call.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> RecvN
receiveN ref recv recvBuf size =
    readIORef ref >>= \cached ->
    tryRecvN cached size recv recvBuf >>= \(wanted, rest) ->
    wanted <$ writeIORef ref rest

----------------------------------------------------------------

-- | Try to assemble exactly @siz0@ bytes, starting from the already
--   received bytes @init0@.
--
--   Returns @(result, leftover)@:
--
--   * If @init0@ already holds at least @siz0@ bytes, it is split at
--     @siz0@.
--   * Otherwise, if @siz0@ is at most @poolLimit@ (4096), chunks from
--     @recv@ are accumulated until enough bytes have arrived. The excess of
--     the final chunk becomes the leftover.
--   * Otherwise, a fresh @siz0@-byte buffer is allocated with 'mallocBS'.
--     @init0@ is copied to its front, @recvBuf@ fills the remainder, and
--     there is no leftover.
--
--   If the stream ends first (an empty chunk from @recv@, or @recvBuf@
--   returning 'False'), @("", "")@ is returned and any partial data is
--   dropped.
tryRecvN :: ByteString -> Int -> IO ByteString -> RecvBuf -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv recvBuf
  | siz0 <= len0      = return (BS.splitAt siz0 init0)
  | siz0 <= poolLimit = recvWithPool [init0] (siz0 - len0)
  | otherwise         = recvWithNewBuf
  where
    -- fixme: hard-coded threshold between pool-backed accumulation and a
    -- dedicated allocation.
    poolLimit :: BufSize
    poolLimit = 4096

    len0 :: BufSize
    len0 = BS.length init0

    -- Accumulate chunks (newest first) until @want@ more bytes are covered.
    recvWithPool :: [ByteString] -> BufSize -> IO (ByteString, ByteString)
    recvWithPool acc want = recv >>= step
      where
        step chunk
          | got == 0    = return ("", "")
          | got >= want = return (BS.concat (reverse (mine : acc)), rest)
          | otherwise   = recvWithPool (chunk : acc) (want - got)
          where
            got = BS.length chunk
            (mine, rest) = BS.splitAt want chunk

    -- Allocate the whole result up front and let recvBuf fill the tail.
    recvWithNewBuf :: IO (ByteString, ByteString)
    recvWithNewBuf = do
        bs@(PS fptr _ _) <- mallocBS siz0
        filled <- withForeignPtr fptr $ \ptr -> do
            ptr' <- copy ptr init0
            recvBuf ptr' (siz0 - len0)
        return $ if filled then (bs, "") else ("", "") -- fixme

#ifndef mingw32_HOST_OS
-- | Raw binding to the C library's @recv(2)@: descriptor, destination
--   buffer, maximum byte count, and flags. It is only declared off Windows;
--   on Windows, 'tryRecv' goes through 'readRawBufferPtr' instead.
--
--   The call is imported @unsafe@, which is cheap but must not block for
--   long. It presumably relies on the socket being in non-blocking mode,
--   with 'tryRecv' waiting via 'threadWaitRead' when no data is ready
--   (NOTE(review): confirm the descriptor is always non-blocking).
--
--   fixme: the type of the return value. C @recv@ returns @ssize_t@, which
--   is wider than 'CInt' on LP64 platforms, so a single receive larger than
--   2^31-1 bytes would not be representable here. Switching to 'CSsize'
--   also requires changing the return type of 'tryRecv'.
foreign import ccall unsafe "recv"
    c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif