module Hans.Tcp.RecvWindow (
Window(),
emptyWindow,
recvSegment,
rcvWnd,
rcvNxt, setRcvNxt,
rcvRight,
moveRcvRight,
sequenceNumberValid,
) where
import Hans.Lens
import Hans.Tcp.Packet
import qualified Data.ByteString as S
import Data.Word (Word16)
-- | A received segment, remembered together with the (inclusive) range of
-- sequence numbers that it occupies.
data Segment = Segment
  { segStart :: !TcpSeqNum     -- ^ first sequence number occupied
  , segEnd   :: !TcpSeqNum     -- ^ last sequence number occupied (inclusive)
  , segHdr   :: !TcpHeader     -- ^ header the segment arrived with
  , segBody  :: !S.ByteString  -- ^ payload bytes
  } deriving (Show)
-- | Wrap a header and payload as a 'Segment'. The occupied range starts at
-- the header's sequence number and ends at the last sequence number the
-- header and payload length account for.
mkSegment :: TcpHeader -> S.ByteString -> Segment
mkSegment hdr body = Segment
  { segStart = tcpSeqNum hdr
  , segEnd   = tcpSegLastSeqNum hdr (S.length body)
  , segHdr   = hdr
  , segBody  = body
  }
-- | The sequence number immediately following a segment.
segNext :: Segment -> TcpSeqNum
segNext = (+ 1) . segEnd
-- | Drop the first @len@ sequence numbers from the front of a segment.
--
-- In TCP (RFC 793) a SYN occupies the sequence number before the first
-- payload byte and a FIN the one after the last, so a segment spans
-- @syn + length body + fin@ sequence numbers. Trimming advances both the
-- segment's start and its header's sequence number by exactly @len@ (keeping
-- them equal, as 'mkSegment' creates them) and drops the payload bytes that
-- fall inside the trimmed range.
--
-- Returns the segment unchanged when @len <= 0@, and 'Nothing' when @len@
-- covers the segment's whole span.
--
-- NOTE(review): if a leading SYN is trimmed away the header still carries
-- the SYN flag; clearing it needs a TcpHeader setter that is not used here.
trimSeg :: Int -> Segment -> Maybe Segment
trimSeg len seg@Segment { .. }
  | len <= 0      = Just seg
  | len >= seqLen = Nothing
  | otherwise     =
    Just $! Segment { segStart = segStart + fromIntegral len
                    , segHdr   = segHdr { tcpSeqNum = tcpSeqNum segHdr
                                                    + fromIntegral len }
                      -- a SYN sits in front of the payload, so when it is
                      -- present one fewer payload byte is trimmed
                    , segBody  = S.drop (len - syn) segBody
                    , .. }
  where
    flag l | view l segHdr = 1
           | otherwise     = 0

    syn    = flag tcpSyn
    -- total number of sequence numbers this segment occupies
    seqLen = syn + S.length segBody + flag tcpFin
-- | Turn two overlapping segments into non-overlapping ones.
--
-- The segment that starts first (@b@ on a tie) is kept whole, and the
-- overlapping front of the other is trimmed off. The result is in ascending
-- order: @[first, trimmedRest]@, or just @[first]@ when the other segment
-- lies entirely inside the first one.
resolveOverlap :: Segment -> Segment -> [Segment]
resolveOverlap a b =
  case trimSeg overlap y of
    Just y' -> [x, y']
    Nothing -> [x]   -- y is completely covered by x
  where
    (x,y) | segStart a < segStart b = (a,b)
          | otherwise               = (b,a)

    -- segEnd is inclusive, so x and y share the sequence numbers
    -- segStart y .. segEnd x, which is (segEnd x - segStart y + 1) of them
    overlap = fromTcpSeqNum (segEnd x - segStart y + 1)
-- | Receive-window state.
data Window = Window
  { wSegments :: ![Segment]   -- ^ segments buffered out of order, kept in
                              --   ascending order of start by insertOutOfOrder
  , wRcvNxt   :: !TcpSeqNum   -- ^ next sequence number expected (RCV.NXT)
  , wRcvRight :: !TcpSeqNum   -- ^ right edge of the receive window
  , wMax      :: !TcpSeqNum   -- ^ maximum window size
  } deriving (Show)
