module Hans.Layer.Tcp.Monad where
import Hans.Address.IP4
import Hans.Channel
import Hans.Layer
import Hans.Layer.IP4
import Hans.Layer.Tcp.Types
import Hans.Layer.Tcp.Window
import Hans.Message.Ip4
import Hans.Message.Tcp
import Control.Applicative(Applicative(..))
import Control.Monad (MonadPlus(..),guard,when)
import Data.Time.Clock.POSIX (POSIXTime)
import MonadLib (get,set)
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict as Map
import qualified Data.Sequence as Seq
import qualified Data.Traversable as T
-- | A channel carrying @Tcp ()@ actions; presumably the handle through which
-- other code schedules work on the TCP layer (TODO confirm against 'Channel').
type TcpHandle = Channel (Tcp ())
-- | The TCP layer monad: a 'Layer' whose private state is a 'TcpState'.
type Tcp = Layer TcpState
-- | State threaded through the 'Tcp' monad.
data TcpState = TcpState
  { tcpSelf :: TcpHandle -- ^ This layer's own channel handle (see 'self').
  , tcpIP4 :: IP4Handle  -- ^ Handle used to send segments via IPv4 (see 'sendSegment').
  , tcpHost :: Host      -- ^ Host-wide state: connection table, TIME_WAIT table,
                         --   port allocation and the initial sequence number.
  }
-- | Initial TCP layer state: stores the layer's own handle and the IPv4
-- handle, and builds a fresh 'Host' from the given start time.
emptyTcpState :: TcpHandle -> IP4Handle -> POSIXTime -> TcpState
emptyTcpState selfHandle ipHandle startTime =
  TcpState { tcpSelf = selfHandle, tcpIP4 = ipHandle, tcpHost = freshHost }
  where
  freshHost = emptyHost startTime
-- | The TCP layer's own channel handle.
self :: Tcp TcpHandle
self = fmap tcpSelf get
-- | Handle to the IPv4 layer used for outgoing segments.
ip4Handle :: Tcp IP4Handle
ip4Handle = fmap tcpIP4 get
-- | Read the host-wide TCP state.
getHost :: Tcp Host
getHost = fmap tcpHost get
-- | Replace the host-wide TCP state; the updated layer state is forced to
-- WHNF before it is stored.
setHost :: Host -> Tcp ()
setHost newHost = get >>= \ st -> set $! st { tcpHost = newHost }
-- | Apply a pure update to the host state; the new 'Host' is forced to WHNF
-- before 'setHost' stores it.
modifyHost :: (Host -> Host) -> Tcp ()
modifyHost update = getHost >>= \ old -> setHost $! update old
-- | Apply 'twReset2MSL' to the TIME_WAIT entry for the given socket id, if
-- there is one; otherwise the table is left unchanged.
resetTimeWait2MSL :: SocketId -> Tcp ()
resetTimeWait2MSL sid = modifyHost resetEntry
  where
  resetEntry h =
    h { hostTimeWaits = Map.adjust twReset2MSL sid (hostTimeWaits h) }
-- | Look up the TIME_WAIT entry matching an incoming segment from @remote@.
-- Returns the socket id derived from the header together with the entry, or
-- 'Nothing' when no such entry exists.
getTimeWait :: IP4 -> TcpHeader -> Tcp (Maybe (SocketId,TimeWaitSock))
getTimeWait remote hdr = fmap findEntry getHost
  where
  key = incomingSocketId remote hdr
  findEntry h =
    case Map.lookup key (hostTimeWaits h) of
      Nothing -> Nothing
      Just tw -> Just (key, tw)
-- | Drop the TIME_WAIT entry for the given socket id (no-op if absent).
removeTimeWait :: SocketId -> Tcp ()
removeTimeWait sid = modifyHost dropEntry
  where
  dropEntry h = h { hostTimeWaits = Map.delete sid (hostTimeWaits h) }
-- | The host's table of (non-TIME_WAIT) connections.
getConnections :: Tcp Connections
getConnections = fmap hostConnections getHost
-- | Replace the host's connection table wholesale.
setConnections :: Connections -> Tcp ()
setConnections table = modifyHost install
  where
  install h = h { hostConnections = table }
