module Hans.Tcp.SendWindow (
Window(),
emptyWindow,
sndNxt, setSndNxt,
sndUna,
sndWnd,
nullWindow,
fullWindow,
flushWindow,
TSClock(),
initialTSClock,
updateTSClock,
tsVal,
queueSegment,
retransmitTimeout,
ackSegment,
handleSack,
) where
import Control.Monad (guard)
import qualified Data.ByteString.Lazy as L
import Data.List (sortBy)
import Data.Maybe (isJust)
import Data.Ord (comparing)
import Data.Time.Clock (UTCTime,NominalDiffTime,addUTCTime,diffUTCTime)
import Data.Word (Word32)

import Hans.Config
import Hans.Lens
import Hans.Tcp.Packet
-- | An entry in the retransmit queue: a segment that has been sent and is not
-- yet fully acknowledged.
data Segment = Segment
  { segHeader    :: !TcpHeader
    -- ^ header the segment was sent with
  , segRightEdge :: !TcpSeqNum
    -- ^ acknowledgement number that covers the whole segment
    -- (computed with 'tcpSegNextAckNum')
  , segBody      :: !L.ByteString
    -- ^ payload; its acknowledged prefix is dropped when the left edge moves
  , segSentAt    :: !(Maybe UTCTime)
    -- ^ time the segment was sent; 'Nothing' once it has been retransmitted
  , segSACK      :: !Bool
    -- ^ whether a SACK block from the peer has covered this segment
  }
-- | Lens onto the TCP header stored with a queued segment.
segHeaderL :: Lens' Segment TcpHeader
segHeaderL f seg@Segment {} = replaceHeader <$> f (segHeader seg)
  where
    replaceHeader newHdr = seg { segHeader = newHdr }
-- | A queued segment's TCP options are the options of its header.
instance HasTcpOptions Segment where
  tcpOptions f = segHeaderL (tcpOptions f)
-- | Build the retransmit-queue entry for a segment sent at @now@.
--
-- The entry starts out un-SACKed and records its send time, so it can later
-- supply an RTT sample.
mkSegment :: TcpHeader -> L.ByteString -> UTCTime -> Segment
mkSegment hdr payload now =
  Segment { segHeader    = hdr
          , segRightEdge = tcpSegNextAckNum hdr payloadLen
          , segBody      = payload
          , segSentAt    = Just now
          , segSACK      = False
          }
  where
    payloadLen = fromIntegral (L.length payload)
-- | Lens onto the sequence number of the segment's first octet (its header's
-- 'tcpSeqNum').
--
-- Setting a value less than or equal to the current left edge leaves the
-- segment unchanged. Setting a larger value trims the acknowledged prefix off
-- the segment: the header's sequence number is moved forward, and the
-- corresponding number of payload bytes is dropped from the body.
leftEdge :: Lens' Segment TcpSeqNum
leftEdge f seg@Segment { segHeader = hdr@TcpHeader { .. }, .. } =
    fmap update (f tcpSeqNum)
  where
    update sn
      | sn <= tcpSeqNum = seg
      | otherwise =
        let len = fromTcpSeqNum (sn - tcpSeqNum)
            -- A SYN occupies one sequence number but no payload byte, so
            -- once it is covered it is cleared and one fewer byte of body is
            -- dropped.
            (hdr', len')
              | view tcpSyn hdr = (set tcpSyn False hdr, len - 1)
              | otherwise       = (hdr, len)
         in Segment { segHeader = hdr' { tcpSeqNum = sn }
                    , segBody   = L.drop len' segBody
                    , .. }
-- | The acknowledgement number that covers the whole segment; the segment
-- occupies sequence numbers up to, but not including, this value.
rightEdge :: Getting r Segment TcpSeqNum
rightEdge = to (\ Segment { segRightEdge = edge } -> edge)
-- | Lens onto the segment's SACKed flag.
--
-- Writing back the value that is already stored returns the original segment
-- unchanged.
sack :: Lens' Segment Bool
sack f seg@Segment { segSACK = current } = fmap rebuild (f current)
  where
    rebuild flag
      | flag /= current = seg { segSACK = flag }
      | otherwise       = seg
-- | Clock that supplies the values for outgoing TCP timestamp options.
data TSClock = TSClock
  { tscVal        :: !Word32
    -- ^ current clock reading, in ticks
  , tscLastUpdate :: !UTCTime
    -- ^ wall-clock time the reading was last brought up to date
  }
