{-# LANGUAGE OverloadedStrings #-}

module Network.Socket.ByteString.Lazy.Posix (
    -- * Send data to a socket
    send
  , sendAll
  ) where

import qualified Data.ByteString.Lazy               as L
import           Data.ByteString.Unsafe             (unsafeUseAsCStringLen)
import           Foreign.Marshal.Array              (allocaArray)

import           Network.Socket.ByteString.IO       (waitWhen0)
import           Network.Socket.ByteString.Internal (c_writev)
import           Network.Socket.Imports
import           Network.Socket.Internal
import           Network.Socket.Posix.IOVec    (IOVec (IOVec))
import           Network.Socket.Types

-- -----------------------------------------------------------------------------
-- Sending
-- | Send a prefix of a lazy 'L.ByteString' with a single @writev@ call.
--
-- Only the first @maxNumChunks@ chunks are considered, and chunks stop being
-- queued once the bytes already queued reach @maxNumBytes@. The byte limit is
-- checked /before/ each chunk is added, so one call may hand slightly more
-- than @maxNumBytes@ to @writev@ (by at most one chunk).
--
-- The result is the count returned by @writev@, which may be smaller than the
-- input length; callers that need everything written should use 'sendAll'.
--
-- NOTE(review): error reporting and any wait-for-writable behaviour live in
-- 'throwSocketErrorWaitWrite', which is defined outside this module.
send
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO Int64 -- ^ Number of bytes sent
send s lbs = do
    siz <- withFdSocket s $ \fd ->
        allocaArray (length chunks) $ \iovs ->
            withPokes chunks iovs $ \niovs ->
                throwSocketErrorWaitWrite s "writev" $ c_writev fd iovs niovs
    return $ fromIntegral siz
  where
    -- Strict chunks that are candidates for this call; the array allocated
    -- above has exactly one slot per candidate.
    chunks = take maxNumChunks (L.toChunks lbs)

    -- Fill the iovec array starting at @p@ from the chunks @ss@, then run @f@
    -- with the number of entries written.
    --
    -- Each 'unsafeUseAsCStringLen' callback encloses the rest of the loop, so
    -- @f@ (and therefore the @writev@ call) runs while every queued chunk's
    -- buffer pointer is still inside its callback and hence still valid.
    withPokes ss p f = loop ss p 0 0
      where
        -- q: next free iovec slot, k: bytes queued so far,
        -- niovs: entries written so far.
        loop (c:rest) q k niovs
            | k < maxNumBytes = unsafeUseAsCStringLen c $ \(buf, n) -> do
                poke q $ IOVec (castPtr buf) (fromIntegral n)
                loop rest (q `plusPtr` iovecSize) (k + n) (niovs + 1)
            | otherwise = f niovs
        loop _ _ _ niovs = f niovs

    -- Stride between consecutive iovec slots. 'sizeOf' only needs a value of
    -- the right type, so a dummy 'IOVec' is used.
    iovecSize = sizeOf (IOVec nullPtr 0)

    maxNumBytes :: Int
    maxNumBytes = 4194304 -- maximum number of bytes to transmit in one system call (4 MiB)

    maxNumChunks :: Int
    maxNumChunks = 1024 -- maximum number of chunks to transmit in one system call

-- | Send the whole lazy 'L.ByteString', calling 'send' repeatedly until no
-- unsent data remains. An empty input returns immediately without touching
-- the socket.
--
-- Failures are not checked here: 'send' goes through
-- @throwSocketErrorWaitWrite@ and, per the original note in this function,
-- reports them by throwing.
sendAll
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO ()
sendAll _ "" = return ()
sendAll s bs0 = loop bs0
  where
    loop bs = do
        -- "send" throws an exception on failure, so only the byte count
        -- needs inspecting.
        sent <- send s bs
        -- NOTE(review): presumably blocks until the socket is writable when
        -- zero bytes were sent; confirm in Network.Socket.ByteString.IO.
        waitWhen0 (fromIntegral sent) s
        -- Test the remainder for emptiness instead of comparing @sent@ with
        -- @L.length bs@: 'L.length' would traverse (and thereby force) the
        -- entire remaining lazy stream on every iteration.
        let rest = L.drop sent bs
        when (not (L.null rest)) $ loop rest