--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

{-# LANGUAGE LambdaCase #-}

-- | This module provides a data type for identity provider (IdP) metadata,
-- such as its X.509 certificate and single sign-on (SSO) URLs.
--
-- @since 0.4
module Network.Wai.SAML2.EntityDescriptor (
    IDPSSODescriptor(..),
    Binding(..)
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString.Base64 as Base64
import qualified Data.X509 as X509
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T

import Network.Wai.SAML2.XML

import Text.XML.Cursor

--------------------------------------------------------------------------------

-- | Metadata of a SAML identity provider (IdP), as described by the
-- @IDPSSODescriptor@ element of its metadata document.
--
-- See section 2.4.3 of [Metadata for the OASIS Security Assertion Markup Language (SAML) V2.0](https://docs.oasis-open.org/security/saml/v2.0/saml-metadata-2.0-os.pdf).
data IDPSSODescriptor = IDPSSODescriptor
    { entityID :: Text
      -- ^ Entity ID of the IdP.
      -- 'Network.Wai.SAML2.Config.saml2ExpectedIssuer' should be compared
      -- against this identifier.
    , x509Certificate :: X509.SignedExact X509.Certificate
      -- ^ X.509 certificate for signed assertions.
    , nameIDFormats :: [Text]
      -- ^ NameID formats supported by the IdP.
    , singleSignOnServices :: [(Binding, Text)]
      -- ^ Single sign-on service URLs, each paired with its 'Binding'.
    }
    deriving Show

-- | SAML 2.0 protocol bindings, i.e. the transports identified by URIs in the
-- @urn:oasis:names:tc:SAML:2.0:bindings@ namespace.
--
-- See [Bindings for the OASIS Security Assertion Markup Language (SAML) V2.0](https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf).
data Binding
    = HTTPPost
      -- ^ Messages are carried base64-encoded inside an HTML form control.
    | HTTPRedirect
      -- ^ Messages are carried in URL parameters.
    | HTTPArtifact
      -- ^ Request and/or response are sent by reference, using a small
      -- stand-in called an artifact.
    | PAOS
      -- ^ Reverse HTTP Binding for SOAP.
    | SOAP
      -- ^ Messages are exchanged over SOAP, a lightweight protocol for
      -- structured information in decentralized, distributed environments.
    | URLEncodingDEFLATE
      -- ^ Messages are encoded into a URL using DEFLATE compression.
    deriving (Show, Eq)

-- | Parses IdP metadata. The cursor is expected to point at the element that
-- carries the @entityID@ attribute and has an @IDPSSODescriptor@ child
-- (presumably an @EntityDescriptor@ — confirm against callers).
instance FromXML IDPSSODescriptor where
    parseXML cursor = do
        -- A missing entityID attribute yields the empty text rather than a
        -- parse failure.
        let entityID = T.concat $ attribute "entityID" cursor
        descriptor <- oneOrFail "IDPSSODescriptor is required"
            $ cursor $/ element (mdName "IDPSSODescriptor")
        -- Only keys usable for signing are considered: metadata may list a
        -- KeyDescriptor with use="encryption" before the signing key, and
        -- that certificate must not be used to verify signed assertions.
        rawCertificate <- oneOrFail "X509Certificate is required"
            $ concatMap certificateTexts
            $ filter isSigningKey
            $ descriptor $/ element (mdName "KeyDescriptor")
        x509Certificate <- either fail pure
            $ X509.decodeSignedObject
            $ Base64.decodeLenient
            $ T.encodeUtf8 rawCertificate
        let nameIDFormats = descriptor
                $/ element (mdName "NameIDFormat")
                &/ content
        -- Metadata may advertise bindings that 'Binding' does not model
        -- (e.g. HTTP-POST-SimpleSign); skip those services instead of
        -- rejecting the whole document.
        singleSignOnServices <- traverse parseService
            $ filter hasSupportedBinding
            $ descriptor $/ element (mdName "SingleSignOnService")
        pure IDPSSODescriptor{..}
      where
        -- A KeyDescriptor may be used for signing unless its (optional)
        -- @use@ attribute restricts it to encryption.
        isSigningKey :: Cursor -> Bool
        isSigningKey keyDescriptor =
            "encryption" `notElem` attribute "use" keyDescriptor

        -- Text content of the X509Certificate elements under a KeyDescriptor.
        certificateTexts :: Cursor -> [Text]
        certificateTexts keyDescriptor = keyDescriptor
            $/ element (dsName "KeyInfo")
            &/ element (dsName "X509Data")
            &/ element (dsName "X509Certificate")
            &/ content

        -- False only when a Binding attribute is present but not recognised
        -- by 'parseBinding'; a missing attribute is kept so that
        -- 'parseService' still reports it as an error.
        hasSupportedBinding :: Cursor -> Bool
        hasSupportedBinding service = case attribute "Binding" service of
            [] -> True
            (uri : _) -> case (parseBinding uri :: Maybe Binding) of
                Nothing -> False
                Just _ -> True

-- | `parseService` @cursor@ reads a service endpoint element: the 'Binding'
-- named by its @Binding@ attribute (decoded with 'parseBinding') together
-- with the URL in its @Location@ attribute. Fails if either attribute is
-- absent or the binding URI is not recognised.
parseService :: MonadFail m => Cursor -> m (Binding, Text)
parseService cursor = do
    bindingURI <- required "Binding is required" "Binding"
    binding <- parseBinding bindingURI
    location <- required "Location is required" "Location"
    pure (binding, location)
  where
    -- Value of the named attribute on @cursor@ via 'oneOrFail', failing
    -- with @message@ when it is absent.
    required message name = oneOrFail message $ attribute name cursor

-- | `parseBinding` @uri@ returns the 'Binding' identified by the SAML binding
-- URI @uri@. An unrecognised URI fails with @"Unknown Binding: "@ followed
-- by the URI.
parseBinding :: MonadFail m => Text -> m Binding
parseBinding uri =
    maybe (fail $ "Unknown Binding: " <> T.unpack uri) pure
        $ lookup uri knownBindings
  where
    -- Exact-match table of supported binding URIs.
    knownBindings :: [(Text, Binding)]
    knownBindings =
        [ ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact", HTTPArtifact)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST", HTTPPost)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect", HTTPRedirect)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", PAOS)
        , ("urn:oasis:names:tc:SAML:2.0:bindings:SOAP", SOAP)
        , ( "urn:oasis:names:tc:SAML:2.0:bindings:URL-Encoding:DEFLATE"
          , URLEncodingDEFLATE
          )
        ]

--------------------------------------------------------------------------------