-- | A timestamp clock reading the given value as of the given time.
initialTSClock :: Word32 -> UTCTime -> TSClock
initialTSClock = TSClock
-- | Advance the timestamp clock to the wall-clock time @now@.
--
-- The clock ticks 'cfgTcpTSClockFrequency' times per second. Only whole ticks
-- are added to the reading. 'tscLastUpdate' is advanced by exactly the time
-- those ticks represent, so any fractional tick carries over to the next
-- update. Resetting it to @now@ would discard that remainder on every call and
-- stall the clock completely when updates arrive more often than once per
-- tick.
--
-- If @now@ is earlier than the last update (the wall clock stepped
-- backwards), the clock is re-anchored at @now@ without ticking, so the
-- reading never jumps.
updateTSClock :: Config -> UTCTime -> TSClock -> TSClock
updateTSClock Config { .. } now clk@TSClock { .. }
  | elapsed < 0 = TSClock { tscVal = tscVal, tscLastUpdate = now }
    -- less than one whole tick so far: keep accumulating
  | ticks <= 0  = clk
  | otherwise   =
    TSClock { tscVal        = tscVal + fromInteger ticks
            , tscLastUpdate = addUTCTime (fromInteger ticks / cfgTcpTSClockFrequency)
                                         tscLastUpdate
            }
  where
    elapsed = diffUTCTime now tscLastUpdate

    ticks :: Integer
    ticks = truncate (elapsed * cfgTcpTSClockFrequency)
-- | The current reading of a timestamp clock.
tsVal :: Getting r TSClock Word32
tsVal = to (\ TSClock { tscVal = reading } -> reading)
-- | Convert the number of ticks between an earlier timestamp-clock value
-- @ecr@ and the clock's current reading into seconds.
--
-- The subtraction is done in 'Word32', so it is modular and stays correct if
-- the clock wrapped around once between the two readings.
measureRTT :: Config -> Word32 -> TSClock -> NominalDiffTime
measureRTT Config { .. } ecr clk =
  fromIntegral (view tsVal clk - ecr) / cfgTcpTSClockFrequency
-- | The retransmit queue: sent, unacknowledged segments, oldest first.
type Segments = [Segment]

-- | State of the sending side of a connection.
data Window = Window
  { wRetransmitQueue :: !Segments
    -- ^ segments awaiting acknowledgement; new segments are appended
  , wSndAvail        :: !Int
    -- ^ space left in the peer's window for new data
  , wSndNxt          :: !TcpSeqNum
    -- ^ sequence number of the next new octet to send
  , wSndWnd          :: !TcpSeqNum
    -- ^ send window advertised by the peer
  , wTSClock         :: !TSClock
    -- ^ clock used for outgoing timestamp values
  }
-- | A window with nothing outstanding: the retransmit queue is empty and the
-- whole initial send window is available.
emptyWindow :: TcpSeqNum  -- ^ initial SND.NXT
            -> TcpSeqNum  -- ^ initial send window
            -> TSClock    -- ^ timestamp clock
            -> Window
emptyWindow nxt wnd clock =
  Window { wRetransmitQueue = []
         , wSndAvail        = fromTcpSeqNum wnd
         , wSndNxt          = nxt
         , wSndWnd          = wnd
         , wTSClock         = clock
         }
-- | Drop every segment from the retransmit queue.
--
-- All other fields, including the available-space count, are left as they
-- are.
flushWindow :: Window -> (Window, ())
flushWindow win@Window {} = (win { wRetransmitQueue = [] }, ())
-- | True when no sent data is awaiting acknowledgement.
nullWindow :: Window -> Bool
nullWindow = null . wRetransmitQueue
-- | True when no space is left in the peer's window for new data.
--
-- A negative count is treated as full as well. That can happen when the
-- window shrinks below the amount of data already outstanding, and such a
-- window has no room at all, just like one with exactly zero space.
fullWindow :: Window -> Bool
fullWindow Window { .. } = wSndAvail <= 0
-- | SND.NXT: the sequence number the next new octet will be sent with.
sndNxt :: Getting r Window TcpSeqNum
sndNxt = to (\ Window { wSndNxt = nxt } -> nxt)
-- | Overwrite SND.NXT, but only while nothing is awaiting acknowledgement.
--
-- The 'Bool' reports whether the update was applied. When segments are still
-- queued, the window is returned unchanged with 'False'.
setSndNxt :: TcpSeqNum -> Window -> (Window, Bool)
setSndNxt nxt win =
  case wRetransmitQueue win of
    [] -> (win { wSndNxt = nxt }, True)
    _  -> (win, False)
