{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Servant.HttpStreams.Internal where
import Prelude ()
import Prelude.Compat
import Control.DeepSeq
(NFData, force)
import Control.Exception
(IOException, SomeException (..), catch, evaluate, throwIO)
import Control.Monad
(unless)
import Control.Monad.Base
(MonadBase (..))
import Control.Monad.Codensity
(Codensity (..))
import Control.Monad.Error.Class
(MonadError (..))
import Control.Monad.IO.Class
(MonadIO (..))
import Control.Monad.Reader
(MonadReader, ReaderT, ask, runReaderT)
import Control.Monad.Trans.Class
(lift)
import Control.Monad.Trans.Except
(ExceptT, runExceptT)
import Data.Bifunctor
(bimap, first)
import Data.ByteString.Builder
(toLazyByteString)
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Lazy as BSL
import qualified Data.CaseInsensitive as CI
import Data.Foldable
(for_, toList)
import Data.Functor.Alt
(Alt (..))
import Data.Maybe
(maybeToList)
import Data.Proxy
(Proxy (..))
import Data.Semigroup
((<>))
import Data.Sequence
(fromList)
import Data.String
(fromString)
import GHC.Generics
import Network.HTTP.Media
(renderHeader)
import Network.HTTP.Types
(Status (..), hContentType, http11, renderQuery)
import Servant.Client.Core
import qualified Network.Http.Client as Client
import qualified Network.Http.Types as Client
import qualified Servant.Types.SourceT as S
import qualified System.IO.Streams as Streams
-- | Everything a 'ClientM' computation needs in order to issue requests:
-- the base URL that request paths are resolved against, and the open
-- http-streams connection they are sent over.
data ClientEnv = ClientEnv
    { baseUrl    :: BaseUrl
      -- ^ Prefix for every request target; its host and port are also used
      -- for the request's Host header (see 'requestToClientRequest').
    , connection :: Client.Connection
      -- ^ Connection on which requests are written and responses are read.
    }
-- | Pair a base URL with an already-open connection. The connection is
-- neither opened nor closed here; 'withClientEnvIO' opens one and runs the
-- continuation under 'Client.withConnection'.
mkClientEnv :: BaseUrl -> Client.Connection -> ClientEnv
mkClientEnv burl conn = ClientEnv { baseUrl = burl, connection = conn }
-- | Open a connection to the host and port named by the 'BaseUrl', run the
-- continuation with a 'ClientEnv' for it, and leave connection cleanup to
-- 'Client.withConnection'.
--
-- The 'baseUrlScheme' selects the transport:
--
-- * 'Http' opens a plain connection with 'Client.openConnection'.
-- * 'Https' opens a TLS connection with 'Client.openConnectionSSL', using
--   http-streams' 'Client.baselineContextSSL'.
--
-- Treating 'Https' like 'Http' would send requests unencrypted to the TLS
-- port, so the scheme must be honoured here.
--
-- NOTE(review): HsOpenSSL asks applications to run OpenSSL code under
-- @withOpenSSL@; confirm whether callers using 'Https' need to wrap this.
withClientEnvIO :: BaseUrl -> (ClientEnv -> IO r) -> IO r
withClientEnvIO burl k = Client.withConnection open $ \conn ->
    k (mkClientEnv burl conn)
  where
    host = fromString (baseUrlHost burl)
    port = fromIntegral (baseUrlPort burl)
    open = case baseUrlScheme burl of
        Http  -> Client.openConnection host port
        Https -> do
            ctx <- Client.baselineContextSSL
            Client.openConnectionSSL ctx host port
-- | Derive the client functions for an API, with every endpoint running in
-- 'ClientM'.
client :: HasClient ClientM api => Proxy api -> Client ClientM api
client api = clientIn api (Proxy :: Proxy ClientM)
-- | Change the monad a client's functions run in, given a natural
-- transformation @m ~> n@. Delegates to 'hoistClientMonad' at the
-- 'ClientM' instance of 'HasClient'.
hoistClient
    :: HasClient ClientM api
    => Proxy api
    -> (forall a. m a -> n a)
    -> Client m api
    -> Client n api
hoistClient api nt clientFns =
    hoistClientMonad (Proxy :: Proxy ClientM) api nt clientFns
-- | The client monad. It reads a 'ClientEnv', can fail with a 'ClientError'
-- through 'MonadError', and runs over @'Codensity' 'IO'@. The request
-- functions in this module use that to invoke the rest of the computation
-- from inside http-streams' response callbacks.
newtype ClientM a = ClientM
    { unClientM :: ReaderT ClientEnv (ExceptT ClientError (Codensity IO)) a
    }
    deriving
        ( Generic
        , Functor
        , Applicative
        , Monad
        , MonadIO
        , MonadReader ClientEnv
        , MonadError ClientError
        )
-- | Lift a plain 'IO' action into 'ClientM' (via 'liftIO').
instance MonadBase IO ClientM where
    liftBase io = ClientM (liftIO io)
-- | Left-biased choice: run the first action and, if it fails with a
-- 'ClientError' raised through 'throwError', run the second instead. The
-- first error is discarded. Failures raised as IO exceptions are not
-- intercepted by 'catchError'.
instance Alt ClientM where
    (<!>) primary fallback = catchError primary (const fallback)
-- | Plain requests go through 'performRequest'. Client errors are raised in
-- 'ClientM''s 'MonadError' channel.
instance RunClient ClientM where
    throwClientError err = throwError err
    runRequest req = performRequest req
-- | Streaming requests go through 'performWithStreamingRequest'.
instance RunStreamingClient ClientM where
    withStreamingRequest req handler = performWithStreamingRequest req handler
-- | Run a 'ClientM' computation and return its error or result.
--
-- The outcome is fully evaluated with 'force' before 'withClientM' returns.
-- Presumably this is so that no lazily-evaluated part of the result
-- outlives the response callback it was computed in.
runClientM :: NFData a => ClientM a -> ClientEnv -> IO (Either ClientError a)
runClientM cm env = withClientM cm env $ \result -> evaluate (force result)
-- | Run a 'ClientM' computation in continuation-passing style.
--
-- The continuation receives either the 'ClientError' raised with
-- 'throwError' or the computation's result. It runs while any response
-- callbacks opened by the computation are still active.
withClientM :: ClientM a -> ClientEnv -> (Either ClientError a -> IO b) -> IO b
withClientM cm env k = runCodensity outcome k
  where
    outcome = runExceptT (runReaderT (unClientM cm) env)
-- | Send a request on the environment's connection and return the response.
--
-- The whole response body is read into a lazy 'BSL.ByteString' inside
-- 'Client.receiveResponse'. A status outside 200–299 is turned into
-- 'FailureResponse' and raised with 'throwError', so it can be handled
-- with 'catchError' / '<!>'.
performRequest :: Request -> ClientM Response
performRequest req = do
    ClientEnv { baseUrl = burl, connection = conn } <- ask
    let (clientReq, writeBody) = requestToClientRequest burl req
    outcome <- ClientM . lift . lift $ Codensity $ \finish -> do
        Client.sendRequest conn clientReq writeBody
        Client.receiveResponse conn $ \rawRes rawBody -> do
            chunks <- Streams.toList rawBody
            let status    = Client.getStatusCode rawRes
                response  = clientResponseToResponse rawRes (BSL.fromChunks chunks)
                isSuccess = status >= 200 && status < 300
            finish (if isSuccess then Right response else Left (mkFailureResponse burl req response))
    either throwError pure outcome
-- | Send a request and pass the response to the callback @k@. The response
-- body is exposed as a 'S.SourceT' that reads from the connection.
--
-- A status outside 200–299 is handled by reading the whole body and
-- raising 'FailureResponse' with 'throwIO'. This is an IO exception, not a
-- 'throwError', so 'catchError' / '<!>' in 'ClientM' do not see it, and it
-- propagates out of 'withClientM' / 'runClientM' as an exception.
--
-- Both @k@ and the rest of the 'ClientM' computation (the Codensity
-- continuation) run inside the 'Client.receiveResponseRaw' callback.
-- Presumably the body stream is only readable while that callback is
-- active — TODO confirm against http-streams.
performWithStreamingRequest :: Request -> (StreamingResponse -> IO a) -> ClientM a
performWithStreamingRequest req k = do
    ClientEnv { baseUrl = burl, connection = conn } <- ask
    let (clientReq, writeBody) = requestToClientRequest burl req
    ClientM . lift . lift $ Codensity $ \finish -> do
        Client.sendRequest conn clientReq writeBody
        Client.receiveResponseRaw conn $ \rawRes rawBody -> do
            let status = Client.getStatusCode rawRes
            unless (status >= 200 && status < 300) $ do
                chunks <- Streams.toList rawBody
                throwIO (mkFailureResponse burl req (clientResponseToResponse rawRes (BSL.fromChunks chunks)))
            result <- k (clientResponseToResponse rawRes (fromInputStream rawBody))
            finish result
-- | Build a 'FailureResponse' from the request that was sent and the fully
-- read response.
--
-- The request's body type is erased to @()@. Its path 'B.Builder' is
-- rendered to a strict ByteString and paired with the base URL.
mkFailureResponse :: BaseUrl -> Request -> ResponseF BSL.ByteString -> ClientError
mkFailureResponse burl request response =
    FailureResponse (bimap (const ()) renderPath request) response
  where
    renderPath path = (burl, BSL.toStrict (toLazyByteString path))
-- | Convert an http-streams response head into servant's 'ResponseF',
-- attaching the given body.
--
-- Status code, status message and headers are copied, and header names
-- are made case-insensitive. The HTTP version is always reported as
-- 'http11'; it is not read from the response.
clientResponseToResponse :: Client.Response -> body -> ResponseF body
clientResponseToResponse r body = Response
    { responseStatusCode  = status
    , responseHeaders     = hdrs
    , responseHttpVersion = http11
    , responseBody        = body
    }
  where
    status = Status (Client.getStatusCode r) (Client.getStatusMessage r)
    hdrs = fromList
        [ (CI.mk name, value)
        | (name, value) <- Client.retrieveHeaders (Client.getHeaders r)
        ]
-- | Translate a servant 'Request' into an http-streams request, plus a
-- function that writes the request body to the connection's output stream.
--
-- * The request target is 'baseUrlPath' followed by the rendered request
--   path and query string.
-- * The Host header is set from the 'BaseUrl' host and port, not from the
--   connection.
-- * @Accept@ comes from 'requestAccept' and @Content-Type@ from the body's
--   media type. Copies of those two headers in 'requestHeaders' are
--   dropped; all other headers are passed through after them.
-- * Chunked transfer encoding is requested for every request, including
--   requests without a body.
requestToClientRequest :: BaseUrl -> Request -> (Client.Request, Streams.OutputStream B.Builder -> IO ())
requestToClientRequest burl r = (clientReq, bodyWriter)
  where
    clientReq = Client.buildRequest1 $ do
        Client.http (Client.Method (requestMethod r)) target
        Client.setHostname (fromString (baseUrlHost burl)) (fromIntegral (baseUrlPort burl))
        mapM_ (\(name, value) -> Client.setHeader (CI.original name) value) outgoingHeaders
        Client.setTransferEncoding

    target = mconcat
        [ fromString (baseUrlPath burl)
        , BSL.toStrict (toLazyByteString (requestPath r))
        , renderQuery True (toList (requestQueryString r))
        ]

    outgoingHeaders = maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ passthrough

    -- Accept and Content-Type are derived separately, so user-supplied copies are skipped.
    passthrough =
        [ hdr
        | hdr@(name, _) <- toList (requestHeaders r)
        , name /= "Accept"
        , name /= "Content-Type"
        ]

    acceptHdr = case toList (requestAccept r) of
        []    -> Nothing
        types -> Just ("Accept", renderHeader types)

    (bodyWriter, contentTypeHdr) = case requestBody r of
        Nothing -> (Client.emptyBody, Nothing)
        Just (reqBody, mediaType) -> (writeBody reqBody, Just (hContentType, renderHeader mediaType))

    writeBody (RequestBodyLBS lbs) os = Streams.writeTo os (Just (B.lazyByteString lbs))
    writeBody (RequestBodyBS bs) os = Streams.writeTo os (Just (B.byteString bs))
    writeBody (RequestBodySource src) os = toOutputStream src os
-- | Run an IO action and return its result in 'Right'.
--
-- An 'IOException' thrown by the action becomes @'Left' ('ConnectionError'
-- e)@. Exceptions of any other type propagate unchanged.
catchConnectionError :: IO a -> IO (Either ClientError a)
catchConnectionError action = (Right <$> action) `catch` toClientError
  where
    toClientError :: IOException -> IO (Either ClientError b)
    toClientError = pure . Left . ConnectionError . SomeException
-- | Expose an io-streams 'Streams.InputStream' as a servant 'S.SourceT'.
--
-- Each step performs one 'Streams.read'. A 'Just' value is yielded and
-- reading continues; 'Nothing' (end of input) stops the stream.
fromInputStream :: Streams.InputStream b -> S.SourceT IO b
fromInputStream is = S.SourceT (\consume -> consume step)
  where
    step = S.Effect (toStep <$> Streams.read is)
    toStep Nothing  = S.Stop
    toStep (Just x) = S.Yield x step
-- | Drain a 'S.SourceT' of lazy ByteStrings into an io-streams output
-- stream, writing each chunk as a 'B.Builder'.
--
-- An 'S.Error' step is raised with 'fail' (an IO user error). The output
-- stream is not terminated here: no 'Nothing' is written to it.
toOutputStream :: S.SourceT IO BSL.ByteString -> Streams.OutputStream B.Builder -> IO ()
toOutputStream (S.SourceT run) os = run drain
  where
    drain step = case step of
        S.Stop             -> pure ()
        S.Error msg        -> fail msg
        S.Skip next        -> drain next
        S.Effect action    -> action >>= drain
        S.Yield chunk next -> do
            Streams.write (Just (B.lazyByteString chunk)) os
            drain next