{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Network.QUIC.Client.Reader (
readerClient
, recvClient
, ConnectionControl(..)
, controlConnection
) where
import Data.List (intersect)
import Network.Socket (getSocketName)
import Network.UDP
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Types
#if defined(mingw32_HOST_OS)
import Network.QUIC.Windows
#endif
-- | Receive loop for a client connection on the given UDP socket.
--
-- It first spins (yielding between attempts) until 'getSocketName'
-- succeeds on the underlying socket. It then repeatedly receives one
-- datagram, decodes it into packets and dispatches every packet by kind.
-- If no datagram arrives within the connection's minimum idle timeout,
-- the socket is closed and the loop stops.
--
-- The whole loop runs under 'handleLogUnit' with a debug logger.
-- Presumably an exception that escapes the loop (e.g. 'recv' on a socket
-- closed by 'rebind') is logged and ends the thread; confirm against
-- 'handleLogUnit'.
readerClient :: UDPSocket -> Connection -> IO ()
readerClient cs0@(UDPSocket s0 _ _) conn = handleLogUnit logAction $ do
    wait
    loop
  where
    -- Spin until the socket has a local name. Any exception caught by
    -- E.handleAny is treated as "not bound yet".
    wait :: IO ()
    wait = do
        bound <- E.handleAny (\_ -> return False) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait

    -- Receive, decode and dispatch datagrams. Close the socket on idle timeout.
    loop :: IO ()
    loop = do
        ito <- readMinIdleTimeout conn
        mbs <- timeout ito $
#if defined(mingw32_HOST_OS)
            windowsThreadBlockHack $
#endif
            recv cs0
        case mbs of
            Nothing -> close cs0
            Just bs -> do
                now <- getTimeMicrosecond
                pkts <- decodePackets bs
                mapM_ (putQ now) pkts
                loop

    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    -- Dispatch one decoded packet that was received at time t.
    --
    -- Broken packets are dropped.
    putQ _ (PacketIB BrokenPacket) = return ()
    -- Version Negotiation handling:
    --   * The packet is recorded in qlog.
    --   * It is ignored when it lists the version we chose or the
    --     'Negotiation' version.
    --   * Otherwise the CIDs are checked and 'VerNego' is thrown to the
    --     main thread. Its VersionInfo uses the first of our non-greasing
    --     versions that the peer also lists ('intersect' keeps the order
    --     of our list), or 'brokenVersionInfo' when there is no common
    --     version or the CID check failed.
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer   = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                nextVerInfo = case myVers `intersect` peerVers of
                    vers@(ver:_) | ok -> VersionInfo ver vers
                    _                 -> brokenVersionInfo
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo
    -- Protected packets are queued together with the receive time, the
    -- size and the encryption level.
    putQ t (PacketIC pkt lvl siz) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl
    -- Retry handling:
    --   * The packet is recorded in qlog.
    --   * If the CIDs and integrity tag check out, the server's source CID
    --     becomes the peer CID and is recorded as retrySrcCID.
    --   * Initial keys are re-derived from (ver, sCID), the token is
    --     stored and the connection is marked as retried.
    --   * Packets released by 'releaseByRetry' are queued for
    --     retransmission.
    -- NOTE(review): RFC 9000 §17.2.5.2 requires a client to process at
    -- most one Retry. Nothing here checks whether one was already
    -- processed; confirm this is enforced elsewhere.
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        when ok $ do
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth { retrySrcCID = Just sCID }
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt
-- | Validate the connection IDs of a Version Negotiation or Retry packet.
--
-- In both cases the packet's destination CID must equal our own CID
-- ('getMyCID'). In addition:
--
-- * @Left sCID@: the packet's source CID must equal the current peer CID
--   ('getPeerCID').
--
-- * @Right (pseudo0, tag)@: 'calculateIntegrityTag', computed from the
--   current version, the current peer CID and @pseudo0@, must equal @tag@.
checkCIDs :: Connection -> CID -> Either CID (ByteString, ByteString) -> IO Bool
checkCIDs conn dCID ex = do
    localCID <- getMyCID conn
    remoteCID <- getPeerCID conn
    extraOk <- case ex of
        Left sCID ->
            return (sCID == remoteCID)
        Right (pseudo0, tag) -> do
            ver <- getVersion conn
            return (calculateIntegrityTag ver remoteCID pseudo0 == tag)
    return (dCID == localCID && extraOk)
-- | Take the next received packet from a client's receive queue.
-- This is simply 'readRecvQ'.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient = readRecvQ
-- | Actions a client can request on an established connection through
-- 'controlConnection'.
data ConnectionControl
    = ChangeServerCID  -- ^ send a RETIRE_CONNECTION_ID frame for a server-issued CID
    | ChangeClientCID  -- ^ issue a new connection ID of our own (NEW_CONNECTION_ID)
    | NATRebinding     -- ^ move to a new local socket, closing the old one shortly after
    | ActiveMigration  -- ^ move to a new local socket and validate the new path
    deriving (Eq, Show)
-- | Apply a 'ConnectionControl' action to a connection.
--
-- Only client connections act on the request. A client first waits for
-- 'waitEstablished' and then returns whether the action was carried out.
-- For any other connection the result is 'False' and nothing is done.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ
    | isClient conn = do
        waitEstablished conn
        controlConnection' conn typ
    | otherwise = return False
-- | Carry out one 'ConnectionControl' action. Called by
-- 'controlConnection' for clients, after 'waitEstablished'.
--
-- * 'ChangeServerCID' waits for a peer CID, then sends a
--   'RetireConnectionID' frame at 'RTT1Level'. The frame carries the
--   first field of that 'CIDInfo' (presumably its sequence number).
--
-- * 'ChangeClientCID' obtains a new CID of our own and sends it in a
--   'NewConnectionID' frame, together with our current CID sequence
--   number plus one.
--
-- * 'NATRebinding' calls 'rebind' with a 5,000 µs delay before the old
--   socket is closed.
--
-- * 'ActiveMigration' waits for a peer CID, calls 'rebind' with a
--   5,000,000 µs delay, then calls 'validatePath' with that CID.
--
-- The result is 'False' only when a required peer CID did not become
-- available within 1,000,000 µs; otherwise it is 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn typ = case typ of
    ChangeServerCID -> do
        mn <- waitPeerCIDBriefly
        case mn of
            Nothing -> return False
            Just (CIDInfo n _ _) -> do
                sendFrames conn RTT1Level [RetireConnectionID n]
                return True
    ChangeClientCID -> do
        cidInfo <- getNewMyCID conn
        x <- (+ 1) <$> getMyCIDSeqNum conn
        sendFrames conn RTT1Level [NewConnectionID cidInfo x]
        return True
    NATRebinding -> do
        rebind conn $ Microseconds 5000
        return True
    ActiveMigration -> do
        mn <- waitPeerCIDBriefly
        case mn of
            Nothing -> return False
            mcidinfo -> do
                rebind conn $ Microseconds 5000000
                validatePath conn mcidinfo
                return True
  where
    -- Shared by the two actions that need a peer CID. Gives up after
    -- 1,000,000 µs.
    waitPeerCIDBriefly = timeout (Microseconds 1000000) $ waitPeerCID conn
-- | Move the connection onto a fresh UDP socket.
--
-- The steps are:
--
-- 1. Derive a new socket from the current one with 'natRebinding'.
-- 2. Install it with 'setSocket'.
-- 3. Fork a new 'readerClient' thread for it and register that thread
--    with 'addReader'.
-- 4. Schedule the socket returned by 'setSocket' (presumably the one it
--    replaced; confirm) to be closed by 'fire' after the given delay.
--    The delay presumably lets the old reader drain in-flight packets.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    oldSock <- getSocket conn
    newSock <- natRebinding oldSock
    replaced <- setSocket conn newSock
    let reader = readerClient newSock conn
    forkIO reader >>= addReader conn
    fire conn microseconds $ close replaced