{-# language BangPatterns #-}
{-# language DuplicateRecordFields #-}
{-# language PatternSynonyms #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
module Network.Unexceptional
( accept_
, socket
, connect
, connectInterruptible
) where
import Control.Exception (bracket,mask_)
import Control.Monad ((<=<))
import Control.Applicative ((<|>))
import Data.Functor (($>))
import Network.Socket (Socket,SockAddr,mkSocket,withFdSocket,SocketOption(SoError),getSocketOption)
import Network.Socket.Address (SocketAddress,pokeSocketAddress,sizeOfSocketAddress)
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.C.Error (Errno(Errno))
import GHC.Conc (threadWaitRead,threadWaitWrite,threadWaitWriteSTM)
import Foreign.C.Error.Pattern (pattern EWOULDBLOCK,pattern EAGAIN,pattern EINPROGRESS,pattern EINTR)
import System.Posix.Types (Fd(Fd))
import Foreign.Ptr (castPtr,nullPtr)
import GHC.Exts (Ptr)
import Control.Concurrent.STM (STM,TVar)
import qualified Control.Concurrent.STM as STM
import qualified Linux.Socket as X
import qualified Posix.Socket as X
import qualified Network.Socket as N
-- | Accept one pending connection on a listening socket, reporting failure
-- as an 'Errno' instead of throwing.
--
-- The calling thread waits with 'threadWaitRead' until the listening
-- descriptor is readable and then issues @accept4@ with the nonblocking and
-- close-on-exec flags. If the call reports 'EAGAIN' or 'EWOULDBLOCK' (the
-- readiness notification was stale or another thread took the connection),
-- the wait-and-accept cycle is repeated. Any other error is returned as
-- 'Left'.
accept_ :: Socket -> IO (Either Errno Socket)
accept_ listing_sock = withFdSocket listing_sock $ \listing_fd ->
  let listenFd = Fd listing_fd
      acceptFlags = X.nonblocking <> X.closeOnExec
      acceptLoop :: IO (Either Errno Socket)
      acceptLoop = do
        -- Wait for readiness first so the non-blocking accept below
        -- normally has a connection to take.
        threadWaitRead listenFd
        X.uninterruptibleAccept4_ listenFd acceptFlags >>= \case
          Left e
            | e == EAGAIN || e == EWOULDBLOCK -> acceptLoop
            | otherwise -> pure (Left e)
          Right (Fd fd) -> Right <$> mkSocket fd
   in acceptLoop
-- | Connect a socket to the given address, reporting failure as an 'Errno'
-- instead of throwing.
--
-- The address is marshalled into a temporary buffer and passed to a
-- non-blocking @connect@:
--
-- * On immediate success the result is @Right ()@.
-- * On 'EINTR' the call is retried.
-- * On 'EINPROGRESS' the thread waits with 'threadWaitWrite' until the
--   descriptor is writable and then reads @SO_ERROR@: zero means the
--   connection was established, any other value is returned as the 'Errno'.
-- * Every other error is returned unchanged in 'Left'.
connect :: Socket -> SockAddr -> IO (Either Errno ())
connect s sa =
  withSocketAddress sa $ \p_sa sz ->
    withFdSocket s $ \fd ->
      let loop :: IO (Either Errno ())
          loop = X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
            Right _ -> pure (Right ())
            Left EINTR -> loop
            Left EINPROGRESS -> do
              -- Writability signals that the asynchronous connect finished;
              -- SO_ERROR tells us whether it succeeded.
              threadWaitWrite (Fd fd)
              getSocketOption s SoError >>= \case
                0 -> pure (Right ())
                errB -> pure (Left (Errno (fromIntegral errB)))
            Left err -> pure (Left err)
       in loop
-- | Variant of 'connect' whose wait for completion can be abandoned by
-- setting a 'TVar'.
--
-- Behaves like 'connect' except while the connection is in progress
-- ('EINPROGRESS'): the thread then waits until either the descriptor becomes
-- writable or the interrupt variable holds 'True'.
--
-- * If the interrupt wins, the result is @Left EAGAIN@.
-- * If the descriptor becomes writable, @SO_ERROR@ decides the outcome:
--   zero yields @Right ()@, any other value is returned as the 'Errno'.
connectInterruptible :: TVar Bool -> Socket -> SockAddr -> IO (Either Errno ())
connectInterruptible !interrupt s sa =
  withSocketAddress sa $ \p_sa sz ->
    withFdSocket s $ \fd ->
      let loop :: IO (Either Errno ())
          loop = X.uninterruptibleConnectPtr (Fd fd) p_sa sz >>= \case
            Right _ -> pure (Right ())
            Left EINTR -> loop
            Left EINPROGRESS -> waitUntilWriteable interrupt (Fd fd) >>= \case
              -- The caller asked us to stop waiting; EAGAIN is the
              -- sentinel this module uses for that situation.
              Interrupted -> pure (Left EAGAIN)
              Ready -> getSocketOption s SoError >>= \case
                0 -> pure (Right ())
                errB -> pure (Left (Errno (fromIntegral errB)))
            Left err -> pure (Left err)
       in loop
-- | Marshal a socket address into a temporary buffer and run an action on it.
--
-- The callback receives a pointer to the encoded address and its size in
-- bytes. The buffer comes from 'allocaBytes', so the pointer must not be
-- used after the callback returns. An address whose encoded size is zero is
-- passed as 'nullPtr' with length @0@ and no buffer is allocated.
withSocketAddress :: SocketAddress sa => sa -> (Ptr sa -> Int -> IO a) -> IO a
withSocketAddress addr f
  | sz == 0 = f nullPtr 0
  | otherwise = allocaBytes sz $ \p -> do
      pokeSocketAddress p addr
      f (castPtr p) sz
  where
    sz = sizeOfSocketAddress addr
-- | How a wait on a file descriptor with an interrupt flag ended.
data Outcome
  = Ready
    -- ^ The descriptor became ready before the interrupt flag was set.
  | Interrupted
    -- ^ The interrupt flag was observed to be 'True'.
-- | STM action that succeeds only when the flag holds 'True'.
--
-- When the flag is 'False' the transaction retries, which makes this usable
-- as one branch of an alternative ('<|>' / 'STM.orElse').
checkFinished :: TVar Bool -> STM ()
checkFinished flag = STM.readTVar flag >>= STM.check
-- | Block until the descriptor is writable or the interrupt flag becomes
-- 'True', whichever happens first.
--
-- Returns 'Ready' when writability was observed and 'Interrupted' when the
-- flag was observed. If both hold at once, the readiness branch is tried
-- first and wins.
--
-- 'threadWaitWriteSTM' hands back a readiness transaction together with a
-- deregistration action. 'bracket' guarantees that the deregistration runs
-- exactly once, including when an asynchronous exception reaches this
-- thread while it is blocked in the transaction.
waitUntilWriteable :: TVar Bool -> Fd -> IO Outcome
waitUntilWriteable !interrupt !fd =
  bracket (threadWaitWriteSTM fd) snd $ \(isReadyAction, _) ->
    STM.atomically $
      (isReadyAction $> Ready) <|> (checkFinished interrupt $> Interrupted)
-- | Create a non-blocking, close-on-exec socket, reporting a failure of the
-- underlying @socket@ call as an 'Errno' instead of throwing.
--
-- Only 'N.Stream' and 'N.Datagram' are supported. Any other socket type is
-- rejected with 'fail', which in 'IO' raises an exception rather than
-- returning 'Left'.
--
-- Creating the descriptor and wrapping it with 'mkSocket' run under 'mask_'
-- so an asynchronous exception cannot arrive between the two steps.
socket ::
     N.Family
  -> N.SocketType
  -> N.ProtocolNumber
  -> IO (Either Errno Socket)
socket !fam !stype !protocol = case stype of
  N.Stream -> finish X.stream
  N.Datagram -> finish X.datagram
  _ -> fail "Network.Unexceptional.socket: Currently only supports stream and datagram types"
  where
    -- Issue the system call for an already-translated socket type.
    finish !sockTy = mask_ $
      X.uninterruptibleSocket
        (X.Family (N.packFamily fam))
        (X.applySocketFlags (X.closeOnExec <> X.nonblocking) sockTy)
        (X.Protocol protocol)
      >>= \case
        Left err -> pure (Left err)
        Right (Fd fd) -> Right <$> mkSocket fd