{-# LANGUAGE CPP #-}

module System.IO.Streams.Internal.Network
  ( socketToStreams
  , socketToStreamsWithBufferSize
  , socketToStreamsWithBufferSizeImpl
  ) where


------------------------------------------------------------------------------
import           Control.Exception          (catch)
import qualified Data.ByteString.Char8      as S
import qualified Data.ByteString.Internal   as S
import           Data.Word                  (Word8)
import           Foreign.ForeignPtr         (newForeignPtr, withForeignPtr)
import           Foreign.Marshal.Alloc      (finalizerFree, mallocBytes)
import           Foreign.Ptr                (Ptr)
import           Network.Socket             (Socket)
import qualified Network.Socket             as N
import qualified Network.Socket.ByteString  as NB
import           Prelude                    (IO, Int, Maybe (..), return, ($!), (<=), (>>=))
import           System.IO.Error            (ioError, isEOFError)
------------------------------------------------------------------------------
import           System.IO.Streams.Internal (InputStream, OutputStream)
import qualified System.IO.Streams.Internal as Streams


------------------------------------------------------------------------------
-- | Default receive buffer size, in bytes, used by 'socketToStreams'.
bUFSIZ :: Int
bUFSIZ = 4096


------------------------------------------------------------------------------
-- | Converts a 'Socket' to an 'InputStream' \/ 'OutputStream' pair, using the
-- default receive buffer size ('bUFSIZ'). Note that, as is usually the case in
-- @io-streams@, writing a 'Nothing' to the generated 'OutputStream' does not
-- cause the underlying 'Socket' to be closed.
socketToStreams :: Socket
                -> IO (InputStream S.ByteString, OutputStream S.ByteString)
socketToStreams = socketToStreamsWithBufferSize bUFSIZ


------------------------------------------------------------------------------
-- | Converts a 'Socket' to an 'InputStream' \/ 'OutputStream' pair, with
-- control over the size of the receive buffers. Note that, as is usually the
-- case in @io-streams@, writing a 'Nothing' to the generated 'OutputStream'
-- does not cause the underlying 'Socket' to be closed.
--
-- With @network >= 2.4.0@ this delegates to
-- 'socketToStreamsWithBufferSizeImpl' using 'N.recvBuf'; with older versions
-- it falls back to 'NB.recv', where an empty read yields 'Nothing'.
socketToStreamsWithBufferSize
    :: Int                      -- ^ how large the receive buffer should be
    -> Socket                   -- ^ network socket
    -> IO (InputStream S.ByteString, OutputStream S.ByteString)
#if MIN_VERSION_network(2,4,0)
socketToStreamsWithBufferSize = socketToStreamsWithBufferSizeImpl N.recvBuf
#else
socketToStreamsWithBufferSize bufsiz socket = do
    is <- Streams.makeInputStream input
    os <- Streams.makeOutputStream output
    return $! (is, os)

  where
    input = do
        s <- NB.recv socket bufsiz
        return $! if S.null s then Nothing else Just s

    output Nothing  = return $! ()
    output (Just s) = if S.null s then return $! () else NB.sendAll socket s
#endif


------------------------------------------------------------------------------
-- | Dependency-injected implementation of 'socketToStreamsWithBufferSize' (for
-- testing): the function used to receive bytes from the socket is supplied by
-- the caller instead of being fixed to 'N.recvBuf'.
--
-- Every read from the resulting 'InputStream' allocates a fresh buffer of the
-- requested size and receives into it. A read of zero (or fewer) bytes, or an
-- EOF 'IOError' raised by the receive function, yields 'Nothing'. Writing
-- 'Nothing' or an empty string to the 'OutputStream' does nothing; any other
-- string is sent in full with 'NB.sendAll'.
socketToStreamsWithBufferSizeImpl
    :: (N.Socket -> Ptr Word8 -> Int -> IO Int)  -- ^ recvBuf
    -> Int                                       -- ^ how large the receive
                                                 --   buffer should be
    -> Socket                                    -- ^ network socket
    -> IO (InputStream S.ByteString, OutputStream S.ByteString)
socketToStreamsWithBufferSizeImpl _recvBuf bufsiz socket = do
    is <- Streams.makeInputStream input
    os <- Streams.makeOutputStream output
    return $! (is, os)

  where
    -- Receive up to @bufsiz@ bytes into @buf@. An EOF 'IOError' is reported as
    -- a zero-byte read; any other 'IOError' is rethrown unchanged.
    recv :: Ptr Word8 -> IO Int
    recv buf = _recvBuf socket buf bufsiz `catch` \ioe ->
               if isEOFError ioe then return 0 else ioError ioe

    -- Allocate a @bufsiz@-byte buffer with malloc and register 'finalizerFree'
    -- on it. (No local signature: the 'ForeignPtr' type itself is not
    -- imported by this module.)
    mkFp = mallocBytes bufsiz >>= newForeignPtr finalizerFree

    -- One read: fresh buffer, receive into it, and wrap the first @n@ bytes as
    -- a 'S.ByteString' backed by that same buffer (no copy is made here).
    input :: IO (Maybe S.ByteString)
    input = do
        fp <- mkFp
        n  <- withForeignPtr fp recv
        return $! if n <= 0
                    then Nothing
                    else Just $! S.fromForeignPtr fp 0 n

    -- End-of-stream and empty chunks are no-ops; the socket is never closed
    -- from here.
    output :: Maybe S.ByteString -> IO ()
    output Nothing  = return $! ()
    output (Just s) = if S.null s then return $! () else NB.sendAll socket s