{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Client.Run (
    run
  , migrate
  ) where

import qualified Network.Socket as NS
import Network.UDP (UDPSocket(..))
import qualified Network.UDP as UDP
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Client.Reader
import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC client.
--   A UDP socket is created according to 'ccServerName' and 'ccPortName'.
--
--   If 'ccAutoMigration' is 'True', an unconnected socket is made.
--   Otherwise, a connected socket is made.
--   Use the 'migrate' API for the connected socket.
--
--   The initial 'VersionInfo' comes from the resumption information when a
--   session or a non-empty token is present, and otherwise from
--   'ccVersions' (its first element is the chosen version).  If the first
--   attempt throws 'NextVersion', the client is retried exactly once with
--   the carried version information, flagged as incompatible version
--   negotiation.
--
--   Throws 'VersionNegotiationFailed' when 'ccVersions' is empty and there
--   is nothing to resume, or when the version handed back by 'NextVersion'
--   is 'brokenVersionInfo'.
run :: ClientConfig -> (Connection -> IO a) -> IO a
-- Don't use handleLogUnit here because of a return value.
run conf client = NS.withSocketsDo $ do
    verInfo <- initialVersionInfo
    ex <- E.try $ runClient conf client False verInfo
    case ex of
      Right v -> return v
      Left (NextVersion nextVerInfo)
        -- Inspect the version we are about to retry with: 'verInfo' is
        -- derived from local config/resumption data, so checking it could
        -- not catch an unusable negotiated version.
        | nextVerInfo == brokenVersionInfo -> E.throwIO VersionNegotiationFailed
        | otherwise -> runClient conf client True nextVerInfo
  where
    resInfo :: ResumptionInfo
    resInfo = ccResumption conf

    initialVersionInfo :: IO VersionInfo
    initialVersionInfo = case resumptionSession resInfo of
        Nothing
          | resumptionToken resInfo == emptyToken -> fromConfig (ccVersions conf)
        _ -> let ver = resumptionVersion resInfo
             in return (VersionInfo ver [ver])

    -- An empty version list used to hit 'head' lazily, deep inside
    -- connection setup; fail early with a QUIC-level error instead.
    fromConfig :: [Version] -> IO VersionInfo
    fromConfig vers@(ver : _) = return (VersionInfo ver vers)
    fromConfig []             = E.throwIO VersionNegotiationFailed

-- | One client connection attempt using the given 'VersionInfo'.
--
--   The connection is obtained from 'createClientConnection' as the acquire
--   step of 'E.bracket'.  The release step marks the connection dead, frees
--   its resources, kills its readers and finally runs the action returned
--   by 'replaceKillTimeouter'.
--
--   Inside the bracket the reader and timeouter threads are forked and
--   registered, the 'Bool' argument is stored with 'setIncompatibleVN', and
--   the resumption token is installed before the handshaker is built.  The
--   background workers (handshaker, sender, receiver, resender, LDCC timer)
--   are raced against the user action; if the workers ever finish first,
--   'MustNotReached' is thrown.  Whatever the outcome (captured with
--   'E.trySyncOrAsync'), 'sendFinal' runs and 'closure' decides the result.
runClient :: ClientConfig -> (Connection -> IO a) -> Bool -> VersionInfo -> IO a
runClient conf client0 isICVN verInfo =
    E.bracket (createClientConnection conf verInfo) release body
  where
    release connRes = do
        let conn = connResConnection connRes
        setDead conn
        freeResources conn
        killReaders conn
        join (replaceKillTimeouter conn)

    body (ConnRes conn myAuthCIDs reader) = do
        addReader conn =<< forkIO reader
        addTimeouter conn =<< forkIO timeouter
        setIncompatibleVN conn isICVN -- must be before handshaker
        setToken conn (resumptionToken (ccResumption conf))
        handshaker <- handshakeClient confWithVersion conn myAuthCIDs
        let ldcc = connLDCC conn
            waitReady
              | ccUse0RTT conf = wait0RTTReady
              | otherwise      = wait1RTTReady
            app = waitReady conn >> client0 conn
            background = foldr1 concurrently_
                [ handshaker
                , sender conn
                , receiver conn
                , resender ldcc
                , ldccTimer ldcc
                ]
            runThreads =
                race background app
                    >>= either (\() -> E.throwIO MustNotReached) return
        outcome <- E.trySyncOrAsync runThreads
        sendFinal conn
        closure conn ldcc outcome

    -- The transport parameters advertise the version information in use.
    confWithVersion = conf
        { ccParameters = (ccParameters conf) { versionInformation = Just verInfo }
        }

-- | Build the client-side 'Connection' together with its UDP socket.
--
--   The socket is created by 'UDP.clientSocket' from 'ccServerName' and
--   'ccPortName'; it is connected unless 'ccAutoMigration' is 'True'.
--   Fresh CIDs are generated for both ends, and the peer CID is also used
--   as the original destination CID.  The packet size is the larger of the
--   address's default and 'ccPacketSize' (0 if unset), capped at the
--   address's maximum.  The returned 'ConnRes' carries a reader action for
--   the socket that the caller is expected to fork.
--
--   This is used as the acquire step of 'E.bracket' in 'runClient'; the
--   release step only ever sees a fully built 'ConnRes'.  If any step after
--   the socket is opened throws, the socket is closed here so it cannot
--   leak.
createClientConnection :: ClientConfig -> VersionInfo -> IO ConnRes
createClientConnection conf@ClientConfig{..} verInfo = do
    us@(UDPSocket _ sa _) <- UDP.clientSocket ccServerName ccPortName (not ccAutoMigration)
    flip E.onException (UDP.close us) $ do
        q <- newRecvQ
        sref <- newIORef us
        let send buf siz = do
                cs <- readIORef sref
                UDP.sendBuf cs buf siz
            recv = recvClient q
        myCID   <- newCID
        peerCID <- newCID
        now <- getTimeMicrosecond
        (qLog, qclean) <- dirQLogger ccQLog now peerCID "client"
        let debugLog msg
              | ccDebugLog = stdoutLogger msg
              | otherwise  = return ()
        debugLog $ "Original CID: " <> bhow peerCID
        let myAuthCIDs   = defaultAuthCIDs { initSrcCID = Just myCID }
            peerAuthCIDs = defaultAuthCIDs { initSrcCID = Just peerCID
                                           , origDstCID = Just peerCID
                                           }
        conn <- clientConnection conf verInfo myAuthCIDs peerAuthCIDs
                                 debugLog qLog ccHooks sref q send recv
        addResource conn qclean
        let ver = chosenVersion verInfo
        initializeCoder conn InitialLevel $ initialSecrets ver peerCID
        setupCryptoStreams conn -- fixme: cleanup
        let pktSiz0 = fromMaybe 0 ccPacketSize
            pktSiz  = (defaultPacketSize sa `max` pktSiz0) `min` maximumPacketSize sa
        setMaxPacketSize conn pktSiz
        setInitialCongestionWindow (connLDCC conn) pktSiz
        setAddressValidated conn
        let reader = readerClient us conn -- dies when the socket 'us' is closed.
        return $ ConnRes conn myAuthCIDs reader

-- | Create a new socket and execute path validation
--   with a new connection ID. Typically, this is used
--   for migration in the case where 'ccAutoMigration' is 'False',
--   but it can also be used when the value is 'True'.
--
--   This delegates to 'controlConnection' with 'ActiveMigration' and
--   returns that call's 'Bool' result unchanged.
migrate :: Connection -> IO Bool
migrate = flip controlConnection ActiveMigration