-- | Find a connection by socket id, returning 'Nothing' if it is unknown.
lookupConnection :: SocketId -> Tcp (Maybe TcpSocket)
lookupConnection sid = fmap (Map.lookup sid) getConnections
-- | Find a connection by socket id, failing with 'mzero' if it is unknown.
getConnection :: SocketId -> Tcp TcpSocket
getConnection sid = maybe mzero return . Map.lookup sid =<< getConnections
-- | Store a socket under the given id.
--
-- A socket in 'TimeWait' is not kept in the connection table: it is removed
-- from it and recorded in the TIME_WAIT table via 'addTimeWait' instead.
-- Any other socket is inserted into (or replaces its entry in) the
-- connection table.
setConnection :: SocketId -> TcpSocket -> Tcp ()
setConnection ident con
  | tcpState con /= TimeWait = do
      table <- getConnections
      setConnections (Map.insert ident con table)
  | otherwise = modifyHost moveToTimeWait
  where
  moveToTimeWait h = h
    { hostTimeWaits   = addTimeWait con (hostTimeWaits h)
    , hostConnections = Map.delete ident (hostConnections h)
    }
-- | Add a connection; identical to 'setConnection', including moving
-- 'TimeWait' sockets to the TIME_WAIT table.
addConnection :: SocketId -> TcpSocket -> Tcp ()
addConnection sid sock = setConnection sid sock
-- | Apply a pure update to the connection stored under the given id.
-- Unknown ids are ignored.
--
-- The updated socket is written back with 'setConnection' rather than a bare
-- map adjustment, so an update that moves the socket into 'TimeWait' migrates
-- it to the TIME_WAIT table. This is the same invariant 'setConnection'
-- maintains for every other write path.
modifyConnection :: SocketId -> (TcpSocket -> TcpSocket) -> Tcp ()
modifyConnection sid k = do
  mb <- lookupConnection sid
  case mb of
    Just tcp -> setConnection sid (k tcp)
    Nothing  -> return ()
-- | Remove a connection from the connection table (no-op if absent). The
-- TIME_WAIT table is not touched; see 'removeTimeWait' for that.
remConnection :: SocketId -> Tcp ()
remConnection sid = getConnections >>= setConnections . Map.delete sid
-- | Send a TCP segment to @dst@ over IPv4.
--
-- The IPv4 header targets @dst@ with the TCP protocol number and the
-- don't-fragment bit cleared. The segment is rendered with its checksum
-- computed against the source address chosen by 'withIP4Source'. The send
-- itself is handed to 'output' as an IO action.
sendSegment :: IP4 -> TcpHeader -> L.ByteString -> Tcp ()
sendSegment dst hdr body = do
  ip4 <- ip4Handle
  let transmit src =
        sendIP4Packet ip4 ipHeader (renderWithTcpChecksumIP4 src dst hdr body)
  output (withIP4Source ip4 dst transmit)
  where
  ipHeader = emptyIP4Header
    { ip4DestAddr     = dst
    , ip4Protocol     = tcpProtocol
    , ip4DontFragment = False
    }
-- | The host's current initial-sequence-number value.
initialSeqNum :: Tcp TcpSeqNum
initialSeqNum = fmap hostInitialSeqNum getHost
-- | Advance the host's initial-sequence-number value by @sn@.
addInitialSeqNum :: TcpSeqNum -> Tcp ()
addInitialSeqNum sn = modifyHost bump
  where
  bump h = h { hostInitialSeqNum = hostInitialSeqNum h + sn }
