{-# LANGUAGE OverloadedStrings #-}
module Network.Socket.ByteString.Lazy.Posix (
send,
sendAll,
) where
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Foreign.Marshal.Array (allocaArray)
import Network.Socket.ByteString.IO (waitWhen0)
import Network.Socket.ByteString.Internal (c_writev)
import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Posix.IOVec (IOVec (IOVec))
import Network.Socket.Types
-- | Send (a prefix of) a lazy 'L.ByteString' with a single vectored @writev@.
--
-- Only the first 1024 strict chunks of the input are considered. Chunks are
-- added to the I/O vector while the running byte total is still below
-- 4194304 bytes (4 MiB); because the limit is checked before a chunk is
-- added, the last chunk included may push the total past that figure.
-- Chunk memory is handed to @writev@ directly via 'unsafeUseAsCStringLen'
-- (no copy), so the write happens inside every chunk's pointer scope.
--
-- Returns the number of bytes written, which may be less than the input
-- length; use 'sendAll' to transmit everything.
send
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO Int64 -- ^ Number of bytes sent
send s lbs = fromIntegral <$> withFdSocket s writeChunks
  where
    -- Candidate chunks for this call, capped at the chunk limit.
    chunks = take maxNumChunks (L.toChunks lbs)

    -- Allocate one IOVec slot per candidate chunk, fill as many as the byte
    -- limit allows, then issue the write with the number actually filled.
    writeChunks fd =
        allocaArray (length chunks) $ \iov ->
            fillIOVecs chunks iov $ \count ->
                throwSocketErrorWaitWrite s "writev" (c_writev fd iov count)

    -- Poke one IOVec per chunk starting at @base@, then run @k@ with the
    -- number of entries written. Recursion happens inside each
    -- 'unsafeUseAsCStringLen' callback so every poked pointer stays valid
    -- until @k@ (and therefore the write) has finished.
    fillIOVecs bufs base k = go bufs base 0 0
      where
        stride = sizeOf (IOVec nullPtr 0)
        go [] _ _ n = k n
        go (b : rest) slot total n
            | total >= maxNumBytes = k n
            | otherwise =
                unsafeUseAsCStringLen b $ \(p, blen) -> do
                    poke slot (IOVec (castPtr p) (fromIntegral blen))
                    go rest (slot `plusPtr` stride) (total + blen) (n + 1)

    -- Upper bounds on what a single system call is asked to transmit.
    maxNumBytes = 4194304 :: Int
    maxNumChunks = 1024 :: Int
-- | Send an entire lazy 'L.ByteString', calling 'send' repeatedly until no
-- bytes remain.
--
-- Failures raised by 'send' are not caught here. If a later round fails,
-- the bytes written by earlier rounds have already been transmitted.
sendAll
    :: Socket -- ^ Connected socket
    -> L.ByteString -- ^ Data to send
    -> IO ()
sendAll _ "" = return ()
sendAll s bs0 = loop bs0
  where
    loop bs = do
        sent <- send s bs
        -- NOTE(review): waitWhen0 is defined elsewhere; presumably it waits
        -- for the socket when a write made no progress (sent == 0) — confirm.
        waitWhen0 (fromIntegral sent) s
        -- Decide whether to continue by checking the remainder for
        -- emptiness instead of comparing 'sent' with 'L.length bs'. Since
        -- 0 <= sent <= length bs, the two tests agree. However, 'L.length'
        -- walks and forces every remaining chunk on each round, which is
        -- quadratic over long streams and defeats lazy streaming. 'L.null'
        -- only inspects the first remaining chunk.
        let rest = L.drop sent bs
        if L.null rest then return () else loop rest