{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language MagicHash #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
module Network.Unexceptional.ByteString
( send
, receive
) where
import Control.Monad (when)
import Data.ByteString.Internal (ByteString(BS))
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Primitive (ByteArray(ByteArray))
import Data.Primitive.Addr (Addr(Addr),plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite)
import GHC.ForeignPtr (ForeignPtr(ForeignPtr),ForeignPtrContents(PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X
-- | Send the whole of a 'ByteString' on a socket.
--
-- The socket's file descriptor is obtained with 'S.withFdSocket', and the
-- 'ByteString' payload is accessed in place (no copy) through
-- 'ByteString.unsafeUseAsCStringLen'. The actual transmission, including
-- retrying after short writes and waiting when the socket would block, is
-- performed by 'sendLoop'. Errors reported by the kernel are returned as
-- 'Left'.
send ::
     Socket
  -> ByteString
  -> IO (Either Errno ())
send s !b = S.withFdSocket s withFd
  where
    -- Borrow the payload's address and length for the duration of the send.
    withFd fd = ByteString.unsafeUseAsCStringLen b (withBuffer fd)
    -- Repackage the raw descriptor and pointer as the types 'sendLoop' takes.
    withBuffer fd (PM.Ptr ptr, len) = sendLoop (Fd fd) (Addr ptr) len
-- | Keep issuing non-blocking sends until all @len@ bytes starting at @addr@
-- have been accepted.
--
-- Each attempt uses 'X.uninterruptibleSend' with @X.noSignal <> X.dontWait@
-- (presumably MSG_NOSIGNAL | MSG_DONTWAIT; confirm against the posix-api
-- bindings). The outcome of each attempt is handled as follows:
--
-- * @EAGAIN@ or @EWOULDBLOCK@: wait with 'threadWaitWrite' until the
--   descriptor is writable, then retry the same range.
-- * Any other errno: returned as 'Left'.
-- * Fewer bytes accepted than requested: advance the address past them and
--   continue with the remainder.
-- * Exactly the requested amount: done, 'Right' @()@.
-- * More than requested: the send contract was violated; raised with 'fail'.
sendLoop :: Fd -> Addr -> Int -> IO (Either Errno ())
sendLoop !fd !addr !len = do
  outcome <- X.uninterruptibleSend fd addr (fromIntegral len) flags
  case outcome of
    Left err
      | wouldBlock err -> threadWaitWrite fd *> sendLoop fd addr len
      | otherwise -> pure (Left err)
    Right acceptedC ->
      let accepted = fromIntegral acceptedC :: Int
       in case compare accepted len of
            LT -> sendLoop fd (plusAddr addr accepted) (len - accepted)
            EQ -> pure (Right ())
            GT -> fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"
  where
    flags = X.noSignal <> X.dontWait
    -- Both errnos mean "try again once the socket is writable".
    wouldBlock e = e == EAGAIN || e == EWOULDBLOCK
-- | Receive up to @n@ bytes from a socket as a freshly allocated
-- 'ByteString'.
--
-- A pinned byte array of @n@ bytes is allocated and filled by 'MB.receive'.
-- Any 'Left' errno from 'MB.receive' is passed straight through.
--
-- On success, if fewer than @n@ bytes arrived, the array is shrunk in place
-- to the received length. The array is then frozen and wrapped as a
-- 'ByteString' without copying. Pinning guarantees the array never moves,
-- which is what makes it sound to use its address as a 'ForeignPtr'.
receive ::
     Socket
  -> Int
  -> IO (Either Errno ByteString)
receive s n = do
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receive s (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right received -> do
      -- Trim the unused tail so the array length equals the payload length.
      when (received < n) $ PM.shrinkMutableByteArray buf received
      ByteArray frozen# <- PM.unsafeFreezeByteArray buf
      -- PlainPtr keeps the (coerced) array alive as long as the ForeignPtr.
      let fp = ForeignPtr (Exts.byteArrayContents# frozen#)
                          (PlainPtr (Exts.unsafeCoerce# frozen#))
      pure (Right (BS fp received))