module Network.Socket.Eager
( Error (..)
, Milliseconds (..)
, Descriptor
, descriptor
, connect
, connect'
, send
, send'
, recv
, recv'
, waitRead
, waitRead'
, waitWrite
, waitWrite'
) where
import Control.Applicative
import Control.Exception
import Control.Monad
import Data.ByteString.Builder
import Data.ByteString.Lazy (ByteString)
import Data.Monoid
import Data.Typeable
import Foreign
import Foreign.C
import GHC.IO.FD
import Network
import Network.Socket hiding (connect, recv, send)
import Network.Socket.Internal (withSockAddr)
import qualified Data.ByteString.Lazy as B
import qualified Data.ByteString as BS
import qualified Network.Socket.ByteString as NS
import qualified Network.Socket.ByteString.Lazy as N
import qualified GHC.IO.Device as IODevice
-- | Failures raised by the timeout-bounded socket operations in this module.
data Error
  = ConnectTimeout   -- ^ the socket did not become writable before the connect timeout
  | RecvTimeout      -- ^ the socket did not become readable before the receive timeout
  | SendTimeout      -- ^ the socket did not become writable before the send timeout
  | ConnectError     -- ^ @connect@ failed, or the socket reported a pending error afterwards
  | ConnectionClosed -- ^ the peer closed the connection before all requested bytes arrived
  deriving (Show, Eq, Typeable)

instance Exception Error where
-- | A timeout, in milliseconds, as passed to the wait/connect/send/recv
-- operations of this module.
newtype Milliseconds = Milliseconds
  { ms :: Int -- ^ raw millisecond count
  }
-- | A socket viewed through GHC's 'FD' device type, so the polling helpers can
-- work on the raw descriptor.
newtype Descriptor = Descriptor
  { fd :: FD -- ^ underlying GHC file-descriptor record
  }
-- | Wrap a 'Socket''s file descriptor as a 'Descriptor'.
--
-- The resulting 'FD' is recorded as non-blocking ('fdIsNonBlocking' = 1).
-- This is a pure conversion and does not change the OS-level descriptor.
descriptor :: Socket -> Descriptor
descriptor sock =
  Descriptor (FD { fdFD = fdSocket sock, fdIsNonBlocking = 1 })
-- | Wait up to the given timeout for the socket to become readable.
-- Returns 'True' if it is readable, 'False' on timeout.
waitRead :: Milliseconds -> Socket -> IO Bool
waitRead timeout = waitRead' timeout . descriptor
-- | 'waitRead' on an already-extracted 'Descriptor': polls for readability
-- with 'IODevice.ready' for at most the given number of milliseconds.
waitRead' :: Milliseconds -> Descriptor -> IO Bool
waitRead' (Milliseconds limit) (Descriptor h) = IODevice.ready h False limit
-- | Wait up to the given timeout for the socket to become writable.
-- Returns 'True' if it is writable, 'False' on timeout.
waitWrite :: Milliseconds -> Socket -> IO Bool
waitWrite timeout = waitWrite' timeout . descriptor
-- | 'waitWrite' on an already-extracted 'Descriptor': polls for writability
-- with 'IODevice.ready' for at most the given number of milliseconds.
waitWrite' :: Milliseconds -> Descriptor -> IO Bool
waitWrite' (Milliseconds limit) (Descriptor h) = IODevice.ready h True limit
-- | Connect a 'Socket' to an address, giving up after the timeout.
-- See 'connect'' for the exceptions thrown.
connect :: Milliseconds -> SockAddr -> Socket -> IO ()
connect timeout addr = connect' timeout addr . descriptor
-- | Connect a 'Descriptor' to an address, waiting at most the given timeout
-- for an asynchronous (in-progress) connect to complete.
--
-- Throws 'ConnectTimeout' if the socket does not become writable in time, and
-- 'ConnectError' if @connect@ fails outright or the socket reports a pending
-- error (@SO_ERROR@) once the attempt finishes.
connect' :: Milliseconds -> SockAddr -> Descriptor -> IO ()
connect' t a d = withSockAddr a loop
  where
    loop ptr sz = do
      r <- c_connect (fdFD (fd d)) ptr (fromIntegral sz)
      -- connect(2) reports failure with -1 and sets errno; 0 means the
      -- connection was established immediately.
      when (r == -1) $ do
        e <- getErrno
        if | e == eINTR -> loop ptr sz
           -- The attempt continues asynchronously. EALREADY is what a retry
           -- after EINTR reports while the attempt is still pending.
           | e == eINPROGRESS || e == eALREADY -> wait
           -- A retry after EINTR may find the connection already completed.
           | e == eISCONN -> return ()
           | otherwise -> throwIO ConnectError
    wait = do
      isReady <- waitWrite' t d
      unless isReady $
        throwIO ConnectTimeout
      -- Writability only means the attempt finished; SO_ERROR tells whether
      -- it actually succeeded.
      e <- socketError d
      when (e /= 0) $
        throwIO ConnectError
