{-# LANGUAGE NamedFieldPuns #-}

module Network.HTTP2.Arch.Context where

import Data.IORef
import Network.HTTP.Types (Method)
import UnliftIO.STM

import Imports hiding (insert)
import Network.HPACK
import Network.HTTP2.Arch.Cache (Cache, emptyCache)
import qualified Network.HTTP2.Arch.Cache as Cache
import Network.HTTP2.Arch.Rate
import Network.HTTP2.Arch.Stream
import Network.HTTP2.Arch.Types
import Network.HTTP2.Frame

-- | Which side of an HTTP/2 connection this endpoint plays.
data Role = Client | Server deriving (Eq, Show)

----------------------------------------------------------------

-- | Role-specific connection state: 'RIS' wraps the server-side record,
--   'RIC' wraps the client-side one.
data RoleInfo
    = RIS ServerInfo
    | RIC ClientInfo

-- | State that exists only on the server side of a connection.
data ServerInfo = ServerInfo {
    -- | Queue of 'Input' values carrying 'Stream's.
    inputQ :: TQueue (Input Stream)
  }

-- | State that exists only on the client side of a connection.
data ClientInfo = ClientInfo {
    -- | Scheme value supplied to 'newClientInfo'.
    scheme    :: ByteString
    -- | Authority value supplied to 'newClientInfo'.
  , authority :: ByteString
    -- | Streams keyed by (method, path); written by 'insertCache' and
    --   read by 'lookupCache'.
  , cache     :: IORef (Cache (Method,ByteString) Stream)
  }

-- | Extract the 'ServerInfo' from a 'RoleInfo'.
--
--   Partial: calls 'error' for any value that is not 'RIS'.
toServerInfo :: RoleInfo -> ServerInfo
toServerInfo (RIS x) = x
toServerInfo _       = error "toServerInfo"

-- | Extract the 'ClientInfo' from a 'RoleInfo'.
--
--   Partial: calls 'error' for any value that is not 'RIC'.
toClientInfo :: RoleInfo -> ClientInfo
toClientInfo (RIC x) = x
toClientInfo _       = error "toClientInfo"

-- | Create server-side role information backed by a freshly created
--   input queue.
newServerInfo :: IO RoleInfo
newServerInfo = RIS . ServerInfo <$> newTQueueIO

-- | Create client-side role information.
--
--   The first two arguments are stored as the 'scheme' and 'authority'
--   fields. The third is handed to 'emptyCache' to build the initial
--   stream cache (presumably its capacity limit; confirm against
--   "Network.HTTP2.Arch.Cache").
newClientInfo :: ByteString -> ByteString -> Int -> IO RoleInfo
newClientInfo scm auth lim =
    RIC . ClientInfo scm auth <$> newIORef (emptyCache lim)

-- | Atomically insert a stream into the client cache under the key
--   (method, path).
--
--   Partial: calls 'error' when the 'RoleInfo' is not 'RIC'.
insertCache :: Method -> ByteString -> Stream -> RoleInfo -> IO ()
insertCache m path v (RIC (ClientInfo _ _ ref)) = atomicModifyIORef' ref $ \c ->
  (Cache.insert (m,path) v c, ())
insertCache _ _ _ _ = error "insertCache"

-- | Look up the stream stored in the client cache under the key
--   (method, path).
--
--   Partial: calls 'error' when the 'RoleInfo' is not 'RIC'.
lookupCache :: Method -> ByteString -> RoleInfo -> IO (Maybe Stream)
lookupCache m path (RIC (ClientInfo _ _ ref)) = Cache.lookup (m,path) <$> readIORef ref
lookupCache _ _ _ = error "lookupCache"

----------------------------------------------------------------

-- | The context for HTTP/2 connection.
data Context = Context {
  -- | Whether this endpoint is the client or the server.
    role               :: Role
  -- | Role-specific state matching 'role'.
  , roleInfo           :: RoleInfo
  -- HTTP/2 settings received from a browser
  , http2settings      :: IORef Settings
  -- | Starts as 'False' in 'newContext'.
  , firstSettings      :: IORef Bool
  -- | Streams registered by 'openStream' and removed by 'closed'.
  , streamTable        :: StreamTable
  -- | Counter incremented by 'opened' and decremented by 'closed'.
  , concurrency        :: IORef Int
  -- | RFC 9113 says "Other frames (from any stream) MUST NOT
  --   occur between the HEADERS frame and any CONTINUATION
  --   frames that might follow". This field is used to implement
  --   this requirement.
  , continued          :: IORef (Maybe StreamId)
  -- | Next stream identifier handed out by 'getMyNewStreamId'.
  , myStreamId         :: IORef StreamId
  -- | Value read by 'getPeerStreamID' and written by 'setPeerStreamID'.
  , peerStreamId       :: IORef StreamId
  , outputQ            :: TQueue (Output Stream)
  , controlQ           :: TQueue Control
  , encodeDynamicTable :: DynamicTable
  , decodeDynamicTable :: DynamicTable
  -- the connection window for data from a server to a browser.
  , connectionWindow   :: TVar WindowSize
  , pingRate           :: Rate
  , settingsRate       :: Rate
  , emptyFrameRate     :: Rate
  }

----------------------------------------------------------------

-- | Build a fresh 'Context' for a connection.
--
--   The 'Role' is derived from the supplied 'RoleInfo': 'RIC' gives
--   'Client', anything else gives 'Server'. The first own stream
--   identifier is 1 for a client and 2 for a server; the peer stream
--   identifier and the concurrency counter start at 0.
newContext :: RoleInfo -> IO Context
newContext rinfo = do
    -- The allocations are independent of each other; they are performed
    -- in field order.
    settingsRef   <- newIORef defaultSettings
    firstRef      <- newIORef False
    table         <- newStreamTable
    concurrentRef <- newIORef 0
    continuedRef  <- newIORef Nothing
    myIdRef       <- newIORef sid0
    peerIdRef     <- newIORef 0
    outQ          <- newTQueueIO
    ctlQ          <- newTQueueIO
    encTable      <- newDynamicTableForEncoding defaultDynamicTableSize
    decTable      <- newDynamicTableForDecoding defaultDynamicTableSize 4096
    connWindow    <- newTVarIO defaultInitialWindowSize
    pingR         <- newRate
    settingsR     <- newRate
    emptyR        <- newRate
    -- Named fields instead of a positional application: several fields
    -- share a type, so a positional mix-up would still type-check.
    return Context {
        role               = rl
      , roleInfo           = rinfo
      , http2settings      = settingsRef
      , firstSettings      = firstRef
      , streamTable        = table
      , concurrency        = concurrentRef
      , continued          = continuedRef
      , myStreamId         = myIdRef
      , peerStreamId       = peerIdRef
      , outputQ            = outQ
      , controlQ           = ctlQ
      , encodeDynamicTable = encTable
      , decodeDynamicTable = decTable
      , connectionWindow   = connWindow
      , pingRate           = pingR
      , settingsRate       = settingsR
      , emptyFrameRate     = emptyR
      }
   where
     rl = case rinfo of
       RIC{} -> Client
       _     -> Server
     sid0 | rl == Client = 1
          | otherwise    = 2

----------------------------------------------------------------

-- | 'True' when the context's 'role' is 'Client'.
isClient :: Context -> Bool
isClient ctx = role ctx == Client

-- | 'True' when the context's 'role' is 'Server'.
isServer :: Context -> Bool
isServer ctx = role ctx == Server

----------------------------------------------------------------

-- | Atomically take the next locally initiated stream identifier.
--
--   Returns the value currently stored in 'myStreamId' and advances the
--   stored value by 2.
getMyNewStreamId :: Context -> IO StreamId
getMyNewStreamId ctx = atomicModifyIORef' (myStreamId ctx) inc2
  where
    -- (value to store, value to return)
    inc2 n = let n' = n + 2 in (n', n)

