module Hans.Layer.Tcp.Handlers (
handleIncomingTcp
, outputSegments
) where
import Hans.Address.IP4
import Hans.Layer
import Hans.Layer.Tcp.Messages
import Hans.Layer.Tcp.Monad
import Hans.Layer.Tcp.Timers
import Hans.Layer.Tcp.Types
import Hans.Layer.Tcp.Window
import Hans.Message.Ip4
import Hans.Message.Tcp
import Control.Monad (guard,when,unless,join)
import Data.Bits (bit)
import Data.Int (Int64)
import Data.Maybe (fromMaybe,isJust,isNothing)
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import qualified Data.Foldable as F
-- | Entry point for a TCP segment carried in an IPv4 packet.
--
-- The segment is rejected (via 'guard') when its checksum does not verify
-- against the IPv4 pseudo-header, then parsed.  It is dispatched to the
-- matching connection's 'segmentArrives'.  Failing that, it goes to a
-- matching TIME-WAIT record, and failing that to 'noConnection'.
handleIncomingTcp :: IP4Header -> S.ByteString -> Tcp ()
handleIncomingTcp ip4 bytes =
  do guard (validateTcpChecksumIP4 srcAddr dstAddr bytes)
     (hdr, body) <- liftRight (parseTcpPacket bytes)
     let unknownConnection =
           timeWaitConnection srcAddr hdr (noConnection srcAddr hdr body)
     withConnection' srcAddr hdr (segmentArrives srcAddr hdr body)
                     unknownConnection
  where
  srcAddr = ip4SourceAddr ip4
  dstAddr = ip4DestAddr ip4
-- | Answer a segment that matches no connection, following RFC 793's
-- reset-generation rules:
--
--   * an RST is never answered;
--   * a segment carrying ACK is answered with 'mkRst';
--   * anything else is answered with 'mkRstAck', given the body length.
--
-- Processing always ends with 'finish'.
noConnection :: IP4 -> TcpHeader -> S.ByteString -> Tcp ()
noConnection src hdr@TcpHeader {} body =
  do unless (tcpRst hdr) (sendSegment src reply L.empty)
     finish
  where
  reply
    | tcpAck hdr = mkRst hdr
    | otherwise  = mkRstAck hdr (S.length body)
-- | Route the segment to a TIME-WAIT record when one matches.  Otherwise
-- run the supplied fallback action.
timeWaitConnection :: IP4 -> TcpHeader -> Tcp () -> Tcp ()
timeWaitConnection src hdr noTimeWait =
  getTimeWait src hdr >>= maybe noTimeWait handle
  where
  handle (sid, tw) = handleTimeWait sid tw hdr
-- | Handle a segment addressed to a socket in TIME-WAIT.
--
--   * RST or SYN: forget the TIME-WAIT record.
--   * ACK: restart the 2MSL timer.  If the segment also carries FIN, reply
--     with a bare ACK acknowledging @SEG.SEQ + 1@.  That reply carries our
--     recorded timestamp option when one exists.
--   * Anything else is ignored.
handleTimeWait :: SocketId -> TimeWaitSock -> TcpHeader -> Tcp ()
handleTimeWait sid tw@TimeWaitSock {} hdr@TcpHeader {}
  | tcpRst hdr || tcpSyn hdr = removeTimeWait sid
  | not (tcpAck hdr)         = return ()
  | otherwise                =
    do resetTimeWait2MSL sid
       when (tcpFin hdr) (sendSegment (sidRemoteHost sid) finReply L.empty)
  where
  withTimestamp = maybe id (setTcpOption . mkTimestamp) (twTimestamp tw)
  finReply      = withTimestamp emptyTcpHeader
    { tcpDestPort   = tcpSourcePort hdr
    , tcpSourcePort = tcpDestPort hdr
    , tcpSeqNum     = twSeqNum tw
    , tcpAckNum     = tcpSeqNum hdr + 1
    , tcpAck        = True
    }
