{-# LANGUAGE CPP #-}
{-# LANGUAGE UndecidableInstances #-}
module Servant.Auth.Server.Internal.ThrowAll where

#if !MIN_VERSION_servant_server(0,16,0)
#define ServerError ServantErr
#endif

import Control.Monad.Error.Class
import Data.Tagged               (Tagged (..))
import Servant                   ((:<|>) (..), ServerError(..))
import Network.HTTP.Types
import Network.Wai

import qualified Data.ByteString.Char8 as BS

-- | Types whose handlers can all be replaced by ones that fail with a single
-- 'ServerError'.
class ThrowAll a where
  -- | 'throwAll' is a convenience function to throw errors across an entire
  -- sub-API.
  --
  -- > throwAll err400 :: Handler a :<|> Handler b :<|> Handler c
  -- >    == throwError err400 :<|> throwError err400 :<|> throwError err400
  throwAll :: ServerError -> a

instance (ThrowAll a, ThrowAll b) => ThrowAll (a :<|> b) where
  -- Specialised type: ServerError -> a :<|> b
  -- (not written as an instance signature: InstanceSigs is not enabled).
  -- Both alternatives throw the same error, recursing into each side.
  throwAll e = throwAll e :<|> throwAll e

-- This instance would be unnecessary if @(->) a@ were itself an instance of
-- 'MonadError'. It is not, so handlers that take arguments need this case.
instance {-# OVERLAPPING #-} ThrowAll b => ThrowAll (a -> b) where
  -- Specialised type: ServerError -> a -> b
  -- (not written as an instance signature: InstanceSigs is not enabled).
  -- The function's argument is ignored; the error is thrown from the result.
  throwAll e = const $ throwAll e

instance {-# OVERLAPPABLE #-} (MonadError ServerError m) => ThrowAll (m a) where
  -- Specialised type: ServerError -> m a
  -- (not written as an instance signature: InstanceSigs is not enabled).
  -- Base case: any monad that can throw a 'ServerError' just throws it.
  throwAll = throwError

-- | for @servant <0.11@
--
-- Answers every request with a response built from the error's status code,
-- reason phrase, headers and body. The request itself is never inspected.
instance {-# OVERLAPPING #-} ThrowAll Application where
  -- Specialised type: ServerError -> Application
  -- (not written as an instance signature: InstanceSigs is not enabled).
  -- NOTE: 'BS.pack' keeps only the low 8 bits of each character, so a
  -- reason phrase outside Latin-1 is truncated rather than encoded.
  throwAll e _req respond
      = respond $ responseLBS (mkStatus (errHTTPCode e) (BS.pack $ errReasonPhrase e))
                              (errHeaders e)
                              (errBody e)

-- | for @servant >=0.11@
--
-- Builds the same error response as the plain 'Application' instance (status
-- code, reason phrase, headers and body taken from the error), wrapped in
-- 'Tagged'. The request itself is never inspected.
instance {-# OVERLAPPING #-} MonadError ServerError m => ThrowAll (Tagged m Application) where
  -- Specialised type: ServerError -> Tagged m Application
  -- (not written as an instance signature: InstanceSigs is not enabled).
  -- NOTE: 'BS.pack' keeps only the low 8 bits of each character, so a
  -- reason phrase outside Latin-1 is truncated rather than encoded.
  throwAll e = Tagged $ \_req respond ->
      respond $ responseLBS (mkStatus (errHTTPCode e) (BS.pack $ errReasonPhrase e))
                              (errHeaders e)
                              (errBody e)