{-# language BangPatterns #-}
{-# language RankNTypes #-}
{-# language DuplicateRecordFields #-}
{-# language LambdaCase #-}
{-# language NamedFieldPuns #-}
{-# language MagicHash #-}
module Socket.Stream.IPv4
(
Listener
, Connection
, Endpoint(..)
, withListener
, withAccepted
, withConnection
, forkAccepted
, forkAcceptedUnmasked
, sendByteArray
, sendByteArraySlice
, sendMutableByteArray
, sendMutableByteArraySlice
, receiveByteArray
, receiveBoundedByteArray
, receiveMutableByteArray
, SocketException(..)
, Context(..)
, Reason(..)
) where
import Control.Concurrent (ThreadId,threadWaitWrite,threadWaitRead)
import Control.Concurrent (forkIO,forkIOWithUnmask)
import Control.Exception (mask,onException)
import Data.Bifunctor (bimap)
import Data.Primitive (ByteArray,MutableByteArray(..))
import Data.Word (Word16)
import Foreign.C.Error (Errno(..),eAGAIN,eWOULDBLOCK,eINPROGRESS)
import Foreign.C.Types (CInt,CSize)
import GHC.Exts (RealWorld,Int(I#),shrinkMutableByteArray#)
import Socket (SocketException(..),Context(..),Reason(..))
import Socket.Debug (debug)
import Socket.IPv4 (Endpoint(..))
import System.Posix.Types (Fd)
import Net.Types (IPv4(..))
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
-- | A socket that is listening for incoming IPv4 stream connections.
-- It is created by 'withListener' and consumed by 'withAccepted',
-- 'forkAccepted' and 'forkAcceptedUnmasked'.
newtype Listener = Listener Fd
-- | A connected IPv4 stream socket, as handed to the callbacks of
-- 'withConnection' and the accept functions. Use it with the send and
-- receive functions of this module.
newtype Connection = Connection Fd
-- | Open a listening IPv4 stream socket bound to the given endpoint and run
-- the callback with it.
--
-- The socket is created nonblocking and close-on-exec, bound, and put into
-- the listening state with a backlog of 16.
--
-- The callback also receives the bound port. This is the requested port, or,
-- when the endpoint asks for port 0, the port reported back by
-- @getsockname@.
--
-- The socket is closed after the callback returns, and also if it throws.
-- Failures of the underlying system calls are returned as 'Left' rather than
-- thrown.
withListener ::
     Endpoint
  -> (Listener -> Word16 -> IO a)
  -> IO (Either SocketException a)
withListener endpoint@Endpoint{port = specifiedPort} f = mask $ \restore -> do
  debug ("withSocket: opening listener " ++ show endpoint)
  opened <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("withSocket: opened listener " ++ show endpoint)
  case opened of
    Left errno -> pure (Left (errorCode Open errno))
    Right fd -> do
      -- Every failure after the socket exists must release the descriptor.
      let abandon :: SocketException -> IO (Either SocketException r)
          abandon e = do
            _ <- S.uninterruptibleClose fd
            pure (Left e)
      bound <- S.uninterruptibleBind fd
        (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
      debug ("withSocket: requested binding for listener " ++ show endpoint)
      case bound of
        Left errno -> abandon (errorCode Bind errno)
        Right _ -> S.uninterruptibleListen fd 16 >>= \case
          Left errno -> do
            _ <- S.uninterruptibleClose fd
            debug "withSocket: listen failed with error code"
            pure (Left (errorCode Listen errno))
          Right _ -> do
            -- Port 0 means "any port": ask the kernel which one we got.
            boundPort <- if specifiedPort /= 0
              then pure (Right specifiedPort)
              else S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
                Left errno -> abandon (errorCode GetName errno)
                Right (requiredSz, sockAddr)
                  | requiredSz /= S.sizeofSocketAddressInternet ->
                      abandon (exception GetName SocketAddressSize)
                  | otherwise -> case S.decodeSocketAddressInternet sockAddr of
                      Nothing -> abandon (exception GetName SocketAddressFamily)
                      Just S.SocketAddressInternet{port = networkPort} -> do
                        let hostPort = S.networkToHostShort networkPort
                        debug ("withSocket: successfully bound listener " ++ show endpoint ++ " and got port " ++ show hostPort)
                        pure (Right hostPort)
            case boundPort of
              Left e -> pure (Left e)
              Right actualPort -> do
                a <- restore (f (Listener fd) actualPort) `onException` S.uninterruptibleClose fd
                closed <- S.uninterruptibleClose fd
                pure (bimap (errorCode Close) (const a) closed)
-- | Wait for and accept one incoming connection on the listener, then run the
-- callback on the current thread with the connection and the peer's endpoint.
--
-- If the callback throws, the connection is closed and the exception
-- propagates. Otherwise the connection goes through an orderly shutdown. The
-- callback's result is returned as 'Right' only if that shutdown succeeds.
withAccepted ::
     Listener
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either SocketException a)
withAccepted lst cb =
  -- Run the per-connection action inline, handing it the same
  -- mask-restoring function that guards the accept. This must stay a lambda:
  -- 'restore' is rank-2 polymorphic.
  internalAccepted (\restore run -> run restore) lst cb
-- | Shared implementation of 'withAccepted', 'forkAccepted' and
-- 'forkAcceptedUnmasked'.
--
-- Waits until the listener is readable, then accepts one connection inside
-- 'mask'. The peer address must decode as an IPv4 socket address of the
-- expected size. Otherwise the accepted descriptor is closed and a 'GetName'
-- error is returned.
--
-- The @wrap@ argument decides how the per-connection work runs (inline or on
-- a new thread). It receives the mask-restoring function and an action. That
-- action runs the user callback (closing the descriptor if the callback
-- throws) and then 'gracefulClose'.
--
-- NOTE(review): on Linux, accept(2) does not copy O_NONBLOCK or FD_CLOEXEC
-- from the listening socket to the new one. If 'S.uninterruptibleAccept'
-- wraps plain accept(2), the accepted socket is blocking, and the unsafe
-- send/receive calls made on it may block an OS thread. Consider accept4 with
-- @L.closeOnExec <> L.nonblocking@ after confirming it is available in
-- Linux.Socket.
internalAccepted ::
     ((forall x. IO x -> IO x) -> ((IO a -> IO b) -> IO (Either SocketException b)) -> IO (Either SocketException c))
  -> Listener
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either SocketException c)
internalAccepted wrap (Listener !lst) f = do
  threadWaitRead lst
  mask $ \restore ->
    S.uninterruptibleAccept lst S.sizeofSocketAddressInternet >>= \case
      Left err
        | err == eAGAIN || err == eWOULDBLOCK ->
            -- The listener is nonblocking (see 'withListener'). Readiness can
            -- be lost between the wait and the accept, for example when
            -- another thread accepts the connection first. No descriptor was
            -- allocated, so wait again with the caller's masking state
            -- restored instead of reporting a spurious failure.
            restore (internalAccepted wrap (Listener lst) f)
        | otherwise -> pure (Left (errorCode Accept err))
      Right (sockAddrRequiredSz, sockAddr, acpt) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet
        then case S.decodeSocketAddressInternet sockAddr of
          Just sockAddrInet -> do
            let acceptedEndpoint = socketAddressInternetToEndpoint sockAddrInet
            debug ("withAccepted: successfully accepted connection from " ++ show acceptedEndpoint)
            wrap restore $ \restore' -> do
              a <- onException (restore' (f (Connection acpt) acceptedEndpoint)) (S.uninterruptibleClose acpt)
              gracefulClose acpt a
          Nothing -> do
            _ <- S.uninterruptibleClose acpt
            pure (Left (exception GetName SocketAddressFamily))
        else do
          _ <- S.uninterruptibleClose acpt
          pure (Left (exception GetName SocketAddressSize))
-- | Close a connected stream socket after an orderly shutdown.
--
-- First shuts down the write side. Then it reads at most one byte, waiting
-- once for readability if the read would block.
--
-- * A 0-byte read means the peer has shut down its side. The descriptor is
--   then closed, and @Right a@ is returned if the close succeeds.
-- * If the peer sends data instead, the result is
--   @Left (exception Shutdown RemoteNotShutdown)@.
-- * If any step fails, the error is returned as 'Left'. That includes a retry
--   that again reports EAGAIN.
--
-- In every case the descriptor is closed before returning.
gracefulClose :: Fd -> a -> IO (Either SocketException a)
gracefulClose fd a = S.uninterruptibleShutdown fd S.write >>= \case
  Left err -> do
    _ <- S.uninterruptibleClose fd
    pure (Left (errorCode Shutdown err))
  Right _ -> do
    buf <- PM.newByteArray 1
    S.uninterruptibleReceiveMutableByteArray fd buf 0 1 mempty >>= \case
      Left err1
        | err1 == eWOULDBLOCK || err1 == eAGAIN -> do
            -- Callers invoke this with async exceptions masked, but this
            -- wait is interruptible. It can also last a long time, since it
            -- waits on the peer. Close the descriptor if an exception arrives
            -- so it does not leak.
            threadWaitRead fd `onException` S.uninterruptibleClose fd
            S.uninterruptibleReceiveMutableByteArray fd buf 0 1 mempty >>= \case
              Left err -> do
                _ <- S.uninterruptibleClose fd
                pure (Left (errorCode Shutdown err))
              Right sz -> finish "A" sz
        | otherwise -> do
            _ <- S.uninterruptibleClose fd
            pure (Left (errorCode Shutdown err1))
      Right sz -> finish "B" sz
  where
  -- Interpret the size of the one-byte probe read and close the socket.
  -- The tag distinguishes the two call sites in the debug output.
  finish tag sz = if sz == 0
    then fmap (bimap (errorCode Close) (const a)) (S.uninterruptibleClose fd)
    else do
      debug ("Socket.Stream.IPv4.gracefulClose: remote not shutdown " ++ tag)
      _ <- S.uninterruptibleClose fd
      pure (Left (exception Shutdown RemoteNotShutdown))
-- | Wait for and accept one incoming connection, then handle it on a new
-- thread created with 'forkIO'.
--
-- The new thread runs the callback, performs the orderly close, and passes
-- the outcome to @consumeException@. The forking thread receives the new
-- 'ThreadId'. Accept-time failures are returned directly, and no thread is
-- created for them.
forkAccepted ::
     Listener
  -> (Either SocketException a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either SocketException ThreadId)
forkAccepted lst consumeException cb =
  internalAccepted
    -- The child inherits the masked state from the accept; 'restore' brings
    -- back the caller's original masking state for the user code.
    (\restore run -> Right <$> forkIO (run restore >>= restore . consumeException))
    lst
    cb
-- | Like 'forkAccepted', but creates the handler thread with
-- 'forkIOWithUnmask'.
--
-- The callback and @consumeException@ run under the @unmask@ function that
-- 'forkIOWithUnmask' provides, rather than under the restore function of the
-- surrounding 'mask'.
forkAcceptedUnmasked ::
     Listener
  -> (Either SocketException a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either SocketException ThreadId)
forkAcceptedUnmasked lst consumeException cb =
  internalAccepted
    (\_ run -> Right <$> forkIOWithUnmask (\unmask -> run unmask >>= unmask . consumeException))
    lst
    cb
-- | Open a nonblocking, close-on-exec IPv4 stream socket, connect it to the
-- remote endpoint, and run the callback with the connection.
--
-- If @connect@ reports EINPROGRESS, the function waits for writability. It
-- then reads @SO_ERROR@ to learn whether the connection actually succeeded.
--
-- If the callback throws, the socket is closed and the exception propagates.
-- Otherwise the connection goes through an orderly shutdown via
-- 'gracefulClose'. Every failure path closes the socket and returns a 'Left'.
withConnection ::
     Endpoint
  -> (Connection -> IO a)
  -> IO (Either SocketException a)
withConnection !remote f = mask $ \restore -> do
  debug ("withSocket: opening connection " ++ show remote)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("withSocket: opened connection " ++ show remote)
  case e1 of
    Left err1 -> pure (Left (errorCode Open err1))
    Right fd -> do
      let sockAddr = S.encodeSocketAddressInternet (endpointToSocketAddressInternet remote)
          -- SO_ERROR is reported as a single C int.
          cintSize = intToCInt (PM.sizeOf (undefined :: CInt))
      merr <- S.uninterruptibleConnect fd sockAddr >>= \case
        Left err2 -> if err2 == eINPROGRESS
          then do
            -- This wait is interruptible even though we are inside 'mask'.
            -- Close the descriptor if an async exception arrives, rather
            -- than leaking it.
            threadWaitWrite fd `onException` S.uninterruptibleClose fd
            pure Nothing
          else pure (Just (errorCode Connect err2))
        Right _ -> pure Nothing
      case merr of
        Just err -> do
          _ <- S.uninterruptibleClose fd
          pure (Left err)
        Nothing -> S.uninterruptibleGetSocketOption fd S.levelSocket S.optionError cintSize >>= \case
          Left err -> do
            _ <- S.uninterruptibleClose fd
            pure (Left (errorCode Option err))
          Right (sz, S.OptionValue val)
            | sz /= cintSize -> do
                _ <- S.uninterruptibleClose fd
                pure (Left (exception Option OptionValueSize))
            | otherwise -> do
                -- A zero SO_ERROR means the (possibly asynchronous) connect
                -- succeeded; anything else is the errno it failed with.
                let err = PM.indexByteArray val 0 :: CInt
                if err == 0
                  then do
                    a <- onException (restore (f (Connection fd))) (S.uninterruptibleClose fd)
                    gracefulClose fd a
                  else do
                    _ <- S.uninterruptibleClose fd
                    pure (Left (errorCode Connect (Errno err)))
-- | Send the entire contents of an immutable byte array over the connection.
sendByteArray ::
     Connection
  -> ByteArray
  -> IO (Either SocketException ())
sendByteArray conn arr = sendByteArraySlice conn arr 0 wholeLength
  where
  wholeLength = PM.sizeofByteArray arr
-- | Send @len0@ bytes of the array starting at offset @off0@. Sends repeat
-- until every byte has been accepted or a send fails. A non-positive length
-- sends nothing and succeeds.
sendByteArraySlice ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either SocketException ())
sendByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
  loop !off !remaining
    | remaining <= 0 = pure (Right ())
    | otherwise = internalSend conn payload off remaining >>= \case
        Left e -> pure (Left e)
        Right sent ->
          let n = csizeToInt sent
          in loop (off + n) (remaining - n)
-- | Send the entire current contents of a mutable byte array over the
-- connection.
sendMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either SocketException ())
sendMutableByteArray conn arr = do
  wholeLength <- PM.getSizeofMutableByteArray arr
  sendMutableByteArraySlice conn arr 0 wholeLength
-- | Send @len0@ bytes of the mutable array starting at offset @off0@. Sends
-- repeat until every byte has been accepted or a send fails. A non-positive
-- length sends nothing and succeeds.
sendMutableByteArraySlice ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either SocketException ())
sendMutableByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
  loop !off !remaining
    | remaining <= 0 = pure (Right ())
    | otherwise = internalSendMutable conn payload off remaining >>= \case
        Left e -> pure (Left e)
        Right sent ->
          let n = csizeToInt sent
          in loop (off + n) (remaining - n)
-- | Issue one send of up to @len@ bytes from the mutable array at @off@ and
-- return the byte count the send call reported.
--
-- If the first attempt would block (EAGAIN/EWOULDBLOCK), the function waits
-- for writability and tries exactly once more. Any other error is returned
-- as a 'Send' error.
internalSendMutable ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either SocketException CSize)
internalSendMutable (Connection !s) !payload !off !len = attempt >>= \case
  Right sz -> pure (Right sz)
  Left err1
    | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> pure (Left (errorCode Send err1))
    | otherwise -> do
        threadWaitWrite s
        bimap (errorCode Send) id <$> attempt
  where
  attempt = S.uninterruptibleSendMutableByteArray s payload
    (intToCInt off)
    (intToCSize len)
    mempty
-- | Issue one send of up to @len@ bytes from the immutable array at @off@ and
-- return the byte count the send call reported.
--
-- If the first attempt would block (EAGAIN/EWOULDBLOCK), the function waits
-- for writability and tries exactly once more. Any other error is returned
-- as a 'Send' error.
internalSend ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either SocketException CSize)
internalSend (Connection !s) !payload !off !len = do
  debug ("send: about to send chunk on stream socket, offset " ++ show off ++ " and length " ++ show len)
  initial <- attempt
  debug "send: just sent chunk on stream socket"
  case initial of
    Right sz -> pure (Right sz)
    Left err1
      | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> pure (Left (errorCode Send err1))
      | otherwise -> do
          debug "send: waiting to for write ready on stream socket"
          threadWaitWrite s
          attempt >>= \case
            Right sz -> pure (Right sz)
            Left err2 -> do
              debug "send: encountered error after sending chunk on stream socket"
              pure (Left (errorCode Send err2))
  where
  attempt = S.uninterruptibleSendByteArray s payload
    (intToCInt off)
    (intToCSize len)
    mempty
-- | Wait until the connection is readable, then perform one receive of up to
-- @maxSz@ bytes into @buf@ at offset @off@.
--
-- Returns the number of bytes received; 0 is passed through unchanged. Any
-- error is returned as a 'Receive' error.
internalReceiveMaximally ::
     Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either SocketException Int)
internalReceiveMaximally (Connection !fd) !maxSz !buf !off = do
  debug "receive: stream socket about to wait"
  threadWaitRead fd
  debug ("receive: stream socket is now readable, receiving up to " ++ show maxSz ++ " bytes at offset " ++ show off)
  result <- S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty
  debug "receive: finished reading from stream socket"
  pure (bimap (errorCode Receive) csizeToInt result)
-- | Receive exactly @total@ bytes from the connection into a fresh byte
-- array.
--
-- * Receives repeat until the array is full.
-- * A 0-byte read before the array is full yields
--   @Left (exception Receive RemoteShutdown)@.
-- * A negative @total@ yields
--   @Left (exception Receive NegativeBytesRequested)@ without allocating.
-- * A @total@ of 0 returns an empty array without touching the socket.
receiveByteArray ::
     Connection
  -> Int
  -> IO (Either SocketException ByteArray)
receiveByteArray !conn !total
  -- Validate before allocating: a negative size must never reach
  -- 'PM.newByteArray'.
  | total < 0 = pure (Left (exception Receive NegativeBytesRequested))
  | otherwise = do
      marr <- PM.newByteArray total
      go marr 0 total
  where
  go !marr !off !remaining = case compare remaining 0 of
    GT -> internalReceiveMaximally conn remaining marr off >>= \case
      Left err -> pure (Left err)
      Right sz -> if sz /= 0
        then go marr (off + sz) (remaining - sz)
        else pure (Left (exception Receive RemoteShutdown))
    EQ -> Right <$> PM.unsafeFreezeByteArray marr
    -- Defensive only: one receive never returns more than was requested.
    LT -> pure (Left (exception Receive NegativeBytesRequested))
-- | Fill the whole mutable byte array with bytes received from the
-- connection.
--
-- Receives repeat until the array is full. A 0-byte read before then yields
-- @Left (exception Receive RemoteShutdown)@.
receiveMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either SocketException ())
receiveMutableByteArray !conn !marr = PM.getSizeofMutableByteArray marr >>= fill 0
  where
  fill !off !remaining
    | remaining <= 0 = pure (Right ())
    | otherwise = internalReceiveMaximally conn remaining marr off >>= \case
        Left err -> pure (Left err)
        Right 0 -> pure (Left (exception Receive RemoteShutdown))
        Right sz -> fill (off + sz) (remaining - sz)
-- | Perform a single receive of at most @total@ bytes and return exactly the
-- bytes obtained.
--
-- * A 0-byte read yields @Left (exception Receive RemoteShutdown)@.
-- * A @total@ of 0 returns an empty array without touching the socket.
-- * A negative @total@ yields
--   @Left (exception Receive NegativeBytesRequested)@.
receiveBoundedByteArray ::
     Connection
  -> Int
  -> IO (Either SocketException ByteArray)
receiveBoundedByteArray !conn !total = case compare total 0 of
  LT -> pure (Left (exception Receive NegativeBytesRequested))
  EQ -> pure (Right mempty)
  GT -> do
    m <- PM.newByteArray total
    internalReceiveMaximally conn total m 0 >>= \case
      Left err -> pure (Left err)
      Right 0 -> pure (Left (exception Receive RemoteShutdown))
      Right sz -> do
        -- Trim the buffer to the bytes actually received before freezing.
        shrinkMutableByteArray m sz
        Right <$> PM.unsafeFreezeByteArray m
-- | Convert an 'Endpoint' (host byte order) into the socket address record
-- used by the system calls, whose port and address are in network byte
-- order.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet Endpoint{address = ip, port = hostPort} =
  S.SocketAddressInternet
    { address = S.hostToNetworkLong (getIPv4 ip)
    , port = S.hostToNetworkShort hostPort
    }
-- | Convert a socket address record (network byte order) back into an
-- 'Endpoint' in host byte order.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint
socketAddressInternetToEndpoint S.SocketAddressInternet{address = netAddr, port = netPort} =
  Endpoint
    { port = S.networkToHostShort netPort
    , address = IPv4 (S.networkToHostLong netAddr)
    }
-- | Wrap a raw errno, tagged with the operation that produced it, as a
-- 'SocketException' with an 'ErrorCode' reason.
errorCode :: Context -> Errno -> SocketException
errorCode ctx errno = case errno of
  Errno code -> SocketException ctx (ErrorCode code)
-- | Build a 'SocketException' for a library-detected failure (one that is
-- not an errno), tagged with the operation it occurred in.
exception :: Context -> Reason -> SocketException
exception = SocketException
-- | Convert an 'Int' to a 'CInt'. Out-of-range values wrap, exactly as with
-- 'fromIntegral'.
intToCInt :: Int -> CInt
intToCInt = fromInteger . toInteger
-- | Convert an 'Int' to a 'CSize'. Negative values wrap, exactly as with
-- 'fromIntegral'.
intToCSize :: Int -> CSize
intToCSize = fromInteger . toInteger
-- | Convert a 'CSize' to an 'Int'. Out-of-range values wrap, exactly as with
-- 'fromIntegral'.
csizeToInt :: CSize -> Int
csizeToInt = fromInteger . toInteger
-- | Shrink a mutable byte array in place to the given size in bytes, using
-- the 'shrinkMutableByteArray#' primop. GHC requires the new size to be no
-- larger than the current size.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray marr#) (I# newSize#) =
  PM.primitive_ (\s -> shrinkMutableByteArray# marr# newSize# s)