-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}

-- | This module exports some basic functions to wait, for a given number
-- of milliseconds, on 'Socket's to become available for reading or writing.
--
-- On top of these, some higher-level functions to 'connect', 'send' and
-- 'recv' with timeouts are provided.
--
-- The implementation is Unix-only and uses GHC-specific functionality
-- from "GHC.IO.Device" and "GHC.IO.FD". In particular, the basic wait
-- functions are just exports of 'GHC.IO.Device.ready'.
module Network.Socket.Eager
    ( Error (..)
    , Milliseconds (..)
    , Descriptor
    , descriptor
    , connect
    , connect'
    , send
    , send'
    , recv
    , recv'

      -- * basic wait functions
    , waitRead
    , waitRead'
    , waitWrite
    , waitWrite'
    ) where

import Control.Applicative
import Control.Exception
import Control.Monad
import Data.ByteString.Builder
import Data.ByteString.Lazy (ByteString)
import Data.Monoid
import Data.Typeable
import Foreign
import Foreign.C
import GHC.IO.FD
import Network
import Network.Socket hiding (connect, recv, send)
import Network.Socket.Internal (withSockAddr)

import qualified Data.ByteString.Lazy as B
import qualified Data.ByteString as BS
import qualified Network.Socket.ByteString as NS
import qualified Network.Socket.ByteString.Lazy as N
import qualified GHC.IO.Device as IODevice

-- | Exceptions thrown by the timeout-aware operations of this module.
data Error
    = ConnectTimeout
      -- ^ Thrown by 'connect'' when the socket did not become writable
      -- within the given timeout.
    | RecvTimeout
      -- ^ Thrown by 'recv'' when the socket did not become readable
      -- within the given timeout.
    | SendTimeout
      -- ^ Thrown by 'send'' when the socket did not become writable
      -- within the given timeout.
    | ConnectError
      -- ^ Thrown by 'connect'' when the @connect@ call fails with an errno
      -- other than @EINTR@ or @EINPROGRESS@, or when the socket reports a
      -- non-zero error after the wait for writability.
    | ConnectionClosed
      -- ^ Thrown by 'recv'' when a receive call returns no data before the
      -- requested number of bytes has been read.
    deriving (Eq, Show, Typeable)

instance Exception Error

-- | A timeout value in milliseconds, as passed on to
-- 'GHC.IO.Device.ready' by the wait functions.
newtype Milliseconds = Milliseconds
    { ms :: Int -- ^ The wrapped number of milliseconds.
    }

-- | A @Descriptor@ encapsulates a @GHC.IO.FD.FD@.
-- Most of the wait functions operate on this descriptor, which
-- is implicitly created in functions which only accept a
-- 'Socket' as parameter.
newtype Descriptor = Descriptor { fd :: FD }

-- | Wrap the file descriptor of a 'Socket' in a 'Descriptor'.
-- The second field of the underlying 'FD' is always set to @1@.
descriptor :: Socket -> Descriptor
descriptor sock = Descriptor (FD (fdSocket sock) 1)

