{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE OverloadedStrings #-}
-- | A servant-client backend that runs requests over an HTTP/2 connection
-- from the @http2-client@ package.
module Network.HTTP2.Client.Servant (
    H2ClientM
  , runH2ClientM
  , H2ClientEnv (..)
  -- * generate functions
  , h2client
  ) where

import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Foldable (traverse_)
import Control.Exception (throwIO)
import Control.Monad (unless, when, (>=>))
import Control.Monad.Catch (MonadCatch, MonadThrow)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Except (ExceptT, runExceptT)
import Control.Monad.Error.Class (MonadError (..))
import Control.Monad.Reader (MonadReader, ReaderT, ask, runReaderT)
import Data.Binary.Builder (toLazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as ByteString
import Data.ByteString.Lazy (fromStrict, toStrict, toChunks)
import Data.Foldable (toList)
import Data.Proxy (Proxy(..))
import Data.Sequence (fromList)
import qualified Data.Text as Text
import GHC.Generics (Generic)
import Network.HPACK (HeaderList)
import Network.HTTP.Media.RenderHeader (renderHeader)
import Network.HTTP2 (flags, setEndStream, testEndStream, payloadLength, toErrorCodeId, ErrorCodeId(RefusedStream))
import Network.HTTP2.Client.Helpers (upload, waitStream, fromStreamResult)
import Network.HTTP.Types.Status (Status(..))
import Network.HTTP.Types.URI (renderQuery)
import Network.HTTP.Types.Version (http20)
import qualified Data.CaseInsensitive as CI
import Text.Read (readMaybe)

import Network.HTTP2.Client
import Servant.Client.Core

-- | Client monad: reads an 'H2ClientEnv' and can fail with a 'ServantError'.
newtype H2ClientM a = H2ClientM
  { unH2ClientM :: ReaderT H2ClientEnv (ExceptT ServantError IO) a }
  deriving ( Functor, Applicative, Monad, MonadIO, Generic
           , MonadReader H2ClientEnv, MonadError ServantError, MonadThrow
           , MonadCatch)

-- | Run a computation against the given environment, returning either the
-- 'ServantError' raised with 'throwError' or the result.
runH2ClientM :: H2ClientM a -> H2ClientEnv -> IO (Either ServantError a)
runH2ClientM cm env = runExceptT $ flip runReaderT env $ unH2ClientM cm

instance RunClient H2ClientM where
  runRequest :: Request -> H2ClientM Response
  runRequest = performRequest

  streamingRequest :: Request -> H2ClientM StreamingResponse
  streamingRequest = performStreamingRequest

  throwServantError :: ServantError -> H2ClientM a
  throwServantError = throwError

-- | Environment for 'H2ClientM': the server authority (sent as the
-- @:authority@ pseudo-header) and the HTTP/2 client connection to use.
data H2ClientEnv = H2ClientEnv ByteString Http2Client

-- | A request-body producer. It calls the handler once per body segment,
-- each time with an action that yields that segment.
type ByteSegments = (IO ByteString -> IO ()) -> IO ()

-- | Construct a ByteSegments from a single ByteString.
onlySegment :: ByteString -> ByteSegments
onlySegment bs handle = handle (pure bs)

-- | Construct a ByteSegments from a lazy list of ByteString.
--
-- Since we expect the handler passed to ByteSegments to do some IO, empty
-- chunks are discarded to save on IOs.
multiSegments :: [ByteString] -> ByteSegments
multiSegments bss handle =
    traverse_ (handle . pure) (filter (not . ByteString.null) bss)

-- | Adapt a streaming body producer to 'ByteSegments'.
--
-- The producers carried by 'RequestBodyStream' and 'RequestBodyStreamChunked'
-- follow the @http-client@ popper convention. Each one hands over a single
-- popper that must be called repeatedly until it returns an empty chunk.
-- 'ByteSegments' instead expects one handler call per segment, so this
-- drains the popper and forwards every non-empty chunk separately.
popperSegments :: ((IO ByteString -> IO ()) -> IO ()) -> ByteSegments
popperSegments givesPopper handle = givesPopper drain
  where
    drain popper = do
        chunk <- popper
        unless (ByteString.null chunk) $ do
            handle (pure chunk)
            drain popper

-- | Pulls data segments from ByteSegments and calls 'upload' on each one.
sendSegments
  :: Http2Client
  -> Http2Stream
  -> OutgoingFlowControl -- ^ Connection
  -> OutgoingFlowControl -- ^ Stream
  -> ByteSegments
  -> IO ()
sendSegments http2client stream ocfc osfc segments = segments go
  where
    go getChunk = do
        dat <- getChunk
        upload dat id http2client ocfc stream osfc

-- | Prepare an HTTP2 request to a given server.
--
-- Returns the header list (pseudo-headers first) and the request body as
-- 'ByteSegments'. HTTP/2 rules applied here (RFC 7540, section 8.1.2):
--
-- * header field names are lowercase; uppercase names make a request
--   malformed;
-- * connection-specific fields such as @Transfer-Encoding@ are not sent,
--   because HTTP/2 frames the body itself;
-- * the query string is part of the @:path@ pseudo-header, and an empty path
--   is sent as @/@.
makeRequest
  :: ByteString
  -- ^ Server's Authority.
  -> Request
  -- ^ The HTTP request.
  -> IO (HeaderList, ByteSegments)
makeRequest authority req = do
    (bodyIO, bodyheaders) <- case requestBody req of
        Nothing -> pure (onlySegment "", [])
        Just (r, ct) -> go ct r
    let headersPairs = baseHeaders <> reqHeaders <> bodyheaders
    pure (headersPairs, bodyIO)
  where
    contentType ct = ("content-type", renderHeader ct)

    contentLength :: Show a => a -> (ByteString, ByteString)
    contentLength n = ("content-length", ByteString.pack $ show n)

    -- Turn a servant body into segments plus its body-describing headers.
    go ct obj = case obj of
        RequestBodyBS bs ->
            pure (onlySegment bs, [contentType ct, contentLength (ByteString.length bs)])
        RequestBodyLBS lbs ->
            pure (multiSegments $ toChunks lbs, [contentType ct])
        RequestBodyBuilder n builder ->
            pure ( multiSegments $ toChunks $ toLazyByteString builder
                 , [contentType ct, contentLength n] )
        RequestBodyStream n act ->
            pure (popperSegments act, [contentType ct, contentLength n])
        RequestBodyStreamChunked act ->
            -- No Transfer-Encoding: HTTP/2 forbids it. DATA frames delimit
            -- the body and END_STREAM marks its end.
            pure (popperSegments act, [contentType ct])
        RequestBodyIO again ->
            again >>= go ct

    baseHeaders =
        [ (":method", requestMethod req)
        , (":scheme", "https")
        , (":path", path)
        , (":authority", authority)
        , ("accept", ByteString.intercalate "," $ toList $ fmap renderHeader $ requestAccept req)
        , ("user-agent", "servant-http2-client/dev")
        ]

    rawPath = toStrict $ toLazyByteString $ requestPath req

    -- :path must not be empty, and it carries the query string (if any).
    path = (if ByteString.null rawPath then "/" else rawPath)
        <> renderQuery True (toList $ requestQueryString req)

    -- HTTP/2 requires lowercase field names; 'CI.foldedCase' provides them.
    reqHeaders = [(CI.foldedCase h, hv) | (h, hv) <- toList (requestHeaders req)]

-- | Push-promise handler that refuses every server push by resetting the
-- promised stream with @REFUSED_STREAM@.
resetPushPromises :: PushPromiseHandler
resetPushPromises _ pps _ _ _ = _rst pps RefusedStream

-- | Implementation of simple requests.
--
-- Opens a stream on the connection and sends the request headers and the
-- whole body. It then waits for the complete response, refusing server
-- pushes. Errors are raised with 'throwError':
--
-- * 'ConnectionError' when 'withHttp2Stream' reports 'TooMuchConcurrency'
--   or the stream result carries an error code;
-- * 'FailureResponse' for a status outside 200-299;
-- * 'DecodeFailure' when the response has no parsable @:status@.
performRequest :: Request -> H2ClientM Response
performRequest req = do
    H2ClientEnv authority http2client <- ask
    let icfc = _incomingFlowControl http2client
        ocfc = _outgoingFlowControl http2client
        headersFlags = id
    (headersPairs, bodyIO) <- liftIO $ makeRequest authority req
    http2rsp <- liftIO $ withHttp2Stream http2client $ \stream ->
        let initStream = headers stream headersPairs headersFlags
            -- NOTE(review): the stream-level incoming window (first handler
            -- argument) is ignored, and 'waitStream' gets the connection-level
            -- one. Confirm that large response bodies cannot stall.
            handler _ osfc = do
                sendSegments http2client stream ocfc osfc bodyIO
                sendData http2client stream setEndStream ""
                streamResult <- waitStream stream icfc resetPushPromises
                pure $ fromStreamResult streamResult
        in StreamDefinition initStream handler
    case http2rsp of
        Left (TooMuchConcurrency _) ->
            throwError $ Servant.Client.Core.ConnectionError "too many concurrent streams"
        Right (Right (hdrs, body, _))
            -- This must be a pattern guard. A guard such as
            -- @let Just status = ...@ always succeeds because the binding is
            -- lazy, which made the fallback branch unreachable.
            | Just status <- lookupStatus hdrs -> do
                let response = mkResponse status hdrs body
                unless (status >= 200 && status < 300) $
                    throwError $ FailureResponse response
                pure response
            | otherwise ->
                throwError $ DecodeFailure "no :status header" (mkResponse 0 hdrs body)
        Right (Left err) ->
            throwError $ Servant.Client.Core.ConnectionError $
                "connection error: " <> Text.pack (show $ toErrorCodeId err)

-- | Build a servant 'Response' from a status code, an HTTP/2 header list and
-- a fully-read body. The status message is left empty, and every header
-- (pseudo-headers included) is copied unchanged.
mkResponse :: Int -> HeaderList -> ByteString -> Response
mkResponse status hdrs body =
    Response (Status status "") (fromList [ (CI.mk h, hv) | (h, hv) <- hdrs ]) http20 (fromStrict body)

-- | Read the numeric @:status@ pseudo-header. Returns 'Nothing' when it is
-- absent or not a number.
lookupStatus :: HeaderList -> Maybe Int
lookupStatus = lookup ":status" >=> readMaybe . ByteString.unpack

-- | Acknowledge @len@ received bytes on both incoming windows, @isfc@ first
-- and then @icfc@. For each window it consumes the credit, adds it straight
-- back, and calls '_updateWindow'.
replenishFlowControls :: IncomingFlowControl -> IncomingFlowControl -> Int -> IO ()
replenishFlowControls icfc isfc len = do
    _ <- _consumeCredit isfc len
    _addCredit isfc len
    _ <- _updateWindow isfc
    _ <- _consumeCredit icfc len
    _addCredit icfc len
    _ <- _updateWindow icfc
    pure ()

-- | Implementation of requests with streaming replies.
--
-- The request headers and body are sent as soon as the stream opens. The
-- returned 'StreamingResponse' waits for the response headers only when it
-- is run, refusing any server pushes that arrive first.
--
-- A 2xx response is passed to the continuation with a body action. Each call
-- yields the next DATA chunk, and @""@ once the server has ended the stream.
-- Other statuses, and a missing @:status@, are thrown with 'throwIO' after
-- the remaining body has been read.
performStreamingRequest :: Request -> H2ClientM StreamingResponse
performStreamingRequest req = do
    H2ClientEnv authority http2client <- ask
    let icfc = _incomingFlowControl http2client
        ocfc = _outgoingFlowControl http2client
        headersFlags = id
    (headersPairs, bodyIO) <- liftIO $ makeRequest authority req
    ret <- liftIO $ withHttp2Stream http2client $ \stream ->
        let initStream = headers stream headersPairs headersFlags
            handler isfc osfc = do
                -- Send the request.
                sendSegments http2client stream ocfc osfc bodyIO
                sendData http2client stream setEndStream ""
                -- Waiting for the response headers is deferred until the
                -- caller runs the streaming response.
                pure $ StreamingResponse (\handleGenResponse -> do
                    (fh, hdrs) <- waitResponseHeaders stream
                    handleHeaders stream icfc isfc fh hdrs handleGenResponse)
        in StreamDefinition initStream handler
    case ret of
        Right streamingResp -> pure streamingResp
        Left (TooMuchConcurrency _) ->
            throwError $ Servant.Client.Core.ConnectionError "too many concurrent streams"
  where
    -- Wait for the response HEADERS frame, refusing any server push that is
    -- announced before it.
    waitResponseHeaders stream = do
        ev <- _waitEvent stream
        case ev of
            StreamHeadersEvent fh hdrs -> pure (fh, hdrs)
            StreamPushPromiseEvent _ ppSid ppHdrs -> do
                _handlePushPromise stream ppSid ppHdrs resetPushPromises
                waitResponseHeaders stream
            _ -> throwIO $ Servant.Client.Core.ConnectionError $
                "unwanted event received in data stream" <> Text.pack (show ev)

    -- Dispatch on the response status. A 2xx response goes to the caller
    -- with a chunked body. Anything else is thrown with its full body.
    handleHeaders stream icfc isfc fh hdrs handleGenResponse =
        case lookupStatus hdrs of
            Just status -> do
                isFinished <- newIORef endOfStream
                unless (status >= 200 && status < 300) $ do
                    wholeBody <- remainingBody
                    throwIO $ FailureResponse $ mkResponse status hdrs wholeBody
                handleGenResponse $
                    mkStreamResponse status hdrs isFinished stream icfc isfc
            Nothing -> do
                wholeBody <- remainingBody
                throwIO $ DecodeFailure "no :status header" (mkResponse 0 hdrs wholeBody)
      where
        endOfStream = testEndStream (flags fh)
        -- A HEADERS frame with END_STREAM has no body. Waiting for DATA
        -- frames in that case would block forever.
        remainingBody
            | endOfStream = pure ""
            | otherwise = consumeBody stream icfc isfc

    -- Helper to consume the whole response body when the status is not a
    -- 2xx. This consumption can itself fail.
    consumeBody stream icfc isfc = do
        (revBss, _) <- waitDataFrames stream icfc isfc []
        pure $ mconcat $ reverse revBss

    -- Accumulate DATA payloads (newest first) until END_STREAM or trailers.
    -- Push promises are refused along the way.
    -- NOTE(review): the final DATA frame is not passed to
    -- 'replenishFlowControls' here, whereas 'nextChunk' replenishes every
    -- frame. Confirm which behavior is intended.
    waitDataFrames stream icfc isfc xs = do
        ev <- _waitEvent stream
        case ev of
            StreamDataEvent fh x
                | testEndStream (flags fh) ->
                    return (x:xs, Nothing)
                | otherwise -> do
                    replenishFlowControls icfc isfc (payloadLength fh)
                    waitDataFrames stream icfc isfc (x:xs)
            StreamPushPromiseEvent _ ppSid ppHdrs -> do
                _handlePushPromise stream ppSid ppHdrs resetPushPromises
                waitDataFrames stream icfc isfc xs
            StreamHeadersEvent _ hdrs ->
                return (xs, Just hdrs)
            _ -> throwIO $ Servant.Client.Core.ConnectionError $
                "unwanted event received in data stream" <> Text.pack (show ev)

    -- Get the next DATA chunk of the stream. An empty result means the body
    -- is over: the IORef records that the server has ended the stream.
    -- Do not copy-paste this utility unless you understand that the IORef
    -- must be entirely owned by this function.
    nextChunk isFinished stream icfc isfc = do
        done <- readIORef isFinished
        if done
            then pure ""
            else do
                ev <- _waitEvent stream
                case ev of
                    StreamDataEvent fh dat -> do
                        replenishFlowControls icfc isfc (payloadLength fh)
                        let final = testEndStream (flags fh)
                        when final $ writeIORef isFinished True
                        -- "" is the end-of-body marker, so an empty DATA
                        -- frame that does not end the stream is skipped.
                        if ByteString.null dat && not final
                            then nextChunk isFinished stream icfc isfc
                            else pure dat
                    StreamPushPromiseEvent _ ppSid ppHdrs -> do
                        _handlePushPromise stream ppSid ppHdrs resetPushPromises
                        nextChunk isFinished stream icfc isfc
                    -- Trailers carrying END_STREAM close the body.
                    StreamHeadersEvent fh _
                        | testEndStream (flags fh) -> do
                            writeIORef isFinished True
                            pure ""
                    _ -> throwIO $ Servant.Client.Core.ConnectionError $
                        "unwanted event received in data stream" <> Text.pack (show ev)

    -- Streaming counterpart of 'mkResponse': the body is the 'nextChunk'
    -- reader for this stream.
    mkStreamResponse status hdrs isFinished stream icfc isfc =
        Response (Status status "") (fromList [ (CI.mk h, hv) | (h, hv) <- hdrs ]) http20 (nextChunk isFinished stream icfc isfc)

-- | Derive the client functions for @api@, running in 'H2ClientM'.
h2client :: HasClient H2ClientM api => Proxy api -> Client H2ClientM api
h2client api = api `clientIn` (Proxy :: Proxy H2ClientM)