-- | Stop processing the current segment, after flushing any segments the
-- socket has queued for output.  This is the same action as 'done'; the
-- separate name marks call sites where the segment is being dropped.
discardAndReturn :: Sock a
discardAndReturn = done
-- | Send any pending output ('outputSegments') and then 'escape' from the
-- rest of the segment-processing pipeline.
done :: Sock a
done = outputSegments >> escape
-- | Process a segment for a socket that has a connection record.  The
-- structure follows the "SEGMENT ARRIVES" event processing of RFC 793,
-- section 3.9.
--
--   * Closed: handled exactly as if no connection existed ('noConnection').
--   * Listen: an ACK triggers 'rst'.  A SYN creates a child socket with
--     'createConnection' and sends it a SYN-ACK.  Processing then ends with
--     'done'.
--   * Every other state first passes the 'updateTimestamp' check.  The
--     segment is dropped when that check asks for it.
--   * SynSent: the ACK, RST and SYN bits are examined.  The segment is
--     dropped, the socket closed, the connection established, or the state
--     moved to SynReceived.  Processing then ends with 'done'.
--   * Remaining states run the sequence-number, RST, SYN and ACK checks in
--     that order, then 'proceedFromStep6' (segment text and FIN).
segmentArrives :: IP4 -> TcpHeader -> S.ByteString -> Sock ()
segmentArrives src hdr body =
  do whenState Closed (inTcp (noConnection src hdr body))

     -- NOTE(review): RFC 793 says an RST arriving in LISTEN is ignored;
     -- here an RST that also has ACK set reaches 'rst'.  If 'rst' does not
     -- escape, a SYN+ACK would both be reset and spawn a child.  Confirm
     -- the semantics of 'rst'.
     whenState Listen $
       do when (tcpAck hdr) (rst hdr)
          when (tcpSyn hdr) $
            do child <- createConnection src hdr
               _ <- withChild child synAck
               return ()
          done

     -- 'updateTimestamp' returns True when the segment must be dropped.
     shouldDrop <- modifyTcpSocket (updateTimestamp hdr)
     when shouldDrop discardAndReturn

     whenState SynSent $
       do tcp <- getTcpSocket
          -- An ACK outside (ISS, SND.NXT] is unacceptable.  Reset the peer
          -- (unless this segment is itself a reset) and drop the segment.
          when (tcpAck hdr) $
            do when (tcpAckNum hdr <= tcpIss tcp ||
                     tcpAckNum hdr > tcpSndNxt tcp) $
                 do unless (tcpRst hdr) (rst hdr)
                    discardAndReturn
          -- Segments with neither SYN nor RST are dropped.
          unless (tcpSyn hdr || tcpRst hdr) discardAndReturn
          let accAcceptable = tcpSndUna tcp <= tcpAckNum hdr &&
                              tcpAckNum hdr <= tcpSndNxt tcp
          -- NOTE(review): accAcceptable does not test tcpAck hdr.  An RST
          -- without ACK whose ack field happens to be in range therefore
          -- also closes the socket, whereas RFC 793 drops such an RST.
          when (tcpRst hdr) $
            do when accAcceptable $
                 do notify False
                    closeSocket
               discardAndReturn
          when (tcpSyn hdr) $
            do advanceRcvNxt 1
               -- Record the peer's MSS, send window and options.
               -- NOTE(review): tcpIn is replaced here by a window built from
               -- tcpSeqNum hdr, which appears to undo the advanceRcvNxt
               -- above.  Confirm against emptyLocalWindow.
               modifyTcpSocket_ $ \ sock -> sock
                 { tcpOutMSS = fromMaybe (tcpInMSS sock) (getMSS hdr)
                 , tcpOut = setSndWind (tcpWindow hdr)
                          $ setSndWindScale (windowScale hdr)
                          $ clearRetransmit
                          $ tcpOut sock
                 , tcpIn = emptyLocalWindow (tcpSeqNum hdr) 14600 0
                 , tcpSack = sackSupported hdr
                 , tcpWindowScale = isJust (findTcpOption OptTagWindowScaling hdr)
                 }
               -- NOTE(review): layout reconstructed from a whitespace-stripped
               -- copy.  The SND.UNA > ISS test is placed inside the ACK
               -- branch, so a SYN without ACK (simultaneous open) reaches
               -- 'done' without entering SynReceived, whereas RFC 793 tests
               -- this at the SYN level.  Confirm against the original layout.
               when (tcpAck hdr) $
                 do handleAck hdr
                    TcpSocket { .. } <- getTcpSocket
                    if tcpSndUna > tcpIss
                       then do ack
                               establishConnection
                               notify True
                               when (tcpUrg hdr) (proceedFromStep6 hdr body)
                       else do setState SynReceived
                               synAck
          done

     checkSequenceNumber hdr body
     checkResetBit hdr
     checkSynBit hdr
     checkAckBit hdr
     proceedFromStep6 hdr body
