{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Servant.Client.Internal.HttpClient where
import Prelude ()
import Prelude.Compat
import Control.Concurrent.MVar
(modifyMVar, newMVar)
import qualified Data.ByteString as BS
import Control.Concurrent.STM.TVar
import Control.Exception
import Control.Monad
import Control.Monad.Base
(MonadBase (..))
import Control.Monad.Catch
(MonadCatch, MonadThrow)
import Control.Monad.Error.Class
(MonadError (..))
import Control.Monad.IO.Class
(liftIO)
import Control.Monad.Reader
import Control.Monad.STM
(STM, atomically)
import Control.Monad.Trans.Control
(MonadBaseControl (..))
import Control.Monad.Trans.Except
import Data.Bifunctor
(bimap)
import Data.ByteString.Builder
(toLazyByteString)
import qualified Data.ByteString.Lazy as BSL
import Data.Either
(either)
import Data.Foldable
(toList)
import Data.Functor.Alt
(Alt (..))
import Data.Maybe
(maybe, maybeToList)
import Data.Proxy
(Proxy (..))
import Data.Semigroup
((<>))
import Data.Sequence
(fromList)
import Data.String
(fromString)
import Data.Time.Clock
(UTCTime, getCurrentTime)
import GHC.Generics
import Network.HTTP.Media
(renderHeader)
import Network.HTTP.Types
(hContentType, renderQuery, statusCode)
import Servant.Client.Core
import qualified Servant.Types.SourceT as S
import qualified Network.HTTP.Client as Client
-- | The environment in which a request is run.
data ClientEnv = ClientEnv
  { manager   :: Client.Manager
    -- ^ The @http-client@ manager handed to the underlying HTTP calls.
  , baseUrl   :: BaseUrl
    -- ^ Supplies the host, port, path prefix and scheme of every request.
  , cookieJar :: Maybe (TVar Client.CookieJar)
    -- ^ Optional shared cookie jar. With 'Nothing' no cookies are inserted
    -- into requests or collected from responses.
  }
-- | Build a 'ClientEnv' from a manager and a base URL, without a cookie jar.
mkClientEnv :: Client.Manager -> BaseUrl -> ClientEnv
mkClientEnv mgr burl = ClientEnv
  { manager   = mgr
  , baseUrl   = burl
  , cookieJar = Nothing
  }
-- | 'clientIn' specialised to the 'ClientM' monad: yields the
-- @Client ClientM api@ for the API identified by the given proxy.
client :: HasClient ClientM api => Proxy api -> Client ClientM api
client api = clientIn api clientMonad
  where
    -- Proxy that pins the monad the client runs in.
    clientMonad :: Proxy ClientM
    clientMonad = Proxy
-- | Change the monad a client runs in by applying a natural transformation.
-- This is 'hoistClientMonad' with the 'HasClient' instance fixed to 'ClientM'.
hoistClient
  :: HasClient ClientM api
  => Proxy api
  -> (forall a. m a -> n a)
  -> Client m api
  -> Client n api
hoistClient api nat cl =
  hoistClientMonad (Proxy :: Proxy ClientM) api nat cl
-- | The monad in which client functions run: a reader over 'ClientEnv'
-- layered on 'ExceptT' 'ClientError' over 'IO'. Failures are reported as
-- 'ClientError' values through the 'MonadError' instance.
newtype ClientM a = ClientM
  { unClientM :: ReaderT ClientEnv (ExceptT ClientError IO) a
    -- ^ Unwrap to the underlying transformer stack.
  }
  deriving
    ( Functor, Applicative, Monad, MonadIO, Generic
    , MonadReader ClientEnv, MonadError ClientError
    , MonadThrow, MonadCatch
    )
-- | Lift an 'IO' action through the wrapped transformer stack.
instance MonadBase IO ClientM where
  liftBase action = ClientM (liftBase action)
-- | Delegates to the instance of the wrapped transformer stack; the captured
-- monadic state of a 'ClientM' computation is its @Either ClientError@ result.
instance MonadBaseControl IO ClientM where
  type StM ClientM a = Either ClientError a

  -- Unwrap each 'ClientM' action before handing it to the inner runner.
  liftBaseWith withRunner =
    ClientM (liftBaseWith (\runInner -> withRunner (runInner . unClientM)))

  restoreM = ClientM . restoreM
-- | Try the left action; if it fails with a 'ClientError', discard that
-- error and run the right action instead.
instance Alt ClientM where
  lhs <!> rhs = catchError lhs (const rhs)
-- | Requests are executed by 'performRequest'; client errors are raised
-- through the 'MonadError' instance of 'ClientM'.
instance RunClient ClientM where
  runRequest req = performRequest req
  throwClientError err = throwError err
-- | Run a 'ClientM' computation in the given environment, returning either
-- the 'ClientError' it failed with or its result.
runClientM :: ClientM a -> ClientEnv -> IO (Either ClientError a)
runClientM cm env = runExceptT (runReaderT (unClientM cm) env)
-- | Execute a servant 'Request' using the 'ClientEnv' in scope.
--
-- The request is converted with 'requestToClientRequest'. When the
-- environment carries a cookie jar, cookies from the jar are inserted into
-- the outgoing request, and the jar is updated from every response in the
-- redirect history as well as from the final response.
--
-- Fails (via 'throwError') with:
--
-- * the 'ClientError' produced by 'catchConnectionError' when the HTTP call
--   throws a 'Client.HttpException';
-- * the error built by 'mkFailureResponse' when the response status code is
--   outside the range 200-299.
performRequest :: Request -> ClientM Response
performRequest req = do
  ClientEnv m burl cookieJar' <- ask
  let clientRequest = requestToClientRequest burl req
  request <- case cookieJar' of
    Nothing -> pure clientRequest
    Just cj -> liftIO $ do
      now <- getCurrentTime
      atomically $ do
        oldCookieJar <- readTVar cj
        -- Reuse the request converted above rather than converting it again,
        -- so both branches are guaranteed to start from the same request.
        let (newRequest, newCookieJar) =
              Client.insertCookiesIntoRequest clientRequest oldCookieJar now
        -- Strict write, matching the strict 'modifyTVar'' used when the jar
        -- is updated from responses, so no thunk is left inside the TVar.
        writeTVar cj $! newCookieJar
        pure newRequest

  response <-
    maybe
      (requestWithoutCookieJar m request)
      (requestWithCookieJar m request)
      cookieJar'
  let status      = Client.responseStatus response
      status_code = statusCode status
      ourResponse = clientResponseToResponse id response
  -- Anything that is not a 2xx status is reported as a failure.
  unless (status_code >= 200 && status_code < 300) $
    throwError $ mkFailureResponse burl req ourResponse
  return ourResponse
  where
    -- Plain request: read the whole body and leave cookies untouched.
    requestWithoutCookieJar
      :: Client.Manager
      -> Client.Request
      -> ClientM (Client.Response BSL.ByteString)
    requestWithoutCookieJar m' request' = do
      eResponse <- liftIO . catchConnectionError $ Client.httpLbs request' m'
      either throwError return eResponse

    -- Request with response history, so that cookies set by intermediate
    -- redirect responses can be recorded too.
    requestWithCookieJar
      :: Client.Manager
      -> Client.Request
      -> TVar Client.CookieJar
      -> ClientM (Client.Response BSL.ByteString)
    requestWithCookieJar m' request' cj = do
      eResponse <-
        liftIO . catchConnectionError . Client.withResponseHistory request' m' $
          updateWithResponseCookies cj
      either throwError return eResponse

    -- Consume the final response body, then fold the cookies of every
    -- (request, response) pair into the jar in one transaction.
    updateWithResponseCookies
      :: TVar Client.CookieJar
      -> Client.HistoriedResponse Client.BodyReader
      -> IO (Client.Response BSL.ByteString)
    updateWithResponseCookies cj responses = do
      now <- getCurrentTime
      bss <- Client.brConsume $ Client.responseBody fRes
      let fRes'        = fRes { Client.responseBody = BSL.fromChunks bss }
          allResponses = Client.hrRedirects responses <> [(fReq, fRes')]
      atomically $ mapM_ (updateCookieJar now) allResponses
      return fRes'
      where
        updateCookieJar
          :: UTCTime
          -> (Client.Request, Client.Response BSL.ByteString)
          -> STM ()
        updateCookieJar now' (req', res') =
          modifyTVar' cj (fst . Client.updateCookieJar res' req' now')

        fReq = Client.hrFinalRequest responses
        fRes = Client.hrFinalResponse responses
-- | Build the 'FailureResponse' error for a request and the response it
-- received.
--
-- The request is transformed with 'bimap' before being stored: its first
-- component is replaced by @()@, and its second component (a builder) is
-- rendered to a strict 'BS.ByteString' and paired with the base URL.
mkFailureResponse :: BaseUrl -> Request -> ResponseF BSL.ByteString -> ClientError
mkFailureResponse burl request response =
  FailureResponse summarisedRequest response
  where
    summarisedRequest = bimap (\_ -> ()) withBaseUrl request
    withBaseUrl builder = (burl, BSL.toStrict (toLazyByteString builder))
-- | Convert an @http-client@ response into a servant 'ResponseF', applying
-- the given function to the body. Status, headers and HTTP version are
-- copied over; the header list is turned into a sequence.
clientResponseToResponse :: (a -> b) -> Client.Response a -> ResponseF b
clientResponseToResponse convertBody resp = Response
  { responseStatusCode  = Client.responseStatus resp
  , responseBody        = convertBody (Client.responseBody resp)
  , responseHeaders     = fromList (Client.responseHeaders resp)
  , responseHttpVersion = Client.responseVersion resp
  }
-- | Translate a servant 'Request' into an @http-client@ 'Client.Request'
-- addressed at the given 'BaseUrl'.
--
-- * Host, port and scheme come from the base URL; the path is the base URL
--   path followed by the request path.
-- * The @Accept@ header is rendered from 'requestAccept' (omitted when that
--   is empty) and the @Content-Type@ header from the media type attached to
--   the body. Any @Accept@ / @Content-Type@ entries already present in
--   'requestHeaders' are dropped in favour of these.
-- * A missing body becomes an empty strict body; a streaming body is sent
--   with chunked encoding.
requestToClientRequest :: BaseUrl -> Request -> Client.Request
requestToClientRequest burl r = Client.defaultRequest
  { Client.method         = requestMethod r
  , Client.host           = fromString (baseUrlHost burl)
  , Client.port           = baseUrlPort burl
  , Client.path           = BSL.toStrict fullPath
  , Client.queryString    = renderQuery True (toList (requestQueryString r))
  , Client.requestHeaders =
      maybeToList acceptHdr ++ maybeToList contentTypeHdr ++ otherHdrs
  , Client.requestBody    = body
  , Client.secure         = isSecure
  }
  where
    -- Base URL path prefix followed by the rendered request path.
    fullPath = fromString (baseUrlPath burl) <> toLazyByteString (requestPath r)

    -- Caller-supplied headers, minus the two that are generated here.
    otherHdrs =
      [ hdr
      | hdr@(name, _) <- toList (requestHeaders r)
      , name /= "Accept" && name /= "Content-Type"
      ]

    acceptHdr = case toList (requestAccept r) of
      [] -> Nothing
      hs -> Just ("Accept", renderHeader hs)

    (body, contentTypeHdr) =
      maybe
        (Client.RequestBodyBS "", Nothing)
        (\(payload, typ) ->
           (convertBody payload, Just (hContentType, renderHeader typ)))
        (requestBody r)

    isSecure = case baseUrlScheme burl of
      Https -> True
      Http  -> False

    convertBody (RequestBodyLBS lbs) = Client.RequestBodyLBS lbs
    convertBody (RequestBodyBS bs) = Client.RequestBodyBS bs
    convertBody (RequestBodySource sourceIO) =
      Client.RequestBodyStreamChunked givesPopper
      where
        -- The current position in the stream is kept in an MVar; each call
        -- of the popper advances it and returns the next strict chunk.
        givesPopper :: (IO BS.ByteString -> IO ()) -> IO ()
        givesPopper needsPopper = S.unSourceT sourceIO $ \step0 -> do
          ref <- newMVar step0
          needsPopper (modifyMVar ref nextBs)

        -- Advance the stream to its next non-empty strict chunk. An empty
        -- chunk is returned only once the stream has stopped.
        nextBs step = case step of
          S.Stop -> return (S.Stop, BS.empty)
          S.Error err -> fail err
          S.Skip rest -> nextBs rest
          S.Effect mstep -> mstep >>= nextBs
          S.Yield lbs rest -> case BSL.toChunks lbs of
            [] -> nextBs rest
            (chunk : more)
              | BS.null chunk -> nextBs remaining
              | otherwise -> return (remaining, chunk)
              where
                -- Unconsumed chunks of this lazy bytestring, then the rest
                -- of the stream.
                remaining = S.Yield (BSL.fromChunks more) rest
-- | Run an 'IO' action, turning a thrown 'Client.HttpException' into a
-- 'ConnectionError' returned in 'Left'. Other exceptions are not caught.
catchConnectionError :: IO a -> IO (Either ClientError a)
catchConnectionError action = do
  outcome <- try action
  pure $ case outcome of
    Right value -> Right value
    Left httpErr ->
      Left (ConnectionError (SomeException (httpErr :: Client.HttpException)))