{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.Recovery.Utils (
    retransmit,
    sendPing,
    mergeLostCandidates,
    mergeLostCandidatesAndClear,
    peerCompletedAddressValidation,
    countAckEli,
    inCongestionRecovery,
    delay,
) where

import Data.Sequence (Seq, ViewL (..), (<|))
import qualified Data.Sequence as Seq
import UnliftIO.Concurrent
import UnliftIO.STM

import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Recovery.Types
import Network.QUIC.Types

----------------------------------------------------------------

-- | Retransmit the ack-eliciting packets among @lostPackets@.
--
-- The 'PlainPacket' of every ack-eliciting 'SentPacket' is handed to
-- 'putRetrans'. Lost packets that are not ack-eliciting are not resent.
-- If none of the lost packets is ack-eliciting, 'sendPing' is called
-- instead, at the connection's current encryption level.
retransmit :: LDCC -> Seq SentPacket -> IO ()
retransmit ldcc lostPackets
    | null packetsToBeResent = getEncryptionLevel ldcc >>= sendPing ldcc
    | otherwise = mapM_ put packetsToBeResent
  where
    -- Non-ack-eliciting packets are filtered out before resending.
    packetsToBeResent :: Seq SentPacket
    packetsToBeResent = Seq.filter spAckEliciting lostPackets

    put :: SentPacket -> IO ()
    put = putRetrans ldcc . spPlainPacket

----------------------------------------------------------------

-- | Request an ack-eliciting probe at encryption level @lvl@.
--
-- Stamps the current time into 'timeOfLastAckElicitingPacket' of @lvl@'s
-- 'LossDetection' state, then stores @Just lvl@ in 'ptoPing'.
-- NOTE(review): presumably the sender watches 'ptoPing' and emits the
-- actual PING frame; confirm against the sender loop.
sendPing :: LDCC -> EncryptionLevel -> IO ()
sendPing LDCC{..} lvl = do
    now <- getTimeMicrosecond
    atomicModifyIORef'' (lossDetection ! lvl) $ \ld ->
        ld{timeOfLastAckElicitingPacket = now}
    atomically $ writeTVar ptoPing $ Just lvl

----------------------------------------------------------------

-- | Fold @lostPackets@ into the shared 'lostCandidates' set.
--
-- The read, 'merge' and write happen inside a single STM transaction, so
-- no other STM update of 'lostCandidates' can interleave with it.
mergeLostCandidates :: LDCC -> Seq SentPacket -> IO ()
mergeLostCandidates LDCC{..} lostPackets = atomically $ do
    SentPackets old <- readTVar lostCandidates
    let new = merge old lostPackets
    writeTVar lostCandidates $ SentPackets new

-- | Drain 'lostCandidates' and return its contents merged with
-- @lostPackets@.
--
-- In one STM transaction the stored candidates are read and the variable
-- is reset to 'emptySentPackets'. The result is @merge old lostPackets@,
-- so candidates already stored act as the first argument of 'merge'.
mergeLostCandidatesAndClear :: LDCC -> Seq SentPacket -> IO (Seq SentPacket)
mergeLostCandidatesAndClear LDCC{..} lostPackets = atomically $ do
    SentPackets old <- readTVar lostCandidates
    writeTVar lostCandidates emptySentPackets
    return $ merge old lostPackets

-- | Merge two sequences of sent packets by ascending 'spPacketNumber'.
--
-- The output is sorted only if both inputs already are; this is not
-- checked. On equal packet numbers the element from the second sequence
-- is emitted first. No element is dropped or deduplicated. When either
-- input runs out, the remainder of the other is used as-is.
merge :: Seq SentPacket -> Seq SentPacket -> Seq SentPacket
merge s1 s2 = case Seq.viewl s1 of
    EmptyL -> s2
    x :< s1' -> case Seq.viewl s2 of
        EmptyL -> s1
        y :< s2'
            | spPacketNumber x < spPacketNumber y -> x <| merge s1' s2
            | otherwise -> y <| merge s1 s2'

----------------------------------------------------------------

-- Sec 6.2.1. Computing PTO
-- "That is, a client does not reset the PTO backoff factor on
--  receiving acknowledgements until it receives a HANDSHAKE_DONE
--  frame or an acknowledgement for one of its Handshake or 1-RTT
--  packets."
-- | Whether the peer is considered to have completed address validation.
--
-- Always 'True' on the server side. On the client side the result is
-- whatever 'isConnectionEstablished' reports.
peerCompletedAddressValidation :: LDCC -> IO Bool
peerCompletedAddressValidation ldcc
    -- For servers: assume clients validate the server's address implicitly.
    | isServer ldcc = return True
    -- For clients: servers complete address validation when a protected
    -- packet is received (approximated here by connection establishment).
    | otherwise = isConnectionEstablished ldcc

----------------------------------------------------------------

-- | @1@ for an ack-eliciting packet and @0@ otherwise. Summing this over
-- a collection of packets counts its ack-eliciting members.
countAckEli :: SentPacket -> Int
countAckEli sentPacket
    | spAckEliciting sentPacket = 1
    | otherwise = 0

----------------------------------------------------------------

-- | Whether a packet sent at @sentTime@ belongs to the recovery period
-- that started at the given time.
--
-- Returns 'False' when no recovery start time is set ('Nothing').
-- Otherwise returns 'True' iff the packet was sent at or before that
-- start time.
inCongestionRecovery :: TimeMicrosecond -> Maybe TimeMicrosecond -> Bool
inCongestionRecovery sentTime = maybe False (sentTime <=)

----------------------------------------------------------------

-- | Suspend the calling thread for the given 'Microseconds' by unwrapping
-- the value and passing it to 'threadDelay'.
delay :: Microseconds -> IO ()
delay (Microseconds microseconds) = threadDelay microseconds