-- | Continue segment processing from the RFC 793 "sixth" step onward.
--
-- No URG processing happens here.  The segment text is handled first
-- ('processSegmentText'), then the FIN bit ('checkFinBit').
proceedFromStep6 :: TcpHeader -> S.ByteString -> Sock ()
proceedFromStep6 hdr body = processSegmentText hdr body >> checkFinBit hdr
-- | RFC 793 "first, check sequence number".
--
-- A segment is acceptable when it overlaps the receive window
-- @RCV.NXT =< x < RCV.NXT + RCV.WND@, using RFC 793's four cases:
--
--   * empty segment, zero window:  SEG.SEQ = RCV.NXT
--   * empty segment, open window:  SEG.SEQ is in the window
--   * data, zero window:           never acceptable
--   * data, open window:           first or last octet is in the window
--
-- An unacceptable segment is answered with an ACK (unless it carries RST)
-- and dropped via 'discardAndReturn'.
--
-- Previously the test above was computed and then used with the inverted
-- sense.  Acceptable segments lacking ACK/URG/RST were dropped, while
-- out-of-window segments went on to be processed.  That let a
-- retransmitted FIN advance RCV.NXT a second time and let out-of-window
-- RSTs reset the connection.
checkSequenceNumber :: TcpHeader -> S.ByteString -> Sock ()
checkSequenceNumber hdr body =
  do win <- tcpIn `fmap` getTcpSocket
     let rcvNxt = lwRcvNxt win
         rcvWnd = lwRcvWind win
         segLen = fromIntegral (S.length body)

         -- Is the sequence number 'off' octets into the segment in window?
         inWindow off = rcvNxt <= tcpSeqNum hdr + off
                     && tcpSeqNum hdr + off < rcvNxt + fromIntegral rcvWnd

         acceptable
           | segLen == 0 = if rcvWnd == 0
                              then tcpSeqNum hdr == rcvNxt
                              else inWindow 0
           | otherwise   = rcvWnd /= 0
                        && (inWindow 0 || inWindow (segLen - 1))

     unless acceptable $
       do unless (tcpRst hdr) ack
          discardAndReturn
-- | RFC 793 "second, check the RST bit".
--
-- When the segment carries RST, processing of the segment ends here:
--
--   * SynReceived: the socket is closed only when it has no parent.
--   * Established, FinWait1, FinWait2, CloseWait, Closing, LastAck: the
--     socket is closed.
--
-- NOTE(review): a passively-opened child (tcpParent = Just _) in
-- SynReceived is not closed here.  Confirm it is cleaned up elsewhere.
--
-- In all other states, or without RST, nothing happens.
checkResetBit :: TcpHeader -> Sock ()
checkResetBit hdr =
  when (tcpRst hdr) $
    do tcp <- getTcpSocket
       whenState SynReceived $
         do when (isNothing (tcpParent tcp)) closeSocket
            done
       whenStates [Established,FinWait1,FinWait2,CloseWait,Closing,LastAck] $
         do closeSocket
            done
-- | RFC 793 "fourth, check the SYN bit".
--
-- In any synchronized state, a SYN whose sequence number lies in the
-- receive window ('inRcvWnd') closes the socket and ends processing.
checkSynBit :: TcpHeader -> Sock ()
checkSynBit hdr =
  when (tcpSyn hdr) $
    whenStates synchronized $
      do synInWindow <- (tcpSeqNum hdr `inRcvWnd`) `fmap` getTcpSocket
         when synInWindow $
           do closeSocket
              done
  where
  synchronized =
    [SynReceived,Established,FinWait1,FinWait2,CloseWait,Closing,LastAck]