-- | Read the stream identifier currently stored in 'peerStreamId'.
getPeerStreamID :: Context -> IO StreamId
getPeerStreamID ctx = readIORef $ peerStreamId ctx

-- | Overwrite the stream identifier stored in 'peerStreamId'.
setPeerStreamID :: Context -> StreamId -> IO ()
setPeerStreamID ctx sid = writeIORef (peerStreamId ctx) sid

----------------------------------------------------------------

{-# INLINE setStreamState #-}
-- | Overwrite the state of a stream. The 'Context' argument is ignored.
setStreamState :: Context -> Stream -> StreamState -> IO ()
setStreamState _ Stream{streamState} val = writeIORef streamState val

-- | Mark a stream as opened: atomically increment the connection's
--   'concurrency' counter, then set the stream state to
--   @'Open' 'JustOpened'@.
opened :: Context -> Stream -> IO ()
opened ctx@Context{concurrency} strm = do
    atomicModifyIORef' concurrency (\x -> (x+1,()))
    setStreamState ctx strm (Open JustOpened)

-- | Record that the remote side has half-closed the stream.
--
--   The state transition is applied atomically:
--
--   * 'Closed' stays as it is.
--   * 'HalfClosedLocal' becomes 'Closed' with the same code, after which
--     'closed' is run with that code.
--   * Any other state becomes 'HalfClosedRemote'.
halfClosedRemote :: Context -> Stream -> IO ()
halfClosedRemote ctx stream@Stream{streamState} = do
    -- Strict variant, as used elsewhere in this module, so no thunk is
    -- left behind in the IORef.
    closingCode <- atomicModifyIORef' streamState closeHalf
    traverse_ (closed ctx stream) closingCode
  where
    closeHalf :: StreamState -> (StreamState, Maybe ClosedCode)
    closeHalf x@(Closed _)         = (x, Nothing)
    closeHalf (HalfClosedLocal cc) = (Closed cc, Just cc)
    closeHalf _                    = (HalfClosedRemote, Nothing)

-- | Record that the local side has half-closed the stream with the
--   given code.
--
--   The state transition is applied atomically:
--
--   * 'Closed' stays as it is.
--   * 'HalfClosedRemote' becomes 'Closed' with the given code, after
--     which 'closed' is run with that code.
--   * Any other state becomes 'HalfClosedLocal' with the given code.
halfClosedLocal :: Context -> Stream -> ClosedCode -> IO ()
halfClosedLocal ctx stream@Stream{streamState} cc = do
    -- Strict variant, as used elsewhere in this module, so no thunk is
    -- left behind in the IORef.
    shouldFinalize <- atomicModifyIORef' streamState closeHalf
    when shouldFinalize $
        closed ctx stream cc
  where
    closeHalf :: StreamState -> (StreamState, Bool)
    closeHalf x@(Closed _)     = (x, False)
    closeHalf HalfClosedRemote = (Closed cc, True)
    closeHalf _                = (HalfClosedLocal cc, False)

-- | Finalize a stream: remove it from the stream table, decrement the
--   connection's 'concurrency' counter, and set its state to 'Closed'
--   with the given code.
--
--   The decrement is unconditional, so running this more than once for
--   the same stream lowers the counter each time (see the TODO below).
closed :: Context -> Stream -> ClosedCode -> IO ()
closed ctx@Context{concurrency,streamTable} strm@Stream{streamNumber} cc = do
    remove streamTable streamNumber
    -- TODO: prevent double-counting
    atomicModifyIORef' concurrency (\x -> (x-1,()))
    setStreamState ctx strm (Closed cc) -- anyway

-- | Create a stream with the given identifier and register it in the
--   stream table.
--
--   The new stream's window size is the 'initialWindowSize' from the
--   settings currently stored in 'http2settings'. When the triggering
--   frame type is 'FrameHeaders' or 'FramePushPromise' the stream is
--   also marked as opened via 'opened' before it is inserted.
openStream :: Context -> StreamId -> FrameType -> IO Stream
openStream ctx@Context{streamTable, http2settings} sid ftyp = do
    ws <- initialWindowSize <$> readIORef http2settings
    newstrm <- newStream sid $ fromIntegral ws
    when (ftyp == FrameHeaders || ftyp == FramePushPromise) $ opened ctx newstrm
    insert streamTable sid newstrm
    return newstrm