module Network.QUIC.Connector where

import Data.IORef
import Network.QUIC.Types
import UnliftIO.STM

-- | Accessors for per-connection state.
--
-- 'getRole' is pure; every other accessor runs in 'IO'.  The method
-- names appear to mirror the fields of 'ConnState' (NOTE(review):
-- instances are not visible here -- confirm against them).
class Connector a where
    -- | The 'Role' of this endpoint.
    getRole :: a -> Role
    -- | Read the current 'EncryptionLevel'.
    getEncryptionLevel :: a -> IO EncryptionLevel
    -- | Read the current maximum packet size.
    getMaxPacketSize :: a -> IO Int
    -- | Read the current 'ConnectionState'.
    getConnectionState :: a -> IO ConnectionState
    -- | Read the current 'PacketNumber'.
    getPacketNumber :: a -> IO PacketNumber
    -- | Read the liveness flag (presumably 'connectionAlive'; confirm in
    -- the instances).
    getAlive :: a -> IO Bool

----------------------------------------------------------------

-- | Mutable state of a single connection.
--
-- 'role' is an immutable value; every other field is a mutable
-- reference ('TVar' or 'IORef') created by 'newConnState'.
data ConnState = ConnState
    { role :: Role
    , connectionState :: TVar ConnectionState
    , packetNumber :: IORef PacketNumber -- squeezing three to one
    , encryptionLevel :: TVar EncryptionLevel -- to synchronize
    , maxPacketSize :: IORef Int
    , -- Explicitly separated from 'ConnectionState'.
      -- It seems that STM triggers a dead-lock if
      -- it is used in the close function of bracket.
      connectionAlive :: IORef Bool
    }

-- | Allocate a fresh 'ConnState' for the given role.
--
-- Initial values: connection state 'Handshaking', packet number @0@,
-- encryption level 'InitialLevel', maximum packet size
-- 'defaultQUICPacketSize', and the alive flag set to 'True'.
newConnState :: Role -> IO ConnState
newConnState rl =
    ConnState rl
        <$> newTVarIO Handshaking
        <*> newIORef 0
        <*> newTVarIO InitialLevel
        <*> newIORef defaultQUICPacketSize
        <*> newIORef True

----------------------------------------------------------------

-- | Which side of the connection this endpoint is.
data Role = Client | Server deriving (Eq, Show)

-- | 'True' when the connector's role (as reported by 'getRole') is 'Client'.
isClient :: Connector a => a -> Bool
isClient conn = getRole conn == Client

-- | 'True' when the connector's role (as reported by 'getRole') is 'Server'.
isServer :: Connector a => a -> Bool
isServer conn = getRole conn == Server

----------------------------------------------------------------

-- | Progress of a connection.
--
-- The derived 'Ord' instance follows declaration order, so
-- @Handshaking < ReadyFor0RTT < ReadyFor1RTT < Established@.
data ConnectionState
    = Handshaking
    | ReadyFor0RTT
    | ReadyFor1RTT
    | Established
    deriving (Eq, Ord, Show)

-- | Read the current 'ConnectionState' via 'getConnectionState' and
-- report whether it is exactly 'Established'.
isConnectionEstablished :: Connector a => a -> IO Bool
isConnectionEstablished conn = (== Established) <$> getConnectionState conn