module Hans.Layer.Tcp.Socket (
Socket()
, sockRemoteHost
, sockRemotePort
, sockLocalPort
, connect, ConnectError(..)
, listen, ListenError(..)
, accept, AcceptError(..)
, close, CloseError(..)
, sendBytes, canSend
, recvBytes, canRecv
) where
import Hans.Address.IP4
import Hans.Channel
import Hans.Layer
import Hans.Layer.Tcp.Handlers
import Hans.Layer.Tcp.Messages
import Hans.Layer.Tcp.Monad
import Hans.Layer.Tcp.Types
import Hans.Layer.Tcp.Window
import Hans.Message.Tcp
import Control.Concurrent (MVar,newEmptyMVar,takeMVar,putMVar)
import Control.Exception (Exception,throwIO)
import Control.Monad (mplus)
import Data.Int (Int64)
import Data.Maybe (isJust)
import Data.Typeable (Typeable)
import qualified Data.ByteString.Lazy as L
-- | A user-facing socket: the handle of the TCP layer it lives in, together
-- with the identifier under which the connection is registered there.
data Socket = Socket
  { sockHandle :: !TcpHandle
    -- ^ Handle of the TCP layer that the socket's operations are sent to.
  , sockId     :: !SocketId
    -- ^ Identifier used to find the connection inside the TCP layer.
  }
-- | The remote host recorded in the socket's 'SocketId'.
sockRemoteHost :: Socket -> IP4
sockRemoteHost sock = sidRemoteHost (sockId sock)
-- | The remote port recorded in the socket's 'SocketId'.
sockRemotePort :: Socket -> TcpPort
sockRemotePort sock = sidRemotePort (sockId sock)
-- | The local port recorded in the socket's 'SocketId'.
sockLocalPort :: Socket -> TcpPort
sockLocalPort sock = sidLocalPort (sockId sock)
-- | Catch-all error.  'blockResult' delivers it through its fallback branch,
-- i.e. when the TCP action it was handed does not succeed.
data SocketGenericError
  = SocketGenericError
  deriving (Show, Typeable)

instance Exception SocketGenericError
-- | Run an action in the TCP layer and block the calling thread until it
-- reports a result.
--
-- A fresh, empty 'MVar' is handed to the action, which is expected to fill it
-- with a 'SocketResult'.  The action is combined via 'mplus' with a fallback
-- that fills the 'MVar' with 'SocketGenericError', so the caller is woken up
-- even if the action itself fails.
--
-- A 'SocketResult' value is returned; a 'SocketError' is re-thrown in 'IO'
-- with 'throwIO'.
blockResult :: TcpHandle -> (MVar (SocketResult a) -> Tcp ()) -> IO a
blockResult tcp action = do
  resultVar <- newEmptyMVar
  let fallback = output (putMVar resultVar (socketError SocketGenericError))
  send tcp (action resultVar `mplus` fallback)
  -- Blocks here until either the action or the fallback fills the MVar.
  takeMVar resultVar >>= \ outcome -> case outcome of
    SocketResult value -> return value
    SocketError err    -> throwIO err
-- | Thrown by 'connect' when the connection attempt is reported as
-- unsuccessful.
data ConnectError
  = ConnectionRefused
  deriving (Show, Typeable)

instance Exception ConnectError
-- | Open a connection to @remote@ on @remotePort@, blocking until the TCP
-- layer reports the outcome.
--
-- The local port is @mbLocal@ when given; otherwise one is obtained with
-- 'allocatePort'.  The new socket record is created in the 'Listen' state,
-- 'syn' is run on it, and it is then moved to 'SynSent'.
--
-- The socket's notification callback completes the call: a 'True'
-- notification yields the connected 'Socket', a 'False' one makes the call
-- throw 'ConnectionRefused'.
connect :: TcpHandle -> IP4 -> TcpPort -> Maybe TcpPort -> IO Socket
connect tcp remote remotePort mbLocal = blockResult tcp $ \ res -> do
  localPort <- maybe allocatePort return mbLocal
  isn       <- initialSeqNum
  now       <- time

  let sid = SocketId
        { sidLocalPort  = localPort
        , sidRemoteHost = remote
        , sidRemotePort = remotePort
        }

      connected = Socket { sockHandle = tcp, sockId = sid }

      -- Invoked by the TCP layer once the attempt has been decided; the
      -- result is forced before it is placed in the MVar.
      notify success
        | success   = putMVar res $! SocketResult connected
        | otherwise = putMVar res $! socketError ConnectionRefused

      sock = (emptyTcpSocket 0 0)
        { tcpSocketId  = sid
        , tcpNotify    = Just notify
        , tcpState     = Listen
        , tcpSndNxt    = isn
        , tcpSndUna    = isn
        , tcpTimestamp = Just (emptyTimestamp now)
        }

  runSock_ sock $ do
    syn
    setState SynSent