-- | RFC 793 "fifth, check the ACK field".
--
-- Segments without ACK are dropped.  Otherwise, per state:
--
--   * SynReceived: an ACK in [SND.UNA, SND.NXT] establishes the
--     connection; any other ACK is answered with 'rst'.
--   * Established, FinWait1, FinWait2, CloseWait, Closing: an ACK in
--     (SND.UNA, SND.NXT] is applied with 'handleAck'.  An ACK for data we
--     have not sent (SEG.ACK > SND.NXT) is answered with an ACK and the
--     segment dropped.  The answering ACK is required by RFC 793 and was
--     previously missing.
--   * FinWait1: once nothing is outstanding, move to FinWait2.
--   * Closing: once nothing is outstanding, enter TIME-WAIT.
--   * LastAck: apply the ACK; once nothing is outstanding, close the socket
--     and end processing.
checkAckBit :: TcpHeader -> Sock ()
checkAckBit hdr
  | not (tcpAck hdr) = discardAndReturn
  | otherwise =
    do whenState SynReceived $
         do tcp <- getTcpSocket
            if tcpSndUna tcp <= segAck && segAck <= tcpSndNxt tcp
               then establishConnection
               else rst hdr

       whenStates [Established,FinWait1,FinWait2,CloseWait,Closing] $
         do tcp <- getTcpSocket
            when (tcpSndUna tcp < segAck && segAck <= tcpSndNxt tcp)
                 (handleAck hdr)
            -- RFC 793: "If the ACK acks something not yet sent
            -- (SEG.ACK > SND.NXT) then send an ACK, drop the segment, and
            -- return."
            when (tcpSndNxt tcp < segAck) $
              do ack
                 discardAndReturn

       whenState FinWait1 $
         do tcp <- getTcpSocket
            when (nothingOutstanding tcp) (setState FinWait2)

       whenState Closing $
         do tcp <- getTcpSocket
            when (nothingOutstanding tcp) enterTimeWait

       whenState LastAck $
         do handleAck hdr
            tcp <- getTcpSocket
            when (nothingOutstanding tcp) $
              do closeSocket
                 done
  where
  segAck = tcpAckNum hdr
-- | RFC 793 "seventh, process the segment text", for Established,
-- FinWait1 and FinWait2 only.
--
-- An empty body that carries SYN or FIN just requests a delayed ACK.  A
-- non-empty body is handed to 'handleData'.  Any wakeup it returns is run
-- through 'tryAgain' via 'outputS'.
processSegmentText :: TcpHeader -> S.ByteString -> Sock ()
processSegmentText hdr body =
  whenStates [Established,FinWait1,FinWait2] $
    if S.null body
       then when (tcpSyn hdr || tcpFin hdr) (modifyTcpTimers_ requestAck)
       else modifyTcpSocket (handleData hdr body)
              >>= maybe (return ()) (outputS . tryAgain)
  where
  requestAck tt = tt { ttDelayedAck = True }
-- | RFC 793 "eighth, check the FIN bit".
--
-- In Closed, Listen and SynSent a FIN cannot be validated, so the segment
-- is dropped.  Otherwise the FIN consumes one sequence number, the waiting
-- queues are flushed ('flushQueues'), an ACK is sent, and the state
-- advances:
--
--   * SynReceived, Established -> CloseWait
--   * FinWait1 -> TIME-WAIT if SND.NXT =< SND.UNA, otherwise Closing
--   * FinWait2 -> TIME-WAIT
--
-- Processing of the segment then ends ('done').
--
-- NOTE(review): RCV.NXT is advanced over the FIN without checking that
-- the FIN is the next in-sequence octet.  Confirm out-of-order FINs cannot
-- reach this point.
checkFinBit :: TcpHeader -> Sock ()
checkFinBit hdr =
  when (tcpFin hdr) $
    do whenStates [Closed,Listen,SynSent] discardAndReturn
       advanceRcvNxt 1
       flushQueues
       ack
       whenStates [SynReceived,Established] (setState CloseWait)
       whenState FinWait1 leaveFinWait1
       whenState FinWait2 enterTimeWait
       done
  where
  leaveFinWait1 =
    do tcp <- getTcpSocket
       if tcpSndNxt tcp <= tcpSndUna tcp
          then enterTimeWait
          else setState Closing
-- | Release everything waiting on the socket's input and output buffers
-- ('flushWaiting').  Their finalizers are handed to 'outputS', input-side
-- finalizers first.
flushQueues :: Sock ()
flushQueues = modifyTcpSocket drainBoth >>= outputS
  where
  drainBoth tcp =
    let (inFinalizers,  inBuf)  = flushWaiting (tcpInBuffer tcp)
        (outFinalizers, outBuf) = flushWaiting (tcpOutBuffer tcp)
     in ( inFinalizers >> outFinalizers
        , tcp { tcpOutBuffer = outBuf
              , tcpInBuffer  = inBuf
              }
        )
