--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | Types representing elements of the W3C XML Encryption standard.
-- See <https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html>
module Network.Wai.SAML2.XML.Encrypted (
    CipherData(..),
    EncryptionMethod(..),
    EncryptedKey(..),
    EncryptedAssertion(..)
) where

--------------------------------------------------------------------------------

import qualified Data.Text as T
import Data.Text.Encoding
import qualified Data.ByteString as BS

import Text.XML.Cursor

import Network.Wai.SAML2.XML
import Network.Wai.SAML2.KeyInfo

--------------------------------------------------------------------------------

-- | Represents some ciphertext, as carried by a @CipherData@ element.
data CipherData = CipherData {
    -- | The UTF-8 encoded text content of the @CipherValue@ child
    -- element(s). This module does not decode the value any further.
    cipherValue :: !BS.ByteString
} deriving (Eq, Show)

-- | Parses a @CipherData@ element by collecting the text of its
-- @CipherValue@ children.
instance FromXML CipherData where
    -- Never fails: when no @CipherValue@ child is present the result holds
    -- an empty 'BS.ByteString'. The text of all matching children is
    -- concatenated and only UTF-8 encoded here; it is not decoded
    -- (e.g. from base64) by this parser.
    parseXML cursor = pure CipherData{
        cipherValue = encodeUtf8
                    $ T.concat
                    $ cursor
                    $/ element (xencName "CipherValue")
                    &/ content
    }

--------------------------------------------------------------------------------

-- | Describes an encryption method, as carried by an @EncryptionMethod@
-- element.
data EncryptionMethod = EncryptionMethod {
    -- | The name of the algorithm, taken from the @Algorithm@ attribute.
    encryptionMethodAlgorithm :: !T.Text,
    -- | The name of the digest algorithm, if any, taken from the
    -- @Algorithm@ attribute of a @DigestMethod@ child element.
    encryptionMethodDigestAlgorithm :: !(Maybe T.Text)
} deriving (Eq, Show)

-- | Parses an @EncryptionMethod@ element.
instance FromXML EncryptionMethod where
    -- Never fails: a missing @Algorithm@ attribute yields an empty
    -- algorithm name rather than a parse error.
    parseXML cursor = pure EncryptionMethod{
        encryptionMethodAlgorithm =
            T.concat $ attribute "Algorithm" cursor,
        -- The digest algorithm is read from the @Algorithm@ attribute of
        -- any @DigestMethod@ child. 'toMaybeText' is defined elsewhere;
        -- presumably it maps "no value found" to 'Nothing' (TODO confirm).
        encryptionMethodDigestAlgorithm =
            toMaybeText $ cursor
                        $/ element (dsName "DigestMethod")
                       >=> attribute "Algorithm"
    }

--------------------------------------------------------------------------------

-- | Represents an encrypted key, as carried by an @EncryptedKey@ element.
data EncryptedKey = EncryptedKey {
    -- | The ID of the key (the element's @Id@ attribute).
    encryptedKeyId :: !T.Text,
    -- | The intended recipient of the key (the element's @Recipient@
    -- attribute).
    encryptedKeyRecipient :: !T.Text,
    -- | The method used to encrypt the key.
    encryptedKeyMethod :: !EncryptionMethod,
    -- | The key data, if the element contains a @KeyInfo@ child.
    encryptedKeyData :: !(Maybe KeyInfo),
    -- | The ciphertext.
    encryptedKeyCipher :: !CipherData
} deriving (Eq, Show)

-- | Parses an @EncryptedKey@ element.
--
-- The @EncryptionMethod@ and @CipherData@ children are looked up through
-- 'oneOrFail' with the messages @"EncryptionMethod is required"@ and
-- @"CipherData is required"@; the @KeyInfo@ child is optional.
instance FromXML EncryptedKey where
    parseXML cursor =  do
        method <- oneOrFail "EncryptionMethod is required" (
            cursor $/ element (xencName "EncryptionMethod")
                ) >>= parseXML

        -- @KeyInfo@ is optional. If several are present only the first one
        -- is parsed; a failure while parsing it fails the whole parse.
        keyData <- case cursor $/ element (dsName "KeyInfo") of
                     [] -> return Nothing
                     (keyInfo :_) -> Just <$> parseXML keyInfo

        cipher <- oneOrFail "CipherData is required" (
            cursor $/ element (xencName "CipherData")
                ) >>= parseXML

        -- The @Id@ and @Recipient@ attributes are not validated: a missing
        -- attribute results in an empty text value.
        pure EncryptedKey{
            encryptedKeyId = T.concat $ attribute "Id" cursor,
            encryptedKeyRecipient = T.concat $ attribute "Recipient" cursor,
            encryptedKeyMethod = method,
            encryptedKeyData = keyData,
            encryptedKeyCipher = cipher
        }

--------------------------------------------------------------------------------

-- | Represents an encrypted SAML assertion.
data EncryptedAssertion = EncryptedAssertion {
    -- | Information about the encryption method used for the assertion
    -- data.
    encryptedAssertionAlgorithm :: !EncryptionMethod,
    -- | The encrypted key.
    encryptedAssertionKey :: !EncryptedKey,
    -- | The ciphertext of the assertion.
    encryptedAssertionCipher :: !CipherData
} deriving (Eq, Show)

-- | Parses an encrypted assertion element.
--
-- The element must contain an @EncryptedData@ child, which in turn must
-- provide an @EncryptionMethod@ and a @CipherData@ child. The
-- @EncryptedKey@ may either be a direct child of the assertion element or
-- be nested in the @KeyInfo@ of the @EncryptedData@ element.
instance FromXML EncryptedAssertion where
    parseXML cursor = do
        encryptedData <- oneOrFail "EncryptedData is required"
            $   cursor
            $/  element (xencName "EncryptedData")

        -- 'parseXML' is composed into the axis here, so it runs in the
        -- list monad and 'oneOrFail' receives the parsed results.
        algorithm <- oneOrFail "Algorithm is required"
            $   encryptedData
            $/  element (xencName "EncryptionMethod")
            >=> parseXML

        -- Candidate keys are gathered from both supported locations.
        -- Because 'parseXML' runs in the list monad here, a key element
        -- that fails to parse yields no candidate, so such a failure
        -- surfaces as "EncryptedKey is required" if no other candidate
        -- parses successfully.
        key <- oneOrFail "EncryptedKey is required" $ mconcat
            [ cursor
                $/ element (xencName "EncryptedKey")
                >=> parseXML
            , cursor
                $/ element (xencName "EncryptedData")
                &/ element (dsName "KeyInfo")
                &/ element (xencName "EncryptedKey")
                >=> parseXML
            ]

        cipher <- oneOrFail "CipherData is required"
               (  encryptedData
              $/  element (xencName "CipherData")
              ) >>= parseXML

        pure EncryptedAssertion{
            encryptedAssertionAlgorithm = algorithm,
            encryptedAssertionKey = key,
            encryptedAssertionCipher = cipher
        }

--------------------------------------------------------------------------------