-- | A window with nothing buffered that expects sequence number @nxt@ next
-- and whose right edge sits @maxWin@ sequence numbers beyond it.
emptyWindow :: TcpSeqNum -> Int -> Window
emptyWindow nxt maxWin =
  let size = fromIntegral maxWin
  in Window { wSegments = []
            , wRcvNxt   = nxt
            , wRcvRight = nxt + size
            , wMax      = size
            }
-- | Lens onto the receive window size.
--
-- Reading yields the distance from RCV.NXT to the window's right edge.
-- Writing a new size moves the right edge to RCV.NXT plus that size.
--
-- NOTE(review): the getter narrows to 'Word16' through 'fromTcpSeqNum';
-- presumably windows above 65535 are truncated — confirm.
rcvWnd :: Lens' Window Word16
rcvWnd f Window { .. } =
  fmap (\ wnd -> Window { wRcvRight = wRcvNxt + fromIntegral wnd, .. })
       (f (fromTcpSeqNum (wRcvRight - wRcvNxt)))
-- | Read-only view of RCV.NXT, the next sequence number expected.
rcvNxt :: Getting r Window TcpSeqNum
rcvNxt = to (\ Window { wRcvNxt = nxt } -> nxt)
-- | Read-only view of the right edge of the receive window.
rcvRight :: Getting r Window TcpSeqNum
rcvRight = to (\ Window { wRcvRight = right } -> right)
-- | Reset RCV.NXT to @nxt@ and move the right edge to @nxt + wMax@.
--
-- This is only permitted while no out-of-order segments are buffered. The
-- 'Bool' says whether the update happened; when it is 'False' the window is
-- returned untouched.
setRcvNxt :: TcpSeqNum -> Window -> (Window,Bool)
setRcvNxt nxt win
  | not (null (wSegments win)) = (win, False)
  | otherwise                  =
    (win { wRcvNxt = nxt, wRcvRight = nxt + wMax win }, True)
-- | Feed an incoming segment to the window.
--
-- Returns 'Nothing' (and the unchanged window) when 'sequenceNumberValid'
-- rejects the segment. Otherwise the segment is added, and the header and
-- payload pairs that have become in-order are returned (possibly none).
recvSegment :: TcpHeader -> S.ByteString -> Window
            -> (Window, Maybe [(TcpHeader,S.ByteString)])
