{-# language BangPatterns #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
module Socket.Datagram.IPv4.Undestined
(
Socket(..)
, Endpoint(..)
, withSocket
, send
, receive
, receiveMutableByteArraySlice_
, SocketException(..)
, Context(..)
, Reason(..)
) where
import Control.Concurrent (threadWaitWrite,threadWaitRead)
import Control.Exception (mask,onException)
import Data.Primitive (ByteArray,MutableByteArray(..))
import Data.Word (Word16)
import Foreign.C.Error (Errno(..),eWOULDBLOCK,eAGAIN)
import Foreign.C.Types (CInt,CSize)
import GHC.Exts (Int(I#),RealWorld,shrinkMutableByteArray#)
import Net.Types (IPv4(..))
import Socket (SocketException(..),Context(..),Reason(..))
import Socket.Debug (debug)
import Socket.IPv4 (Endpoint(..))
import System.Posix.Types (Fd)
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
-- | An IPv4 datagram socket, represented by its file descriptor.
-- 'withSocket' hands one of these to its callback after binding.
newtype Socket = Socket Fd
  deriving stock (Eq, Ord)
-- | Open an IPv4 datagram socket (close-on-exec, non-blocking), bind it to
-- @endpoint@, and run the callback with the socket and the port it is bound
-- to.
--
-- When the requested port is 0, the port actually assigned is read back with
-- getsockname. Otherwise the requested port is passed through unchanged.
-- Everything except the callback runs under 'mask'.
--
-- If the callback throws, the descriptor is closed. If it returns normally,
-- the descriptor is closed and the close result is checked. Failures while
-- opening, binding, reading back the socket name, or closing are reported
-- as 'Left'. Whenever a setup step fails after the socket was opened, the
-- descriptor is closed before returning.
withSocket ::
     Endpoint
  -> (Socket -> Word16 -> IO a)
  -> IO (Either SocketException a)
withSocket endpoint@Endpoint{port = specifiedPort} f = mask $ \restore -> do
  debug ("withSocket: opening socket " ++ show endpoint)
  opened <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.datagram)
    S.defaultProtocol
  debug ("withSocket: opened socket " ++ show endpoint)
  case opened of
    Left err -> pure (Left (errorCode Open err))
    Right fd -> do
      -- Release the descriptor before reporting a setup failure.
      let abandon :: SocketException -> IO (Either SocketException b)
          abandon failure = do
            S.uninterruptibleErrorlessClose fd
            pure (Left failure)
      bound <- S.uninterruptibleBind fd
        (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
      debug ("withSocket: requested binding for " ++ show endpoint)
      case bound of
        Left err -> abandon (errorCode Bind err)
        Right _ -> do
          -- Port 0 asks the kernel to pick a port, so look up which one it chose.
          resolved <- if specifiedPort /= 0
            then pure (Right specifiedPort)
            else S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
              Left err -> abandon (errorCode GetName err)
              Right (requiredSz, sockAddr)
                | requiredSz /= S.sizeofSocketAddressInternet ->
                    abandon (exception GetName SocketAddressSize)
                | otherwise -> case S.decodeSocketAddressInternet sockAddr of
                    Nothing -> abandon (exception GetName SocketAddressFamily)
                    Just S.SocketAddressInternet{port = actualPort} -> do
                      let cleanPort = S.networkToHostShort actualPort
                      debug ("withSocket: successfully bound " ++ show endpoint ++ " and got port " ++ show cleanPort)
                      pure (Right cleanPort)
          case resolved of
            Left err -> pure (Left err)
            Right actualPort -> do
              -- Only the user callback is run with async exceptions unmasked.
              a <- restore (f (Socket fd) actualPort)
                `onException` S.uninterruptibleErrorlessClose fd
              S.uninterruptibleClose fd >>= \case
                Left err -> pure (Left (errorCode Close err))
                Right _ -> pure (Right a)
-- | Send @len@ bytes of @payload@, starting at byte offset @off@, as a single
-- datagram to @remote@.
--
-- The socket is non-blocking. A send that would block (EAGAIN/EWOULDBLOCK)
-- waits for the socket to become writable with 'threadWaitWrite' and then
-- tries again, for as many rounds as needed. Previously only one retry was
-- made, and a second would-block was reported as a hard error.
--
-- Any other errno is returned as 'ErrorCode' in the 'Send' context. If the
-- kernel accepts a byte count different from @len@, the result is
-- 'MessageTruncated' with the accepted count first and @len@ second.
send ::
     Socket
  -> Endpoint
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either SocketException ())
send (Socket !s) !remote !payload !off !len = do
  debug ("send: about to send to " ++ show remote)
  attempt
  where
    -- The destination address is the same for every attempt.
    destination = S.encodeSocketAddressInternet (endpointToSocketAddressInternet remote)
    attempt = do
      e <- S.uninterruptibleSendToByteArray s payload
        (intToCInt off)
        (intToCSize len)
        mempty
        destination
      debug ("send: just sent to " ++ show remote)
      case e of
        Left err
          | err == eWOULDBLOCK || err == eAGAIN -> do
              debug ("send: waiting to for write ready to send to " ++ show remote)
              threadWaitWrite s
              attempt
          | otherwise -> do
              debug ("send: encountered error after sending")
              pure (Left (errorCode Send err))
        Right sz
          | csizeToInt sz == len -> do
              debug ("send: success")
              pure (Right ())
          | otherwise ->
              pure (Left (exception Send (MessageTruncated (csizeToInt sz) len)))
-- | Receive one datagram of at most @maxSz@ bytes. Returns the sender's
-- endpoint and a newly allocated array that holds exactly the received bytes.
-- The buffer is allocated at @maxSz@ and then shrunk to the received size
-- before it is frozen.
--
-- The call blocks with 'threadWaitRead' until the socket is readable. If the
-- non-blocking receive still reports EAGAIN/EWOULDBLOCK, it waits again
-- instead of returning an error. This can happen when another thread took
-- the datagram, or when the kernel discarded a datagram after signalling
-- readiness.
--
-- The receive passes 'L.truncate' (presumably MSG_TRUNC), so the reported
-- size may exceed @maxSz@. In that case the datagram did not fit and
-- 'MessageTruncated' is returned.
receive ::
     Socket
  -> Int
  -> IO (Either SocketException (Endpoint,ByteArray))
receive (Socket !fd) !maxSz = do
  awaitReadable
  marr <- PM.newByteArray maxSz
  let attempt = do
        e <- S.uninterruptibleReceiveFromMutableByteArray fd marr 0
          (intToCSize maxSz) L.truncate S.sizeofSocketAddressInternet
        debug "receive: finished reading from socket"
        case e of
          Left err
            | err == eWOULDBLOCK || err == eAGAIN -> awaitReadable >> attempt
            | otherwise -> pure (Left (errorCode Receive err))
          Right (sockAddrRequiredSz,sockAddr,recvSz)
            | csizeToInt recvSz > maxSz ->
                pure (Left (exception Receive (MessageTruncated maxSz (csizeToInt recvSz))))
            | sockAddrRequiredSz /= S.sizeofSocketAddressInternet ->
                pure (Left (exception Receive SocketAddressSize))
            | otherwise -> case S.decodeSocketAddressInternet sockAddr of
                Nothing -> pure (Left (exception Receive SocketAddressFamily))
                Just sockAddrInet -> do
                  -- Trim the buffer to the datagram's size before freezing it.
                  shrinkMutableByteArray marr (csizeToInt recvSz)
                  arr <- PM.unsafeFreezeByteArray marr
                  pure (Right (socketAddressInternetToEndpoint sockAddrInet, arr))
  attempt
  where
    awaitReadable = do
      debug "receive: about to wait"
      threadWaitRead fd
      debug "receive: socket is now readable"
-- | Receive one datagram into the caller's buffer. At most @maxSz@ bytes are
-- written, starting at byte offset @off@. Returns the datagram size; the
-- sender's address is not reported.
--
-- The call blocks with 'threadWaitRead' until the socket is readable. If the
-- non-blocking receive still reports EAGAIN/EWOULDBLOCK (a spurious
-- readiness, or a datagram taken by another reader), it waits again instead
-- of returning an error.
--
-- The receive passes 'L.truncate' (presumably MSG_TRUNC), so the reported
-- size may exceed @maxSz@. In that case 'MessageTruncated' is returned.
receiveMutableByteArraySlice_ ::
     Socket
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either SocketException Int)
receiveMutableByteArraySlice_ (Socket !fd) !buf !off !maxSz = attempt
  where
    attempt = do
      threadWaitRead fd
      e <- S.uninterruptibleReceiveFromMutableByteArray_ fd buf
        (intToCInt off) (intToCSize maxSz) L.truncate
      case e of
        Left err
          | err == eWOULDBLOCK || err == eAGAIN -> attempt
          | otherwise -> pure (Left (errorCode Receive err))
        Right recvSz ->
          let received = csizeToInt recvSz
          in if received <= maxSz
               then pure (Right received)
               else pure (Left (exception Receive (MessageTruncated maxSz received)))
-- | Build the POSIX IPv4 socket address for an 'Endpoint'. The port and the
-- address are converted from host to network byte order.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet Endpoint{address = host, port = portNumber} =
  S.SocketAddressInternet
    { address = S.hostToNetworkLong (getIPv4 host)
    , port = S.hostToNetworkShort portNumber
    }
-- | Inverse of 'endpointToSocketAddressInternet'. Converts the network-order
-- port and address of a POSIX IPv4 socket address back to host order.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint
socketAddressInternetToEndpoint S.SocketAddressInternet{port = netPort, address = netAddr} =
  Endpoint
    { port = S.networkToHostShort netPort
    , address = IPv4 (S.networkToHostLong netAddr)
    }
-- | Convert an 'Int' to a C @int@. Out-of-range values wrap, as with
-- 'fromIntegral'.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n

-- | Convert an 'Int' to a C @size_t@. Negative values wrap, as with
-- 'fromIntegral'.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n

-- | Convert a C @size_t@ back to an 'Int' (via 'fromIntegral').
csizeToInt :: CSize -> Int
csizeToInt sz = fromIntegral sz
-- | Report a raw errno returned by a syscall made in the given context.
errorCode :: Context -> Errno -> SocketException
errorCode ctx (Errno code) = SocketException ctx (ErrorCode code)
-- | Report a non-errno failure reason in the given context.
exception :: Context -> Reason -> SocketException
exception = SocketException
-- | Shrink a mutable byte array in place to the given size in bytes, using the
-- 'shrinkMutableByteArray#' primop. Per the GHC.Exts documentation, the new
-- size must not exceed the array's current size.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray mba#) (I# newSize#) =
  PM.primitive_ (\s# -> shrinkMutableByteArray# mba# newSize# s#)