-- | Like 'waitRead'', but builds the 'Descriptor' from the given 'Socket'.
waitRead :: Milliseconds -> Socket -> IO Bool
waitRead timeout = waitRead' timeout . descriptor
{-# INLINEABLE waitRead #-}

-- | Ask 'IODevice.ready' about the descriptor with the write flag set to
-- 'False' (i.e. wait for reading), waiting at most the given number of
-- milliseconds. The result of 'IODevice.ready' is returned unchanged.
waitRead' :: Milliseconds -> Descriptor -> IO Bool
waitRead' (Milliseconds msecs) (Descriptor dev) =
    IODevice.ready dev False msecs
{-# INLINEABLE waitRead' #-}

-- | Like 'waitWrite'', but builds the 'Descriptor' from the given 'Socket'.
waitWrite :: Milliseconds -> Socket -> IO Bool
waitWrite timeout = waitWrite' timeout . descriptor
{-# INLINEABLE waitWrite #-}

-- | Ask 'IODevice.ready' about the descriptor with the write flag set to
-- 'True' (i.e. wait for writing), waiting at most the given number of
-- milliseconds. The result of 'IODevice.ready' is returned unchanged.
waitWrite' :: Milliseconds -> Descriptor -> IO Bool
waitWrite' (Milliseconds msecs) (Descriptor dev) =
    IODevice.ready dev True msecs
{-# INLINEABLE waitWrite' #-}

-- | Like 'connect'', but builds the 'Descriptor' from the given 'Socket'.
connect :: Milliseconds -> SockAddr -> Socket -> IO ()
connect timeout addr = connect' timeout addr . descriptor

-- | Connect the descriptor to the given address.
--
-- If the underlying @connect@ call reports @EINPROGRESS@, the descriptor
-- is awaited for writability for at most the given timeout and the
-- socket's error status is inspected afterwards.
--
-- Throws 'ConnectTimeout' if the descriptor does not become writable in
-- time, and 'ConnectError' if @connect@ fails with any errno other than
-- @EINTR@ / @EINPROGRESS@ or if the socket error status is non-zero.
connect' :: Milliseconds -> SockAddr -> Descriptor -> IO ()
connect' t a d = withSockAddr a attempt
  where
    rawFd = fdFD (fd d)

    -- Issue the connect call; anything but -1 means success.
    attempt addrPtr addrLen = do
        rc <- c_connect rawFd addrPtr (fromIntegral addrLen)
        when (rc == -1) $
            getErrno >>= onFailure (attempt addrPtr addrLen)

    -- Decide what to do based on the errno of a failed connect call.
    onFailure retry errno
        | errno == eINTR       = retry
        | errno == eINPROGRESS = awaitCompletion
        | otherwise            = throwIO ConnectError

    -- Wait for writability, then check the socket's error status.
    awaitCompletion = do
        writable <- waitWrite' t d
        unless writable (throwIO ConnectTimeout)
        pending <- socketError d
        unless (pending == 0) (throwIO ConnectError)

-- | Like 'send'', but builds the 'Descriptor' from the given 'Socket'.
send :: Milliseconds -> ByteString -> Socket -> IO ()
send timeout payload sock = send' timeout payload (descriptor sock) sock

-- | Send the complete lazy 'ByteString' over the socket.
--
-- Before every send call the descriptor is awaited for writability; the
-- timeout applies to each of these waits individually. Throws
-- 'SendTimeout' if a wait does not succeed. An empty input sends nothing.
send' :: Milliseconds -> ByteString -> Descriptor -> Socket -> IO ()
send' t payload d s = pump payload
  where
    pump !rest
        | B.null rest = return ()
        | otherwise   = do
            writable <- waitWrite' t d
            unless writable (throwIO SendTimeout)
            sent <- N.send s rest
            pump (B.drop sent rest)

-- | Like 'recv'', but builds the 'Descriptor' from the given 'Socket'.
recv :: Milliseconds -> Int -> Socket -> IO ByteString
recv timeout count sock = recv' timeout count (descriptor sock) sock

-- | Receive the requested number of bytes from the socket.
--
-- Before every receive call the descriptor is awaited for readability;
-- the timeout applies to each of these waits individually. Throws
-- 'RecvTimeout' if a wait does not succeed and 'ConnectionClosed' if a
-- receive call yields no data. Requesting @0@ bytes returns an empty
-- 'ByteString' without touching the socket.
recv' :: Milliseconds -> Int -> Descriptor -> Socket -> IO ByteString
recv' t n d s
    | n == 0    = return B.empty
    | otherwise = fmap toLazyByteString (collect 0 mempty)
  where
    -- @got@ is the number of bytes received so far, @acc@ their builder.
    collect got acc = do
        readable <- waitRead' t d
        unless readable (throwIO RecvTimeout)
        chunk <- NS.recv s (n - got)
        when (BS.null chunk) (throwIO ConnectionClosed)
        let acc' = acc <> byteString chunk
            got' = got + BS.length chunk
        if got' >= n
            then return acc'
            else collect got' acc'
-----------------------------------------------------------------------------
-- Internal:

-- SOL_SOCKET and SO_ERROR (used via #const below) are defined by this header.
#include <sys/socket.h>

-- | Read the @SO_ERROR@ option (at level @SOL_SOCKET@) of the descriptor
-- via @getsockopt@ and return its value.
--
-- The @getsockopt@ call is retried when interrupted; any other failure of
-- the call itself is raised as an 'IOError' by 'throwErrnoIfMinus1Retry_'.
socketError :: Descriptor -> IO CInt
socketError d =
    -- plen: in/out size of the option value buffer, initialised to the
    -- size of a CInt; pval: receives the option value.
    with (fromIntegral $ sizeOf (undefined :: CInt)) $ \plen ->
        with 0 $ \pval -> do
            throwErrnoIfMinus1Retry_ "getsockopt" $
                c_getsockopt (fdFD (fd d)) sockLevel sockError pval plen
            peek pval

-- | The @SOL_SOCKET@ option level.
sockLevel :: CInt
sockLevel = #const SOL_SOCKET

-- | The @SO_ERROR@ option name.
sockError :: CInt
sockError = #const SO_ERROR

-- | Raw binding to the C @connect@ function.
foreign import ccall unsafe "connect"
    c_connect :: CInt -> Ptr SockAddr -> CInt -> IO CInt

-- | Raw binding to the C @getsockopt@ function.
foreign import ccall unsafe "getsockopt"
    c_getsockopt :: CInt      -- socket
                 -> CInt      -- protocol level
                 -> CInt      -- name
                 -> Ptr CInt  -- value
                 -> Ptr CInt  -- length
                 -> IO CInt