{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language MagicHash #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}

module Network.Unexceptional.ByteString
  ( send
  , sendInterruptible
  , receive
  , receiveExactly
  , receiveExactlyInterruptible
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent.STM (STM,TVar)
import Control.Exception (bracket)
import Control.Monad (when)
import Control.Monad ((<=<))
import Data.ByteString.Internal (ByteString(BS))
import Data.Bytes.Types (MutableBytes(MutableBytes))
import Data.Functor (($>))
import Data.Primitive (ByteArray(ByteArray))
import Data.Primitive.Addr (Addr(Addr),plusAddr)
import Foreign.C.Error (Errno)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN)
import GHC.Conc (threadWaitWrite,threadWaitWriteSTM)
import GHC.ForeignPtr (ForeignPtr(ForeignPtr),ForeignPtrContents(PlainPtr))
import Network.Socket (Socket)
import System.Posix.Types (Fd(Fd))

import qualified Control.Concurrent.STM as STM
import qualified Data.ByteString.Unsafe as ByteString
import qualified Data.Primitive as PM
import qualified GHC.Exts as Exts
import qualified Linux.Socket as X
import qualified Network.Socket as S
import qualified Network.Unexceptional.MutableBytes as MB
import qualified Posix.Socket as X
import qualified Control.Concurrent.STM as STM

-- | Send the entire byte sequence. This calls POSIX @send@ in a loop
-- until all of the bytes have been sent.
send ::
     Socket
  -> ByteString
  -> IO (Either Errno ())
send s !b = S.withFdSocket s withFd
  where
    -- The first send is attempted without testing whether the socket is
    -- ready for writes, because it is uncommon for the transmit buffer to
    -- already be full.
    withFd rawFd = ByteString.unsafeUseAsCStringLen b $ \(PM.Ptr ptr, len) ->
      sendLoop (Fd rawFd) (Addr ptr) len

-- Send @len@ bytes starting at @addr@ on @fd@, passing
-- @X.noSignal <> X.dontWait@ to every send call.
--
-- The first attempt is made without waiting for the descriptor to become
-- writable. Only after @EAGAIN@/@EWOULDBLOCK@ does the thread block in
-- 'threadWaitWrite' and then retry. A partial send advances the pointer and
-- loops on the remainder. Any other errno is returned as 'Left'.
sendLoop :: Fd -> Addr -> Int -> IO (Either Errno ())
sendLoop !fd !addr !len = do
  result <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case result of
    Left e
      | e == EAGAIN || e == EWOULDBLOCK -> do
          threadWaitWrite fd
          sendLoop fd addr len
      | otherwise -> pure (Left e)
    Right sentSzC
      | sentSz == len -> pure (Right ())
      | sentSz < len -> sendLoop fd (plusAddr addr sentSz) (len - sentSz)
      | otherwise ->
          fail "Network.Unexceptional.ByteString.sendLoop: send claimed to send too many bytes"
      where
        sentSz = fromIntegral sentSzC :: Int

-- | Send the entire byte sequence. This calls POSIX @send@ in a loop
-- until all of the bytes have been sent.
--
-- While waiting for the socket to become writable, the wait is abandoned
-- once the 'TVar' holds 'True'. In that case the result is @Left EAGAIN@.
sendInterruptible ::
     TVar Bool
  -> Socket
  -> ByteString
  -> IO (Either Errno ())
sendInterruptible !interrupt s !b = S.withFdSocket s withFd
  where
    -- The first send is attempted without testing whether the socket is
    -- ready for writes, because it is uncommon for the transmit buffer to
    -- already be full.
    withFd rawFd = ByteString.unsafeUseAsCStringLen b $ \(PM.Ptr ptr, len) ->
      sendInterruptibleLoop interrupt (Fd rawFd) (Addr ptr) len

-- Interruptible counterpart of 'sendLoop'. The first send on @fd@ is
-- attempted without waiting for writability.
--
-- On @EAGAIN@/@EWOULDBLOCK@ it waits via 'waitUntilWriteable'. If that wait
-- reports 'Interrupted', @Left EAGAIN@ is returned. Partial sends advance
-- the pointer and loop. Any other errno is returned as 'Left'.
sendInterruptibleLoop :: TVar Bool -> Fd -> Addr -> Int -> IO (Either Errno ())
sendInterruptibleLoop !interrupt !fd !addr !len = do
  result <- X.uninterruptibleSend fd addr (fromIntegral len) (X.noSignal <> X.dontWait)
  case result of
    Left e
      | e == EAGAIN || e == EWOULDBLOCK -> do
          outcome <- waitUntilWriteable interrupt fd
          case outcome of
            Ready -> sendInterruptibleLoop interrupt fd addr len
            -- An interrupted wait is reported to the caller as EAGAIN.
            Interrupted -> pure (Left EAGAIN)
      | otherwise -> pure (Left e)
    Right sentSzC
      | sentSz == len -> pure (Right ())
      | sentSz < len ->
          sendInterruptibleLoop interrupt fd (plusAddr addr sentSz) (len - sentSz)
      | otherwise ->
          fail "Network.Unexceptional.ByteString.sendInterruptibleLoop: send claimed to send too many bytes"
      where
        sentSz = fromIntegral sentSzC :: Int

-- | Receive up to the given number of bytes into a freshly allocated pinned
-- buffer. The buffer is shrunk to the number of bytes actually received.
-- If this returns zero bytes, it means that the peer has
-- performed an orderly shutdown.
receive ::
     Socket
  -> Int -- ^ Maximum number of bytes to receive
  -> IO (Either Errno ByteString)
receive s n = do
  buf <- PM.newPinnedByteArray n
  outcome <- MB.receive s (MutableBytes buf 0 n)
  case outcome of
    Left err -> pure (Left err)
    Right received -> Right <$> finish buf received
  where
    -- Trim the pinned buffer to the received length, then wrap it as a
    -- ByteString without copying.
    finish buf received = do
      when (received < n) (PM.shrinkMutableByteArray buf received)
      ByteArray frozen <- PM.unsafeFreezeByteArray buf
      let contents = PlainPtr (Exts.unsafeCoerce# frozen)
          fptr = ForeignPtr (Exts.byteArrayContents# frozen) contents
      pure (BS fptr received)

-- | Blocks until an exact number of bytes has been received. The bytes are
-- written into a freshly allocated pinned buffer, which is returned as a
-- 'ByteString' without copying.
receiveExactly ::
     Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteString)
receiveExactly !s !n = do
  buf <- PM.newPinnedByteArray n
  MB.receiveExactly s (MutableBytes buf 0 n)
    >>= either (pure . Left) (\_ -> Right <$> freezeAsByteString buf)
  where
    -- Wrap the (fully populated) pinned buffer as a ByteString of length n.
    freezeAsByteString buf = do
      ByteArray frozen <- PM.unsafeFreezeByteArray buf
      let contents = PlainPtr (Exts.unsafeCoerce# frozen)
      pure (BS (ForeignPtr (Exts.byteArrayContents# frozen) contents) n)

-- | Blocks until an exact number of bytes has been received. The receive is
-- delegated to 'MB.receiveExactlyInterruptible' together with the 'TVar'.
-- Presumably it gives up once the variable becomes 'True'; that is not
-- visible here, so verify it in that module. On success, the freshly
-- allocated pinned buffer is returned as a 'ByteString' without copying.
receiveExactlyInterruptible ::
     TVar Bool
  -> Socket
  -> Int -- ^ Exact number of bytes to receive, must be greater than zero
  -> IO (Either Errno ByteString)
receiveExactlyInterruptible !intr !s !n = do
  buf <- PM.newPinnedByteArray n
  MB.receiveExactlyInterruptible intr s (MutableBytes buf 0 n)
    >>= either (pure . Left) (\_ -> Right <$> freezeAsByteString buf)
  where
    -- Wrap the (fully populated) pinned buffer as a ByteString of length n.
    freezeAsByteString buf = do
      ByteArray frozen <- PM.unsafeFreezeByteArray buf
      let contents = PlainPtr (Exts.unsafeCoerce# frozen)
      pure (BS (ForeignPtr (Exts.byteArrayContents# frozen) contents) n)

-- Block until either @fd@ becomes writable ('Ready') or @interrupt@ holds
-- 'True' ('Interrupted'). Readiness is tried first, so if both conditions
-- hold at the same time the result is 'Ready'.
--
-- The registration made by 'threadWaitWriteSTM' is released by 'bracket'.
-- This happens even when the blocking 'STM.atomically' is cut short by an
-- asynchronous exception (e.g. from 'System.Timeout.timeout' or
-- 'Control.Concurrent.killThread'). Previously that case skipped the
-- deregistration.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _deregister) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)

-- | Result of waiting in 'waitUntilWriteable'.
data Outcome
  = Ready
    -- ^ The file descriptor became ready for writes.
  | Interrupted
    -- ^ The interrupt variable was observed to hold 'True'.

-- | Succeed once the variable holds 'True'; otherwise retry the enclosing
-- STM transaction.
checkFinished :: TVar Bool -> STM ()
checkFinished var = do
  finished <- STM.readTVar var
  STM.check finished