{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Closer (closure) where

import Foreign.Marshal.Alloc
import qualified Network.UDP as UDP
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import Foreign.Ptr

import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Logger
import Network.QUIC.Packet
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

-- | Tear down a connection once its main action has finished.
--
-- A 'Right' result is a clean shutdown: a CONNECTION_CLOSE carrying
-- 'NoError' is sent and the value is returned unchanged.  A 'Left'
-- exception is translated into the matching close frame (if any) and then
-- rethrown:
--
-- * 'TransportErrorIsSent'           -> 'ConnectionClose' with that error;
--                                       the same exception is rethrown.
-- * 'ApplicationProtocolErrorIsSent' -> 'ConnectionCloseApp'; the same
--                                       exception is rethrown.
-- * 'Abort'                          -> 'ConnectionCloseApp'; rethrown as
--                                       'ApplicationProtocolErrorIsSent'.
-- * 'VerNego'                        -> no frame is sent; rethrown as
--                                       'NextVersion'.
-- * anything else                    -> no frame is sent; rethrown as is.
closure :: Connection -> LDCC -> Either E.SomeException a -> IO a
closure conn ldcc = either rethrow finish
  where
    -- Normal termination: announce NoError, then hand back the result.
    finish x = closure' conn ldcc (ConnectionClose NoError 0 "") >> return x

    -- Send the given close frame, then raise the given exception.
    closeThen :: E.Exception e => Frame -> e -> IO b
    closeThen frame e = closure' conn ldcc frame >> E.throwIO e

    -- Guards are tried in order; each 'E.fromException' targets the
    -- exception type implied by its pattern.
    rethrow :: E.SomeException -> IO b
    rethrow se
      | Just e@(TransportErrorIsSent err desc) <- E.fromException se =
          closeThen (ConnectionClose err 0 desc) e
      | Just e@(ApplicationProtocolErrorIsSent err desc) <- E.fromException se =
          closeThen (ConnectionCloseApp err desc) e
      | Just (Abort err desc) <- E.fromException se =
          closeThen (ConnectionCloseApp err desc)
                    (ApplicationProtocolErrorIsSent err desc)
      | Just (VerNego vers) <- E.fromException se =
          E.throwIO (NextVersion vers)
      | otherwise = E.throwIO se

-- | Send a close packet carrying @frame@ and hand the socket to a
-- background 'closer' thread; the caller does not wait for it.
--
-- Readers are killed first.  The packet is encoded once into a freshly
-- allocated send buffer.  The forked thread then runs 'closer', which
-- retransmits it and reads peer datagrams into a second buffer.  When that
-- thread ends (normally or by exception) it logs any exception, frees both
-- buffers and closes the UDP socket.
--
-- If anything fails before the thread is forked (buffer allocation,
-- encoding, socket lookup, PTO lookup), both buffers are freed before the
-- exception propagates, so they are not leaked.
closure' :: Connection -> LDCC -> Frame -> IO ()
closure' conn ldcc frame = do
    killReaders conn
    let bufsiz = maximumUdpPayloadSize
    sendBuf <- mallocBytes bufsiz
    -- If the second allocation fails, the first buffer must still be freed.
    recvBuf <- mallocBytes bufsiz `E.onException` free sendBuf
    let freeBufs = free sendBuf >> free recvBuf
        prepare = do
            s <- encodeCC conn (SizedBuffer sendBuf bufsiz) frame
            u <- getSocket conn
            p <- getPTO ldcc
            return (s, u, p)
    -- Until forkFinally takes ownership of the buffers, release them here
    -- on failure.
    (siz, us, pto) <- prepare `E.onException` freeBufs
    let clos = do
            UDP.close us
            -- Just in case: also close the socket the connection holds now
            -- (it is fetched again and may differ from @us@).
            -- UDP.close never throws exceptions.
            getSocket conn >>= UDP.close
        send = UDP.sendBuf us sendBuf siz
        recv = UDP.recvBuf us recvBuf bufsiz
        hook = onCloseCompleted $ connHooks conn
    void $ forkFinally (closer conn pto send recv hook) $ \e -> do
        case e of
          Left e' -> connDebugLog conn $ "closure' " <> bhow e'
          Right _ -> return ()
        freeBufs
        clos

-- | Encode packet(s) carrying @frame@ into the given buffer and return the
-- total number of bytes written.
--
-- The packet uses the connection's current encryption level, except that
-- 0-RTT is replaced by Initial.  While the connection is at the Handshake
-- level, the frame is written twice, back to back in the same buffer:
-- first in an Initial packet, then in a Handshake packet placed right after
-- it.  (NOTE(review): presumably because the peer may not hold Handshake
-- keys yet — cf. RFC 9000 §10.2.3.)
--
-- Each packet calls 'nextPacketNumber' once.  A packet whose encoding
-- reports a negative size contributes 0 bytes and is not qlogged.
encodeCC :: Connection -> SizedBuffer -> Frame -> IO Int
encodeCC conn res0@(SizedBuffer sendBuf0 bufsiz0) frame = do
    current <- getEncryptionLevel conn
    let target = if current == RTT0Level then InitialLevel else current
    if target /= HandshakeLevel
      then encCC res0 target
      else do
        initialLen <- encCC res0 InitialLevel
        -- The Handshake packet goes into the space left after the Initial one.
        let rest = SizedBuffer (sendBuf0 `plusPtr` initialLen) (bufsiz0 - initialLen)
        handshakeLen <- encCC rest HandshakeLevel
        return (initialLen + handshakeLen)
  where
    -- Build, encode and qlog one packet at the given level.
    encCC :: SizedBuffer -> EncryptionLevel -> IO Int
    encCC buf lvl = do
        hdr <- mkHeader conn lvl
        pn <- nextPacketNumber conn
        let pkt = PlainPacket hdr (Plain (Flags 0) pn [frame] 0)
        (len, _) <- encodePlainPacket conn buf pkt Nothing
        if len < 0
          then return 0
          else do
            getTimeMicrosecond >>= qlogSent conn pkt
            return len

-- | Body of the background thread started by closure'.
--
-- Sends the close packet up to three times.  After each send it first
-- drains incoming datagrams for one PTO ('skip'), then waits up to half a
-- PTO for another datagram:
--
-- * nothing arrives    -> run @hook@ and stop.
-- * a read returns 0   -> stop without running @hook@.
-- * anything else      -> send again (until three sends have been made).
--
-- On Windows builds, a server sends once and returns (see the CPP guard).
closer :: Connection -> Microseconds -> IO () -> IO Int -> IO () -> IO ()
closer _conn (Microseconds pto) send recv hook
#if defined(mingw32_HOST_OS)
  | isServer _conn = send
#endif
  | otherwise      = loop (3 :: Int)
  where
    loop 0 = return ()
    loop n = do
        send
        getTimeMicrosecond >>= skip (Microseconds pto)
        mx <- timeout (Microseconds (pto !>>. 1)) "closer 1" recv
        case mx of
          Nothing -> hook
          Just 0  -> return ()
          Just _  -> loop (n - 1)

    -- Discard datagrams until @total@ microseconds have passed since
    -- @base@.  The window is cut short if a read times out or returns 0, or
    -- if less than 5000 microseconds of it remain.
    --
    -- The remaining time is always measured against the original @total@.
    -- Previously it was subtracted from the shrinking remainder, which
    -- counted earlier waits twice and ended the window early.
    skip :: Microseconds -> TimeMicrosecond -> IO ()
    skip (Microseconds total) base = go (Microseconds total)
      where
        go tmo = do
            mx <- timeout tmo "closer 2" recv
            case mx of
              Nothing -> return ()
              Just 0  -> return ()
              Just _  -> do
                  Microseconds elapsed <- getElapsedTimeMicrosecond base
                  let remaining = total - elapsed
                  when (remaining >= 5000) $ go (Microseconds remaining)