--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Types to represent SAML2 responses.
module Network.Wai.SAML2.Response (
    -- * SAML2 responses
    Response(..),
    removeSignature,
    extractSignedInfo,
    extractPrefixList,

    -- * Re-exports
    module Network.Wai.SAML2.StatusCode,
    module Network.Wai.SAML2.Signature
) where

--------------------------------------------------------------------------------

import Data.Maybe (listToMaybe)
import qualified Data.Text as T
import Data.Time

import Text.XML
import Text.XML.Cursor

import Network.Wai.SAML2.Assertion
import Network.Wai.SAML2.XML
import Network.Wai.SAML2.XML.Encrypted
import Network.Wai.SAML2.StatusCode
import Network.Wai.SAML2.Signature

--------------------------------------------------------------------------------

-- | Represents SAML2 responses.

-- Reference [StatusResponseType]
data Response = Response
    { responseDestination :: !T.Text
    -- ^ The intended destination of this response.
    , responseInResponseTo :: !(Maybe T.Text)
    -- ^ The ID of the request this response corresponds to, if any.
    --
    -- @since 0.4
    , responseId :: !T.Text
    -- ^ The unique ID of the response.
    , responseIssueInstant :: !UTCTime
    -- ^ The timestamp when the response was issued.
    , responseVersion :: !T.Text
    -- ^ The SAML version.
    , responseIssuer :: !T.Text
    -- ^ The name of the issuer.
    , responseStatusCode :: !StatusCode
    -- ^ The status of the response.
    , responseSignature :: !Signature
    -- ^ The response signature.
    , responseAssertion :: !(Maybe Assertion)
    -- ^ The unencrypted assertion, if one was present.
    --
    -- @since 0.4
    , responseEncryptedAssertion :: !(Maybe EncryptedAssertion)
    -- ^ The encrypted assertion, if one was present.
    --
    -- @since 0.4
    } deriving (Eq, Show)

instance FromXML Response where
    -- Reference [StatusResponseType]
    --
    -- Parses a response element at @cursor@. The steps run in this order:
    -- the @IssueInstant@ timestamp, the status code (failing with
    -- "Invalid status code"), then the mandatory signature (failing with
    -- "Signature is required"). The plain and encrypted assertions are
    -- optional: the first child of each kind that parses is kept, and
    -- children that do not parse are ignored.
    parseXML cursor = do
        issueInstant <- parseUTCTime (attrText "IssueInstant")

        statusCode <- maybe (fail "Invalid status code") pure (parseXML cursor)

        signature <- oneOrFail "Signature is required"
            (cursor $/ element (dsName "Signature")) >>= parseXML

        pure Response
            { responseDestination = attrText "Destination"
            , responseId = attrText "ID"
            , responseInResponseTo = listToMaybe (attribute "InResponseTo" cursor)
            , responseIssueInstant = issueInstant
            , responseVersion = attrText "Version"
            , responseIssuer = T.concat
                (cursor $/ element (saml2Name "Issuer") &/ content)
            , responseStatusCode = statusCode
            , responseSignature = signature
            , responseAssertion = firstParsed (saml2Name "Assertion")
            , responseEncryptedAssertion =
                firstParsed (saml2Name "EncryptedAssertion")
            }
      where
        -- All values of the named attribute on @cursor@, concatenated
        -- (empty text when the attribute is absent).
        attrText name = T.concat (attribute name cursor)

        -- The first direct child with the given name that parses
        -- successfully. Parsing runs in the list monad, so failures are
        -- simply skipped.
        firstParsed :: FromXML a => Name -> Maybe a
        firstParsed name = listToMaybe ((cursor $/ element name) >>= parseXML)

--------------------------------------------------------------------------------

-- | Returns 'False' only for an element node whose name equals
-- @'dsName' "Signature"@; every other node (including non-element nodes
-- such as text or comments) yields 'True'.
isNotSignature :: Node -> Bool
isNotSignature n = case n of
    NodeElement Element{ elementName = name } -> name /= dsName "Signature"
    _ -> True

-- | 'removeSignature' @document@ removes every @<Signature>@ element that is
-- a direct child of the root element of @document@ and returns the resulting
-- document.
--
-- Only the root's immediate children are filtered. @<Signature>@ elements
-- nested deeper in the tree (e.g. inside a child element) are left untouched.
-- The prologue, epilogue, and the root's name and attributes are preserved.
removeSignature :: Document -> Document
removeSignature document = document { documentRoot = stripped }
  where
    root = documentRoot document
    -- Filter the root's immediate child nodes only; descendants are kept.
    stripped = root { elementNodes = filter isNotSignature (elementNodes root) }

-- | 'nodes' @cursor@ returns the single node that @cursor@ points at, lifted
-- into the monad with 'pure'. It never fails.
nodes :: MonadFail m => Cursor -> m Node
nodes cursor = pure (node cursor)

-- | 'extractSignedInfo' @cursor@ extracts the @<SignedInfo>@ element from the
-- document represented by @cursor@.
--
-- The element is looked up along the path @Signature/SignedInfo@ below
-- @cursor@ (both names built with 'dsName'). Selecting the match is delegated
-- to 'oneOrFail', which reports "SignedInfo is required" on failure.
extractSignedInfo :: MonadFail m => Cursor -> m Element
extractSignedInfo cursor = do
    signedInfoNode <- oneOrFail "SignedInfo is required" candidates >>= nodes
    -- Expected to always match, since the cursors come from the 'element'
    -- axis (presumably element nodes only); a mismatch would call 'fail'.
    NodeElement signedInfo <- pure signedInfoNode
    pure signedInfo
  where
    candidates = cursor $/ element (dsName "Signature")
                        &/ element (dsName "SignedInfo")

-- | Obtain a list of InclusiveNamespaces entries used for exclusive XML
-- canonicalisation.
--
-- Starting from @cursor@, this follows the child path
-- @Reference/Transforms/Transform/InclusiveNamespaces@, reads every
-- @PrefixList@ attribute found there and splits each value on whitespace.
-- Prefixes are returned in document order.
--
-- @since 0.5
extractPrefixList :: Cursor -> [T.Text]
extractPrefixList cursor =
    [ prefix
    | inclusive <- cursor $/ element (dsName "Reference")
                          &/ element (dsName "Transforms")
                          &/ element (dsName "Transform")
                          &/ element (ecName "InclusiveNamespaces")
    , prefixList <- attribute "PrefixList" inclusive
    , prefix <- T.words prefixList
    ]

--------------------------------------------------------------------------------

-- Reference [StatusResponseType]
--   Source: https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38
--   Section: 3.2.2 Complex Type StatusResponseType