module Amazon.SNS.Verify.Validate
  ( validateSnsMessage
  , handleSubscription
  , SNSNotificationValidationError(..)
  , ValidSNSMessage(..)
  ) where

import Amazon.SNS.Verify.Prelude

import Amazon.SNS.Verify.Payload
import Control.Error (ExceptT, catMaybes, headMay, runExceptT, throwE)
import Control.Monad (when)
import Data.ByteArray.Encoding (Base(Base64), convertFromBase)
import Data.PEM (pemContent, pemParseLBS)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Data.X509
  ( HashALG(..)
  , PubKeyALG(..)
  , SignatureALG(..)
  , SignedCertificate
  , certPubKey
  , decodeSignedCertificate
  , getCertificate
  )
import Data.X509.Validation
  (SignatureFailure, SignatureVerification(..), verifySignature)
import Network.HTTP.Simple
  (getResponseBody, getResponseStatusCode, httpLbs, parseRequest_)

-- | An SNS payload whose signature has been verified, tagged by the kind of
-- message it carried.
data ValidSNSMessage
  = SNSMessage Text
  -- ^ A notification; holds the payload's message text.
  | SNSSubscribe SNSSubscription
  -- ^ A subscription confirmation, with its subscription details.
  | SNSUnsubscribe SNSSubscription
  -- ^ An unsubscribe confirmation, with its subscription details.
  deriving stock (Show, Eq)

-- | Validate an SNS message by verifying its signature.
--
-- The payload's signature is base64-decoded, the signing certificate is
-- fetched via 'retrieveCertificate', and the canonical string produced by
-- 'unsignedSignature' is checked against the certificate's public key. The
-- algorithm is detailed in the documentation below.
--
-- <https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html>
--
-- The digest follows the payload's signature version: @"2"@ selects SHA-256,
-- any other value uses SHA-1.
--
-- On a passing signature the result is the payload classified as a
-- 'ValidSNSMessage'; a failing signature yields @Left ('InvalidPayload' e)@.
--
-- NOTE(review): a signature that is not valid base64 is reported through
-- @unTry BadSignature@; whether that surfaces as a @Left@ or as a thrown
-- exception depends on @unTry@ in "Amazon.SNS.Verify.Prelude" — confirm.
validateSnsMessage
  :: MonadIO m
  => SNSPayload
  -> m (Either SNSNotificationValidationError ValidSNSMessage)
validateSnsMessage payload@SNSPayload {..} = runExceptT $ do
  signature <- unTry BadSignature $ convertFromBase Base64 $ encodeUtf8
    snsSignature
  signedCert <- retrieveCertificate payload
  let
    -- SNS signs with SHA1withRSA for SignatureVersion 1 and SHA256withRSA for
    -- SignatureVersion 2. Unrecognised versions keep the historical SHA-1
    -- behaviour so existing callers are unaffected.
    digest = case snsSignatureVersion of
      "2" -> HashSHA256
      _ -> HashSHA1
    verdict = verifySignature
      (SignatureALG digest PubKeyALG_RSA)
      (certPubKey $ getCertificate signedCert)
      (unsignedSignature payload)
      signature
  case verdict of
    SignaturePass -> pure $ case snsTypePayload of
      Notification{} -> SNSMessage snsMessage
      SubscriptionConfirmation x -> SNSSubscribe x
      UnsubscribeConfirmation x -> SNSUnsubscribe x
    SignatureFailed e -> throwE $ InvalidPayload e

-- | Download and decode the X.509 certificate named by the payload's signing
-- certificate URL.
--
-- The response body is parsed as PEM; the content of the first PEM entry is
-- decoded as a signed certificate. An empty PEM list fails with
-- @'BadPem' "Empty List"@. PEM and certificate decoding failures are reported
-- through @unTry@ with 'BadPem' and 'BadCert' respectively.
--
-- NOTE(review): the URL is taken from the payload and fetched as-is; this
-- function performs no check on its scheme or host, and the response status
-- code is not inspected. Confirm whether callers validate the URL beforehand.
retrieveCertificate
  :: MonadIO m
  => SNSPayload
  -> ExceptT SNSNotificationValidationError m SignedCertificate