-- | Send the whole lazy 'ByteString' over the socket, bounding each wait for
-- writability by the timeout. See 'send'' for the exceptions thrown.
send :: Milliseconds -> ByteString -> Socket -> IO ()
send timeout payload sock =
  let d = descriptor sock
  in send' timeout payload d sock
-- | Send an entire lazy 'ByteString', waiting up to the timeout for the
-- descriptor to become writable before each (possibly partial) write.
--
-- Returns immediately for an empty payload. Throws 'SendTimeout' when a wait
-- for writability times out.
send' :: Milliseconds -> ByteString -> Descriptor -> Socket -> IO ()
send' t payload d s = go payload
  where
    go remaining
      | B.null remaining = return ()
      | otherwise = do
          writable <- waitWrite' t d
          unless writable $ throwIO SendTimeout
          -- N.send may write only a prefix; continue with what is left.
          sent <- N.send s remaining
          go (B.drop sent remaining)
-- | Receive exactly the requested number of bytes from the socket, bounding
-- each wait for readability by the timeout. See 'recv'' for details.
recv :: Milliseconds -> Int -> Socket -> IO ByteString
recv timeout count sock = recv' timeout count (descriptor sock) sock
-- | Receive exactly @n@ bytes, waiting up to the timeout for the descriptor to
-- become readable before each read.
--
-- Returns an empty 'ByteString' when @n@ is 0. Throws 'RecvTimeout' if a wait
-- for readability times out, and 'ConnectionClosed' if a read returns no data
-- (peer closed) before @n@ bytes have arrived.
recv' :: Milliseconds -> Int -> Descriptor -> Socket -> IO ByteString
recv' _ 0 _ _ = return B.empty
recv' t n d s = toLazyByteString <$> go 0 mempty
  where
    -- k: number of bytes received so far; bytes: the data received so far.
    go !k bytes = do
      isReady <- waitRead' t d
      unless isReady $
        throwIO RecvTimeout
      -- Request only the bytes still missing so we never read past n.
      a <- NS.recv s (n - k)
      if BS.null a
        then throwIO ConnectionClosed
        else do
          let b = bytes <> byteString a
              m = k + BS.length a
          if m < n then go m b else return b
-- | Read the descriptor's pending socket error via @getsockopt@ with
-- 'sockLevel' / 'sockError'. A result of 0 means no error is pending.
--
-- Retries on EINTR and throws an 'IOError' if @getsockopt@ fails.
socketError :: Descriptor -> IO CInt
socketError (Descriptor h) =
  with 0 $ \valPtr ->
    with optLen $ \lenPtr -> do
      throwErrnoIfMinus1Retry_ "getsockopt" $
        c_getsockopt (fdFD h) sockLevel sockError valPtr lenPtr
      peek valPtr
  where
    -- Size of the CInt output buffer, passed in/out as the option length.
    optLen = fromIntegral (sizeOf (0 :: CInt))
-- | @getsockopt@ level and option name used by 'socketError'.
--
-- The values 1 and 4 correspond to SOL_SOCKET and SO_ERROR on Linux.
-- NOTE(review): BSD/macOS use different values for both constants; confirm
-- that the target platform matches before relying on these.
sockLevel, sockError :: CInt
sockLevel = 1
sockError = 4
-- | Raw binding to the C @connect@ function: descriptor, address pointer,
-- address length.
foreign import ccall unsafe "connect"
  c_connect :: CInt -> Ptr SockAddr -> CInt -> IO CInt

-- | Raw binding to the C @getsockopt@ function.
foreign import ccall unsafe "getsockopt"
  c_getsockopt
    :: CInt      -- ^ socket descriptor
    -> CInt      -- ^ option level
    -> CInt      -- ^ option name
    -> Ptr CInt  -- ^ output buffer for the option value
    -> Ptr CInt  -- ^ in/out length of the output buffer
    -> IO CInt