-- | Take a free port from the host's pool via 'takePort', storing the
-- updated host. Fails with 'mzero' when no port is available.
allocatePort :: Tcp TcpPort
allocatePort = getHost >>= maybe mzero claim . takePort
  where
  claim (port, host') = setHost host' >> return port
-- | Return a port to the host's pool via 'releasePort'.
closePort :: TcpPort -> Tcp ()
closePort port = modifyHost (\ h -> releasePort port h)
-- | Computations over a single 'TcpSocket', in continuation-passing style on
-- top of 'Tcp'.
--
-- @unSock m s x k@ runs @m@ starting from socket @s@. It finishes either by
-- calling the success continuation @k@ with the updated socket and a result,
-- or early by calling the escape continuation @x@ with the socket as it
-- stands (see 'escape').
newtype Sock a = Sock
  { unSock :: forall r. TcpSocket -> Escape r -> Next a r
  -> Tcp (TcpSocket,Maybe r)
  }
-- | Escape continuation: abandons the rest of a 'Sock' computation, given the
-- current socket.
type Escape r = TcpSocket -> Tcp (TcpSocket,Maybe r)
-- | Success continuation: receives the updated socket and the result value.
type Next a r = TcpSocket -> a -> Tcp (TcpSocket, Maybe r)
-- | Maps over the result by post-composing the success continuation; the
-- escape continuation is passed through untouched.
instance Functor Sock where
  fmap f m = Sock (\ s esc next -> unSock m s esc (\ s' a -> next s' (f a)))
-- | 'pure' hands its value straight to the success continuation. '<*>' runs
-- the function computation, then the argument computation on the socket it
-- produced, sharing the same escape continuation.
instance Applicative Sock where
  pure v = Sock (\ s _ next -> next s v)
  mf <*> ma = Sock $ \ s esc next ->
    let afterF s1 g = unSock ma s1 esc (\ s2 b -> next s2 (g b))
    in unSock mf s esc afterF
-- | Bind threads the updated socket from the first computation into the
-- second; both share the caller's escape continuation.
instance Monad Sock where
  return = pure
  m >>= f = Sock $ \ s esc next ->
    let rest s1 a = unSock (f a) s1 esc next
    in unSock m s esc rest
-- | Run a 'Tcp' action inside 'Sock', leaving the current socket unchanged.
inTcp :: Tcp a -> Sock a
inTcp m = Sock (\ s _ next -> m >>= next s)
-- | Abandon the rest of the 'Sock' computation, keeping the socket as it is
-- at this point.
escape :: Sock a
escape = Sock (\ cur esc _ -> esc cur)
-- | 'runSock', discarding both the final socket and the result.
runSock_ :: TcpSocket -> Sock a -> Tcp ()
runSock_ tcp sm = runSock tcp sm >>= \ _ -> return ()
-- | 'runSock', returning only the final socket.
runSock' :: TcpSocket -> Sock a -> Tcp TcpSocket
runSock' tcp sm = runSock tcp sm >>= \ (final, _) -> return final
-- | Run a 'Sock' computation against a socket.
--
-- The socket's timestamp, if present, is first updated with 'stepTimestamp'
-- at the current 'time'. After the computation finishes, the resulting
-- socket is written back with 'addConnection' (which moves 'TimeWait'
-- sockets to the TIME_WAIT table). The result is 'Nothing' when the
-- computation took the 'escape' path.
runSock :: TcpSocket -> Sock a -> Tcp (TcpSocket,Maybe a)
runSock tcp sm = do
  now <- time
  let start = tcp { tcpTimestamp = fmap (stepTimestamp now) (tcpTimestamp tcp) }
  outcome@(final, _) <- unSock sm start onEscape onFinish
  addConnection (tcpSocketId final) final
  return outcome
  where
  onEscape sock       = return (sock, Nothing)
  onFinish sock value = return (sock, Just value)
-- | Run a 'Sock' action against every connection in the table. A connection
-- whose run fails keeps its previous value. The resulting table, filtered
-- with 'removeClosed', then replaces the connection table.
--
-- NOTE(review): 'runSock'' already writes each updated socket back through
-- 'addConnection', which moves 'TimeWait' sockets to the TIME_WAIT table.
-- The final 'setConnections' overwrites the whole table with the swept map,
-- so such a move (or any connection added during the sweep) may be undone
-- unless 'removeClosed' filters it. Confirm against 'removeClosed'.
eachConnection :: Sock () -> Tcp ()
eachConnection m = do
  table   <- getConnections
  swept   <- T.mapM runOrKeep table
  setConnections (removeClosed swept)
  where
  runOrKeep tcp = runSock' tcp m `mplus` return tcp
-- | Like 'withConnection'', but fails with 'mzero' when no connection matches.
withConnection :: IP4 -> TcpHeader -> Sock a -> Tcp ()
withConnection remote hdr action = withConnection' remote hdr action noMatch
  where
  noMatch = mzero
-- | Run a 'Sock' action on the connection an incoming segment belongs to.
--
-- The connection keyed by the segment's full socket id is preferred. Failing
-- that, the listening socket on the destination port is used. If neither
-- exists, the @noConn@ fallback runs instead.
withConnection' :: IP4 -> TcpHeader -> Sock a -> Tcp () -> Tcp ()
withConnection' remote hdr m noConn = do
  table <- getConnections
  let found = Map.lookup establishedKey table `mplus` Map.lookup listenKey table
  maybe noConn (\ con -> runSock_ con m) found
  where
  establishedKey = incomingSocketId remote hdr
  listenKey      = listenSocketId (tcpDestPort hdr)
-- | Run a 'Sock' action on a listening socket that is accepting connections.
-- Fails with 'mzero' if the id is unknown or the socket is not
-- listening/accepting. Yields the action's result ('Nothing' if it escaped).
listeningConnection :: SocketId -> Sock a -> Tcp (Maybe a)
listeningConnection sid m = do
  tcp <- getConnection sid
  guard (acceptingListener tcp)
  runSock tcp m >>= \ (_, result) -> return result
  where
  acceptingListener t = tcpState t == Listen && isAccepting t
-- | Run a 'Sock' action on the connection with the given id, failing with
-- 'mzero' if it is unknown. No state check is made.
establishedConnection :: SocketId -> Sock a -> Tcp ()
establishedConnection sid m = getConnection sid >>= \ tcp -> runSock_ tcp m
-- | The id of the current socket's parent, if it has one.
getParent :: Sock (Maybe SocketId)
getParent = fmap tcpParent getTcpSocket
-- | Run a 'Sock' action on the current socket's parent connection.
-- Yields 'Nothing' when there is no parent or the action escaped. Fails
-- with 'mzero' (via 'getConnection') if the parent id is not in the
-- connection table.
inParent :: Sock a -> Sock (Maybe a)
inParent m = getParent >>= maybe (return Nothing) runOnParent
  where
  runOnParent pid = inTcp $ do
    parent <- getConnection pid
    (_, result) <- runSock parent m
    return result
-- | Run a 'Sock' action on another socket (stored back by 'runSock'),
-- yielding its result or 'Nothing' if it escaped.
withChild :: TcpSocket -> Sock a -> Sock (Maybe a)
withChild tcp m = inTcp (runSock tcp m >>= \ (_, result) -> return result)
-- | The current socket.
getTcpSocket :: Sock TcpSocket
getTcpSocket = Sock $ \ cur _ next -> next cur cur
-- | Replace the current socket.
setTcpSocket :: TcpSocket -> Sock ()
setTcpSocket replacement = Sock $ \ _ _ next -> next replacement ()
-- | The timers of the current socket.
getTcpTimers :: Sock TcpTimers
getTcpTimers = getTcpSocket >>= \ tcp -> return (tcpTimers tcp)
-- | Update the current socket with a function that also yields a result.
-- The pair is bound lazily, so neither component is forced here.
modifyTcpSocket :: (TcpSocket -> (a,TcpSocket)) -> Sock a
modifyTcpSocket f = Sock $ \ cur _ next ->
  let (result, updated) = f cur
  in next updated result
-- | Update the current socket with a pure function.
modifyTcpSocket_ :: (TcpSocket -> TcpSocket) -> Sock ()
modifyTcpSocket_ f = modifyTcpSocket ((,) () . f)
-- | Update the current socket's timers with a function that also yields a
-- result.
modifyTcpTimers :: (TcpTimers -> (a,TcpTimers)) -> Sock a
modifyTcpTimers k = modifyTcpSocket onTimers
  where
  onTimers tcp =
    let (result, timers') = k (tcpTimers tcp)
    in (result, tcp { tcpTimers = timers' })
-- | Update the current socket's timers with a pure function.
modifyTcpTimers_ :: (TcpTimers -> TcpTimers) -> Sock ()
modifyTcpTimers_ f = modifyTcpTimers ((,) () . f)
-- | Set the connection state of the current socket.
setState :: ConnState -> Sock ()
setState newState = modifyTcpSocket_ withState
  where
  withState tcp = tcp { tcpState = newState }
-- | The connection state of the current socket.
getState :: Sock ConnState
getState = getTcpSocket >>= \ tcp -> return (tcpState tcp)
-- | Run @body@ only when the current socket is in the given state.
whenState :: ConnState -> Sock () -> Sock ()
whenState state body = getState >>= \ cur -> when (state == cur) body
-- | Run @body@ only when the current socket is in one of the given states.
whenStates :: [ConnState] -> Sock () -> Sock ()
whenStates states body = getState >>= \ cur -> when (cur `elem` states) body
-- | Append an acceptor to the back of the current socket's acceptor queue.
pushAcceptor :: Acceptor -> Sock ()
pushAcceptor k = modifyTcpSocket_ enqueue
  where
  enqueue tcp = tcp { tcpAcceptors = tcpAcceptors tcp Seq.|> k }
-- | Remove and return the acceptor at the front of the current socket's
-- queue, or 'Nothing' if the queue is empty (the socket is left unchanged).
popAcceptor :: Sock (Maybe Acceptor)
popAcceptor = getTcpSocket >>= \ tcp ->
  case Seq.viewl (tcpAcceptors tcp) of
    Seq.EmptyL -> return Nothing
    front Seq.:< rest -> do
      setTcpSocket $! tcp { tcpAcceptors = rest }
      return (Just front)
-- | Fire and clear the socket's pending notification callback, if any,
-- passing it @success@. The callback's IO action is emitted via 'outputS'.
-- On failure the socket is also marked 'tcpUserClosed'.
notify :: Bool -> Sock ()
notify success = do
  pending <- modifyTcpSocket takePending
  maybe (return ()) (\ callback -> outputS (callback success)) pending
  where
  takePending tcp =
    let cleared = tcp { tcpNotify = Nothing }
    in if success
         then (tcpNotify tcp, cleared)
         else (tcpNotify tcp, cleared { tcpUserClosed = True })
-- | Emit an IO action from inside 'Sock' via the layer's 'output'.
outputS :: IO () -> Sock ()
outputS action = inTcp (output action)
-- | Advance the receive-side next sequence number by @n@ (via 'addRcvNxt').
advanceRcvNxt :: TcpSeqNum -> Sock ()
advanceRcvNxt n = modifyTcpSocket_ bump
  where
  bump tcp = tcp { tcpIn = addRcvNxt n (tcpIn tcp) }
-- | Advance the send-side next sequence number by @n@.
advanceSndNxt :: TcpSeqNum -> Sock ()
advanceSndNxt n = modifyTcpSocket_ bump
  where
  bump tcp = tcp { tcpSndNxt = tcpSndNxt tcp + n }
-- | The remote IPv4 address of the current socket.
remoteHost :: Sock IP4
remoteHost = getTcpSocket >>= \ tcp -> return (sidRemoteHost (tcpSocketId tcp))
-- | Send a segment to the current socket's remote host.
tcpOutput :: TcpHeader -> L.ByteString -> Sock ()
tcpOutput hdr body = remoteHost >>= \ dst -> inTcp (sendSegment dst hdr body)
-- | Shut down the current socket's buffers.
--
-- The retransmit queue is cleared and both buffers are passed through
-- 'shutdownWaiting'. The IO actions that call returns (outbound first, then
-- inbound) are emitted via 'outputS'.
shutdown :: Sock ()
shutdown = modifyTcpSocket drain >>= outputS
  where
  drain tcp =
    let (outWaiters, outBuf) = shutdownWaiting (tcpOutBuffer tcp)
        (inWaiters,  inBuf)  = shutdownWaiting (tcpInBuffer tcp)
        drained = tcp { tcpOut       = clearRetransmit (tcpOut tcp)
                      , tcpOutBuffer = outBuf
                      , tcpInBuffer  = inBuf
                      }
    in (outWaiters >> inWaiters, drained)
-- | Run 'shutdown' on the current socket, then mark it 'Closed'.
closeSocket :: Sock ()
closeSocket = shutdown >> setState Closed