module Amazon.SNS.Verify.Validate
( validateSnsMessage
, handleSubscription
, SNSNotificationValidationError(..)
, ValidSNSMessage(..)
) where
import Amazon.SNS.Verify.Prelude
import Amazon.SNS.Verify.Payload
import Control.Error (ExceptT, catMaybes, headMay, runExceptT, throwE)
import Control.Monad (when)
import Data.ByteArray.Encoding (Base(Base64), convertFromBase)
import Data.PEM (pemContent, pemParseLBS)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.X509
( HashALG(..)
, PubKeyALG(..)
, SignatureALG(..)
, SignedCertificate
, certPubKey
, decodeSignedCertificate
, getCertificate
)
import Data.X509.Validation
(SignatureFailure, SignatureVerification(..), verifySignature)
import Network.HTTP.Simple
(getResponseBody, getResponseStatusCode, httpLbs, parseRequest_)
-- | An SNS message whose signature has been verified, classified by the
-- payload type it carried.
data ValidSNSMessage
  = SNSMessage Text
  -- ^ A @Notification@; carries the message body.
  | SNSSubscribe SNSSubscription
  -- ^ A @SubscriptionConfirmation@.
  | SNSUnsubscribe SNSSubscription
  -- ^ An @UnsubscribeConfirmation@.
  deriving stock (Show, Eq)
-- | Verify the signature of a decoded SNS payload and classify it.
--
-- The signing certificate is obtained via 'retrieveCertificate'. Its RSA
-- public key is used to check the base64-decoded @Signature@ against the
-- canonical string built by 'unsignedSignature'.
--
-- Results:
--
-- * 'Right' with 'SNSMessage' \/ 'SNSSubscribe' \/ 'SNSUnsubscribe' when the
--   signature verifies;
-- * 'Left' 'BadSignature' when @Signature@ is not valid base64;
-- * 'Left' 'InvalidPayload' when the signature does not match;
-- * whatever 'retrieveCertificate' reports for certificate problems.
--
-- Exceptions raised while fetching the certificate (e.g. by the HTTP
-- client) are not caught here.
validateSnsMessage
  :: MonadIO m
  => SNSPayload
  -> m (Either SNSNotificationValidationError ValidSNSMessage)
validateSnsMessage payload@SNSPayload {..} = runExceptT $ do
  -- Report a malformed signature through the ExceptT error channel so it
  -- surfaces as 'Left', like the 'InvalidPayload' case below.
  signature <-
    either (throwE . BadSignature) pure
      $ convertFromBase Base64
      $ encodeUtf8 snsSignature
  signedCert <- retrieveCertificate payload
  let
    -- SignatureVersion "2" is signed with SHA256withRSA. Every other value
    -- keeps the historical SHA1withRSA behaviour (SignatureVersion "1").
    hashAlg
      | snsSignatureVersion == "2" = HashSHA256
      | otherwise = HashSHA1
    verification = verifySignature
      (SignatureALG hashAlg PubKeyALG_RSA)
      (certPubKey $ getCertificate signedCert)
      (unsignedSignature payload)
      signature
  case verification of
    SignaturePass -> pure $ case snsTypePayload of
      Notification{} -> SNSMessage snsMessage
      SubscriptionConfirmation x -> SNSSubscribe x
      UnsubscribeConfirmation x -> SNSUnsubscribe x
    SignatureFailed e -> throwE $ InvalidPayload e
-- | Download the document at the payload's @SigningCertURL@ and decode its
-- first PEM block as a signed X.509 certificate.
--
-- Errors, reported through the 'ExceptT' error channel:
--
-- * 'BadPem' if the response body cannot be parsed as PEM, or contains no
--   PEM block (message @"Empty List"@);
-- * 'BadCert' if the first PEM block is not a decodable certificate.
--
-- Exceptions from the HTTP request itself (including an unparseable URL
-- rejected by @parseRequest_@) are not caught.
--
-- NOTE(review): the URL is taken verbatim from the payload and is not
-- checked to be an HTTPS Amazon SNS host. A forged message could therefore
-- point at a certificate of its sender's choosing. AWS recommends
-- validating the certificate's origin; consider enforcing that here.
retrieveCertificate
  :: MonadIO m
  => SNSPayload
  -> ExceptT SNSNotificationValidationError m SignedCertificate
