{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Network.Wai.Middleware.Servant.Errors
(
errorMw
, errorMwDefJson
, HasErrorBody (..)
, ErrorMsg (..)
, StatusCode (..)
, ErrorLabels (..)
, getErrorLabels
, encodeAsJsonError
, encodeAsPlainText
)where
import Prelude.Compat
import Data.Aeson (Value (..), encode)
import qualified Data.ByteString as B
import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as LB
import qualified Data.HashMap.Strict as H
import Data.IORef (modifyIORef', newIORef, readIORef)
import Data.Kind (Type)
import Data.List (find)
import Data.Proxy (Proxy (..))
import Data.Scientific (Scientific)
import Data.String.Conversions (cs)
import qualified Data.Text as T
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)
import qualified Network.HTTP.Media as M
import Network.HTTP.Types (Header, Status (..), hContentLength, hContentType)
import Network.Wai (Middleware, Response, responseHeaders, responseLBS, responseStatus,
                    responseToStream)
import Servant.API.ContentTypes (Accept (..), JSON, PlainText)
-- | An HTTP status code, wrapped so that it cannot be mixed up with other
-- 'Int' values when it is handed to an error encoder.
newtype StatusCode = StatusCode { unStatusCode :: Int }
  deriving (Eq, Ord, Show)
-- | The textual content of an error response: either the body produced by
-- the wrapped application or, when that body is empty, the status message.
newtype ErrorMsg = ErrorMsg { unErrorMsg :: T.Text }
  deriving Show
-- | The pair of labels an encoder attaches to the two parts of an error
-- response.
data ErrorLabels = ErrorLabels
  { errName       :: T.Text -- ^ label for the error message
  , errStatusName :: T.Text -- ^ label for the status code
  }
-- | Describes how to render an error response body.
--
-- @ctyp@ is the content type of the rendered body; its 'Accept' instance
-- supplies the @Content-Type@ header value. @opts@ is a type-level list of
-- label names; the instances in this module accept either two labels
-- (error label, status label) or the empty list.
class Accept ctyp => HasErrorBody (ctyp :: Type) (opts :: [Symbol]) where
  -- | Render the status code and error message as a response body.
  --
  -- Neither class parameter appears in the method type, so call sites select
  -- the instance with type applications: @encodeError \@ctyp \@opts@.
  encodeError :: StatusCode -> ErrorMsg -> LB.ByteString
-- | JSON error bodies with caller-chosen field names: @errLabel@ names the
-- field holding the message, @statusLabel@ the field holding the status code.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
  => HasErrorBody JSON '[errLabel, statusLabel] where
    encodeError = encodeAsJsonError (getErrorLabels @errLabel @statusLabel)
-- | JSON error bodies with the default field names @"error"@ and @"status"@.
instance HasErrorBody JSON '[] where
  encodeError = encodeError @JSON @'["error", "status"]
-- | Plain-text error bodies with caller-chosen labels: @errLabel@ precedes
-- the message and @statusLabel@ precedes the status code.
instance (KnownSymbol errLabel, KnownSymbol statusLabel)
  => HasErrorBody PlainText '[errLabel, statusLabel] where
    encodeError = encodeAsPlainText (getErrorLabels @errLabel @statusLabel)
-- | Plain-text error bodies with the default labels @"error"@ and
-- @"status"@.
instance HasErrorBody PlainText '[] where
  encodeError = encodeError @PlainText @'["error", "status"]
-- | 'errorMw' specialised to JSON bodies with the default labels
-- (@"error"@ and @"status"@).
errorMwDefJson :: Middleware
errorMwDefJson = errorMw @JSON @'[]
-- | Middleware that re-encodes error responses in the format selected by
-- @ctyp@ and @opts@.
--
-- A response is rewritten only when its status code lies in the 400–599
-- range /and/ it carries no @Content-Type@ header. Every other response is
-- passed to the responder unchanged.
--
-- Both type parameters are ambiguous and must be supplied with type
-- applications, e.g. @errorMw \@JSON \@'["error", "status"]@.
errorMw :: forall ctyp opts. HasErrorBody ctyp opts => Middleware
errorMw baseApp req respond =
  baseApp req $ \response -> do
    let status       = responseStatus response
        mcontentType = getContentTypeHeader response
    case (status, mcontentType) of
      (Status code _, Nothing)
        | isErrorCode code ->
            newResponse @ctyp @opts status response >>= respond
      _ -> respond response
  where
    -- Client (4xx) and server (5xx) error codes.
    isErrorCode :: Int -> Bool
    isErrorCode code = code >= 400 && code < 600

    -- The first Content-Type header of the response, if any.
    getContentTypeHeader :: Response -> Maybe Header
    getContentTypeHeader = find ((hContentType ==) . fst) . responseHeaders
-- | Build the replacement for an error response.
--
-- The original body is read in full and used as the error message; when it
-- is empty the status message is used instead. The message and status code
-- are rendered with 'encodeError' for @ctyp@ and @opts@, and the result is
-- returned with the original status.
--
-- Headers of the original response are carried over, so information such as
-- @WWW-Authenticate@ or @Allow@ is not lost. The exceptions are
-- @Content-Type@ and @Content-Length@: both describe the old body, so they
-- are dropped and a fresh @Content-Type@ for @ctyp@ is added.
newResponse
  :: forall ctyp opts . HasErrorBody ctyp opts
  => Status
  -> Response
  -> IO Response
newResponse status@(Status code statusMsg) response = do
  body <- responseBody response
  let ctypHeader  = (hContentType, M.renderHeader $ contentType (Proxy @ctyp))
      keptHeaders = filter (not . describesOldBody . fst) (responseHeaders response)
      content     = ErrorMsg . cs $ if B.null body then statusMsg else body
      newContent  = encodeError @ctyp @opts (StatusCode code) content
  return $ responseLBS status (ctypHeader : keptHeaders) newContent
  where
    -- Headers that are only valid for the body being replaced.
    describesOldBody name = name == hContentType || name == hContentLength
-- | Read the complete body of a response into a strict 'B.ByteString'.
--
-- The response is run through 'responseToStream'; every chunk it emits is
-- appended to an accumulator, and flush requests are ignored. The whole body
-- is therefore held in memory.
responseBody :: Response -> IO B.ByteString
responseBody res =
  let (_status, _headers, streamBody) = responseToStream res in
  streamBody $ \writeBody -> do
    chunksRef <- newIORef mempty
    -- First callback receives each chunk; second is the flush action (no-op).
    writeBody (\chunk -> modifyIORef' chunksRef (<> chunk)) (return ())
    cs . toLazyByteString <$> readIORef chunksRef
-- | Encode an error as a JSON object with two fields: the error message
-- (a string) under 'errName' and the status code (a number) under
-- 'errStatusName'.
encodeAsJsonError :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsJsonError ErrorLabels {..} code content =
  encode $ Object
         $ H.fromList
           [ (errName, String $ unErrorMsg content)
           , (errStatusName, Number $ toScientific code)
           ]
  where
    toScientific :: StatusCode -> Scientific
    toScientific = fromIntegral . unStatusCode
-- | Encode an error as plain text by concatenating, in order: 'errName', the
-- error message, 'errStatusName' and the decimal status code.
--
-- No separator is inserted between the parts; any punctuation or spacing has
-- to be part of the labels themselves.
encodeAsPlainText :: ErrorLabels -> StatusCode -> ErrorMsg -> LB.ByteString
encodeAsPlainText ErrorLabels {..} code content =
  cs $  errName
     <> unErrorMsg content
     <> errStatusName
     <> cs (show $ unStatusCode code)
-- | Turn two type-level symbols into term-level 'ErrorLabels'.
--
-- Both symbols are ambiguous and must be supplied with type applications:
-- @getErrorLabels \@"error" \@"status"@.
getErrorLabels
  :: forall errLabel statusLabel . (KnownSymbol errLabel, KnownSymbol statusLabel)
  => ErrorLabels
getErrorLabels = ErrorLabels (label (Proxy @errLabel)) (label (Proxy @statusLabel))
  where
    -- The text of a type-level symbol.
    label :: KnownSymbol t => Proxy t -> T.Text
    label = cs . symbolVal