{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}

module Network.HTTP2.Client.FrameConnection (
      Http2FrameConnection(..)
    , newHttp2FrameConnection
    , frameHttp2RawConnection
    -- * Interact at the Frame level.
    , Http2ServerStream(..)
    , Http2FrameClientStream(..)
    , makeFrameClientStream
    , sendOne
    , sendBackToBack
    , next
    , closeConnection
    ) where

import           Control.DeepSeq (deepseq)
import           Control.Exception.Lifted (bracket)
import           Control.Concurrent.MVar.Lifted (newMVar, takeMVar, putMVar)
import           Control.Monad ((>=>), void, when)
import qualified Data.ByteString as ByteString
import           Network.HTTP2.Frame (FrameHeader(..), FrameFlags, FramePayload, FrameDecodeError, encodeInfo, decodeFramePayload)
import qualified Network.HTTP2.Frame as HTTP2
import           Network.Socket (HostName, PortNumber)
import qualified Network.TLS as TLS

import           Network.HTTP2.Client.Exceptions
import           Network.HTTP2.Client.RawConnection

-- | A frame-level view over an HTTP/2 network connection: a factory for
-- client streams, a source of server frames, and a way to close everything.
data Http2FrameConnection = Http2FrameConnection
  { _makeFrameClientStream :: HTTP2.StreamId -> Http2FrameClientStream
    -- ^ Starts a new client stream for the given stream identifier.
  , _serverStream :: Http2ServerStream
    -- ^ Receives frames sent by the server.
  , _closeConnection :: ClientIO ()
    -- ^ Closes the underlying network connection.
  }

-- | Abruptly closes the network connection underlying an
-- 'Http2FrameConnection'.
closeConnection :: Http2FrameConnection -> ClientIO ()
closeConnection conn = _closeConnection conn

-- | Opens the client side of a stream with the given identifier on a
-- frame connection.
makeFrameClientStream :: Http2FrameConnection
                      -> HTTP2.StreamId
                      -> Http2FrameClientStream
makeFrameClientStream conn sid = _makeFrameClientStream conn sid

-- | The sending side of a single client stream.
data Http2FrameClientStream = Http2FrameClientStream
  { _sendFrames :: ClientIO [(FrameFlags -> FrameFlags, FramePayload)] -> ClientIO ()
    -- ^ Sends to the server the frames produced by the given action.
    -- Each payload is paired with a 'FrameFlags' modifier (e.g., to set the
    -- end-of-stream flag).
  , _getStreamId :: HTTP2.StreamId
    -- ^ Identifier of this stream. TODO: hide me
  }

-- | Sends a single frame to the server, applying the given modifier to the
-- frame's flags.
sendOne :: Http2FrameClientStream -> (FrameFlags -> FrameFlags) -> FramePayload -> ClientIO ()
sendOne client modifyFlags payload =
    _sendFrames client $ return [(modifyFlags, payload)]

-- | Sends several frames to the server as one batch, so they go out
-- back-to-back on the connection.
sendBackToBack :: Http2FrameClientStream -> [(FrameFlags -> FrameFlags, FramePayload)] -> ClientIO ()
sendBackToBack client = _sendFrames client . return

-- | The receiving side of a frame connection.
data Http2ServerStream = Http2ServerStream
  { _nextHeaderAndFrame :: ClientIO (FrameHeader, Either FrameDecodeError FramePayload)
    -- ^ Waits for the next frame from the server and returns its header
    -- together with the result of decoding its payload.
  }

-- | Blocks until the server sends another frame. Returns the frame header
-- and either the decoded payload or the decoding error.
next :: Http2FrameConnection -> ClientIO (FrameHeader, Either FrameDecodeError FramePayload)
next conn = _nextHeaderAndFrame (_serverStream conn)

-- | Adds framing around a 'RawHttp2Connection'.
--
-- Outgoing frames are produced, serialized and written while holding a
-- single writer lock shared by all client streams of this connection.
-- Incoming frames are read as a 9-byte header followed by as many bytes as
-- the header's payload length announces. A short read on either part fails
-- with 'EarlyEndOfStream'.
frameHttp2RawConnection
  :: RawHttp2Connection
  -> ClientIO Http2FrameConnection
frameHttp2RawConnection http2conn = do
    -- The writer lock is created here and only passed to the local helpers
    -- below. It must never escape this function's scope, or it may lead to
    -- bugs (e.g., https://ro-che.info/articles/2014-07-30-bracket ).
    writerMutex <- newMVar ()
    return Http2FrameConnection
        { _makeFrameClientStream = clientStreamWith writerMutex
        , _serverStream          = serverStream
        , _closeConnection       = _close http2conn
        }
  where
    -- Runs an action while holding the lock. 'bracket' puts the lock back
    -- even if the action fails.
    guarded lock action =
        bracket (takeMVar lock) (putMVar lock) (\_ -> action)

    -- Builds the sending side of the stream identified by @sid@.
    clientStreamWith lock sid = Http2FrameClientStream
        { _sendFrames  = sendWith lock sid
        , _getStreamId = sid
        }

    -- Runs the frame-producing action, encodes every frame for stream @sid@
    -- and writes them, all while holding the writer lock.
    sendWith lock sid produceFrames = guarded lock $ do
        pairs <- produceFrames
        let encoded = [ HTTP2.encodeFrame (encodeInfo modifyFlags sid) payload
                      | (modifyFlags, payload) <- pairs
                      ]
        -- Force serialization of the frames while still write-protected to
        -- avoid out-of-order errors.
        encoded `deepseq` _sendRaw http2conn encoded

    -- Reads one frame (fixed-size header, then payload) from the connection.
    serverStream = Http2ServerStream $ do
        headerBytes <- _nextRaw http2conn 9
        when (ByteString.length headerBytes /= 9) $
            throwError EarlyEndOfStream
        let (fType, hdr) = HTTP2.decodeFrameHeader headerBytes
            len          = payloadLength hdr
        body <- _nextRaw http2conn len
        when (ByteString.length body /= len) $
            throwError EarlyEndOfStream
        -- TODO: consider splitting the iteration here to give a chance to
        -- _not_ decode the frame, or consider laziness enough.
        return (hdr, decodeFramePayload fType hdr body)

-- | Connects to the given host and port (over TLS when client parameters
-- are supplied) and wraps the raw connection in an 'Http2FrameConnection'
-- for frame-to-frame communication.
newHttp2FrameConnection :: HostName
                        -> PortNumber
                        -> Maybe TLS.ClientParams
                        -> ClientIO Http2FrameConnection
newHttp2FrameConnection host port params = do
    raw <- newRawHttp2Connection host port params
    frameHttp2RawConnection raw