{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
#include "HsNetDef.h"
module Network.Socket.Syscall where
import Foreign.Marshal.Utils (with)
import qualified Control.Exception as E
#if defined(mingw32_HOST_OS)
import Foreign (FunPtr)
import GHC.Conc (asyncDoProc)
#else
import Foreign.C.Error (getErrno, eINTR, eINPROGRESS)
import GHC.Conc (threadWaitWrite)
#endif
#ifdef HAVE_ADVANCED_SOCKET_FLAGS
import Network.Socket.Cbits
#else
import Network.Socket.Fcntl
#endif
import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Options
import Network.Socket.Types
-- | Create a new socket for the given address family, socket type and
-- protocol number.
--
-- The descriptor is obtained from the C @socket@ call and, should any of
-- the follow-up configuration steps throw, is released again with
-- @c_close@ (via 'E.bracketOnError').  The descriptor is put into
-- non-blocking mode, either through a creation flag or through
-- @setNonBlockIfNeeded@, depending on how the package was configured.
--
-- Where the build supports it, the @IPv6Only@ option is set to 0 for
-- 'AF_INET6' sockets of type 'Stream' or 'Datagram'.
socket :: Family         -- ^ address family
       -> SocketType     -- ^ socket type
       -> ProtocolNumber -- ^ protocol number
       -> IO Socket      -- ^ the freshly created socket
socket family stype protocol =
    E.bracketOnError openDescriptor c_close $ \fd -> do
        ensureNonBlocking fd
        sock <- mkSocket fd
        clearIPv6Only sock
        return sock
  where
    -- Translate the socket type, add creation flags and issue the syscall.
    openDescriptor = do
        rawType <- packSocketTypeOrThrow "socket" stype
        throwSocketErrorIfMinus1Retry "Network.Socket.socket" $
            c_socket (packFamily family) (withCreationFlags rawType) protocol

    -- With advanced socket flags the non-blocking bit is requested at
    -- creation time; otherwise it is applied afterwards.
#ifdef HAVE_ADVANCED_SOCKET_FLAGS
    withCreationFlags rawType = rawType .|. sockNonBlock
    ensureNonBlocking _ = return ()
#else
    withCreationFlags rawType = rawType
    ensureNonBlocking fd = setNonBlockIfNeeded fd
#endif

#if HAVE_DECL_IPV6_V6ONLY
    clearIPv6Only sock = when dualStackCandidate (resetV6Only sock)

    dualStackCandidate =
        family == AF_INET6 && (stype == Stream || stype == Datagram)

# if defined(mingw32_HOST_OS)
    -- On Windows an IOException raised while changing the option is ignored.
    resetV6Only sock = setSocketOption sock IPv6Only 0 `E.catch` ignoreIOError

    ignoreIOError :: E.IOException -> IO ()
    ignoreIOError _ = return ()
# elif defined(__OpenBSD__)
    -- On OpenBSD the option is left untouched.
    resetV6Only _ = return ()
# else
    resetV6Only sock = setSocketOption sock IPv6Only 0
# endif
#else
    clearIPv6Only _ = return ()
#endif
-- | Bind the socket to the given address.
--
-- The address is marshalled with @withSocketAddress@ and handed to the C
-- @bind@ call.  A result of -1 is turned into an exception by
-- @throwSocketErrorIfMinus1Retry@; the numeric result is otherwise
-- discarded.
bind :: SocketAddress sa => Socket -> sa -> IO ()
bind s sa =
    withSocketAddress sa $ \addrPtr addrLen ->
        withFdSocket s $ \fd ->
            void $
                throwSocketErrorIfMinus1Retry "Network.Socket.bind"
                    (c_bind fd addrPtr (fromIntegral addrLen))
-- | Connect the socket to a remote address.
--
-- Runs inside @withSocketsDo@, marshals the address with
-- @withSocketAddress@ and delegates the actual work to 'connectLoop'.
connect :: SocketAddress sa => Socket -> sa -> IO ()
connect s sa =
    withSocketsDo . withSocketAddress sa $ \addrPtr addrLen ->
        connectLoop s addrPtr (fromIntegral addrLen)
-- | Issue the C @connect@ call on the socket for an already marshalled
-- address of the given length.
--
-- When the call returns -1:
--
-- * on Windows a socket error is thrown immediately;
--
-- * elsewhere the call is retried when @errno@ is 'eINTR'; when it is
--   'eINPROGRESS' the thread waits until the descriptor is writable and
--   then inspects the @SoError@ option, throwing if it is non-zero; any
--   other @errno@ is thrown as a socket error.
connectLoop :: SocketAddress sa => Socket -> Ptr sa -> CInt -> IO ()
connectLoop s p_sa sz = withFdSocket s attempt
  where
    errLoc = "Network.Socket.connect: " ++ show s

    attempt fd = do
        rc <- c_connect fd p_sa sz
        when (rc == -1) (onFailure fd)

#if defined(mingw32_HOST_OS)
    onFailure _ = throwSocketError errLoc
#else
    onFailure fd = getErrno >>= recover
      where
        recover errno
          | errno == eINTR       = attempt fd
          | errno == eINPROGRESS = awaitWritable
          | otherwise            = throwSocketError errLoc

    -- The connection attempt is still in flight: block until the
    -- descriptor is writable, then check the pending socket error.
    awaitWritable = do
        withFdSocket s (threadWaitWrite . fromIntegral)
        pending <- getSocketOption s SoError
        when (pending /= 0) $
            throwSocketErrorCode errLoc (fromIntegral pending)
#endif
-- | Mark the socket as listening, passing the given backlog value to the
-- C @listen@ call.  A result of -1 is reported as an exception by
-- @throwSocketErrorIfMinus1Retry_@.
listen :: Socket -> Int -> IO ()
listen s backlog =
    withFdSocket s $ \fd ->
        throwSocketErrorIfMinus1Retry_ "Network.Socket.listen"
            (c_listen fd (fromIntegral backlog))
-- | Accept a connection on a listening socket.
--
-- Returns the new connected socket together with the address written into
-- the address buffer by the accept call.
--
-- How the underlying call is made depends on the platform:
--
-- * on Windows with the threaded runtime, a safe foreign call to @accept@;
--
-- * on Windows without it, the accept runs through 'asyncDoProc';
--
-- * elsewhere, @accept4@ with the non-blocking and close-on-exec flags
--   when advanced socket flags are available, or plain @accept@ followed by
--   @setNonBlockIfNeeded@ and @setCloseOnExecIfNeeded@ on the new
--   descriptor.
accept :: SocketAddress sa => Socket -> IO (Socket, sa)
accept listing_sock =
    withNewSocketAddress $ \addrBuf addrLen ->
        withFdSocket listing_sock $ \listenFd -> do
            connFd <- acceptRaw listenFd addrBuf addrLen
            conn   <- mkSocket connFd
            peer   <- peekSocketAddress addrBuf
            return (conn, peer)
  where
    errLoc = "Network.Socket.accept"

#if defined(mingw32_HOST_OS)
    acceptRaw fd sa sz
      | threaded =
          with (fromIntegral sz) $ \lenPtr ->
              throwSocketErrorIfMinus1Retry errLoc
                  (c_accept_safe fd sa lenPtr)
      | otherwise = do
          params <- c_newAcceptParams fd (fromIntegral sz) sa
          status <- asyncDoProc c_acceptDoProc params
          -- Read the new descriptor and free the parameter block before
          -- examining the status, so the block is released on both paths.
          connFd <- c_acceptNewSock params
          c_free params
          when (status /= 0) $
              throwSocketErrorCode errLoc (fromIntegral status)
          return connFd
#else
# ifdef HAVE_ADVANCED_SOCKET_FLAGS
    acceptRaw fd sa sz =
        with (fromIntegral sz) $ \lenPtr ->
            throwSocketErrorWaitRead listing_sock errLoc
                (c_accept4 fd sa lenPtr (sockNonBlock .|. sockCloexec))
# else
    acceptRaw fd sa sz =
        with (fromIntegral sz) $ \lenPtr -> do
            connFd <- throwSocketErrorWaitRead listing_sock errLoc
                          (c_accept fd sa lenPtr)
            setNonBlockIfNeeded connFd
            setCloseOnExecIfNeeded connFd
            return connFd
# endif
#endif
-- Raw bindings to the C socket functions used by this module.  Each returns
-- the C result code unchanged; callers in this module compare it against -1
-- or hand it to a throwSocketError helper.

-- C socket call.  Called in this module with the packed family, the packed
-- socket type (plus creation flags) and the protocol number.
foreign import CALLCONV unsafe "socket"
c_socket :: CInt -> CInt -> CInt -> IO CInt
-- C bind call: descriptor, address pointer, address length.
foreign import CALLCONV unsafe "bind"
c_bind :: CInt -> Ptr sa -> CInt -> IO CInt
-- C connect call: descriptor, address pointer, address length.  The safety
-- level is selected by a preprocessor macro.
foreign import CALLCONV SAFE_ON_WIN "connect"
c_connect :: CInt -> Ptr sa -> CInt -> IO CInt
-- C listen call: descriptor, backlog.
foreign import CALLCONV unsafe "listen"
c_listen :: CInt -> CInt -> IO CInt
#ifdef HAVE_ADVANCED_SOCKET_FLAGS
-- C accept4 call: descriptor, address buffer, pointer to the address
-- length, flags.
foreign import CALLCONV unsafe "accept4"
c_accept4 :: CInt -> Ptr sa -> Ptr CInt -> CInt -> IO CInt
#else
-- C accept call: descriptor, address buffer, pointer to the address length.
foreign import CALLCONV unsafe "accept"
c_accept :: CInt -> Ptr sa -> Ptr CInt -> IO CInt
#endif
#if defined(mingw32_HOST_OS)
-- Windows only: the same C accept function imported as a safe call.  Used by
-- accept when the threaded flag below is True.
foreign import CALLCONV safe "accept"
c_accept_safe :: CInt -> Ptr sa -> Ptr CInt -> IO CInt
-- Value of the runtime function rtsSupportsBoundThreads; accept uses it to
-- choose between the safe call and the asyncDoProc path.
foreign import ccall unsafe "rtsSupportsBoundThreads"
threaded :: Bool
-- Helpers declared in HsNet.h for the asyncDoProc path of accept.
-- Extracts the accepted descriptor from a parameter block.
foreign import ccall unsafe "HsNet.h acceptNewSock"
c_acceptNewSock :: Ptr () -> IO CInt
-- Builds a parameter block from the listening descriptor, the address
-- length and the address buffer.
foreign import ccall unsafe "HsNet.h newAcceptParams"
c_newAcceptParams :: CInt -> CInt -> Ptr a -> IO (Ptr ())
-- Address of the C procedure that accept passes to asyncDoProc.
foreign import ccall unsafe "HsNet.h &acceptDoProc"
c_acceptDoProc :: FunPtr (Ptr () -> IO Int)
-- C free; accept uses it to release the parameter block.
foreign import ccall unsafe "free"
c_free:: Ptr a -> IO ()
#endif