--------------------------------------------------------------------------------
-- SAML2 Middleware for WAI                                                   --
--------------------------------------------------------------------------------
-- This source code is licensed under the MIT license found in the LICENSE    --
-- file in the root directory of this source tree.                            --
--------------------------------------------------------------------------------

-- | SAML2 signatures.
module Network.Wai.SAML2.Signature (
    CanonicalisationMethod(..),
    SignatureMethod(..),
    DigestMethod(..),
    SignedInfo(..),
    Reference(..),
    Signature(..)
) where

--------------------------------------------------------------------------------

import qualified Data.ByteString as BS
import qualified Data.Text as T
import Data.Text.Encoding

import Text.XML.Cursor

import Network.Wai.SAML2.XML

--------------------------------------------------------------------------------

-- | XML canonicalisation algorithms that a 'SignedInfo' may declare.
data CanonicalisationMethod
    = C14N_1_0
      -- ^ Original C14N 1.0 specification.
    | C14N_EXC_1_0
      -- ^ Exclusive C14N 1.0 specification.
    | C14N_1_1
      -- ^ C14N 1.1 specification.
    deriving (Eq, Show)

-- Reads the @Algorithm@ attribute of a canonicalisation-method element.
-- Only the Exclusive C14N 1.0 URI is recognised, so 'C14N_1_0' and
-- 'C14N_1_1' are never produced by this parser; any other value fails.
instance FromXML CanonicalisationMethod where
    parseXML cursor
        | algorithm == "http://www.w3.org/2001/10/xml-exc-c14n#" =
            pure C14N_EXC_1_0
        | otherwise =
            fail "Not a valid CanonicalisationMethod"
      where
        -- A missing attribute concatenates to "", which falls through to
        -- the failure branch.
        algorithm = T.concat (attribute "Algorithm" cursor)

-- | Signature algorithms that a 'SignedInfo' may declare.
data SignatureMethod
    = RSA_SHA256
      -- ^ RSA with SHA256 digest
    deriving (Eq, Show)

-- Reads the @Algorithm@ attribute of a signature-method element.
-- Only the RSA-SHA256 algorithm URI is accepted; anything else fails.
instance FromXML SignatureMethod where
    parseXML cursor
        | algorithm == "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" =
            pure RSA_SHA256
        | otherwise =
            fail "Not a valid SignatureMethod"
      where
        -- A missing attribute concatenates to "", which falls through to
        -- the failure branch.
        algorithm = T.concat (attribute "Algorithm" cursor)

--------------------------------------------------------------------------------

-- | Digest algorithms that a 'Reference' may declare.
data DigestMethod
    = DigestSHA256
      -- ^ SHA256
    deriving (Eq, Show)

-- Reads the @Algorithm@ attribute of a digest-method element.
-- Only the SHA-256 algorithm URI is accepted; anything else fails.
instance FromXML DigestMethod where
    parseXML cursor
        | algorithm == "http://www.w3.org/2001/04/xmlenc#sha256" =
            pure DigestSHA256
        | otherwise =
            fail "Not a valid DigestMethod"
      where
        -- A missing attribute concatenates to "", which falls through to
        -- the failure branch.
        algorithm = T.concat (attribute "Algorithm" cursor)

-- | A reference to some entity together with the digest the IdP computed
-- for it.
data Reference = Reference
    { referenceURI :: !T.Text
      -- ^ The URI of the entity that is referenced.
    , referenceDigestMethod :: !DigestMethod
      -- ^ The method that was used to calculate the digest for the
      -- entity that is referenced.
    , referenceDigestValue :: !BS.ByteString
      -- ^ The digest of the entity that was calculated by the IdP.
    } deriving (Eq, Show)

-- Parses a reference element: its @URI@ attribute, its DigestMethod child
-- (via the 'DigestMethod' parser) and the text of its DigestValue child.
instance FromXML Reference where
    parseXML cursor = do
        -- SAML requires a same-document reference of the form "#<ID>"
        -- (SAML Core 2.0, section 5.4.2). Strip the "#" explicitly instead
        -- of blindly dropping the first character, and reject URIs that
        -- lack it rather than silently mangling them.
        uri <- case T.stripPrefix "#" (T.concat $ attribute "URI" cursor) of
            Just fragment -> pure fragment
            Nothing -> fail "Reference URI must start with #"

        digestMethod <- oneOrFail "DigestMethod is required" (
            cursor $/ element (dsName "DigestMethod")
            ) >>= parseXML

        -- The DigestValue text is only UTF-8 encoded here, not decoded.
        -- NOTE(review): presumably base64 per XMLDSig, with decoding done
        -- by the caller; confirm.
        let digestValue = encodeUtf8 $ T.concat $
                cursor $/ element (dsName "DigestValue") &/ content

        pure Reference{
            referenceURI = uri,
            referenceDigestMethod = digestMethod,
            referenceDigestValue = digestValue
        }

--------------------------------------------------------------------------------

-- | The part of a signature that the IdP actually signs: it names the
-- canonicalisation and signature algorithms and carries a 'Reference'
-- (with digest) to the signed entity.
data SignedInfo = SignedInfo
    { signedInfoCanonicalisationMethod :: !CanonicalisationMethod
      -- ^ The XML canonicalisation method used.
    , signedInfoSignatureMethod :: !SignatureMethod
      -- ^ The method used to compute the signature for the referenced entity.
    , signedInfoReference :: !Reference
      -- ^ The reference to some entity, along with a digest.
    } deriving (Eq, Show)

-- Parses a SignedInfo element from its CanonicalizationMethod,
-- SignatureMethod and Reference children, in that order. Each child is
-- selected with 'oneOrFail', which fails with the given message when the
-- child cannot be selected (NOTE(review): exact cardinality rule lives in
-- Network.Wai.SAML2.XML; confirm).
instance FromXML SignedInfo where
    parseXML cursor =
        SignedInfo
            <$> (oneOrFail "CanonicalizationMethod is required"
                    (children "CanonicalizationMethod") >>= parseXML)
            <*> (oneOrFail "SignatureMethod is required"
                    (children "SignatureMethod") >>= parseXML)
            <*> (oneOrFail "Reference is required"
                    (children "Reference") >>= parseXML)
      where
        -- Direct children of this element with the given ds:-namespaced name.
        children name = cursor $/ element (dsName name)

-- | A response signature: the 'SignedInfo' the IdP signed plus the
-- signature value computed over it.
data Signature = Signature
    { signatureInfo :: !SignedInfo
      -- ^ Information about the data for which the IdP has computed digests.
    , signatureValue :: !BS.ByteString
      -- ^ The signature of the 'SignedInfo' value.
    } deriving (Eq, Show)

-- Parses a Signature element: its required SignedInfo child and the text
-- of its SignatureValue child. The signature text is only UTF-8 encoded
-- here, not decoded; a missing SignatureValue yields an empty ByteString.
instance FromXML Signature where
    parseXML cursor = do
        info <- parseXML =<< oneOrFail "SignedInfo is required" signedInfoNodes
        pure Signature
            { signatureInfo = info
            , signatureValue = encodeUtf8 (T.concat valueText)
            }
      where
        signedInfoNodes = cursor $/ element (dsName "SignedInfo")
        valueText = cursor $/ element (dsName "SignatureValue") &/ content

--------------------------------------------------------------------------------