{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Reader (
    readerClient
  , recvClient
  , ConnectionControl(..)
  , controlConnection
  ) where

import Data.List (intersect)
import Network.Socket (getSocketName)
import Network.UDP
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.Types
#if defined(mingw32_HOST_OS)
import Network.QUIC.Windows
#endif

-- | Receiving loop for a client connection.
--
-- The thread first waits until 'getSocketName' succeeds on the underlying
-- socket. It then repeatedly receives a datagram, decodes it with
-- 'decodePackets' and dispatches every decoded packet with @putQ@.
-- If nothing is received within the connection's minimum idle timeout
-- ('readMinIdleTimeout'), the socket is closed and the loop ends.
-- Exceptions are handled by 'handleLogUnit' with a debug logger
-- (presumably logged and swallowed — confirm in 'handleLogUnit').
--
-- readerClient dies when the socket is closed.
readerClient :: UDPSocket -> Connection -> IO ()
readerClient cs0@(UDPSocket s0 _ _) conn = handleLogUnit logAction $ do
    wait
    loop
  where
    -- Spin, yielding between attempts, until 'getSocketName' succeeds.
    -- Any synchronous exception it raises is treated as "not bound yet".
    wait :: IO ()
    wait = do
        bound <- E.handleAny (\_ -> return False) $ do
            _ <- getSocketName s0
            return True
        unless bound $ do
            yield
            wait

    loop :: IO ()
    loop = do
        ito <- readMinIdleTimeout conn
        mbs <- timeout ito "readerClient" $
#if defined(mingw32_HOST_OS)
          windowsThreadBlockHack $
#endif
            recv cs0
        case mbs of
          -- Idle timeout expired: close the socket and stop looping.
          Nothing -> close cs0
          Just bs -> do
            now <- getTimeMicrosecond
            pkts <- decodePackets bs
            mapM_ (putQ now) pkts
            loop

    -- Debug logger given to 'handleLogUnit'; prefixes every message.
    logAction msg = connDebugLog conn ("debug: readerClient: " <> msg)

    -- Dispatch one decoded packet received at time @t@.

    -- Broken packets are dropped.
    putQ _ (PacketIB BrokenPacket) = return ()

    -- Version Negotiation. The packet is ignored if it lists the version we
    -- chose or the 'Negotiation' version. Otherwise the main thread is
    -- interrupted with 'VerNego'. That carries the first version that we
    -- support (greasing versions excluded) and the peer also lists. If the
    -- CIDs do not check out, or there is no such common version, it carries
    -- 'brokenVersionInfo' instead.
    putQ t (PacketIV pkt@(VersionNegotiationPacket dCID sCID peerVers)) = do
        qlogReceived conn pkt t
        myVerInfo <- getVersionInfo conn
        let myVer   = chosenVersion myVerInfo
            myVers0 = otherVersions myVerInfo
        -- Ignore this VN if our original version is included.
        when (myVer `notElem` peerVers && Negotiation `notElem` peerVers) $ do
            ok <- checkCIDs conn dCID (Left sCID)
            let myVers = filter (not . isGreasingVersion) myVers0
                nextVerInfo = case myVers `intersect` peerVers of
                  vers@(ver:_) | ok -> VersionInfo ver vers
                  _                 -> brokenVersionInfo
            E.throwTo (mainThreadId conn) $ VerNego nextVerInfo

    -- Crypt packets are tagged with receive time, size and encryption level
    -- and pushed onto the connection's receive queue.
    putQ t (PacketIC pkt lvl siz) =
        writeRecvQ (connRecvQ conn) $ mkReceivedPacket pkt t siz lvl

    -- Retry. Only when 'checkCIDs' accepts the packet:
    --   * switch the peer CID to the Retry's source CID and record it as
    --     @retrySrcCID@;
    --   * re-derive the Initial-level keys from that CID;
    --   * store the token and mark the connection as retried;
    --   * re-queue the packets released by 'releaseByRetry' as
    --     'OutRetrans' output.
    putQ t (PacketIR pkt@(RetryPacket ver dCID sCID token ex)) = do
        qlogReceived conn pkt t
        ok <- checkCIDs conn dCID ex
        when ok $ do
            resetPeerCID conn sCID
            setPeerAuthCIDs conn $ \auth -> auth { retrySrcCID = Just sCID }
            initializeCoder conn InitialLevel $ initialSecrets ver sCID
            setToken conn token
            setRetried conn True
            releaseByRetry (connLDCC conn) >>= mapM_ put
      where
        put ppkt = putOutput conn $ OutRetrans ppkt
ppkt

-- | Check the connection IDs of a received Version Negotiation or Retry
-- packet against this connection's state.
--
-- In both cases the packet's destination CID must equal our own CID
-- ('getMyCID'). The extra condition depends on the third argument:
--
-- * @Left sCID@: the packet's source CID must equal the current peer CID
--   ('getPeerCID').
-- * @Right (pseudo, tag)@: 'calculateIntegrityTag', computed with the
--   current version and peer CID over @pseudo@, must equal @tag@.
--
-- In this module, Version Negotiation handling passes @Left sCID@.
-- Retry handling passes the Retry packet's own field.
checkCIDs :: Connection -> CID -> Either CID (ByteString,ByteString) -> IO Bool
checkCIDs conn dCID ex = do
    myCID   <- getMyCID conn
    peerCID <- getPeerCID conn
    let dstMatches = dCID == myCID
    case ex of
      Left sCID ->
          return $ dstMatches && sCID == peerCID
      Right (pseudo0, tag) -> do
          ver <- getVersion conn
          let tagMatches = calculateIntegrityTag ver peerCID pseudo0 == tag
          return $ dstMatches && tagMatches

-- | Read the next packet from a receive queue (a thin alias of
-- 'readRecvQ'). 'readerClient' fills the connection's queue
-- ('connRecvQ') via 'writeRecvQ'.
recvClient :: RecvQ -> IO ReceivedPacket
recvClient = readRecvQ

----------------------------------------------------------------

-- | How to control a connection.
--
-- Interpreted by 'controlConnection', which acts only on client
-- connections.
data ConnectionControl
    = ChangeServerCID
      -- ^ Obtain a peer connection ID via @waitPeerCID@ (waiting at most
      --   one second) and send a RETIRE_CONNECTION_ID frame for its
      --   sequence number.
    | ChangeClientCID
      -- ^ Obtain a new connection ID of our own and announce it with a
      --   NEW_CONNECTION_ID frame.
    | NATRebinding
      -- ^ Switch to a newly rebound socket; the previous socket is closed
      --   after a very short delay.
    | ActiveMigration
      -- ^ Switch to a newly rebound socket, then validate the path using a
      --   peer connection ID.
    deriving (Eq, Show)

-- | Apply a 'ConnectionControl' action to a connection.
--
-- On a client, wait for the connection to be established and then perform
-- the action. The result is 'True' if the action was carried out and
-- 'False' if it was abandoned. On a server nothing is done and 'False' is
-- returned.
controlConnection :: Connection -> ConnectionControl -> IO Bool
controlConnection conn typ =
    if isClient conn
        then waitEstablished conn >> controlConnection' conn typ
        else return False

-- | Carry out a 'ConnectionControl' action. 'controlConnection' calls this
-- only for clients, after 'waitEstablished'.
--
-- 'ChangeServerCID' and 'ActiveMigration' first wait up to one second for
-- a peer connection ID. If none arrives they return 'False' without doing
-- anything. All other paths return 'True'.
controlConnection' :: Connection -> ConnectionControl -> IO Bool
controlConnection' conn ctl = case ctl of
    ChangeServerCID -> do
        mcid <- awaitPeerCID "controlConnection' 1"
        case mcid of
          Nothing -> return False
          Just (CIDInfo seqNum _ _) -> do
              sendFrames conn RTT1Level [RetireConnectionID seqNum]
              return True
    ChangeClientCID -> do
        newCID <- getNewMyCID conn
        nextSeqNum <- (+ 1) <$> getMyCIDSeqNum conn
        sendFrames conn RTT1Level [NewConnectionID newCID nextSeqNum]
        return True
    NATRebinding -> do
        rebind conn (Microseconds 5000) -- nearly 0
        return True
    ActiveMigration -> do
        mcid <- awaitPeerCID "controlConnection' 2"
        case mcid of
          Nothing -> return False
          Just _ -> do
              rebind conn (Microseconds 5000000)
              validatePath conn mcid
              return True
  where
    -- Wait at most one second for a peer CID; the label is handed to
    -- 'timeout'.
    awaitPeerCID :: String -> IO (Maybe CIDInfo)
    awaitPeerCID label =
        timeout (Microseconds 1000000) label $ waitPeerCID conn -- fixme

-- | Move the connection onto a freshly rebound socket.
--
-- Steps:
--   1. Derive a new socket from the current one with 'natRebinding'.
--   2. Install it with 'setSocket'.
--   3. Start a new 'readerClient' thread on it and register that thread
--      with 'addReader'.
--   4. Schedule closing of the socket that 'setSocket' handed back, via
--      'fire' with the given delay.
rebind :: Connection -> Microseconds -> IO ()
rebind conn microseconds = do
    newSock <- getSocket conn >>= natRebinding
    oldSock <- setSocket conn newSock
    readerTid <- forkIO $ readerClient newSock conn
    addReader conn readerTid
    -- Close the socket returned by 'setSocket' rather than the one read by
    -- 'getSocket' above, just in case they differ.
    fire conn microseconds $ close oldSock