-- | Lens onto the send window advertised by the peer.
--
-- Setting a new window adjusts the available space by the signed difference
-- between the new and old window sizes. Both sizes are converted to 'Int'
-- before subtracting. Subtracting them as 'TcpSeqNum' first would wrap
-- around on a shrinking window and add a huge positive amount instead of
-- removing space. A window that shrinks below the outstanding data can
-- therefore leave the available space negative.
sndWnd :: Lens' Window TcpSeqNum
sndWnd f Window { .. } =
    fmap update (f wSndWnd)
  where
    update wnd =
      Window { wSndWnd   = wnd
             , wSndAvail = wSndAvail + (fromTcpSeqNum wnd - fromTcpSeqNum wSndWnd)
             , .. }
-- | SND.UNA: the oldest unacknowledged sequence number.
--
-- This is the left edge of the first queued segment, or SND.NXT when the
-- queue is empty.
sndUna :: Getting r Window TcpSeqNum
sndUna = to oldestUnacked
  where
    oldestUnacked Window { wRetransmitQueue = (seg : _) } = view leftEdge seg
    oldestUnacked Window { wSndNxt = nxt }                = nxt
-- | Prepare a segment carrying (a prefix of) @body@ for sending.
--
-- The timestamp clock is advanced to @now@. @mkHdr@ is then given the clock
-- reading and SND.NXT and builds the header. The body is trimmed to the space
-- left in the peer's window. Outcomes:
--
--  * The segment's sequence length ('tcpSegLen') is zero: the header is
--    returned with an empty body and flag 'False', and the window is not
--    changed. Because the body is trimmed first, this case also applies when
--    the window is full and the header (presumably) carries neither SYN nor
--    FIN.
--  * No window space is left (the count can be negative after the peer
--    shrinks its window): 'Nothing'.
--  * Otherwise the segment is appended to the retransmit queue, SND.NXT and
--    the available space are updated, and the header and trimmed body are
--    returned. The flag is 'True' when the queue was previously empty, which
--    presumably tells the caller to start the retransmit timer.
queueSegment :: Config -> UTCTime -> (Word32 -> TcpSeqNum -> TcpHeader) -> L.ByteString
             -> Window -> (Window,Maybe (Bool,TcpHeader,L.ByteString))
queueSegment cfg now mkHdr body win
  | size == 0          = (win, Just (False,hdr,L.empty))
  | wSndAvail win <= 0 = (win,Nothing)
  | otherwise          = (win',Just (startRTO,hdr,trimmedBody))
  where
    clock'      = updateTSClock cfg now (wTSClock win)
    hdr         = mkHdr (view tsVal clock') (wSndNxt win)
    trimmedBody = L.take (fromIntegral (wSndAvail win)) body
    seg         = mkSegment hdr trimmedBody now
    size        = tcpSegLen hdr (fromIntegral (L.length trimmedBody))
    win'        = win { wRetransmitQueue = wRetransmitQueue win ++ [seg]
                      , wSndAvail        = wSndAvail win - size
                      , wSndNxt          = wSndNxt win + fromIntegral size
                      , wTSClock         = clock'
                      }
    startRTO    = null (wRetransmitQueue win)
-- | Handle a retransmission timeout.
--
-- Returns the oldest queued segment's header and body for resending (the
-- header is reused unchanged). That segment's send time is cleared, so it no
-- longer counts as a send-time RTT sample, and the SACK mark is cleared on
-- every queued segment. An empty queue yields 'Nothing'.
retransmitTimeout :: Window -> (Window,Maybe (TcpHeader,L.ByteString))
retransmitTimeout win = (win { wRetransmitQueue = resetQueue queue }, oldestPacket queue)
  where
    queue = wRetransmitQueue win

    oldestPacket (seg : _) = Just (segHeader seg, segBody seg)
    oldestPacket []        = Nothing

    resetQueue (seg : rest) = map (set sack False) (seg { segSentAt = Nothing } : rest)
    resetQueue []           = []
