{-# LANGUAGE DataKinds    #-}
{-# LANGUAGE TypeFamilies #-}

-- | Servant server authentication.

module Servant.Auth.Hmac.Server
       ( HmacAuth
       , HmacAuthContextHandlers
       , HmacAuthContext
       , HmacAuthHandler
       , hmacAuthServerContext
       , hmacAuthHandler
       , hmacAuthHandlerMap
       ) where

import Control.Monad.Except (throwError)
import Data.ByteString (ByteString)
import Data.Maybe (fromMaybe)
import Network.Wai (rawPathInfo, rawQueryString, requestHeaderHost, requestHeaders, requestMethod)
import Servant (Context ((:.), EmptyContext))
import Servant.API (AuthProtect)
import Servant.Server (Handler, err401, errBody)
import Servant.Server.Experimental.Auth (AuthHandler, AuthServerData, mkAuthHandler)

import Servant.Auth.Hmac.Crypto (RequestPayload (..), SecretKey, Signature, keepWhitelistedHeaders,
                                 verifySignatureHmac)

import qualified Network.Wai as Wai (Request)


-- | Servant auth combinator for HMAC-signed requests
-- (@AuthProtect "hmac-auth"@). Protect API endpoints with this type.
type HmacAuth = AuthProtect "hmac-auth"

-- | Successful HMAC authentication carries no user data, only @()@.
type instance AuthServerData HmacAuth = ()

-- | Auth handler that inspects a raw 'Wai.Request' and yields @()@ when the
-- request is accepted.
type HmacAuthHandler = AuthHandler Wai.Request ()
-- | Type-level list of the handlers stored in 'HmacAuthContext'.
type HmacAuthContextHandlers = '[HmacAuthHandler]
-- | Server 'Context' containing only the HMAC auth handler.
type HmacAuthContext = Context HmacAuthContextHandlers

-- | Build the server 'Context' holding the HMAC auth handler built by
-- 'hmacAuthHandler' from the given signing function and secret key.
--
-- Intended to be supplied as the server context (e.g. to servant's
-- @serveWithContext@) so that endpoints protected by 'HmacAuth' can find
-- their handler.
hmacAuthServerContext
    :: (SecretKey -> ByteString -> Signature)  -- ^ Signing function
    -> SecretKey  -- ^ Secret key that was used for signing 'Request'
    -> HmacAuthContext
hmacAuthServerContext signer sk = hmacAuthHandler signer sk :. EmptyContext

-- | Create 'HmacAuthHandler' from signing function and secret key.
--
-- This is 'hmacAuthHandlerMap' with a no-op request mapper ('pure'): the
-- incoming request is verified exactly as it was received.
hmacAuthHandler
    :: (SecretKey -> ByteString -> Signature)  -- ^ Signing function
    -> SecretKey  -- ^ Secret key that was used for signing 'Request'
    -> HmacAuthHandler
hmacAuthHandler = hmacAuthHandlerMap pure

{- | Like 'hmacAuthHandler' but allows you to specify an additional mapping
function for 'Wai.Request'. This can be useful if you want to print the
incoming request (for logging purposes) or filter some headers (to match the
signature). The given function is applied before signature verification.

If verification fails, the handler throws 'err401' whose body is the message
returned by 'verifySignatureHmac'.
-}
hmacAuthHandlerMap
    :: (Wai.Request -> Handler Wai.Request)  -- ^ Request mapper
    -> (SecretKey -> ByteString -> Signature)  -- ^ Signing function
    -> SecretKey  -- ^ Secret key that was used for signing 'Request'
    -> HmacAuthHandler
hmacAuthHandlerMap mapper signer sk = mkAuthHandler handler
  where
    handler :: Wai.Request -> Handler ()
    handler req = do
        -- The mapper runs first so verification sees the transformed request.
        newReq <- mapper req
        let payload = waiRequestToPayload newReq
        -- 'Nothing' means the request is accepted; @Just msg@ is a rejection
        -- and @msg@ becomes the body of the 401 response.
        case verifySignatureHmac signer sk payload of
            Nothing  -> pure ()
            Just msg -> throwError $ err401 { errBody = msg }

----------------------------------------------------------------------------
-- Internals
----------------------------------------------------------------------------

-- getWaiRequestBody :: Wai.Request -> IO ByteString
-- getWaiRequestBody request = BS.concat <$> getChunks
--   where
--     getChunks :: IO [ByteString]
--     getChunks = requestBody request >>= \chunk ->
--         if chunk == BS.empty
--         then pure []
--         else (chunk:) <$> getChunks

-- | Convert a 'Wai.Request' into the 'RequestPayload' that is passed to
-- 'verifySignatureHmac'.
--
-- * 'rpMethod' is the raw request method.
-- * 'rpHeaders' are the request headers filtered by 'keepWhitelistedHeaders'.
-- * 'rpRawUrl' is the @Host@ header (empty when absent) followed by the raw
--   path and the raw query string.
--
-- NOTE(review): the request body is never read here (doing so would need
-- 'IO'). As a result 'rpContent' is always empty and the body takes no part
-- in verification on the server side. Confirm that clients sign an empty
-- body as well.
waiRequestToPayload :: Wai.Request -> RequestPayload
waiRequestToPayload req = RequestPayload
    { rpMethod  = requestMethod req
    , rpContent = mempty
    , rpHeaders = keepWhitelistedHeaders $ requestHeaders req
    , rpRawUrl  = host <> rawPathInfo req <> rawQueryString req
    }
  where
    -- A missing Host header contributes nothing to the URL.
    host :: ByteString
    host = fromMaybe mempty (requestHeaderHost req)