{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Network.Wai.Middleware.Servant.Errors
(
errorMw
, errorMwDefJson
, HasErrorBody (..)
, ErrorMsg (..)
, StatusCode (..)
, ErrorLabels (..)
, getErrorLabels
)where
import Data.Aeson (Value (..), encode)
import qualified Data.ByteString as B
import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as LB
import qualified Data.HashMap.Strict as H
import Data.IORef (modifyIORef', newIORef, readIORef)
import Data.Kind (Type)
import Data.List (find)
import Data.Monoid ((<>))
import Data.Proxy (Proxy (..))
import Data.Scientific (Scientific)
import Data.String.Conversions (cs)
import qualified Data.Text as T
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import qualified Network.HTTP.Media as M
import Network.HTTP.Types (Header, Status (..), hContentType)
import Network.Wai (Response, Middleware, responseHeaders, responseLBS, responseStatus,
responseToStream)
import Servant.API.ContentTypes (Accept (..), JSON, PlainText)
-- | HTTP status code of the response being rewritten, as handed to
-- 'encodeError' when building the replacement error body.
newtype StatusCode = StatusCode { unStatusCode :: Int }
  deriving (Eq, Ord, Show)
-- | Error message text placed in the rewritten body. The middleware fills
-- it from the original response body, or from the HTTP status message when
-- that body is empty.
newtype ErrorMsg = ErrorMsg { unErrorMsg :: T.Text }
  deriving Show
-- | Labels used when encoding an error body.
data ErrorLabels = ErrorLabels
  { errName       :: T.Text
    -- ^ label attached to the error message
  , errStatusName :: T.Text
    -- ^ label attached to the numeric status code
  }
-- | Content types in which the middleware can render error bodies.
--
-- @ctyp@ is a servant content type (its 'Accept' instance is required, and
-- supplies the Content-Type of the rewritten response); @opts@ is a
-- type-level list of 'Symbol' options. The instances in this module accept
-- either @'[]@ (default labels) or a two-element label list
-- @'[errLabel, statusLabel]@.
class Accept ctyp => HasErrorBody (ctyp :: Type) (opts :: [Symbol]) where
  -- | Encode a status code and an error message into a response body.
  -- Neither class parameter occurs in this type, so callers pick the
  -- instance with type applications, e.g. @encodeError \@JSON \@'[]@.
  encodeError :: StatusCode -> ErrorMsg -> LB.ByteString
-- | JSON error bodies whose field names come from the type-level labels:
-- @errLabel@ names the message field, @statusLabel@ the status-code field.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
    => HasErrorBody JSON '[errLabel, statusLabel] where
  encodeError status msg = encodeAsJsonError labels status msg
    where
      labels = getErrorLabels @errLabel @statusLabel
-- | JSON error bodies with the default labels: the message goes under
-- @"error"@ and the status code under @"status"@.
instance HasErrorBody JSON '[] where
  encodeError code msg = encodeError @JSON @'["error", "status"] code msg
-- | Plain-text error bodies using the type-level labels: @errLabel@
-- precedes the message and @statusLabel@ precedes the status code.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
    => HasErrorBody PlainText '[errLabel, statusLabel] where
  encodeError status msg = encodeAsPlainText labels status msg
    where
      labels = getErrorLabels @errLabel @statusLabel
-- | Plain-text error bodies with the default labels @"error"@ and
-- @"status"@.
--
-- This must delegate to the /PlainText/ instance. Delegating to JSON would
-- make the plain-text default emit a JSON object instead.
instance HasErrorBody PlainText '[] where
  encodeError = encodeError @PlainText @'["error", "status"]
-- | 'errorMw' specialised to JSON bodies and the empty option list. The
-- @JSON '[]@ instance uses the labels @"error"@ and @"status"@.
errorMwDefJson :: Middleware
errorMwDefJson app = errorMw @JSON @'[] app
-- | Middleware that rewrites error responses into a uniform body encoded by
-- @'HasErrorBody' ctyp opts@.
--
-- A response is rewritten only when both of these hold:
--
-- * its status code is 4xx or 5xx;
-- * it carries no Content-Type header.
--
-- Every other response is passed through unchanged.
errorMw :: forall ctyp opts. HasErrorBody ctyp opts => Middleware
errorMw baseApp req respond =
  baseApp req $ \ response -> do
    let status       = responseStatus response
        mcontentType = getContentTypeHeader response
    case (status, mcontentType) of
      -- Only genuine errors are rewritten. Rewriting 1xx/2xx/3xx responses
      -- would drop headers such as Location on redirects, because the
      -- replacement keeps only Content-Type. It would also add a body to
      -- responses like 204 No Content.
      (Status code _, Nothing)
        | code >= 400 -> newResponse @ctyp @opts status response >>= respond
      _ -> respond response
  where
    -- Header names are case-insensitive, so this matches any spelling of
    -- Content-Type.
    getContentTypeHeader :: Response -> Maybe Header
    getContentTypeHeader = find ((hContentType ==) . fst) . responseHeaders
-- | Build the replacement error response for 'errorMw'.
--
-- The new response has these parts:
--
-- * the same status as the original;
-- * a single Content-Type header taken from @ctyp@'s 'Accept' instance;
-- * a body produced by 'encodeError' for @ctyp@ and @opts@.
--
-- The error message is the original body, or the status message when that
-- body is empty. All other headers of the original response are dropped.
newResponse
  :: forall ctyp opts . HasErrorBody ctyp opts
  => Status
  -> Response
  -> IO Response
newResponse status@(Status code statusMsg) response = do
  body <- responseBody response
  let -- Advertise the media type actually produced by 'encodeError'. This
      -- must not be hard-coded to JSON, or PlainText bodies get mislabelled.
      header     = (hContentType, M.renderHeader $ contentType (Proxy @ctyp))
      content    = ErrorMsg . cs $ if B.null body then statusMsg else body
      newContent = encodeError @ctyp @opts (StatusCode code) content
  return $ responseLBS status [header] newContent
-- | Collect the whole body of a WAI 'Response' into a strict 'B.ByteString'.
--
-- The response is turned into a stream with 'responseToStream'. Each chunk
-- it emits is appended to a builder held in an 'IORef', and the flush
-- action does nothing.
responseBody :: Response -> IO B.ByteString
responseBody res = withStream drain
  where
    (_, _, withStream) = responseToStream res
    drain stream = do
      acc <- newIORef mempty
      let emit chunk = modifyIORef' acc (<> chunk)
          flush      = return ()
      stream emit flush
      cs . toLazyByteString <$> readIORef acc
-- | Render an error as a two-field JSON object:
--
-- * the message, as a JSON string, under the 'errName' label;
-- * the status code, as a JSON number, under the 'errStatusName' label.
encodeAsJsonError :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsJsonError (ErrorLabels msgKey statusKey) (StatusCode code) (ErrorMsg msg) =
  encode (Object fields)
  where
    fields    = H.fromList [(msgKey, String msg), (statusKey, Number statusNum)]
    statusNum = fromIntegral code :: Scientific
-- | Render an error as plain text by concatenating four parts, in order:
--
-- 1. the message label;
-- 2. the message;
-- 3. the status label;
-- 4. the decimal status code.
--
-- NOTE(review): no separators are inserted, so the output looks like
-- @errorNot Foundstatus404@. Confirm this is the intended wire format.
encodeAsPlainText :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsPlainText (ErrorLabels msgKey statusKey) (StatusCode code) (ErrorMsg msg) =
  cs (T.concat [msgKey, msg, statusKey, T.pack (show code)])
-- | Reflect two type-level 'Symbol's into an 'ErrorLabels' value.
-- @errLabel@ becomes 'errName' and @statusLabel@ becomes 'errStatusName'.
-- Call it with type applications:
-- @getErrorLabels \@"error" \@"status"@.
getErrorLabels
  :: forall errLabel statusLabel . (KnownSymbol errLabel, KnownSymbol statusLabel)
  => ErrorLabels
getErrorLabels =
  ErrorLabels
    { errName       = T.pack (symbolVal (Proxy :: Proxy errLabel))
    , errStatusName = T.pack (symbolVal (Proxy :: Proxy statusLabel))
    }