{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language MagicHash #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.ByteString
  ( send
  , sendInterruptible
  , receive
  , receiveExactly
  , receiveExactlyInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM,TVar)
import GHC.Conc (threadWaitWrite,threadWaitWriteSTM)
import Control.Monad (when)
import Control.Monad ((<=<))
import Data.ByteString.Internal (ByteString(BS))
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Functor (($>))
import Data.Primitive (ByteArray(ByteArray))
import Data.Primitive.Addr (Addr(Addr),plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.ForeignPtr (ForeignPtr(ForeignPtr),ForeignPtrContents(PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X
import qualified Control.Concurrent.STM as STM

-- | Send the entire byte sequence. This calls POSIX @send@ in a loop
-- until all of the bytes have been sent.
--
-- Returns @Left errno@ as soon as a send attempt fails with anything
-- other than @EAGAIN@ / @EWOULDBLOCK@; bytes handed to the kernel by
-- earlier iterations are not rolled back.
send ::
     Socket
  -> ByteString
  -> IO (Either Errno ())
send s !b =
  S.withFdSocket s $ \rawFd ->
    ByteString.unsafeUseAsCStringLen b $ \(PM.Ptr payload, total) ->
      -- The first attempt is made without checking whether the socket
      -- is ready for writes, because it is uncommon for the transmit
      -- buffer to already be full.
      sendLoop (Fd rawFd) (Addr payload) total

-- Worker for 'send'. Does not wait for the file descriptor to be ready
-- before the first attempt; it only parks the thread (threadWaitWrite)
-- after the send reports EAGAIN or EWOULDBLOCK.
sendLoop ::
     Fd
  -> Addr -- start of the bytes that still have to be sent
  -> Int  -- number of bytes that still have to be sent
  -> IO (Either Errno ())
sendLoop !fd !addr !len = do
  result <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case result of
    Left err
      -- Transmit buffer is full: wait for writability, then retry the
      -- same range.
      | err == EAGAIN || err == EWOULDBLOCK ->
          threadWaitWrite fd >> sendLoop fd addr len
      | otherwise -> pure (Left err)
    Right written
      | sent == len -> pure (Right ())
      -- Partial write: advance past what was accepted and keep going.
      | sent < len -> sendLoop fd (plusAddr addr sent) (len - sent)
      | otherwise -> fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"
      where
        sent = fromIntegral written :: Int

-- | Send the entire byte sequence. This calls POSIX @send@ in a loop
-- until all of the bytes have been sent.
--
-- The interrupt variable is only consulted while the loop is waiting
-- for the socket to become writable. If it holds 'True' at that point,
-- the result is @Left EAGAIN@. Some prefix of the bytes may already
-- have been sent when that happens.
sendInterruptible ::
     TVar Bool -- ^ Interrupt: set to 'True' to abandon a blocked send
  -> Socket
  -> ByteString
  -> IO (Either Errno ())
sendInterruptible !interrupt s !b =
  S.withFdSocket s $ \rawFd ->
    ByteString.unsafeUseAsCStringLen b $ \(PM.Ptr payload, total) ->
      -- The first attempt is made without checking whether the socket
      -- is ready for writes, because it is uncommon for the transmit
      -- buffer to already be full.
      sendInterruptibleLoop interrupt (Fd rawFd) (Addr payload) total

-- Worker for 'sendInterruptible'. Does not wait for the file descriptor
-- to be ready before the first attempt. After EAGAIN / EWOULDBLOCK it
-- waits for either writability or the interrupt variable.
sendInterruptibleLoop ::
     TVar Bool
  -> Fd
  -> Addr -- start of the bytes that still have to be sent
  -> Int  -- number of bytes that still have to be sent
  -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !addr !len = do
  result <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case result of
    Left err
      | err == EAGAIN || err == EWOULDBLOCK -> do
          outcome <- waitUntilWriteable interrupt fd
          case outcome of
            Ready -> sendInterruptibleLoop interrupt fd addr len
            -- Interruption is reported to the caller as EAGAIN.
            Interrupted -> pure (Left EAGAIN)
      | otherwise -> pure (Left err)
    Right written
      | sent == len -> pure (Right ())
      -- Partial write: advance past what was accepted and keep going.
      | sent < len -> sendInterruptibleLoop interrupt fd (plusAddr addr sent) (len - sent)
      | otherwise -> fail "Network.Unexceptional.ByteString.sendInterruptibleLoop: send claimed to send too many bytes"
      where
        sent = fromIntegral written :: Int

-- | If this returns zero bytes, it means that the peer has
-- performed an orderly shutdown.
receive ::
     Socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO (Either Errno ByteString)
receive s n = do
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receive s (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right got -> do
      -- Trim the buffer down to the number of bytes actually received.
      when (got < n) (PM.shrinkMutableByteArray buf got)
      Right <$> freezePinned buf got

-- | Blocks until an exact number of bytes has been received.
receiveExactly ::
     Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteString)
receiveExactly !s !n = do
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receiveExactly s (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right _ -> Right <$> freezePinned buf n

-- | Blocks until an exact number of bytes has been received.
--
-- The interrupt variable is handed unchanged to
-- 'MB.receiveExactlyInterruptible'; how an interruption is reported is
-- decided there (presumably analogous to 'sendInterruptible' -- confirm
-- against that module).
receiveExactlyInterruptible ::
     TVar Bool
  -> Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteString)
receiveExactlyInterruptible !intr !s !n = do
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receiveExactlyInterruptible intr s (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right _ -> Right <$> freezePinned buf n

-- Freeze a byte array and expose its first @len@ bytes as a
-- 'ByteString' without copying. The ByteString's pointer is the array's
-- payload address and its finalizer slot ('PlainPtr') holds the array
-- itself. Every caller in this module passes an array obtained from
-- 'PM.newPinnedByteArray' and does not touch the mutable array again
-- afterwards.
freezePinned :: PM.MutableByteArray Exts.RealWorld -> Int -> IO ByteString
freezePinned buf len = do
  ByteArray frozen# <- PM.unsafeFreezeByteArray buf
  let contents = PlainPtr (Exts.unsafeCoerce# frozen#)
  pure (BS (ForeignPtr (Exts.byteArrayContents# frozen#) contents) len)

-- Block until the file descriptor is reported writable or the interrupt
-- variable holds 'True'. When both conditions hold, 'Ready' wins because
-- it is the left operand of '<|>'. The registration obtained from
-- 'threadWaitWriteSTM' is released before returning.
--
-- NOTE(review): the deregistration is skipped if the transaction throws
-- or the thread receives an asynchronous exception while blocked --
-- confirm whether that is acceptable.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd = do
  (awaitWritable, unregister) <- threadWaitWriteSTM fd
  let becameReady = awaitWritable $> Ready
      wasInterrupted = checkFinished interrupt $> Interrupted
  result <- STM.atomically (becameReady <|> wasInterrupted)
  unregister
  pure result

-- Why 'waitUntilWriteable' stopped waiting.
data Outcome
  = Ready       -- the descriptor was reported writable
  | Interrupted -- the interrupt variable held True

-- Retry the enclosing transaction until the variable holds 'True'.
checkFinished :: TVar Bool -> STM ()
checkFinished = STM.check <=< STM.readTVar