--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Implements WAI 'Middleware' for SAML2 service providers. Two different
-- interfaces are supported (with equivalent functionality): one which simply
-- stores the outcome of the validation process in the request vault and one
-- which passes the outcome to a callback.
module Network.Wai.SAML2 (
    -- * Callback-based middleware
    --
    -- $callbackBasedMiddleware
    Result(..),
    saml2Callback,

    -- * Vault-based middleware
    --
    -- $vaultBasedMiddleware
    assertionKey,
    errorKey,
    saml2Vault,
    relayStateKey,

    -- * Re-exports
    module Network.Wai.SAML2.Config,
    module Network.Wai.SAML2.Error,
    module Network.Wai.SAML2.Assertion
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString as BS
import Data.Functor ((<&>))
import Data.Maybe (fromMaybe)
import qualified Data.Vault.Lazy as V

import Network.Wai
import Network.Wai.Parse

import Network.Wai.SAML2.Config
import Network.Wai.SAML2.Validation
import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.Error
import qualified Network.Wai.SAML2.Response as SAML2

import System.IO.Unsafe (unsafePerformIO)

--------------------------------------------------------------------------------

-- | Checks whether the method of the given 'Request' is exactly @POST@.
-- The comparison is on the raw 'Method' bytes, so it is case-sensitive.
isPOST :: Request -> Bool
isPOST = (== "POST") . requestMethod

--------------------------------------------------------------------------------

-- $callbackBasedMiddleware
--
-- This 'Middleware' provides a SAML2 service provider (SP) implementation
-- that can be wrapped around an existing WAI 'Application'. The middleware is
-- parameterised over the SAML2 configuration and a callback. If the middleware
-- intercepts a request made to the endpoint given by the SAML2 configuration,
-- the result of validating the SAML2 response contained in the request body
-- will be passed to the callback.
--
-- > saml2Callback cfg callback mainApp
-- >   where callback (Left err) app req sendResponse = do
-- >             -- a POST request was made to the assertion endpoint, but
-- >             -- something went wrong, details of which are provided by
-- >             -- the error: this should probably be logged as it may
-- >             -- indicate that an attack was attempted against the
-- >             -- endpoint, but you *must* not show the error
-- >             -- to the client as it would severely compromise
-- >             -- system security
-- >             --
-- >             -- you may also want to return e.g. a HTTP 400 or 401 status
-- >
-- >         callback (Right result) app req sendResponse = do
-- >             -- a POST request was made to the assertion endpoint and the
-- >             -- SAML2 response was successfully validated:
-- >             -- you *must* check that you have not encountered the
-- >             -- assertion ID before; we assume that there is a
-- >             -- computation tryRetrieveAssertion which looks up
-- >             -- assertions by ID in e.g. a database
-- >             previous <- tryRetrieveAssertion (assertionId (assertion result))
-- >
-- >             case previous of
-- >                 Just something -> -- a replay attack has occurred
-- >                 Nothing -> do
-- >                     -- store the assertion id somewhere
-- >                     storeAssertion (assertionId (assertion result))
-- >
-- >                     -- the assertion is valid and you can now e.g.
-- >                     -- retrieve user data from your database
-- >                     -- before proceeding with the request by e.g.
-- >                     -- redirecting them to the main view

-- | 'saml2Callback' @config callback@ produces SAML2 'Middleware' for
-- the given @config@. A @POST@ request to the assertion endpoint named by
-- @config@ ('saml2AssertionPath') is intercepted: its @SAMLResponse@
-- parameter is validated and the outcome is passed to @callback@, along
-- with the inner application, the request, and the response continuation.
-- A request body without a @SAMLResponse@ parameter yields
-- @'Left' 'InvalidRequest'@. Every other request goes straight to the
-- inner application.
saml2Callback
    :: SAML2Config
    -> (Either SAML2Error Result -> Middleware)
    -> Middleware
saml2Callback cfg callback app req sendResponse
    | isAssertionPost = do
        (params, _files) <- parseRequestBodyEx noFileUploads lbsBackEnd req
        -- a body without the expected payload is reported as invalid
        outcome <- maybe (pure $ Left InvalidRequest)
                         (validateWith $ lookup "RelayState" params)
                         (lookup "SAMLResponse" params)
        callback outcome app req sendResponse
    -- not one of the requests we handle, so defer to the inner application
    | otherwise = app req sendResponse
  where
    isAssertionPost = rawPathInfo req == saml2AssertionPath cfg && isPOST req

    -- the default parser options, except that file uploads are refused
    -- since the assertion endpoint never expects any
    noFileUploads =
        (setMaxRequestNumFiles 0 . setMaxRequestFileSize 0)
            defaultParseRequestBodyOptions

    -- validate the encoded response and, on success, bundle the result
    -- with the (optional) relay state taken from the same request body
    validateWith mRelayState encoded =
        validateResponse cfg encoded <&> fmap (\(a, r) -> Result
            { relayState = mRelayState
            , assertion  = a
            , response   = r
            })

--------------------------------------------------------------------------------

-- $vaultBasedMiddleware
--
-- This is a simpler-to-use 'Middleware' which stores the outcome of a request
-- made to the assertion endpoint in the request vault. The inner WAI
-- application can then check for the presence of an assertion or an error
-- with 'V.lookup' and 'assertionKey' or 'errorKey' respectively. At most one
-- of these two locations will be populated for a given request, i.e. it is
-- not possible for an assertion to be validated and an error to occur. When
-- an assertion is stored, the relay state sent with it (if any) is also
-- available under 'relayStateKey'.
--
-- > saml2Vault cfg $ \req sendResponse -> do
-- >     case V.lookup errorKey (vault req) of
-- >         Just err ->
-- >             -- log the error, but you *must* not show the error
-- >             -- to the client as it would severely compromise
-- >             -- system security
-- >         Nothing -> pure () -- carry on
-- >
-- >     case V.lookup assertionKey (vault req) of
-- >         Nothing -> pure () -- carry on
-- >         Just assertion -> do
-- >             -- a valid assertion was processed by the middleware,
-- >             -- you *must* check that you have not encountered the
-- >             -- assertion ID before; we assume that there is a
-- >             -- computation tryRetrieveAssertion which looks up
-- >             -- assertions by ID in e.g. a database
-- >             result <- tryRetrieveAssertion (assertionId assertion)
-- >
-- >             case result of
-- >                 Just something -> -- a replay attack has occurred
-- >                 Nothing -> do
-- >                     -- store the assertion id somewhere
-- >                     storeAssertion (assertionId assertion)
-- >
-- >                     -- the assertion is valid

-- | 'assertionKey' is a vault key for retrieving assertions from
-- request vaults if the 'saml2Vault' 'Middleware' is used.
assertionKey :: V.Key Assertion
assertionKey = unsafePerformIO V.newKey
-- Each vault key must be created exactly once for the whole program. If GHC
-- inlined the 'unsafePerformIO' 'V.newKey' call into its use sites, the
-- middleware and the inner application could end up holding different keys,
-- and values inserted by 'saml2Vault' could never be looked up again.
-- NOINLINE prevents that.
{-# NOINLINE assertionKey #-}

-- | 'relayStateKey' is a vault key for retrieving the relay state
-- from request vaults if the 'saml2Vault' 'Middleware' is used
-- and the assertion is valid. It is only populated when the POST
-- request that carried the assertion also had a @RelayState@ parameter.
relayStateKey :: V.Key BS.ByteString
relayStateKey = unsafePerformIO V.newKey
-- NOINLINE for the same reason as 'assertionKey': the key must be unique.
{-# NOINLINE relayStateKey #-}

-- | 'errorKey' is a vault key for retrieving SAML2 errors from request vaults
-- if the 'saml2Vault' 'Middleware' is used.
errorKey :: V.Key SAML2Error
errorKey = unsafePerformIO V.newKey
-- NOINLINE for the same reason as 'assertionKey': the key must be unique.
{-# NOINLINE errorKey #-}

-- | 'saml2Vault' @config@ produces SAML2 'Middleware' for the given @config@.
-- The outcome of processing a SAML2 response at the configured endpoint is
-- stored in the request vault before the inner application is called: under
-- 'errorKey' on failure, or under 'assertionKey' (and, if the request had
-- one, 'relayStateKey') on success.
saml2Vault :: SAML2Config -> Middleware
saml2Vault cfg = saml2Callback cfg stash
  where
    -- 'saml2Callback' hands us the outcome of processing a SAML2 response
    -- intercepted at the configured endpoint; we record that outcome in the
    -- request vault and then run the inner application on the updated request
    stash (Left err) inner request respond =
        inner request { vault = V.insert errorKey err (vault request) } respond
    stash (Right result) inner request respond =
        let original  = vault request
            -- the relay state is only recorded if the POST request had one
            withRelay = fromMaybe original $
                (\rs -> V.insert relayStateKey rs original) <$> relayState result
            updated   = V.insert assertionKey (assertion result) withRelay
        in inner request { vault = updated } respond

--------------------------------------------------------------------------------

-- | The outcome of successfully validating a SAML2 response, as handed to
-- the callback given to 'saml2Callback'.
data Result = Result
    { -- | The relay state sent alongside the SAML2 response in the POST
      -- request, if there was one.
      relayState :: !(Maybe BS.ByteString)
    , -- | The assertion taken from the validated response.
      assertion :: !Assertion
    , -- | The complete response received from the IdP.
      --
      -- @since 0.4
      response :: !SAML2.Response
    }
    deriving (Eq, Show)

--------------------------------------------------------------------------------