-- | Process an acknowledgement number received from the peer.
--
-- An ACK outside @[SND.UNA, SND.NXT]@ leaves the window unchanged and yields
-- 'Nothing'. Otherwise the ACK is applied:
--
--  * every segment it fully covers leaves the retransmit queue, and a
--    partially covered segment is trimmed;
--  * the acknowledged bytes are returned to the available space;
--  * the timestamp clock is advanced to @now@.
--
-- The result then carries whether the retransmit queue is now empty, and an
-- RTT sample if the newly acknowledged segments allow one.
ackSegment :: Config -> UTCTime -> TcpSeqNum -> Window
           -> (Window, Maybe (Bool,Maybe NominalDiffTime))
ackSegment cfg now ack win
  | view sndUna win <= ack && ack <= view sndNxt win =
    ( win', Just (null (wRetransmitQueue win'), mbMeasurement) )
  | otherwise =
    ( win, Nothing )
  where
    win' = win { wRetransmitQueue = queue'
               , wSndAvail        = wSndAvail win + fromTcpSeqNum (ack - view sndUna win)
               , wTSClock         = updateTSClock cfg now (wTSClock win)
               }

    -- Walk the queue and collect the acknowledged segments, newest first.
    -- A segment only counts when the ACK covers at least one of its sequence
    -- numbers. An ACK equal to a segment's left edge acknowledges nothing in
    -- it, so that segment must not be collected, or it would supply the RTT
    -- sample.
    go acks segs@(seg : rest)
      | view rightEdge seg <= ack = go (seg:acks) rest
      | view leftEdge  seg <  ack = (seg:acks, set leftEdge ack seg : rest)
      | otherwise                 = (acks,segs)
    go acks [] = (acks,[])

    (ackd,queue') = go [] (wRetransmitQueue win)

    mbMeasurement =
      case ackd of
        seg : _
          -- prefer the timestamp carried by the newest acknowledged segment
          | Just (OptTimestamp val _) <- findTcpOption OptTagTimestamp seg ->
            return (measureRTT cfg val (wTSClock win'))
          | otherwise ->
            do let samples = filter (isJust . segSentAt) ackd
               guard (not (null samples))
               -- 'last' is safe: the guard above ensures 'samples' is
               -- non-empty. Retransmitted segments have no send time and are
               -- skipped.
               sent <- segSentAt (last samples)
               -- elapsed time is now - sent (diffUTCTime a b = a - b)
               return $! diffUTCTime now sent
        [] -> Nothing
-- | Apply SACK blocks received from the peer.
--
-- Returns the updated window together with every queued segment that is
-- still not marked as SACKed, as header/body pairs.
handleSack :: [SackBlock] -> Window -> (Window,[(TcpHeader,L.ByteString)])
handleSack blocks win = (marked, sackRetransmit marked)
  where
    marked = processSackBlocks blocks win
-- | Header/body pairs for every queued segment not marked as SACKed, in
-- queue order.
sackRetransmit :: Window -> [(TcpHeader,L.ByteString)]
sackRetransmit = map toPacket . filter (not . segSACK) . wRetransmitQueue
  where
    toPacket seg = (segHeader seg, segBody seg)
-- | Mark every queued segment that lies inside one of the SACK blocks as
-- SACKed.
--
-- The blocks are sorted by left edge and merged against the retransmit queue
-- in a single pass. This assumes the queue is ordered by sequence number, as
-- 'queueSegment' builds it by appending.
processSackBlocks :: [SackBlock] -> Window -> Window
processSackBlocks blocks win@Window {} =
    win { wRetransmitQueue = mark (wRetransmitQueue win) (sortBy (comparing sbLeft) blocks) }
  where
    mark [] _     = []
    mark queue [] = queue
    mark queue@(seg : segs) pending@(blk : later)
      | segWithin seg (sbLeft blk) (sbRight blk) = set sack True seg : mark segs pending
      | view leftEdge seg >= sbRight blk         = mark queue later
      | otherwise                                = seg : mark segs pending
-- | True when the segment lies entirely inside the SACK block @[l, r)@.
--
-- Both a SACK block's right edge (RFC 2018) and 'rightEdge' are exclusive:
-- each names the sequence number just past the last covered octet. A segment
-- that ends exactly at the block's right edge is therefore fully covered, so
-- the upper bound must be @<=@.
segWithin :: Segment -> TcpSeqNum -> TcpSeqNum -> Bool
segWithin seg l r = view leftEdge seg >= l && view rightEdge seg <= r