{-# LANGUAGE OverloadedStrings #-}

-- | Connection shutdown: emit a CONNECTION_CLOSE frame, then keep the
-- sockets alive for a short while in a background thread before releasing
-- them.
module Network.QUIC.Closer (closure) where

import Foreign.Marshal.Alloc
import Foreign.Ptr
import qualified Network.Socket as NS
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

-- | Finish a connection according to the outcome of its main action.
--
-- * @Right x@: a CONNECTION_CLOSE with 'NoError' is sent and @x@ is returned.
-- * 'TransportErrorIsSent' / 'ApplicationProtocolErrorIsSent': the matching
--   close frame is sent and the same exception is rethrown.
-- * 'Abort': an application close frame is sent and the error is rethrown
--   as 'ApplicationProtocolErrorIsSent'.
-- * 'VerNego': nothing is sent; 'NextVersion' is thrown instead.
-- * Any other exception is rethrown unchanged without sending anything.
closure :: Connection -> LDCC -> Either E.SomeException a -> IO a
closure conn ldcc (Right x) = do
    closure' conn ldcc $ ConnectionClose NoError 0 ""
    return x
closure conn ldcc (Left se)
    | Just e@(TransportErrorIsSent err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionClose err 0 desc
        E.throwIO e
    | Just e@(ApplicationProtocolErrorIsSent err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionCloseApp err desc
        E.throwIO e
    | Just (Abort err desc) <- E.fromException se = do
        closure' conn ldcc $ ConnectionCloseApp err desc
        E.throwIO $ ApplicationProtocolErrorIsSent err desc
    | Just (VerNego ver) <- E.fromException se =
        E.throwIO $ NextVersion ver
    | otherwise = E.throwIO se

-- | Stop the reader threads, detach the connection's sockets, encode @frame@
-- and hand the sockets to a background 'closer' thread.
--
-- The buffer, the sockets and the timeouter are released exactly once:
-- by the forked thread when it finishes, or immediately if setup throws
-- before that thread is started.
closure' :: Connection -> LDCC -> Frame -> IO ()
closure' conn ldcc frame = do
    killReaders conn
    killTimeouter <- replaceKillTimeouter conn
    socks <- clearSockets conn
    case socks of
        -- No socket left to send on (e.g. already cleared): only make sure
        -- the timeouter is stopped instead of failing on a pattern match.
        [] -> killTimeouter
        s : _ -> do
            let bufsiz = maximumUdpPayloadSize
            -- One allocation: the packet is encoded at offset 0 and the
            -- receive area starts at offset 2 * bufsiz.
            sendBuf <- mallocBytes (bufsiz * 3)
            let recvBuf = sendBuf `plusPtr` (bufsiz * 2)
                recv = NS.recvBuf s recvBuf bufsiz
                hook = onCloseCompleted $ connHooks conn
                cleanup = do
                    free sendBuf
                    mapM_ NS.close socks
                    killTimeouter
            -- Anything thrown here would otherwise leak sendBuf and the sockets.
            (send, pto) <- (`E.onException` cleanup) $ do
                siz <- encodeCC conn frame sendBuf bufsiz
                send <-
                    if isClient conn
                        then do
                            msa <- getServerAddr conn
                            return $ case msa of
                                Nothing -> NS.sendBuf s sendBuf siz
                                Just sa -> NS.sendBufTo s sendBuf siz sa
                        else return $ NS.sendBuf s sendBuf siz
                pto <- getPTO ldcc
                return (send, pto)
            void $ forkFinally (closer pto send recv hook) $ \_ -> cleanup

-- | Encode @frame@ into @sendBuf0@ (at most @bufsiz0@ bytes) and return the
-- number of bytes written.
--
-- The encryption level is the connection's current one, with 'RTT0Level'
-- mapped to 'InitialLevel'. At 'HandshakeLevel` two packets are written
-- back to back: one at 'InitialLevel', then one at 'HandshakeLevel'.
-- A packet whose encoding reports a negative size contributes 0 bytes.
encodeCC :: Connection -> Frame -> Buffer -> BufferSize -> IO Int
encodeCC conn frame sendBuf0 bufsiz0 = do
    lvl0 <- getEncryptionLevel conn
    let lvl
            | lvl0 == RTT0Level = InitialLevel
            | otherwise = lvl0
    if lvl == HandshakeLevel
        then do
            siz0 <- encCC sendBuf0 bufsiz0 InitialLevel
            let sendBuf1 = sendBuf0 `plusPtr` siz0
                bufsiz1 = bufsiz0 - siz0
            siz1 <- encCC sendBuf1 bufsiz1 HandshakeLevel
            return (siz0 + siz1)
        else encCC sendBuf0 bufsiz0 lvl
  where
    -- Encode one packet carrying only @frame@ at @lvl@; a successfully
    -- encoded packet is also reported to qlog.
    encCC sendBuf bufsiz lvl = do
        header <- mkHeader conn lvl
        mypn <- nextPacketNumber conn
        let plain = Plain (Flags 0) mypn [frame] 0
            ppkt = PlainPacket header plain
        siz <- fst <$> encodePlainPacket conn sendBuf bufsiz ppkt Nothing
        if siz >= 0
            then do
                now <- getTimeMicrosecond
                qlogSent conn ppkt now
                return siz
            else return 0

-- | Background close loop, run for at most 3 rounds.
--
-- Each round:
--
-- 1. sends the encoded close packet;
-- 2. drains incoming datagrams until @pto@ has passed since the send;
-- 3. waits up to @pto / 2@ for one more datagram.
--
-- In step 3:
--
-- * If nothing arrives, @hook@ runs and the loop stops.
-- * A zero-length read stops the loop without running @hook@.
-- * Any other datagram starts the next round (presumably to answer a peer
--   that is still sending, cf. RFC 9000 §10.2.1).
closer :: Microseconds -> IO Int -> IO Int -> IO () -> IO ()
closer (Microseconds pto) send recv hook = loop (3 :: Int)
  where
    loop 0 = return ()
    loop n = do
        _ <- send
        getTimeMicrosecond >>= skip
        mx <- timeout (Microseconds (pto .>>. 1)) recv
        case mx of
            Nothing -> hook
            Just 0 -> return ()
            Just _ -> loop (n - 1)

    -- Discard datagrams until @pto@ has elapsed since @base@. The remaining
    -- time is always measured against the fixed total @pto@; the previous
    -- version subtracted the cumulative elapsed time from an already-reduced
    -- duration, which cut the window short.
    skip base = drain (Microseconds pto)
      where
        drain tmo = do
            mx <- timeout tmo recv
            case mx of
                Nothing -> return ()
                Just 0 -> return ()
                Just _ -> do
                    Microseconds elapsed <- getElapsedTimeMicrosecond base
                    let remaining = pto - elapsed
                    -- Not worth re-arming the timeout for less than 5 ms.
                    when (remaining >= 5000) $ drain (Microseconds remaining)