{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
#if __GLASGOW_HASKELL__ < 710
{-# LANGUAGE OverlappingInstances #-}
#endif
{-# LANGUAGE ScopedTypeVariables #-}

module Servant.Server.Internal.BasicAuth where

#if __GLASGOW_HASKELL__ < 710
import Data.Functor ((<$>))
#endif
import Control.Monad (guard)
import qualified Data.ByteString as BS
import Data.ByteString.Base64 (decodeLenient)
import Data.CaseInsensitive (CI(..))
import Data.Monoid ((<>))
import Data.Typeable (Typeable)
import Data.Word8 (isSpace, toLower, _colon)
import GHC.Generics
import Snap.Core
-- WAI imports kept commented out; this module uses Snap.Core instead.
-- import Network.HTTP.Types (Header)
-- import Network.Wai (Request, requestHeaders)
import Servant.API.BasicAuth (BasicAuthData(BasicAuthData))
import Servant.Server.Internal.RoutingApplication
import Servant.Server.Internal.ServantErr

-- * Basic Auth

-- NOTE: this module's implementation of basic authentication is not
-- immune to certain kinds of timing attacks: decoding payloads does not
-- take a fixed amount of time.

-- | The result of authentication/authorization, as produced by a
-- user-supplied check and interpreted by 'runBasicAuth'.
data BasicAuthResult usr
  = Unauthorized
    -- ^ Access refused; 'runBasicAuth' fails fatally with 403 (no challenge).
  | BadPassword
    -- ^ Wrong password; 'runBasicAuth' fails fatally with 401 and a
    -- @WWW-Authenticate@ challenge.
  | NoSuchUser
    -- ^ Unknown user; handled exactly like 'BadPassword' (401 + challenge).
  | Authorized usr
    -- ^ Success, carrying the authenticated user value.
  deriving (Eq, Show, Read, Generic, Typeable, Functor)
-- | Datatype wrapping a function used to check authentication.
--
-- The wrapped function receives the decoded 'BasicAuthData' and decides,
-- inside @m@, which 'BasicAuthResult' applies.
newtype BasicAuthCheck m usr = BasicAuthCheck
  { unBasicAuthCheck :: BasicAuthData -> m (BasicAuthResult usr)
  }
  deriving (Generic, Typeable, Functor)

-- | Internal method to make a basic-auth challenge header,
-- @WWW-Authenticate: Basic realm="<realm>"@.
--
-- The realm is emitted as an HTTP quoted-string, so any double-quote or
-- backslash byte inside it is sent as a backslash-escaped quoted-pair.
-- Without this, a realm containing @"@ would terminate the quoted value
-- early and produce a malformed header.
mkBAChallengerHdr :: BS.ByteString -> (CI BS.ByteString, BS.ByteString)
mkBAChallengerHdr realm =
  ("WWW-Authenticate", "Basic realm=\"" <> escapeQuoted realm <> "\"")
  where
    escapeQuoted = BS.concatMap escapeByte
    -- 0x22 is '"', 0x5C is '\'; both must be prefixed with '\'.
    escapeByte w
      | w == 0x22 || w == 0x5C = BS.pack [0x5C, w]
      | otherwise              = BS.singleton w

-- | Find and decode an @Authorization@ header from the request as Basic Auth.
--
-- Returns 'Nothing' in three cases:
--
-- * the header is absent;
-- * its scheme token is not @basic@ (compared case-insensitively);
-- * the decoded payload contains no @:@.
--
-- Only the first @:@ separates user name from password, so the password
-- itself may contain colons.
decodeBAHdr :: Request -> Maybe BasicAuthData
decodeBAHdr req = do
  ah <- getHeader "Authorization" req
  let (b, rest) = BS.break isSpace ah
  guard (BS.map toLower b == "basic")
  -- decodeLenient returns a plain ByteString: malformed base64 is not
  -- reported as an error here.
  let decoded = decodeLenient (BS.dropWhile isSpace rest)
  let (username, passWithColonAtHead) = BS.break (== _colon) decoded
  -- uncons drops the leading ':' and fails (Nothing) when there was none.
  (_, password) <- BS.uncons passWithColonAtHead
  return (BasicAuthData username password)

-- | Run and check basic authentication, returning the appropriate HTTP
-- error per the spec:
--
-- * missing or undecodable credentials, 'BadPassword' or 'NoSuchUser':
--   fatal 401 carrying a @WWW-Authenticate@ challenge for @realm@;
--
-- * 'Unauthorized': fatal 403;
--
-- * 'Authorized': succeeds with the user value.
runBasicAuth
  :: MonadSnap m
  => Request -> BS.ByteString -> BasicAuthCheck m usr -> DelayedM m usr
runBasicAuth req realm (BasicAuthCheck ba) =
  case decodeBAHdr req of
    Nothing -> plzAuthenticate
    Just e  -> DelayedM (const $ Route <$> ba e) >>= \res -> case res of
      BadPassword    -> plzAuthenticate
      NoSuchUser     -> plzAuthenticate
      Unauthorized   -> delayedFailFatal err403
      Authorized usr -> return usr
  where
    plzAuthenticate =
      delayedFailFatal err401 { errHeaders = [mkBAChallengerHdr realm] }