retrieveCertificate SNSPayload {..} = do
  response <- httpLbs $ parseRequest_ $ T.unpack snsSigningCertURL
  pems <-
    either (throwE . BadPem) pure $ pemParseLBS $ getResponseBody response
  firstPem <- maybe (throwE $ BadPem "Empty List") pure $ headMay pems
  either (throwE . BadCert) pure
    $ decodeSignedCertificate
    $ pemContent firstPem
-- | Build the canonical string-to-sign for this payload, UTF-8 encoded.
--
-- Each present field contributes its key name and then its value, each
-- followed by @\\n@, in the fixed order below. The optional keys (@Subject@
-- for notifications; @SubscribeURL@ and @Token@ for subscribe/unsubscribe
-- confirmations) are emitted only when the payload carries them.
unsignedSignature :: SNSPayload -> ByteString
unsignedSignature SNSPayload {..} =
  encodeUtf8 $ mconcat $ (<> "\n") <$> catMaybes
    [ Just "Message"
    , Just snsMessage
    , Just "MessageId"
    , Just snsMessageId
    -- The key must match SNS's spelling exactly ("SubscribeURL"), or the
    -- string-to-sign differs from what SNS signed.
    , "SubscribeURL" <$ mSubscribeUrl
    , mSubscribeUrl
    , "Subject" <$ mSubject
    , mSubject
    , Just "Timestamp"
    , Just snsTimestamp
    , "Token" <$ mToken
    , mToken
    , Just "TopicArn"
    , Just snsTopicArn
    , Just "Type"
    , Just snsType
    ]
  where
    -- Only notifications carry an (optional) Subject; only subscribe and
    -- unsubscribe confirmations carry a Token and SubscribeURL.
    (mSubject, mToken, mSubscribeUrl) = case snsTypePayload of
      Notification x -> (snsSubject x, Nothing, Nothing)
      SubscriptionConfirmation x ->
        (Nothing, Just $ snsToken x, Just $ snsSubscribeURL x)
      UnsubscribeConfirmation x ->
        (Nothing, Just $ snsToken x, Just $ snsSubscribeURL x)
-- | Act on a verified SNS message.
--
-- * 'SNSMessage': returns the message body as 'Right'.
-- * 'SNSSubscribe': requests the subscription's @SubscribeURL@. If the
--   response status is 300 or above it yields 'Left' 'BadSubscription';
--   otherwise it yields 'Left' 'SubscribeMessageResponded'.
-- * 'SNSUnsubscribe': yields 'Left' 'UnsubscribeMessage'.
--
-- Exceptions from the HTTP request itself are not caught.
handleSubscription
  :: MonadIO m
  => ValidSNSMessage
  -> m (Either SNSNotificationValidationError Text)
handleSubscription = runExceptT . \case
  SNSMessage t -> pure t
  SNSSubscribe SNSSubscription {..} -> do
    response <- httpLbs $ parseRequest_ $ T.unpack snsSubscribeURL
    when (getResponseStatusCode response >= 300)
      $ throwE
      $ BadSubscription ()
    -- A successful confirmation is still reported as a 'Left' so the caller
    -- does not treat it as an ordinary message body.
    throwE SubscribeMessageResponded
  SNSUnsubscribe{} -> throwE UnsubscribeMessage
-- | Reasons an SNS message was rejected, or was not an ordinary
-- notification.
data SNSNotificationValidationError
  = BadPem String
  -- ^ In this module: the certificate response could not be parsed as
  -- PEM, or contained no PEM block.
  | BadSignature String
  -- ^ In this module: the @Signature@ field was not valid base64.
  | BadCert String
  -- ^ In this module: the first PEM block was not a decodable certificate.
  | BadJSONParse String
  -- ^ Not raised in this module; presumably used when decoding the raw
  -- payload elsewhere (confirm against callers).
  | BadSubscription ()
  -- ^ The subscription confirmation request returned a status >= 300.
  | InvalidPayload SignatureFailure
  -- ^ The signature did not verify against the certificate's public key.
  | UnsubscribeMessage
  -- ^ The message was an unsubscribe confirmation.
  | SubscribeMessageResponded
  -- ^ A subscription confirmation was received and its URL was requested.
  deriving stock (Show, Eq)
  deriving anyclass Exception