-- | Reconcile the segment's timestamp option with the socket's timestamp
-- state.  Returns @(True, unchanged socket)@ when the segment should be
-- dropped, and @(False, updated socket)@ otherwise.
--
-- The new state ('ts'') is 'Nothing' in any of these cases:
--
--   * the socket has no timestamp state;
--   * the segment lacks the timestamp option;
--   * the segment has ACK set and its echoed value (TSecr) is newer than
--     our current timestamp.  "Newer" uses wrapping 32-bit arithmetic,
--     assuming 'tsTimestamp' is a wrapping 32-bit counter (TODO confirm).
--
-- Otherwise the peer's TSval is recorded in 'tsLastTimestamp'.  A non-SYN,
-- non-RST segment is dropped when it disagrees with the socket about
-- whether timestamps are in use.
--
-- Fix: the distance computation had lost its '-' operator, which made it
-- ill-typed.
updateTimestamp :: TcpHeader -> TcpSocket -> (Bool,TcpSocket)
updateTimestamp hdr tcp
  | shouldDrop = (True, tcp)
  | otherwise  = (False, tcp { tcpTimestamp = ts' })
  where
  shouldDrop = not (tcpSyn hdr || tcpRst hdr)
            && isJust (tcpTimestamp tcp) /= isJust ts'

  ts' = do
    ts <- tcpTimestamp tcp
    OptTimestamp them echo <- findTcpOption OptTagTimestamp hdr
    -- How far our clock is ahead of the echoed value, modulo 2^32.
    let rel       = tsTimestamp ts - echo
        isGreater = 0 < rel && rel < bit 31
    when (tcpAck hdr) (guard (tsTimestamp ts == echo || isGreater))
    return ts { tsLastTimestamp = them }
-- | Feed a data-bearing segment to the receive side of a socket.
--
-- The segment first goes through 'incomingPacket' on the local window.
-- The bodies of the segments it returns are concatenated and appended to
-- the receive buffer with 'putBytes'.  The result is the wakeup from
-- 'putBytes' (if any) together with the updated socket.
--
-- The delayed-ACK flag is only ever raised here, never cleared.  The
-- previous version overwrote it, so a segment that released no bytes
-- (e.g. an out-of-order one) could cancel an ACK still owed for earlier
-- data.
--
-- NOTE(review): when 'putBytes' returns Nothing, the socket keeps the
-- window updated by 'incomingPacket', but the released bytes are not
-- buffered and no ACK is requested.  Confirm that cannot lose data.
handleData :: TcpHeader -> S.ByteString -> TcpSocket
           -> (Maybe Wakeup, TcpSocket)
handleData hdr body tcp0 = fromMaybe (Nothing,tcp) $ do
  (wakeup,buf') <- putBytes bytes (tcpInBuffer tcp)
  let timers = tcpTimers tcp
      tcp'   = tcp
        { tcpInBuffer = buf'
        , tcpTimers   = timers
            { ttDelayedAck = or [ ttDelayedAck timers
                                , not (L.null bytes)
                                , tcpSyn hdr
                                , tcpFin hdr ]
            }
        }
  return (wakeup, tcp')
  where
  (segs,win') = incomingPacket hdr body (tcpIn tcp0)
  tcp         = tcp0 { tcpIn = win' }
  bytes       = L.fromChunks (map inBody (F.toList segs))
-- | Apply the segment's acknowledgment to the socket.
--
-- When 'receiveAck' accepts the ACK against the output window:
--
--   * the new output window is stored;
--   * SND.UNA becomes SEG.ACK;
--   * if the acknowledged segment is 'outFresh', the RTO is calibrated from
--     its send time and the current 'time'.
--
-- Otherwise the socket is left unchanged.
handleAck :: TcpHeader -> Sock ()
handleAck hdr =
  do now <- inTcp time
     modifyTcpSocket_ (applyAck now)
  where
  applyAck now tcp = maybe tcp (acknowledged now tcp) (receiveAck hdr (tcpOut tcp))

  acknowledged now tcp (seg, out') =
    tcp { tcpOut    = out'
        , tcpSndUna = tcpAckNum hdr
        , tcpTimers = rtt (tcpTimers tcp)
        }
    where
    rtt = if outFresh seg then calibrateRTO now (outTime seg) else id
-- | Move the socket into TIME-WAIT: clear the retransmit queue, arm the
-- 2MSL timer with 'mslTimeout', then set the state.
enterTimeWait :: Sock ()
enterTimeWait =
  do modifyTcpSocket_ stopRetransmitting
     set2MSL mslTimeout
     setState TimeWait
  where
  stopRetransmitting tcp = tcp { tcpOut = clearRetransmit (tcpOut tcp) }
-- | Build the socket record for a new connection spawned by a SYN that
-- arrived on the current (listening) socket.  The record is only
-- returned, not registered anywhere by this function.
--
-- The child starts in SynReceived with a fresh 'initialSeqNum' as ISS,
-- SND.NXT and SND.UNA.  It takes the peer's window, window scale, MSS
-- (default 'defaultMSS') and SACK-permitted option from the SYN.  When the
-- listener has timestamp state and the SYN carries a timestamp, the
-- listener's state is copied with the peer's TSval recorded.
createConnection :: IP4 -> TcpHeader -> Sock TcpSocket
createConnection ip4 hdr =
  do isn      <- inTcp initialSeqNum
     listener <- getTcpSocket
     let fresh = emptyTcpSocket (tcpWindow hdr) (windowScale hdr)
         childTimestamp =
           do ts <- tcpTimestamp listener
              OptTimestamp val _ <- findTcpOption OptTagTimestamp hdr
              return ts { tsLastTimestamp = val }
     return fresh
       { tcpParent      = Just (listenSocketId (tcpDestPort hdr))
       , tcpSocketId    = incomingSocketId ip4 hdr
       , tcpState       = SynReceived
       , tcpIss         = isn
       , tcpSndNxt      = isn
       , tcpSndUna      = isn
       , tcpIn          = emptyLocalWindow (tcpSeqNum hdr) 14600 0
       , tcpOutMSS      = fromMaybe defaultMSS (getMSS hdr)
       , tcpTimestamp   = childTimestamp
       , tcpSack        = sackSupported hdr
       , tcpWindowScale = isJust (findTcpOption OptTagWindowScaling hdr)
       }
-- | Complete the handshake for this socket.
--
-- If the parent ('inParent') yields an acceptor from 'popAcceptor', the
-- acceptor is given this socket's id (via 'outputS') and the socket
-- becomes Established.  Otherwise a FIN-ACK is sent, the socket moves to
-- FinWait1, and processing ends.
--
-- NOTE(review): this is also called on the SynSent path (active open).
-- Confirm what 'inParent' yields for a socket without a parent.
establishConnection :: Sock ()
establishConnection =
  do acceptor <- join `fmap` inParent popAcceptor
     case acceptor of
       Nothing -> do finAck
                     setState FinWait1
                     done
       Just k  -> do sid <- tcpSocketId `fmap` getTcpSocket
                     outputS (k sid)
                     setState Established
-- | Transmit whatever the socket can send right now.
--
-- 'genSegments' is run at the current 'time' and every segment it
-- produces is sent with 'outputSegment'.  Any wakeups it returns are then
-- run through 'tryAgain' in a single 'outputS' action.
outputSegments :: Sock ()
outputSegments =
  do now <- inTcp time
     (wakeups, segments) <- modifyTcpSocket (genSegments now)
     F.forM_ segments outputSegment
     unless (null wakeups) $
       outputS (F.traverse_ tryAgain wakeups)
-- | The maximum segment size advertised in the header's MSS option, if any.
getMSS :: TcpHeader -> Maybe Int64
getMSS hdr =
  case findTcpOption OptTagMaxSegmentSize hdr of
    Just (OptMaxSegmentSize n) -> Just (fromIntegral n)
    _                          -> Nothing
-- | The shift count from the header's window-scaling option, or 0 when the
-- option is absent.
windowScale :: TcpHeader -> Int
windowScale hdr =
  case findTcpOption OptTagWindowScaling hdr of
    Just (OptWindowScaling n) -> fromIntegral n
    _                         -> 0
-- | Whether the header carries the SACK-permitted option.
sackSupported :: TcpHeader -> Bool
sackSupported hdr =
  case findTcpOption OptTagSackPermitted hdr of
    Nothing -> False
    Just _  -> True