module Z.IO.Network.TCP (
TCPClientConfig(..)
, UVStream
, defaultTCPClientConfig
, initTCPClient
, getTCPSockName
, TCPServerConfig(..)
, defaultTCPServerConfig
, startTCPServer
, getTCPPeerName
, helloWorldWorker
, echoWorker
, setTCPNoDelay
, setTCPKeepAlive
, initTCPStream
) where
import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.IO.Class
import Data.Coerce
import Foreign.C
import Foreign.Ptr
import GHC.Ptr

import Data.Primitive.PrimArray

import qualified Z.Data.Vector as V
import Z.Foreign
import Z.IO.Buffered
import Z.IO.Exception
import Z.IO.Network.SocketAddr
import Z.IO.Resource
import Z.IO.UV.FFI
import Z.IO.UV.Manager
-- | Options consumed by 'initTCPClient'.
data TCPClientConfig = TCPClientConfig
    { tcpClientAddr      :: Maybe SocketAddr
      -- ^ local address to bind before connecting; 'Nothing' skips the bind
    , tcpRemoteAddr      :: SocketAddr
      -- ^ remote address to connect to
    , tcpClientNoDelay   :: Bool
      -- ^ when 'True', @uv_tcp_nodelay@ is enabled on the new socket
    , tcpClientKeepAlive :: CUInt
      -- ^ keep-alive delay handed to @uv_tcp_keepalive@ (presumably seconds,
      -- libuv's unit -- confirm); @0@ leaves keep-alive unconfigured
    }
-- | Default client settings: no local bind, connect to the loopback address
-- on port 8888, @uv_tcp_nodelay@ enabled, keep-alive delay of 30.
defaultTCPClientConfig :: TCPClientConfig
defaultTCPClientConfig = TCPClientConfig
    { tcpClientAddr      = Nothing
    , tcpRemoteAddr      = SocketAddrInet 8888 inetLoopback
    , tcpClientNoDelay   = True
    , tcpClientKeepAlive = 30
    }
-- | Create a connected TCP client stream as a 'Resource'.
--
-- Steps, in order: allocate a TCP stream on the current thread's UV manager,
-- optionally bind it to 'tcpClientAddr', apply the no-delay / keep-alive
-- options, then issue the connect request to 'tcpRemoteAddr' through
-- 'withUVRequest'. Any libuv error is thrown via 'throwUVIfMinus_'.
initTCPClient :: HasCallStack => TCPClientConfig -> Resource UVStream
initTCPClient TCPClientConfig{..} = do
    mgr    <- liftIO getUVManager
    stream <- initTCPStream mgr
    let tcpHdl = uvsHandle stream
        bindLocal localAddr =
            withSocketAddrUnsafe localAddr $ \ localPtr ->
                throwUVIfMinus_ (uv_tcp_bind tcpHdl localPtr 0)
    liftIO $ do
        mapM_ bindLocal tcpClientAddr
        when tcpClientNoDelay $
            throwUVIfMinus_ (uv_tcp_nodelay tcpHdl 1)
        when (tcpClientKeepAlive > 0) $
            throwUVIfMinus_ (uv_tcp_keepalive tcpHdl 1 tcpClientKeepAlive)
        withSocketAddrUnsafe tcpRemoteAddr $ \ remotePtr ->
            void (withUVRequest mgr (\ _ -> hs_uv_tcp_connect tcpHdl remotePtr))
    pure stream
-- | Options consumed by 'startTCPServer'.
data TCPServerConfig = TCPServerConfig
    { tcpListenAddr            :: SocketAddr
      -- ^ address the listening socket is bound to
    , tcpListenBacklog         :: Int
      -- ^ listen backlog; 'startTCPServer' raises values below 128 to 128
      -- and also uses the result as the capacity of its accept buffer
    , tcpServerWorkerNoDelay   :: Bool
      -- ^ when 'True', @uv_tcp_nodelay@ is enabled on every accepted connection
    , tcpServerWorkerKeepAlive :: CUInt
      -- ^ keep-alive delay for accepted connections (presumably seconds,
      -- libuv's unit -- confirm); @0@ leaves keep-alive unconfigured
    }
-- | Default server settings: listen on 'inetAny' port 8888 with a backlog of
-- 128; accepted connections get @uv_tcp_nodelay@ and a keep-alive delay of 30.
defaultTCPServerConfig :: TCPServerConfig
defaultTCPServerConfig = TCPServerConfig
    { tcpListenAddr            = SocketAddrInet 8888 inetAny
    , tcpListenBacklog         = 128
    , tcpServerWorkerNoDelay   = True
    , tcpServerWorkerKeepAlive = 30
    }
-- | Example worker: write the 11 bytes @hello world@ to the stream once,
-- then return.
helloWorldWorker :: UVStream -> IO ()
helloWorldWorker uvs = writeOutput uvs greeting 11
  where
    -- Static, NUL-free byte literal; the length 11 is passed explicitly.
    greeting = Ptr "hello world"#
-- | Example worker: echo every chunk read from the stream back to the peer,
-- flushing after each chunk, and return once the input reaches EOF.
--
-- 'readBuffer' signals EOF with an empty chunk (per Z-IO's Buffered docs).
-- The previous 'forever' loop kept calling it after EOF, spinning
-- indefinitely on a closed connection. That meant the worker thread, and
-- the stream it holds, were never released.
echoWorker :: UVStream -> IO ()
echoWorker uvs = do
    i <- newBufferedInput uvs
    o <- newBufferedOutput uvs
    let loop = do
            chunk <- readBuffer i
            -- An empty chunk means EOF: stop echoing and let the caller
            -- release the stream.
            unless (V.null chunk) $ do
                writeBuffer o chunk
                flushBuffer o
                loop
    loop
-- | Start a TCP server listening on 'tcpListenAddr' and run
-- @tcpServerWorker@ on every accepted connection. Each connection gets its
-- own thread, spawned with 'forkBa'.
--
-- The accept loop is wrapped in 'forever', so this function only returns by
-- throwing. That happens when bind/listen fails, or when a negative (error)
-- value is found among the accepted FDs. On the way out, 'bracket' releases
-- the accept check handle and 'withResource' releases the listening stream.
--
-- Exceptions raised by a worker happen in its forked thread. This loop
-- discards the 'forkBa' result and does not rethrow them.
startTCPServer :: HasCallStack
               => TCPServerConfig
               -> (UVStream -> IO ())
               -> IO ()
startTCPServer TCPServerConfig{..} tcpServerWorker = do
    -- The configured backlog is raised to at least 128. The same value is
    -- used as the listen backlog and as the capacity of the accept buffer.
    let backLog = max tcpListenBacklog 128
    serverUVManager <- getUVManager
    withResource (initTCPStream serverUVManager) $ \ (UVStream serverHandle serverSlot _ _) -> do
        withSocketAddrUnsafe tcpListenAddr $ \ addrPtr -> do
            throwUVIfMinus_ (uv_tcp_bind serverHandle addrPtr 0)
        bracket
            (throwOOMIfNull $ hs_uv_accept_check_alloc serverHandle)
            hs_uv_accept_check_close $
            \ check -> do
                -- Pinned, so the raw pointer handed to the manager below
                -- stays valid while accepted FDs are written into it.
                acceptBuf <- newPinnedPrimArray backLog
                let acceptBufPtr = coerce (mutablePrimArrayContents acceptBuf :: Ptr UVFD)
                -- While holding the manager: register the buffer for this
                -- slot with the table value set to its last index
                -- (backLog-1), start listening, then arm the accept check.
                withUVManager' serverUVManager $ do
                    pokeBufferTable serverUVManager serverSlot acceptBufPtr (backLog-1)
                    throwUVIfMinus_ (hs_uv_listen serverHandle (fromIntegral backLog))
                    throwUVIfMinus_ $ hs_uv_accept_check_init check
                m <- getBlockMVar serverUVManager serverSlot
                forever $ do
                    -- Block until the slot's MVar is filled (presumably by
                    -- the manager once FDs have been accepted -- confirm).
                    _ <- takeMVar m
                    acceptBufCopy <- withUVManager' serverUVManager $ do
                        -- Clear any notification that arrived before the
                        -- manager was taken; this batch covers it.
                        _ <- tryTakeMVar m
                        -- New FDs occupy indices (acceptCountDown+1) ..
                        -- (backLog-1), as the copy below shows.
                        acceptCountDown <- peekBufferTable serverUVManager serverSlot
                        -- Reset the slot so the next batch starts at the end again.
                        pokeBufferTable serverUVManager serverSlot acceptBufPtr (backLog-1)
                        -- -1 means every entry was filled; resume listening
                        -- (presumably paused on the C side when full -- confirm).
                        when (acceptCountDown == -1) (hs_uv_listen_resume serverHandle)
                        -- Copy the new FDs out so the shared buffer can be reused.
                        let acceptCount = backLog - 1 - acceptCountDown
                        acceptBuf' <- newPrimArray acceptCount
                        copyMutablePrimArray acceptBuf' 0 acceptBuf (acceptCountDown+1) acceptCount
                        unsafeFreezePrimArray acceptBuf'
                    forM_ [0..sizeofPrimArray acceptBufCopy-1] $ \ i -> do
                        let fd = indexPrimArray acceptBufCopy i
                        if fd < 0
                            -- A negative entry is treated as a UV error code.
                            -- Throwing here leaves the accept loop and releases
                            -- the server resources.
                            then throwUVIfMinus_ (return fd)
                            -- Each connection is set up in its own thread, using
                            -- that thread's UV manager rather than the server's.
                            else void . forkBa $ do
                                uvm <- getUVManager
                                withResource (initUVStream (\ loop hdl -> do
                                    throwUVIfMinus_ (uv_tcp_init loop hdl)
                                    throwUVIfMinus_ (hs_uv_tcp_open hdl fd)) uvm) $ \ uvs -> do
                                    -- NOTE(review): if hs_uv_tcp_open fails, it is not
                                    -- visible here whether fd gets closed -- confirm.
                                    when tcpServerWorkerNoDelay . throwUVIfMinus_ $
                                        uv_tcp_nodelay (uvsHandle uvs) 1
                                    when (tcpServerWorkerKeepAlive > 0) . throwUVIfMinus_ $
                                        uv_tcp_keepalive (uvsHandle uvs) 1 tcpServerWorkerKeepAlive
                                    tcpServerWorker uvs
-- | Allocate a 'UVStream' resource whose handle is initialised as a TCP
-- handle with @uv_tcp_init@. 'initUVStream' supplies the loop and handle
-- pointers; a negative libuv result is thrown.
initTCPStream :: HasCallStack => UVManager -> Resource UVStream
initTCPStream uvm = initUVStream tcpInit uvm
  where
    tcpInit loop hdl = throwUVIfMinus_ $ uv_tcp_init loop hdl
-- | Enable ('True') or disable ('False') @TCP_NODELAY@ on the stream via
-- @uv_tcp_nodelay@, throwing on a libuv error.
setTCPNoDelay :: HasCallStack => UVStream -> Bool -> IO ()
setTCPNoDelay uvs nodelay = throwUVIfMinus_ $ uv_tcp_nodelay hdl flag
  where
    hdl  = uvsHandle uvs
    -- fromEnum maps False/True to 0/1, the flag values libuv expects.
    flag = fromIntegral (fromEnum nodelay)
-- | Configure TCP keep-alive on the stream via @uv_tcp_keepalive@.
--
-- A non-zero @delay@ enables keep-alive with that delay (presumably seconds,
-- libuv's unit -- confirm). A @delay@ of @0@ disables keep-alive.
-- Throws on a libuv error.
setTCPKeepAlive :: HasCallStack => UVStream -> CUInt -> IO ()
setTCPKeepAlive uvs delay =
    throwUVIfMinus_ $ if delay == 0
        then uv_tcp_keepalive hdl 0 0
        else uv_tcp_keepalive hdl 1 delay
  where
    hdl = uvsHandle uvs
-- | Return the local address the TCP stream is bound to, as reported by
-- @uv_tcp_getsockname@. Throws on a libuv error.
getTCPSockName :: HasCallStack => UVStream -> IO SocketAddr
getTCPSockName uvs =
    withSocketAddrStorageUnsafe $ \ storage -> do
        -- In/out length argument, initialised to the full storage size.
        let storageLen = fromIntegral sizeOfSocketAddrStorage :: CInt
        _ <- withPrimUnsafe storageLen $ \ lenPtr ->
            throwUVIfMinus_ (uv_tcp_getsockname (uvsHandle uvs) storage lenPtr)
        return ()
-- | Return the remote peer's address for the TCP stream, as reported by
-- @uv_tcp_getpeername@. Throws on a libuv error.
getTCPPeerName :: HasCallStack => UVStream -> IO SocketAddr
getTCPPeerName uvs =
    withSocketAddrStorageUnsafe $ \ storage -> do
        -- In/out length argument, initialised to the full storage size.
        let storageLen = fromIntegral sizeOfSocketAddrStorage :: CInt
        _ <- withPrimUnsafe storageLen $ \ lenPtr ->
            throwUVIfMinus_ (uv_tcp_getpeername (uvsHandle uvs) storage lenPtr)
        return ()