-- | Thrown by 'listen' when a connection is already registered under the
-- listening identifier of the requested port.
data ListenError
  = ListenError
  deriving (Show, Typeable)

instance Exception ListenError
-- | Create a listening socket on @port@.
--
-- The source address argument is currently ignored.  If a connection already
-- exists under the listening identifier for @port@, 'ListenError' is thrown;
-- otherwise a fresh socket in the 'Listen' state is registered and returned.
listen :: TcpHandle -> IP4 -> TcpPort -> IO Socket
listen tcp _src port = blockResult tcp $ \ res -> do
  let sid     = listenSocketId port
      reply r = output (putMVar res r)
  existing <- lookupConnection sid
  case existing of
    -- The port already has a listener (or other connection) registered.
    Just _  -> reply (socketError ListenError)

    Nothing -> do
      now <- time
      let listener = (emptyTcpSocket 0 0)
            { tcpSocketId  = sid
            , tcpState     = Listen
            , tcpTimestamp = Just (emptyTimestamp now)
            }
      addConnection sid listener
      reply (SocketResult Socket { sockHandle = tcp, sockId = sid })
-- | Thrown by 'accept' when the socket it is called on is not in the
-- 'Listen' state.
data AcceptError
  = AcceptError
  deriving (Show, Typeable)

instance Exception AcceptError
-- | Wait for a connection on a listening socket.
--
-- When the socket is in the 'Listen' state, an acceptor callback is pushed
-- onto it; the call blocks until that callback is invoked with the
-- 'SocketId' of a connection, which is returned as a new 'Socket' sharing
-- the listener's TCP handle.  In any other state 'AcceptError' is thrown.
accept :: Socket -> IO Socket
accept sock = blockResult tcp $ \ res ->
  establishedConnection (sockId sock) $ do
    state <- getState
    case state of
      Listen -> pushAcceptor (putMVar res . SocketResult . child)
      _      -> outputS (putMVar res (socketError AcceptError))
  where
    tcp       = sockHandle sock
    -- The accepted connection lives in the same TCP layer as the listener.
    child sid = Socket { sockHandle = tcp, sockId = sid }
-- | Thrown by 'close' when the socket cannot be closed: either its state does
-- not permit closing, or the connection lookup itself failed.
data CloseError
  = CloseError
  deriving (Show, Typeable)

instance Exception CloseError
-- | Close a socket from the user's side.
--
-- The socket is first marked as user-closed ('userClose'); what happens next
-- depends on its state:
--
--   * 'Established', 'SynReceived': 'finAck', then move to 'FinWait1'.
--
--   * 'CloseWait': 'finAck', then move to 'LastAck'.
--
--   * 'SynSent': the socket is closed immediately.
--
--   * 'Listen': the state is set to 'Closed' and the socket is closed.
--
--   * 'FinWait1', 'FinWait2': nothing further is done; the call succeeds.
--
--   * anything else: 'CloseError' is thrown.
--
-- 'CloseError' is also thrown when the connection action itself fails (the
-- 'mplus' alternative below).
close :: Socket -> IO ()
close sock = blockResult (sockHandle sock) $ \ res -> do
  let reply r = output (putMVar res r)
      failed  = reply (socketError CloseError)
      done    = inTcp (reply (SocketResult ()))

      -- Send a FIN,ACK, enter the given state and report success.
      finThen next = do
        finAck
        setState next
        done

      shutdown = establishedConnection (sockId sock) $ do
        userClose
        state <- getState
        case state of
          Established -> finThen FinWait1
          CloseWait   -> finThen LastAck
          SynSent     -> closeSocket >> done
          SynReceived -> finThen FinWait1
          Listen      -> setState Closed >> closeSocket >> done
          FinWait1    -> done
          FinWait2    -> done
          _           -> inTcp failed

  shutdown `mplus` failed
-- | Record on the current socket that the user has asked for it to be closed.
userClose :: Sock ()
userClose = modifyTcpSocket_ markClosed
  where
    markClosed tcp = tcp { tcpUserClosed = True }
-- | Check whether the socket can currently take outgoing data: it must be in
-- the 'Established' state and its output buffer must not be full.
canSend :: Socket -> IO Bool
canSend sock =
  blockResult (sockHandle sock) $ \ res ->
    establishedConnection (sockId sock) $ do
      tcp <- getTcpSocket
      let established = tcpState tcp == Established
          hasRoom     = not (isFull (tcpOutBuffer tcp))
      outputS (putMVar res (SocketResult (established && hasRoom)))