recvSegment hdr body win =
  case sequenceNumberValid (wRcvNxt win) (wRcvRight win) hdr body of
    Nothing  -> (win, Nothing)
    Just seg ->
      let (win', ready) = addSegment seg win
      in (win', Just (map (\ s -> (segHdr s, segBody s)) ready))
-- | Push the right edge of the window forward by @n@ sequence numbers.
--
-- The step is first clamped to the range 0 .. wMax, using 'TcpSeqNum's
-- ordering.
moveRcvRight :: Int -> Window -> (Window, ())
moveRcvRight n win = (win { wRcvRight = view rcvRight win + step }, ())
  where
    step = min (max 0 (fromIntegral n)) (wMax win)
-- | Add an accepted segment to the window.
--
-- A segment beginning anywhere other than RCV.NXT is buffered and nothing
-- is delivered. One beginning exactly at RCV.NXT moves the left edge.
addSegment :: Segment -> Window -> (Window, [Segment])
addSegment seg win
  | segStart seg /= wRcvNxt win = (insertOutOfOrder seg win, [])
  | otherwise                   = advanceLeft seg win
-- | Deliver a segment that starts exactly at RCV.NXT.
--
-- RCV.NXT is moved past it and past any buffered segments that now follow
-- on contiguously. The delivered segments are returned in order.
advanceLeft :: Segment -> Window -> (Window, [Segment])
advanceLeft seg win =
  case wSegments win of
    [] -> (win { wRcvNxt = segNext seg }, [seg])
    _  ->
      let merged                 = insertOutOfOrder seg win
          (nxt', ready, pending) = splitContiguous (wSegments merged)
      in (merged { wSegments = pending, wRcvNxt = nxt' }, ready)
-- | Insert a segment into the window's buffered segments.
--
-- The list is kept in ascending order of start, without overlaps. When the
-- new segment overlaps a stored one, 'resolveOverlap' keeps the
-- earlier-starting segment whole and trims the other. The piece that ends
-- last is then carried on through the rest of the list, because it may
-- still overlap later segments.
insertOutOfOrder :: Segment -> Window -> Window
insertOutOfOrder seg Window { .. } = Window { wSegments = segs', .. }
  where
    segs' = loop seg wSegments

    loop new [] = [new]
    loop new segs@(x:xs)
      | segEnd new < segStart x = new : segs
      | segStart new > segEnd x = x : loop new xs
      | otherwise =
        case resolveOverlap new x of
          [kept]   -> loop kept xs      -- one segment covered the other
          [lo, hi] -> lo : loop hi xs   -- hi ends last; keep checking it
          other    -> other ++ xs       -- not produced by resolveOverlap
-- | Split the run of back-to-back segments off the front of the list.
--
-- Returns three things: the sequence number just past the run, the run
-- itself in order, and the remaining segments. The list must be non-empty.
splitContiguous :: [Segment] -> (TcpSeqNum,[Segment],[Segment])
splitContiguous (seg:segs) = loop [seg] (segNext seg) segs
  where
    -- 'from' is the sequence number the next segment must start at to
    -- extend the run; it advances past each segment as it is absorbed
    loop acc from (x:xs) | segStart x == from = loop (x:acc) (segNext x) xs
    loop acc from xs = (from, reverse acc, xs)
splitContiguous [] = error "splitContiguous: empty list"
-- | Decide whether an incoming segment is acceptable to the receive window.
--
-- This follows the four-case acceptability test of RFC 793 (zero or non-zero
-- segment length against a zero or non-zero window). When a segment is only
-- partly inside the window, its payload is trimmed to the in-window bytes.
-- A segment that starts before RCV.NXT gets a header whose sequence number
-- is moved up to RCV.NXT. Returns 'Nothing' when the segment is not
-- acceptable.
--
-- NOTE(review): only the payload and sequence number are adjusted; SYN/FIN
-- flags on a trimmed header are left as they arrived — confirm this is
-- intended.
sequenceNumberValid :: TcpSeqNum     -- ^ RCV.NXT
                    -> TcpSeqNum     -- ^ right edge of the window
                    -> TcpHeader
                    -> S.ByteString
                    -> Maybe Segment
sequenceNumberValid nxt wnd hdr@TcpHeader { .. } payload
  | payloadLen == 0 =
    if nullWindow
       then if tcpSeqNum == nxt then Just (mkSegment hdr S.empty) else Nothing
       else if seqNumInWindow   then Just (mkSegment hdr S.empty) else Nothing

  | otherwise =
    if nullWindow
       then Nothing
       else if | seqNumInWindow  -> Just (mkSegment hdr  seg')
               | dataEndInWindow -> Just (mkSegment hdr' seg')
               | otherwise       -> Nothing
  where
    nullWindow = nxt == wnd

    -- segment length in sequence space (includes SYN/FIN)
    payloadLen = tcpSegLen hdr (fromIntegral (S.length payload))

    -- last sequence number the segment occupies
    segLast = tcpSeqNum + fromIntegral (payloadLen - 1)

    hdr' = hdr { tcpSeqNum = nxt }

    -- Only ever form non-negative sequence-number differences: a "negative"
    -- difference wraps around and would drop or keep the wrong byte count.
    -- Leading bytes below RCV.NXT are dropped only when the segment starts
    -- before it.
    dropLen | tcpSeqNum < nxt = fromTcpSeqNum (nxt - tcpSeqNum)
            | otherwise       = 0

    -- bytes from the segment start up to (not including) the right edge;
    -- seg' is only used when the segment starts before the right edge
    keepLen = fromTcpSeqNum (wnd - tcpSeqNum)

    -- the payload, trimmed to the part that lies inside the window
    seg' = S.copy $ S.drop dropLen $ S.take keepLen payload

    seqNumInWindow  = nxt <= tcpSeqNum && tcpSeqNum < wnd
    dataEndInWindow = nxt <= segLast   && segLast   < wnd