{-# LANGUAGE OverloadedStrings #-}

-- | WAI middleware implementing HTTP Bearer token authentication
-- (RFC 6750).
--
-- This module is based on 'Network.Wai.Middleware.HttpAuth'.

module Network.Wai.Middleware.BearerTokenAuth
  ( tokenAuth
  , tokenAuth'
  , tokenListAuth
  , CheckToken
  ) where

import Data.ByteString (ByteString)
import qualified Data.ByteString as S
import Data.Word8 (isSpace, toLower)
import Network.HTTP.Types (hAuthorization, hContentType, status401)
import Network.Wai (Middleware, Request(requestHeaders), Response, responseLBS)

-- | Decides whether a Bearer token is acceptable.
--
-- The argument is the raw token taken from the @Authorization@ header; a
-- result of 'True' lets the request through. Because the result is in 'IO',
-- the check may consult external state.
type CheckToken =
  ByteString -> IO Bool

-- | Perform token authentication.
--
-- The Bearer token is taken from the request's @Authorization@ header and
-- handed to the supplied 'CheckToken'. If the check returns 'True', the
-- wrapped application runs unchanged. Otherwise a @401 Unauthorized@ HTTP
-- response is sent. This also happens when the header is missing or does
-- not use the @Bearer@ scheme.
--
-- > tokenAuth (\tok -> return $ tok == "abcd")
tokenAuth :: CheckToken -> Middleware
tokenAuth checker = tokenAuth' (const checker)

-- | Like 'tokenAuth', but also passes the incoming 'Request' to the
-- authentication function, so the token check can take the request into
-- account.
tokenAuth' :: (Request -> CheckToken) -> Middleware
tokenAuth' checkByReq app req sendRes = do
  authorized <- check (checkByReq req) req
  if authorized
    then app req sendRes -- Pass the Application on successful auth
    else sendRes rspUnauthorized -- Send a @401 Unauthorized@ response on failed auth

-- | Perform token authentication based on a fixed list of allowed tokens.
--
-- A request is accepted when its Bearer token is byte-for-byte equal to one
-- of the given tokens.
--
-- NOTE(review): the comparison uses 'elem' (i.e. '(==)' on 'ByteString'),
-- which is not constant-time. If timing side channels matter for the
-- deployment, consider a constant-time comparison instead.
--
-- > tokenListAuth ["secret1", "secret2"]
tokenListAuth :: [ByteString] -> Middleware
tokenListAuth tokens = tokenAuth (pure . (`elem` tokens))

-- | Run a token check against the Bearer token of a request.
--
-- Returns 'False' without invoking the checker when no Bearer token can be
-- extracted from the request's @Authorization@ header.
check :: CheckToken -> Request -> IO Bool
check checkCreds req =
  maybe (return False) checkCreds (extractBearerFromRequest req)

-- | The @401 Unauthorized@ response sent when authentication fails.
--
-- It includes a @WWW-Authenticate: Bearer@ challenge, which RFC 7235
-- requires on every 401 response, and a short plain-text body.
rspUnauthorized :: Response
rspUnauthorized =
  responseLBS
    status401
    [(hContentType, "text/plain"), ("WWW-Authenticate", "Bearer")]
    "Bearer token authentication is required"

-- | Look up the @Authorization@ header of a request and extract its Bearer
-- token, if there is one.
extractBearerFromRequest :: Request -> Maybe ByteString
extractBearerFromRequest req =
  lookup hAuthorization (requestHeaders req) >>= extractBearerAuth

-- | Extract the Bearer token from an __Authorization__ header value.
--
-- The scheme name is matched case-insensitively. Whitespace between the
-- scheme and the token is skipped. Returns 'Nothing' in two cases:
--
-- * the scheme is not @Bearer@;
-- * no token follows the scheme (RFC 6750 §2.1 requires a non-empty
--   token).
--
-- Adapted from @extractBearerAuth@ in "Network.Wai.Middleware.HttpAuth":
-- https://hackage.haskell.org/package/wai-extra-3.1.11/docs/Network-Wai-Middleware-HttpAuth.html
extractBearerAuth :: ByteString -> Maybe ByteString
extractBearerAuth bs
  | S.map toLower scheme == "bearer", not (S.null token) = Just token
  | otherwise = Nothing
  where
    (scheme, rest) = S.break isSpace bs
    -- Skip the separator so only the credential itself is returned.
    token = S.dropWhile isSpace rest