-- | Queue bytes for sending on a socket and return how many were accepted.
--
-- The bytes are offered to the socket's output buffer via 'outputBytes'.  If
-- that yields a count straight away, the count is the result.  Otherwise the
-- call stays blocked until the registered wakeup fires: with 'True' the whole
-- attempt is re-submitted to the TCP layer, with 'False' the result is 0.
-- 'outputSegments' is run after every attempt.
sendBytes :: Socket -> L.ByteString -> IO Int64
sendBytes sock bytes =
  blockResult tcp (\ res -> attempt (putMVar res . SocketResult))
  where
    tcp = sockHandle sock

    -- One send attempt; @deliver@ completes the blocked caller.
    attempt deliver = establishedConnection (sockId sock) $ do
      let wakeup continue
            | continue  = send tcp (attempt deliver)
            | otherwise = deliver 0
      mbWritten <- modifyTcpSocket (outputBytes bytes wakeup)
      -- Nothing means the wakeup is now responsible for finishing the call.
      maybe (return ()) (outputS . deliver) mbWritten
      outputSegments
-- | Try to place bytes in a socket's output buffer.
--
-- In the 'Established' and 'CloseWait' states the result of 'writeBytes' is
-- returned together with the socket carrying the updated buffer.  In every
-- other state the socket is left untouched and @Just 0@ is reported.
outputBytes :: L.ByteString -> Wakeup -> TcpSocket -> (Maybe Int64, TcpSocket)
outputBytes bytes wakeup tcp =
  case tcpState tcp of
    Established -> accepted
    CloseWait   -> accepted
    -- Closing, LastAck, FinWait1, FinWait2, TimeWait and all remaining
    -- states: nothing may be written.
    _           -> rejected
  where
    (mbWritten, outBuf) = writeBytes bytes wakeup (tcpOutBuffer tcp)
    accepted            = (mbWritten, tcp { tcpOutBuffer = outBuf })
    rejected            = (Just 0, tcp)
-- | Check whether the socket has data ready to be read: it must be in the
-- 'Established' state and its input buffer must not be empty.
--
-- NOTE(review): 'inputBytes' still hands out buffered data in 'CloseWait',
-- while this check reports 'False' there -- confirm that is intended.
canRecv :: Socket -> IO Bool
canRecv sock =
  blockResult (sockHandle sock) $ \ res ->
    establishedConnection (sockId sock) $ do
      tcp <- getTcpSocket
      let established = tcpState tcp == Established
          hasData     = not (isEmpty (tcpInBuffer tcp))
      outputS (putMVar res (SocketResult (established && hasData)))
-- | Read bytes from a socket; @len@ is the amount passed on to 'readBytes'.
--
-- The request is run against the socket's input buffer via 'inputBytes'.  If
-- that yields bytes straight away, they are the result.  Otherwise the call
-- stays blocked until the registered wakeup fires: with 'True' the whole
-- attempt is re-submitted to the TCP layer, with 'False' the result is the
-- empty string.
recvBytes :: Socket -> Int64 -> IO L.ByteString
recvBytes sock len =
  blockResult tcp (\ res -> attempt (putMVar res . SocketResult))
  where
    tcp = sockHandle sock

    -- One receive attempt; @deliver@ completes the blocked caller.
    attempt deliver = establishedConnection (sockId sock) $ do
      let wakeup continue
            | continue  = send tcp (attempt deliver)
            | otherwise = deliver L.empty
      mbRead <- modifyTcpSocket (inputBytes len wakeup)
      -- Nothing means the wakeup is now responsible for finishing the call.
      maybe (return ()) (outputS . deliver) mbRead
-- | Try to take bytes out of a socket's input buffer.
--
--   * 'Established': the result of 'readBytes' and the updated buffer.
--
--   * 'CloseWait': as above when 'readBytes' produced bytes; otherwise the
--     socket is left untouched and the empty string is reported.
--
--   * 'Closing', 'LastAck', 'TimeWait': the socket is left untouched and the
--     empty string is reported.
--
--   * any other state: treated like 'Established'.
inputBytes :: Int64 -> Wakeup -> TcpSocket -> (Maybe L.ByteString, TcpSocket)
inputBytes len wakeup tcp =
  case tcpState tcp of
    Established -> consumed
    CloseWait
      | isJust mbRead -> consumed
      -- Nothing buffered: report the empty string rather than keep the
      -- buffer returned by 'readBytes'.
      | otherwise     -> finished
    Closing     -> finished
    LastAck     -> finished
    TimeWait    -> finished
    _           -> consumed
  where
    (mbRead, inBuf) = readBytes len wakeup (tcpInBuffer tcp)
    consumed        = (mbRead, tcp { tcpInBuffer = inBuf })
    finished        = (Just L.empty, tcp)