retrieveCertificate SNSPayload { snsSigningCertURL = certUrl } = do
  certResponse <- httpLbs (parseRequest_ (T.unpack certUrl))
  pemEntries <- unTry BadPem (pemParseLBS (getResponseBody certResponse))
  let firstEntryBody = pemContent <$> headMay pemEntries
  certBytes <- fromMaybeM (throwE (BadPem "Empty List")) firstEntryBody
  unTry BadCert (decodeSignedCertificate certBytes)

-- | Build the canonical string-to-sign for a payload, UTF-8 encoded.
--
-- Every included key and every value is emitted as its own
-- newline-terminated line, in the order listed below. Optional pairs are
-- dropped entirely when their value is absent:
--
-- * @Subject@ is only present for notifications that carry a subject.
-- * @SubscribeURL@ and @Token@ are only present for subscription and
--   unsubscription confirmations.
unsignedSignature :: SNSPayload -> ByteString
unsignedSignature SNSPayload {..} =
  encodeUtf8 $ mconcat $ (<> "\n") <$> catMaybes
    [ Just "Message"
    , Just snsMessage
    , Just "MessageId"
    , Just snsMessageId
    -- The key must be spelled exactly as SNS signs it; a misspelled key makes
    -- every confirmation message fail verification.
    , "SubscribeURL" <$ mSubscribeUrl
    , mSubscribeUrl
    , "Subject" <$ mSubject
    , mSubject
    , Just "Timestamp"
    , Just snsTimestamp
    , "Token" <$ mToken
    , mToken
    , Just "TopicArn"
    , Just snsTopicArn
    , Just "Type"
    , Just snsType
    ]
 where
  -- Optional fields, depending on the kind of message.
  (mSubject, mToken, mSubscribeUrl) = case snsTypePayload of
    Notification x -> (snsSubject x, Nothing, Nothing)
    SubscriptionConfirmation x ->
      (Nothing, Just $ snsToken x, Just $ snsSubscribeURL x)
    UnsubscribeConfirmation x ->
      (Nothing, Just $ snsToken x, Just $ snsSubscribeURL x)

-- | Act on an already-validated message.
--
-- * A notification returns its message text as @Right@.
-- * A subscription confirmation issues a request to the subscribe URL. A
--   response status of 300 or above yields @Left ('BadSubscription' ())@;
--   otherwise the result is @Left 'SubscribeMessageResponded'@, since there is
--   no message text to return.
-- * An unsubscribe confirmation yields @Left 'UnsubscribeMessage'@ without
--   making any request.
handleSubscription
  :: MonadIO m
  => ValidSNSMessage
  -> m (Either SNSNotificationValidationError Text)
handleSubscription message = runExceptT $ case message of
  SNSMessage body -> pure body
  SNSUnsubscribe{} -> throwE UnsubscribeMessage
  SNSSubscribe SNSSubscription { snsSubscribeURL = confirmUrl } -> do
    confirmation <- httpLbs (parseRequest_ (T.unpack confirmUrl))
    when
      (getResponseStatusCode confirmation >= 300)
      (throwE (BadSubscription ()))
    throwE SubscribeMessageResponded

-- | Everything that can go wrong while validating or handling an SNS message.
data SNSNotificationValidationError
  = BadPem String
  -- ^ The downloaded certificate could not be parsed as PEM, or the PEM list
  -- was empty.
  | BadSignature String
  -- ^ The payload's signature could not be base64-decoded.
  | BadCert String
  -- ^ The PEM content could not be decoded as a signed certificate.
  | BadJSONParse String
  -- ^ Not produced in this module; presumably a JSON decoding failure raised
  -- elsewhere — confirm against its use sites.
  | BadSubscription ()
  -- ^ The subscribe URL responded with a status code of 300 or above.
  | InvalidPayload SignatureFailure
  -- ^ Signature verification failed for the given reason.
  | MissingMessageTypeHeader
  -- ^ Not produced in this module; presumably raised by a caller inspecting
  -- request headers — confirm against its use sites.
  | UnsubscribeMessage
  -- ^ 'handleSubscription' was given an unsubscribe confirmation.
  | SubscribeMessageResponded
  -- ^ 'handleSubscription' confirmed a subscription; there is no message text
  -- to return.
  deriving stock (Show, Eq)
  deriving anyclass (Exception)