{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Server.Run (
    run
  , stop
  ) where

import qualified Network.Socket as NS
import Network.UDP (UDPSocket(..), ListenSocket(..))
import qualified Network.UDP as UDP
import System.Log.FastLogger
import UnliftIO.Async
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Closer
import Network.QUIC.Common
import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Crypto
import Network.QUIC.Exception
import Network.QUIC.Handshake
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Packet
import Network.QUIC.Parameters
import Network.QUIC.QLogger
import Network.QUIC.Qlog
import Network.QUIC.Receiver
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Server.Reader
import Network.QUIC.Types

----------------------------------------------------------------

-- | Running a QUIC server.
--
--   Opens a listening UDP socket for every address in 'scAddresses' and
--   starts one dispatcher thread per socket plus a timeouter thread.  The
--   action is then executed with each new connection in a new lightweight
--   thread.  The accept loop runs 'forever', so this call only ends via an
--   exception, e.g. when the calling thread is killed with 'stop' (the
--   calling thread's id is recorded on every connection).  On the way out,
--   the dispatch state is cleared and the dispatcher/timeouter threads are
--   killed.
run :: ServerConfig -> (Connection -> IO ()) -> IO ()
run conf server = NS.withSocketsDo $ handleLogUnit debugLog $ do
    baseThreadId <- myThreadId
    E.bracket setup teardown $ \(dispatch, _) -> forever $ do
        acc <- accept dispatch
        void $ forkIO (runServer conf server dispatch baseThreadId acc)
  where
    doDebug = isJust $ scDebugLog conf
    debugLog msg | doDebug   = stdoutLogger ("run: " <> msg)
                 | otherwise = return ()
    setup = do
        dispatch <- newDispatch
        -- If a socket cannot be created, 'teardown' will not run (the
        -- bracket's acquire failed), so undo the partial setup here.
        ssas <- openListenSockets (scAddresses conf)
                  `E.onException` clearDispatch dispatch
        tids <- mapM (runDispatcher dispatch conf) ssas
        ttid <- forkIO timeouter -- fixme
        return (dispatch, ttid:tids)
    teardown (dispatch, tids) = do
        clearDispatch dispatch
        mapM_ killThread tids
    -- Open one listening socket per address, in order.  If opening any
    -- address throws, every socket opened before it is closed and the
    -- exception is re-thrown.
    openListenSockets [] = return []
    openListenSockets (sa:sas) = do
        ssa <- UDP.serverSocket sa
        rest <- openListenSockets sas `E.onException` closeListenSocket ssa
        return (ssa : rest)
    closeListenSocket :: ListenSocket -> IO ()
    closeListenSocket (ListenSocket s _ _) = NS.close s

-- | Serve a single accepted client, from connection setup to teardown.
--
-- The connection is created by 'createServerConnection' and released by
-- the bracket's release action.  That action calls 'setDead',
-- 'freeResources', 'killReaders' (non-Windows only), and finally closes
-- the socket returned by 'getSocket'.  While the connection is alive, the
-- handshaker, 'sender', 'receiver', 'resender' and 'ldccTimer' run
-- concurrently and are raced against the user action.  The outcome (a
-- result or any exception) is handed to 'closure'.
--
-- Typically, ConnectionIsClosed breaks acceptStream.
-- And the exception should be ignored.
runServer :: ServerConfig -> (Connection -> IO ()) -> Dispatch -> ThreadId -> Accept -> IO ()
runServer conf server0 dispatch baseThreadId acc =
    E.bracket openConn closeConn serveConn
  where
    openConn = createServerConnection conf dispatch acc baseThreadId

    closeConn connRes = do
        let conn = connResConnection connRes
        setDead conn
        freeResources conn
#if !defined(mingw32_HOST_OS)
        killReaders conn
#endif
        getSocket conn >>= UDP.close

    serveConn (ConnRes conn authCIDs _readerLoop) =
        handleLogUnit (logFor conn) $ do
#if !defined(mingw32_HOST_OS)
            readerTid <- forkIO _readerLoop
            addReader conn readerTid
#endif
            -- Put the accepted version information into the server's
            -- parameters before starting the handshake.
            let params = (scParameters conf)
                    { versionInformation = Just (accVersionInfo acc) }
                confWithVI = conf { scParameters = params }
            handshaker <- handshakeServer confWithVI conn authCIDs
            let ldcc = connLDCC conn
                userAction = do
                    wait1RTTReady conn
                    afterHandshakeServer conn
                    server0 conn
                workers = foldr1 concurrently_
                    [ handshaker
                    , sender conn
                    , receiver conn
                    , resender ldcc
                    , ldccTimer ldcc
                    ]
                -- The worker group finishing first is treated as an
                -- internal error ('MustNotReached').
                runAll = race workers userAction
                    >>= either (\() -> E.throwIO MustNotReached) return
            E.trySyncOrAsync runAll >>= closure conn ldcc

    logFor conn msg = do
        connDebugLog conn ("runServer: " <> msg)
        qlogDebug conn $ Debug $ toLogStr msg

-- | Create the server-side connection state for an accepted client.
--
-- This function:
--
--   * opens a per-client UDP socket with 'UDP.accept';
--   * builds the 'Connection' with qlog and debug loggers keyed by the
--     original destination CID;
--   * installs the Initial-level keys;
--   * clamps the packet size using the listening address;
--   * wires up the token manager, the base thread id and the CID
--     (un)registration hooks from the 'Accept'.
--
-- The returned 'ConnRes' carries the connection, this side's
-- authenticated CIDs and, except on Windows, the reader action for the
-- per-client socket.
--
-- If any step after the socket is opened throws, the socket is closed
-- before the exception propagates.  'runServer' closes it only in its
-- bracket's release action, which never runs when this acquire action
-- fails.
-- NOTE(review): cleanups registered with 'addResource' (qlog/debug log,
-- CID unregistration) are presumably run only by 'freeResources', so
-- they are still skipped on this failure path.
createServerConnection :: ServerConfig -> Dispatch -> Accept -> ThreadId
                       -> IO ConnRes
createServerConnection conf@ServerConfig{..} dispatch Accept{..} baseThreadId = do
    us <- UDP.accept accMySocket accPeerSockAddr
    build us `E.onException` UDP.close us
  where
    build :: UDPSocket -> IO ConnRes
    build us = do
        let ListenSocket _ mysa _ = accMySocket
        sref <- newIORef us
        -- Always send through the socket currently stored in 'sref'.
        let send buf siz = void $ do
                UDPSocket{..} <- readIORef sref
                NS.sendBuf udpSocket buf siz
            recv = recvServer accRecvQ
        let myCID = fromMaybe (error "createServerConnection: initSrcCID is missing")
                  $ initSrcCID accMyAuthCIDs
            ocid  = fromMaybe (error "createServerConnection: origDstCID is missing")
                  $ origDstCID accMyAuthCIDs
        (qLog, qclean)     <- dirQLogger scQLog accTime ocid "server"
        (debugLog, dclean) <- dirDebugLogger scDebugLog ocid
        debugLog $ "Original CID: " <> bhow ocid
        conn <- serverConnection conf accVersionInfo accMyAuthCIDs accPeerAuthCIDs
                                 debugLog qLog scHooks sref accRecvQ send recv
        addResource conn qclean
        addResource conn dclean
        -- Initial keys derive from the Retry source CID when a Retry was
        -- sent, otherwise from the original destination CID.
        let cid = fromMaybe ocid $ retrySrcCID accMyAuthCIDs
            ver = chosenVersion accVersionInfo
        initializeCoder conn InitialLevel $ initialSecrets ver cid
        setupCryptoStreams conn -- fixme: cleanup
        let pktSiz = (defaultPacketSize mysa `max` accPacketSize)
                     `min` maximumPacketSize mysa
        setMaxPacketSize conn pktSiz
        setInitialCongestionWindow (connLDCC conn) pktSiz
        debugLog $ "Packet size: " <> bhow pktSiz <> " (" <> bhow accPacketSize <> ")"
        when accAddressValidated $ setAddressValidated conn
        --
        let retried = isJust $ retrySrcCID accMyAuthCIDs
        when retried $ do
            qlogRecvInitial conn
            qlogSentRetry conn
        --
        let mgr = tokenMgr dispatch
        setTokenManager conn mgr
        --
        setBaseThreadId conn baseThreadId
        --
        setRegister conn accRegister accUnregister
        accRegister myCID conn
        addResource conn $ do
            myCIDs <- getMyCIDs conn
            mapM_ accUnregister myCIDs
        --
#if defined(mingw32_HOST_OS)
        return $ ConnRes conn accMyAuthCIDs undefined
#else
        let reader = readerServer us conn -- dies when us is closed.
        return $ ConnRes conn accMyAuthCIDs reader
#endif

-- | Server-side work done once 1-RTT is ready, before the user action runs.
--
-- Steps:
--
--   * obtain a new CID of ours via 'getNewMyCID' and pass it to the
--     registration hook returned by 'getRegister';
--   * generate a token for the connection's version and encrypt it with
--     the connection's token manager;
--   * send 'NewToken', 'NewConnectionID' (for the new CID, second field 0)
--     and 'HandshakeDone' frames at 'RTT1Level'.
--
-- The body runs under 'handleLogT', logging with the prefix
-- "afterHandshakeServer: ".
afterHandshakeServer :: Connection -> IO ()
afterHandshakeServer conn = handleLogT logAction $ do
    -- Issue and register one more CID of ours.
    newCID <- getNewMyCID conn
    registerCID <- getRegister conn
    registerCID (cidInfoCID newCID) conn
    -- Build the encrypted token carried by the NewToken frame.
    ver <- getVersion conn
    plainToken <- generateToken ver
    tokMgr <- getTokenManager conn
    sealedToken <- encryptToken tokMgr plainToken
    sendFrames conn RTT1Level
        [ NewToken sealedToken
        , NewConnectionID newCID 0
        , HandshakeDone
        ]
  where
    logAction msg = connDebugLog conn ("afterHandshakeServer: " <> msg)

-- | Stopping the base thread of the server.
--
-- Looks up the base thread id recorded on the connection when it was
-- created (the thread that called 'run') and kills that thread.  This ends
-- 'run''s accept loop and triggers its cleanup.
stop :: Connection -> IO ()
stop conn = do
    baseTid <- getBaseThreadId conn
    killThread baseTid