{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE OverloadedStrings  #-}
{-# LANGUAGE CPP #-}

module Network.Mail.Mime.SES.Internal where

import           Crypto.Hash                 (SHA256, hmac, hmacGetDigest, hash)
import           Data.Bifunctor              (bimap)
import           Data.Byteable               (toBytes)
import           Data.ByteString             (ByteString)
import qualified Data.ByteString             as B
import           Data.ByteString.Base16      as Base16
import qualified Data.ByteString.Char8       as S8
import qualified Data.ByteString.Lazy        as L
import           Data.Char                   (toLower)
import           Data.CaseInsensitive        (CI)
import qualified Data.CaseInsensitive        as CI
import           Data.List                   (sort)
#if MIN_VERSION_base(4, 11, 0)
#else
import           Data.Monoid ((<>))
#endif

import           Data.Time                   (UTCTime)
import           Data.Time.Format            (formatTime)
import           Network.HTTP.Client         (Request, RequestBody(RequestBodyLBS, RequestBodyBS),
#if MIN_VERSION_http_client(0, 5, 0)
                                             parseRequest,
#else
                                             checkStatus,
                                             parseUrl,
#endif
                                             method, host, path, requestHeaders, queryString, requestBody
                                             )
#if MIN_VERSION_time(1,5,0)
import           Data.Time                   (defaultTimeLocale)
#else
import           System.Locale               (defaultTimeLocale)
#endif

-- | Build the canonical request of AWS Signature Version 4, see
-- <https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html>.
--
-- The result is the newline-separated sequence of: the HTTP method, the
-- path, the query string, the canonical headers, the list of signed header
-- names and the hex-encoded SHA-256 hash of the payload.
makeCanonicalRequest :: ByteString -> ByteString -> ByteString -> [(CI ByteString, ByteString)] -> ByteString -> ByteString
makeCanonicalRequest requestMethod requestPath requestQueryString headers payload =
  S8.intercalate "\n"
    [ requestMethod
    , requestPath
    , requestQueryString
    , canonicalHeaders
    , makeListOfHeaders headers
    , unaryHashBase16 payload
    ]
  where
    -- Header names are lower-cased first; the pairs are then ordered by
    -- name and, for equal names, by value.
    loweredHeaders =
      sort [(bytesToLowerCase (CI.original name), value) | (name, value) <- headers]
    -- Every header line carries its own trailing newline, so joining the
    -- sections with "\n" leaves an empty line after the header section.
    canonicalHeaders =
      S8.concat [name <> ":" <> value <> "\n" | (name, value) <- loweredHeaders]

-- | Build the canonical request for an http-client 'Request'.
--
-- The method, path, query string and body are read from the request; the
-- headers are the ones returned by 'patchedRequestHeaders'.
canonicalizeRequest :: Request -> ByteString
canonicalizeRequest request =
  makeCanonicalRequest verb resource query signedHeaders body
  where
    verb          = method request
    resource      = path request
    query         = queryString request
    signedHeaders = patchedRequestHeaders request
    body          = requestBodyAsByteString request

-- | Build the string to sign of AWS Signature Version 4, see
-- <https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html>.
--
-- The result consists of four newline-separated lines: the algorithm name,
-- the request timestamp, the credential scope and the hex-encoded SHA-256
-- hash of the canonical request.
makeStringToSign :: ByteString -> UTCTime -> ByteString -> ByteString -> ByteString
makeStringToSign service time region canonicalRequest =
  S8.intercalate "\n" [algorithm, timestamp, scope, requestDigest]
  where
    algorithm     = "AWS4-HMAC-SHA256"
    timestamp     = formatAmazonTime time
    scope         = makeCredentialScope service time region
    requestDigest = unaryHashBase16 canonicalRequest

-- | Compute the hex-encoded signature of AWS Signature Version 4, see
-- <https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html>.
--
-- The signing key is derived by a chain of HMAC-SHA256 applications, each
-- step using the previous result as the key: starting from the secret
-- prefixed with @AWS4@, the date, the region, the service and the constant
-- @aws4_request@ are hashed in turn. The string to sign is then hashed with
-- the resulting key.
makeSig :: ByteString -> UTCTime -> ByteString -> ByteString -> ByteString -> ByteString
makeSig service time region secret stringToSign =
  Base16.encode (keyedHash signingKey stringToSign)
  where
    dateKey    = keyedHash ("AWS4" <> secret) (formatAmazonDate time)
    regionKey  = keyedHash dateKey region
    serviceKey = keyedHash regionKey service
    signingKey = keyedHash serviceKey "aws4_request"

-- | Build the value of the @Authorization@ header of AWS Signature Version 4,
-- see <https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html>.
--
-- The value names the algorithm, the credential (key identifier followed by
-- the credential scope), the signed header names and the signature.
makeAuthorizationString :: ByteString -> UTCTime -> ByteString -> [(CI ByteString, ByteString)] -> ByteString -> ByteString -> ByteString
makeAuthorizationString service time region headers keyId sig =
  S8.concat
    [ "AWS4-HMAC-SHA256 Credential="
    , keyId
    , "/"
    , makeCredentialScope service time region
    , ", SignedHeaders="
    , makeListOfHeaders headers
    , ", Signature="
    , sig
    ]

-- | Render a timestamp with the format @%Y%m%dT%H%M%SZ@, for example
-- @20150830T123456Z@.
formatAmazonTime :: UTCTime -> ByteString
formatAmazonTime time =
  S8.pack (formatTime defaultTimeLocale "%Y%m%dT%H%M%SZ" time)

-- | Render the date part of a timestamp with the format @%Y%m%d@, for
-- example @20150830@.
formatAmazonDate :: UTCTime -> ByteString
formatAmazonDate time =
  S8.pack (formatTime defaultTimeLocale "%Y%m%d" time)

-- | Parse a URL into a 'Request'.
--
-- With @http-client >= 0.5@ the URL is handed to 'parseRequest' unchanged.
-- With older versions the URL is parsed by @parseUrl@ and the resulting
-- request gets a @checkStatus@ that never reports an exception.
--
-- Parse failures are reported by the underlying parser in 'IO'.
buildRequest :: String -> IO Request
#if MIN_VERSION_http_client(0, 5, 0)
buildRequest url = parseRequest url
#else
buildRequest url = fmap disableStatusCheck (parseUrl url)
  where
    -- The record update has to be applied to the parsed 'Request'. Record
    -- update binds tighter than function application, so writing
    -- @parseUrl url {checkStatus = ...}@ would try to update the 'String'.
    disableStatusCheck request = request {checkStatus = \_ _ _ -> Nothing}
#endif

-- | Extract the body of a request as a strict 'ByteString'.
--
-- Only the 'RequestBodyBS' and 'RequestBodyLBS' constructors are supported;
-- any other kind of body results in a call to 'error'.
requestBodyAsByteString :: Request -> ByteString
requestBodyAsByteString = strictBody . requestBody
  where
    strictBody (RequestBodyBS strict) = strict
    strictBody (RequestBodyLBS lazy)  = L.toStrict lazy
    strictBody _                      = error "Not implemented."

-- | Length in bytes of the request body, as obtained from
-- 'requestBodyAsByteString'.
requestBodyLength :: Request -> Int
requestBodyLength request = B.length (requestBodyAsByteString request)

-- | Render the names of the given headers as a single string: the names are
-- lower-cased, sorted and separated by semicolons.
makeListOfHeaders :: [(CI ByteString, ByteString)] -> ByteString
makeListOfHeaders headers = S8.intercalate ";" (sort loweredNames)
  where
    loweredNames = [bytesToLowerCase (CI.original name) | (name, _) <- headers]

-- | The headers of a request, followed by a @Host@ and a @Content-Length@
-- header computed from the request itself.
patchedRequestHeaders :: Request -> [(CI ByteString, ByteString)]
patchedRequestHeaders request =
  requestHeaders request ++ [hostHeader, contentLengthHeader]
  where
    hostHeader :: (CI ByteString, ByteString)
    hostHeader = (CI.mk "Host", host request)

    -- @http-client@ [adds the @Content-Length@ header automatically when sending the request](https://hackage.haskell.org/package/http-client-0.7.1/docs/Network-HTTP-Client.html#v:requestHeaders),
    -- so we have to reconstruct it by hand.
    contentLengthHeader :: (CI ByteString, ByteString)
    contentLengthHeader =
      (CI.mk "Content-Length", S8.pack (show (requestBodyLength request)))

-- | Build the credential scope: the date, the region, the service and the
-- constant @aws4_request@, separated by slashes.
makeCredentialScope :: ByteString -> UTCTime -> ByteString -> ByteString
makeCredentialScope service time region =
  formatAmazonDate time <> "/" <> region <> "/" <> service <> "/" <> "aws4_request"

-- | Lower-case a 'ByteString' byte by byte.
--
-- Every byte is interpreted as a 'Char' with the same code point and passed
-- through 'toLower'. The mapping is done directly on the 'ByteString', so no
-- intermediate 'String' is built.
bytesToLowerCase :: ByteString -> ByteString
bytesToLowerCase = S8.map toLower

-- | Hex-encoded SHA-256 digest of the given bytes.
unaryHashBase16 :: ByteString -> ByteString
unaryHashBase16 payload = Base16.encode (toBytes (hash @SHA256 payload))

-- | Raw (not hex-encoded) HMAC-SHA256 digest of a payload. The first
-- argument is the key, the second one the payload.
keyedHash :: ByteString -> ByteString -> ByteString
keyedHash key payload = toBytes digest
  where
    digest = hmacGetDigest (hmac @SHA256 key payload)