{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Network.GRPC.Client (
RPC(..)
, Authority
, Timeout(..)
, open
, RawReply
, singleRequest
, streamReply
, streamRequest
, steppedBiDiStream
, generalHandler
, CompressMode(..)
, StreamDone(..)
, BiDiStep(..)
, RunBiDiStep
, HandleMessageStep
, HandleTrailersStep
, IncomingEvent(..)
, OutgoingEvent(..)
, InvalidState(..)
, StreamReplyDecodingError(..)
, UnallowedPushPromiseReceived(..)
, InvalidParse(..)
, Compression
, gzip
, uncompressed
, HeaderList
) where
import Control.Concurrent.Async (concurrently)
import Control.Exception (SomeException, Exception(..), throwIO)
import Data.ByteString.Char8 (unpack)
import Data.ByteString.Lazy (toStrict)
import Data.Binary.Builder (toLazyByteString)
import Data.Binary.Get (Decoder(..), pushChunk, pushEndOfInput)
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Char8 (ByteString)
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import Data.Monoid ((<>))
import Data.ProtoLens.Service.Types (Service(..), HasMethod, HasMethodImpl(..), StreamingType(..))
import GHC.TypeLits (Symbol)
import Network.GRPC.HTTP2.Types
import Network.GRPC.HTTP2.Encoding
import Network.HTTP2
import Network.HPACK
import Network.HTTP2.Client hiding (next)
import Network.HTTP2.Client.Helpers
-- | A header list whose names are wrapped in 'CI' so that lookups by header
-- name are case-insensitive; values are kept as raw 'ByteString's.
type CIHeaderList = [(CI ByteString, ByteString)]
-- | Outcome of a call that waits for the whole reply.
--
-- * @Left code@: an HTTP2 'ErrorCode' was reported for the stream.
-- * @Right (headers, trailers, result)@: the response headers, the optional
--   trailers, and either an error 'String' or the decoded message.
type RawReply a = Either ErrorCode (CIHeaderList, Maybe CIHeaderList, (Either String a))
-- | Exception raised when the server sends a PUSH_PROMISE on a stream used
-- by this client; the handlers in this module never accept pushed streams.
data UnallowedPushPromiseReceived = UnallowedPushPromiseReceived
    deriving (Show)

instance Exception UnallowedPushPromiseReceived
-- | Push-promise handler that ignores all of its arguments and rejects the
-- promise by throwing 'UnallowedPushPromiseReceived'.
throwOnPushPromise :: PushPromiseHandler
throwOnPushPromise = \_ _ _ _ _ -> throwIO UnallowedPushPromiseReceived
-- | Wait for the stream to finish and convert what was collected into a
-- 'RawReply'.
--
-- A stream-level error code is passed through as 'Left'. Otherwise the
-- headers and trailers are converted to case-insensitive form. If the
-- /headers/ carry a @grpc-message@ entry its value becomes the 'Left' error
-- string (trailers are not consulted for this); if not, the accumulated body
-- is run through the output decoder for the RPC.
--
-- Any push promise received while waiting makes 'throwOnPushPromise' throw.
waitReply :: (Service s, HasMethod s m) => RPC s m -> Decoding -> Http2Stream -> IncomingFlowControl -> IO (RawReply (MethodOutput s m))
waitReply rpc decoding stream flowControl =
    fmap (format . fromStreamResult) (waitStream stream flowControl throwOnPushPromise)
  where
    -- Stream errors are forwarded untouched.
    format (Left errCode) = Left errCode
    format (Right (hdrs, dat, trls)) =
        let ciHdrs = headerstoCIHeaders hdrs
            ciTrls = headerstoCIHeaders <$> trls
            body = case lookup grpcMessageH ciHdrs of
                Just errMsg -> Left (unpack errMsg)
                Nothing -> decodeBody dat
        in Right (ciHdrs, ciTrls, body)

    -- Feed the whole body at once and signal end of input so that a
    -- truncated message surfaces as a decoding failure.
    decodeBody dat = fromDecoder (pushEndOfInput (pushChunk freshDecoder dat))

    freshDecoder = decodeOutput rpc (_getDecodingCompression decoding)
-- | Wrap every header name with 'CI.mk' so the list can be searched
-- case-insensitively; header values and ordering are unchanged.
headerstoCIHeaders :: HeaderList -> CIHeaderList
headerstoCIHeaders = map (\(name, value) -> (CI.mk name, value))
-- | Exception thrown by 'streamReply' when a message in the reply stream
-- cannot be decoded; the 'String' describes the decoder failure.
data StreamReplyDecodingError = StreamReplyDecodingError String
    deriving (Show)

instance Exception StreamReplyDecodingError
-- | Exception used when the stream delivers an event that is not acceptable
-- at that point (for instance no headers first, or a stream error); the
-- 'String' names the problem.
data InvalidState = InvalidState String
    deriving (Show)

instance Exception InvalidState
-- | A prepared call for method @m@ of service @s@ producing an @a@.
--
-- The wrapped function receives everything 'open' has set up for the call:
-- the client, the stream, the stream's incoming and outgoing flow-control
-- handles, and the encoding/decoding settings.
newtype RPCCall s (m :: Symbol) a = RPCCall
    { runRPC
        :: Http2Client
        -> Http2Stream
        -> IncomingFlowControl
        -> OutgoingFlowControl
        -> Encoding
        -> Decoding
        -> IO a
    }
-- | Recover the 'RPC' proxy matching the service and method of a call.
-- The call itself is never inspected; only its type is used.
rpcFromCall :: RPCCall s m a -> RPC s m
rpcFromCall = const RPC
-- | Open a new HTTP2 stream on the client and run an 'RPCCall' on it.
--
-- The stream is started with a HEADERS frame containing the gRPC request
-- headers (method @POST@, scheme hard-coded to @http@, the given authority,
-- the RPC path, timeout, message encoding, accepted encodings, content-type
-- and @te: trailers@) followed by the caller-supplied extra headers. The
-- call is then run with the stream and its flow-control handles.
--
-- Returns 'Left' 'TooMuchConcurrency' when 'withHttp2Stream' refuses to open
-- the stream, otherwise the result of the call.
open :: (Service s, HasMethod s m)
     => Http2Client
     -> Authority
     -> HeaderList
     -> Timeout
     -> Encoding
     -> Decoding
     -> RPCCall s m a
     -> IO (Either TooMuchConcurrency a)
open conn authority extraheaders timeout encoding decoding call =
    withHttp2Stream conn $ \stream ->
        StreamDefinition
            (headers stream request setEndHeader)
            (\isfc osfc -> runRPC call conn stream isfc osfc encoding decoding)
  where
    -- The proxy is only needed to compute the request path.
    rpc = rpcFromCall call
    outgoingCompression = _getEncodingCompression encoding
    incomingCompression = _getDecodingCompression decoding
    acceptedEncodings =
        mconcat [grpcAcceptEncodingHVdefault, ",", grpcCompressionHV incomingCompression]
    request =
        [ (":method", "POST")
        , (":scheme", "http")
        , (":authority", authority)
        , (":path", path rpc)
        , (CI.original grpcTimeoutH, showTimeout timeout)
        , (CI.original grpcEncodingH, grpcCompressionHV outgoingCompression)
        , (CI.original grpcAcceptEncodingH, acceptedEncodings)
        , ("content-type", grpcContentTypeHV)
        , ("te", "trailers")
        ] <> extraheaders
-- | Call a server-streaming RPC: send one request, then fold over every
-- message of the reply.
--
-- The request is sent with the end-of-stream flag. The first event on the
-- stream must be the response headers, otherwise 'InvalidState' is thrown.
-- Each decoded message is passed to the handler together with the current
-- state and the response headers; the handler returns the next state. The
-- call finishes when a second HEADERS event (the trailers) arrives and
-- returns the final state, the headers and the trailers.
--
-- Throws 'UnallowedPushPromiseReceived' on a push promise, 'InvalidState' on
-- a stream error, and 'StreamReplyDecodingError' when a message cannot be
-- decoded.
streamReply
  :: (Service s, HasMethod s m, MethodStreamingType s m ~ 'ServerStreaming)
  => RPC s m
  -> a
  -> MethodInput s m
  -> (a -> HeaderList -> MethodOutput s m -> IO a)
  -> RPCCall s m (a, HeaderList, HeaderList)
streamReply rpc v0 req handler = RPCCall $ \conn stream isfc osfc encoding decoding ->
    let -- A decoder with no input yet; a new one is used for every message.
        freshDecoder = decodeOutput rpc (_getDecodingCompression decoding)

        -- Block until the next stream event and dispatch on it.
        awaitEvent acc decoder hdrs = do
            event <- _waitEvent stream
            case event of
                StreamPushPromiseEvent _ _ _ ->
                    throwIO UnallowedPushPromiseReceived
                StreamHeadersEvent _ trls ->
                    return (acc, hdrs, trls)
                StreamErrorEvent _ _ ->
                    throwIO (InvalidState "stream error")
                StreamDataEvent _ dat -> do
                    -- Account for the received bytes on the incoming
                    -- flow-control handle before decoding them.
                    let received = ByteString.length dat
                    _addCredit isfc received
                    _ <- _consumeCredit isfc received
                    _ <- _updateWindow isfc
                    drain acc decoder hdrs dat

        -- Decode as many complete messages as the chunk contains, then go
        -- back to waiting with the partially-fed decoder.
        drain acc decoder hdrs chunk =
            case pushChunk decoder chunk of
                Done leftover _ (Right msg) -> do
                    acc' <- handler acc hdrs msg
                    drain acc' freshDecoder hdrs leftover
                Done _ _ (Left err) ->
                    throwIO (StreamReplyDecodingError $ "done-error: " ++ err)
                Fail _ _ err ->
                    throwIO (StreamReplyDecodingError $ "fail-error: " ++ err)
                pending@(Partial _) ->
                    awaitEvent acc pending hdrs
    in do
        sendSingleMessage rpc req encoding setEndStream conn (_outgoingFlowControl conn) stream osfc
        firstEvent <- _waitEvent stream
        case firstEvent of
            StreamHeadersEvent _ hdrs ->
                awaitEvent v0 freshDecoder hdrs
            _ ->
                throwIO (InvalidState "no headers")
-- | Marker returned by a 'streamRequest' handler (as @Left StreamDone@) to
-- signal that there are no more messages to send.
data StreamDone = StreamDone
-- | Per-message compression choice: 'Compressed' uses the compression of the
-- call's 'Encoding', 'Uncompressed' sends the message with 'uncompressed'.
data CompressMode = Compressed | Uncompressed
-- | Call a client-streaming RPC: send messages produced by a stateful
-- handler, then wait for the single reply.
--
-- The handler is run repeatedly on the current state. @Right (mode, msg)@
-- sends @msg@ with the chosen compression and continues with the new state;
-- @Left _@ closes the sending side with an empty DATA frame carrying the
-- end-of-stream flag and waits for the reply via 'waitReply'. The result is
-- the final state paired with that reply.
streamRequest
  :: (Service s, HasMethod s m, MethodStreamingType s m ~ 'ClientStreaming)
  => RPC s m
  -> a
  -> (a -> IO (a, Either StreamDone (CompressMode, MethodInput s m)))
  -> RPCCall s m (a, RawReply (MethodOutput s m))
streamRequest rpc v0 handler = RPCCall $ \conn stream isfc streamFlowControl encoding decoding ->
    let connFlowControl = _outgoingFlowControl conn

        compressionFor Compressed = _getEncodingCompression encoding
        compressionFor Uncompressed = uncompressed

        pump state = do
            (state', step) <- handler state
            case step of
                Left _ -> do
                    sendData conn stream setEndStream ""
                    reply <- waitReply rpc decoding stream isfc
                    pure (state', reply)
                Right (mode, msg) -> do
                    sendSingleMessage rpc msg (Encoding (compressionFor mode)) id conn connFlowControl stream streamFlowControl
                    pump state'
    in pump v0
-- | Encode one message and send it on the stream, respecting both the
-- stream-level and the connection-level outgoing flow control.
--
-- The encoded message is sent in as many DATA frames as the available
-- credit requires. Only the last frame gets the caller's 'FlagSetter'
-- (e.g. end-of-stream); intermediate frames are sent with unchanged flags.
sendSingleMessage
  :: (Service s, HasMethod s m)
  => RPC s m
  -> MethodInput s m
  -> Encoding
  -> FlagSetter
  -> Http2Client
  -> OutgoingFlowControl
  -> Http2Stream
  -> OutgoingFlowControl
  -> IO ()
sendSingleMessage rpc msg encoding flagMod conn connectionFlowControl stream streamFlowControl =
    upload payload
  where
    payload =
        toStrict . toLazyByteString $
            encodeInput rpc (_getEncodingCompression encoding) msg

    upload chunk = do
        let !needed = ByteString.length chunk
        -- Ask the stream first, then the connection for what the stream
        -- granted; the results are treated as the amounts actually granted.
        streamGrant <- _withdrawCredit streamFlowControl needed
        granted <- _withdrawCredit connectionFlowControl streamGrant
        -- Hand back to the stream whatever the connection did not cover.
        _receiveCredit streamFlowControl (streamGrant - granted)
        if granted /= needed
            then do
                let (sendNow, sendLater) = ByteString.splitAt granted chunk
                sendData conn stream id sendNow
                upload sendLater
            else
                sendData conn stream flagMod chunk
-- | Call an RPC with a single request message and wait for the whole reply.
--
-- The message is sent with the end-of-stream flag, so nothing else can be
-- sent afterwards; the reply is collected with 'waitReply'.
singleRequest
  :: (Service s, HasMethod s m)
  => RPC s m
  -> MethodInput s m
  -> RPCCall s m (RawReply (MethodOutput s m))
singleRequest rpc msg = RPCCall $ \conn stream isfc osfc encoding decoding ->
    sendSingleMessage rpc msg encoding setEndStream conn (_outgoingFlowControl conn) stream osfc
        >> waitReply rpc decoding stream isfc
-- | Callback run by 'steppedBiDiStream' for a decoded message: it receives
-- the response headers, the current state and the message, and returns the
-- new state.
type HandleMessageStep s m a = HeaderList -> a -> MethodOutput s m -> IO a
-- | Callback run by 'steppedBiDiStream' when trailers arrive: it receives
-- the response headers, the current state and the trailers, and returns the
-- new state.
type HandleTrailersStep a = HeaderList -> a -> HeaderList -> IO a
-- | The next action requested by the driver of a 'steppedBiDiStream'.
data BiDiStep s m a
    = Abort
    -- ^ Close the sending side and finish with the current state.
    | SendInput !CompressMode !(MethodInput s m)
    -- ^ Send one message with the given compression choice.
    | WaitOutput (HandleMessageStep s m a) (HandleTrailersStep a)
    -- ^ Wait for the next message or the trailers and run the matching
    -- callback.
-- | Driver of a 'steppedBiDiStream': given the current state, produce the
-- (possibly updated) state and the next 'BiDiStep' to perform.
type RunBiDiStep s m a = a -> IO (a, BiDiStep s m a)
-- | Run a bidirectional-streaming RPC one step at a time.
--
-- The driver is asked for a 'BiDiStep' after every completed step:
--
-- * 'Abort' sends an empty DATA frame with the end-of-stream flag and
--   returns the current state.
-- * 'SendInput' sends one message (with the chosen compression).
-- * 'WaitOutput' delivers the next message to the message callback, or the
--   trailers to the trailers callback. The first 'WaitOutput' additionally
--   waits for the response headers; any other first event throws
--   'InvalidState'.
--
-- Throws 'UnallowedPushPromiseReceived' on a push promise, 'InvalidState' on
-- a stream error and 'InvalidParse' when a message cannot be decoded.
--
-- A message that is already complete in the decoder (because several
-- messages arrived in one DATA frame) is delivered by the next 'WaitOutput'
-- without waiting for another stream event. Waiting first would stall until
-- more data arrived and would drop the buffered message if the trailers came
-- next.
steppedBiDiStream
  :: (Service s, HasMethod s m, MethodStreamingType s m ~ 'BiDiStreaming)
  => RPC s m
  -> a
  -> RunBiDiStep s m a
  -> RPCCall s m a
steppedBiDiStream rpc v0 handler = RPCCall $ \conn stream isfc streamFlowControl encoding decoding ->
    let ocfc = _outgoingFlowControl conn
        decompress = _getDecodingCompression decoding
        newDecoder = decodeOutput rpc decompress

        goStep _ _ (v1, Abort) = do
            sendData conn stream setEndStream ""
            pure v1
        goStep hdrs decode (v1, SendInput doCompress msg) = do
            let compress = case doCompress of
                    Compressed -> _getEncodingCompression encoding
                    Uncompressed -> uncompressed
            sendSingleMessage rpc msg (Encoding compress) id conn ocfc stream streamFlowControl
            handler v1 >>= goStep hdrs decode
        goStep jhdrs@(Just hdrs) decode unchanged@(v1, WaitOutput handleMsg handleEof) =
            -- Look at what the decoder already holds before blocking on the
            -- stream: leftover bytes from the previous chunk may have
            -- completed (or broken) the next message.
            case decode of
                Done unusedDat _ (Right val) -> do
                    v2 <- handleMsg hdrs v1 val
                    handler v2 >>= goStep jhdrs (pushChunk newDecoder unusedDat)
                Done _ _ (Left err) ->
                    throwIO $ InvalidParse $ "done-err: " ++ err
                Fail _ _ err ->
                    throwIO $ InvalidParse $ "done-fail: " ++ err
                Partial _ ->
                    _waitEvent stream >>= \case
                        (StreamPushPromiseEvent _ _ _) ->
                            throwIO UnallowedPushPromiseReceived
                        (StreamHeadersEvent _ trls) -> do
                            v2 <- handleEof hdrs v1 trls
                            handler v2 >>= goStep jhdrs newDecoder
                        (StreamErrorEvent _ _) ->
                            throwIO (InvalidState "stream error")
                        (StreamDataEvent _ dat) -> do
                            -- Account for the received bytes on the incoming
                            -- flow-control handle.
                            _addCredit isfc (ByteString.length dat)
                            _ <- _consumeCredit isfc (ByteString.length dat)
                            _ <- _updateWindow isfc
                            -- Same step again; the decoder state decides
                            -- whether a message is ready or more data is
                            -- needed.
                            goStep jhdrs (pushChunk decode dat) unchanged
        goStep Nothing decode unchanged =
            -- No headers seen yet: they must be the first event.
            _waitEvent stream >>= \case
                (StreamHeadersEvent _ hdrs) ->
                    goStep (Just hdrs) decode unchanged
                _ ->
                    throwIO (InvalidState "no headers")
    in handler v0 >>= goStep Nothing newDecoder
-- | Exception describing a message that could not be decoded; the 'String'
-- carries the decoder's error text.
data InvalidParse = InvalidParse String
    deriving (Show)

instance Exception InvalidParse
-- | Events reported to the incoming-side callback of 'generalHandler'.
data IncomingEvent s m a
    = Headers HeaderList
    -- ^ The response headers (first event of the stream).
    | RecvMessage (MethodOutput s m)
    -- ^ One decoded message.
    | Trailers HeaderList
    -- ^ The trailers; the incoming loop stops after this event.
    | Invalid SomeException
    -- ^ A problem (unexpected event, stream error, decoding failure); the
    -- incoming loop stops after this event.
-- | Actions requested by the outgoing-side callback of 'generalHandler'.
data OutgoingEvent s m b
    = Finalize
    -- ^ Close the sending side; the outgoing loop stops.
    | SendMessage CompressMode (MethodInput s m)
    -- ^ Send one message with the given compression choice.
-- | Run an RPC with independent incoming and outgoing loops.
--
-- Both loops run concurrently (via 'concurrently') and the call returns the
-- final states of both once each has finished.
--
-- * Outgoing: the producer is run on its state; 'SendMessage' sends one
--   message and continues, 'Finalize' sends an empty DATA frame with the
--   end-of-stream flag and stops.
-- * Incoming: the first event must be the response headers (reported as
--   'Headers'); afterwards each decoded message is reported as
--   'RecvMessage'. The loop stops after reporting 'Trailers' or 'Invalid'.
--   Problems are reported through 'Invalid' rather than thrown.
--
-- Every complete message is reported as soon as its bytes have been
-- received, including when several messages arrive in one DATA frame. The
-- decoder is drained before waiting for the next stream event; otherwise a
-- message completed by leftover bytes would be held back until another DATA
-- frame arrived and lost if the trailers came first.
generalHandler
  :: (Service s, HasMethod s m)
  => RPC s m
  -> a
  -> (a -> IncomingEvent s m a -> IO a)
  -> b
  -> (b -> IO (b, OutgoingEvent s m b))
  -> RPCCall s m (a,b)
generalHandler rpc v0 handle w0 next = RPCCall $ \conn stream isfc osfc encoding decoding ->
    go conn stream isfc osfc encoding decoding
  where
    go conn stream isfc osfc encoding decoding =
        concurrently (incomingLoop Nothing newDecoder v0) (outGoingLoop w0)
      where
        ocfc = _outgoingFlowControl conn
        newDecoder = decodeOutput rpc decompress
        decompress = _getDecodingCompression decoding

        outGoingLoop v1 = do
            (v2, event) <- next v1
            case event of
                Finalize -> do
                    sendData conn stream setEndStream ""
                    return v2
                SendMessage doCompress msg -> do
                    let compress = case doCompress of
                            Compressed -> _getEncodingCompression encoding
                            Uncompressed -> uncompressed
                    sendSingleMessage rpc msg (Encoding compress) id conn ocfc stream osfc
                    outGoingLoop v2

        -- Before any headers: the first event must be the response headers.
        incomingLoop Nothing decode v1 =
            _waitEvent stream >>= \case
                (StreamHeadersEvent _ hdrs) ->
                    handle v1 (Headers hdrs) >>= incomingLoop (Just hdrs) decode
                _ ->
                    handle v1 (Invalid $ toException $ InvalidState "no headers")
        -- After the headers: first deliver whatever the decoder already
        -- holds, and only wait for the stream when it needs more input.
        incomingLoop jhdrs decode v1 =
            case decode of
                (Done unusedDat _ (Right val)) ->
                    handle v1 (RecvMessage val) >>= incomingLoop jhdrs (pushChunk newDecoder unusedDat)
                (Done _ _ (Left err)) ->
                    handle v1 (Invalid $ toException $ InvalidParse $ "invalid-done-parse: " ++ err)
                (Fail _ _ err) ->
                    handle v1 (Invalid $ toException $ InvalidParse $ "invalid-parse: " ++ err)
                (Partial _) ->
                    _waitEvent stream >>= \case
                        (StreamHeadersEvent _ hdrs) ->
                            handle v1 (Trailers hdrs)
                        (StreamDataEvent _ dat) -> do
                            -- Account for the received bytes on the incoming
                            -- flow-control handle.
                            _addCredit isfc (ByteString.length dat)
                            _ <- _consumeCredit isfc (ByteString.length dat)
                            _ <- _updateWindow isfc
                            incomingLoop jhdrs (pushChunk decode dat) v1
                        (StreamPushPromiseEvent _ _ _) ->
                            handle v1 (Invalid $ toException UnallowedPushPromiseReceived)
                        (StreamErrorEvent _ _) ->
                            handle v1 (Invalid $ toException $ InvalidState "stream error")