{-# LANGUAGE AllowAmbiguousTypes #-}
module Servant.Auth.Hmac.Crypto
(
SecretKey (..)
, Signature (..)
, sign
, signSHA256
, RequestPayload (..)
, requestSignature
, verifySignatureHmac
, whitelistHeaders
, keepWhitelistedHeaders
, authHeaderName
) where
import Crypto.Hash (hash)
import Crypto.Hash.Algorithms (MD5, SHA256)
import Crypto.Hash.IO (HashAlgorithm)
import Crypto.MAC.HMAC (HMAC (hmacGetDigest), hmac)
import Data.ByteString (ByteString)
import Data.CaseInsensitive (foldedCase)
import Data.List (sort, uncons)
import Network.HTTP.Types (Header, HeaderName, Method, RequestHeaders)
import qualified Data.ByteArray as BA (constEq, convert)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Lazy as LBS
-- | Secret key used both to produce ('requestSignature') and to check
-- ('verifySignatureHmac') request signatures.
--
-- No 'Show' instance is derived, so the key cannot accidentally end up in
-- logs or debug output via 'show'.
newtype SecretKey = SecretKey
    { unSecretKey :: ByteString  -- ^ raw key bytes fed to HMAC
    }
-- | A request signature.
--
-- Values built by 'sign' hold the Base64-encoded HMAC digest.
-- 'verifySignatureHmac' wraps the text that follows the @HMAC @ prefix of the
-- authentication header in this same type.
newtype Signature = Signature
    { unSignature :: ByteString  -- ^ signature bytes as transmitted
    }
    deriving (Eq)
-- | Compute the HMAC of a message under a secret key and Base64-encode it.
--
-- The hash algorithm is chosen with a type application, e.g.
-- @sign \@SHA256 key msg@ (see 'signSHA256').
sign :: forall algo . (HashAlgorithm algo)
    => SecretKey   -- ^ key for the HMAC
    -> ByteString  -- ^ message to authenticate
    -> Signature   -- ^ Base64 of the raw HMAC digest
sign (SecretKey key) message = Signature (Base64.encode digestBytes)
  where
    -- Raw digest bytes of the MAC computed with the requested algorithm.
    digestBytes :: ByteString
    digestBytes = BA.convert (hmacGetDigest mac)

    mac = hmac @_ @_ @algo key message
{-# INLINE sign #-}
-- | 'sign' specialised to HMAC-SHA256.
signSHA256 :: SecretKey -> ByteString -> Signature
signSHA256 key message = sign @SHA256 key message
{-# INLINE signSHA256 #-}
-- | The parts of an HTTP request that are covered by the signature.
--
-- 'requestSignature' builds the string to sign from all four fields.
-- 'verifySignatureHmac' also expects the 'authHeaderName' header to be present
-- in 'rpHeaders' and strips it before recomputing the signature.
data RequestPayload = RequestPayload
    { rpMethod  :: !Method          -- ^ HTTP method of the request
    , rpContent :: !ByteString      -- ^ request body; only its MD5 digest enters the string to sign
    , rpHeaders :: !RequestHeaders  -- ^ request headers; names are case-folded when signing
    , rpRawUrl  :: !ByteString      -- ^ request URL, included verbatim in the string to sign
    } deriving (Show)
-- | Compute the signature of a request.
--
-- The string to sign is the newline-separated concatenation of:
--
--   1. the HTTP method,
--   2. the raw MD5 digest of the request body,
--   3. the normalized headers (one per line, sorted),
--   4. the raw URL.
--
-- That string is passed to the supplied signing function together with the key.
requestSignature
    :: (SecretKey -> ByteString -> Signature)  -- ^ signing function, e.g. 'signSHA256'
    -> SecretKey                               -- ^ secret key
    -> RequestPayload                          -- ^ request to sign
    -> Signature
requestSignature signer sk payload = signer sk (stringToSign payload)
  where
    stringToSign :: RequestPayload -> ByteString
    stringToSign p = BS.intercalate "\n"
        [ rpMethod p
        , hashMD5 (rpContent p)
        , normalizeHeaders (rpHeaders p)
        , rpRawUrl p
        ]

    -- Each header becomes its case-folded name immediately followed by its
    -- value (no separator). The lines are sorted so that the order of headers
    -- in the request does not affect the signature.
    normalizeHeaders :: [Header] -> ByteString
    normalizeHeaders hdrs = BS.intercalate "\n" $ sort
        [ foldedCase name <> value | (name, value) <- hdrs ]
-- | Names of the headers retained by 'keepWhitelistedHeaders'.
whitelistHeaders :: [HeaderName]
whitelistHeaders =
    [ authHeaderName
    , "Host"
    , "Accept-Encoding"
    ]
-- | Drop every header whose name is not listed in 'whitelistHeaders'.
--
-- Names are compared as 'HeaderName' values, i.e. case-insensitively.
-- The relative order of the kept headers is preserved.
keepWhitelistedHeaders :: [Header] -> [Header]
keepWhitelistedHeaders = filter (\(name, _) -> name `elem` whitelistHeaders)
-- | Verify the HMAC signature carried by a request.
--
-- The signature is read from the first 'authHeaderName' header, whose value
-- must have the form @HMAC <signature>@. That header is removed from the
-- payload, the signature is recomputed over the rest with 'requestSignature',
-- and the two signatures are compared in constant time.
--
-- Returns 'Nothing' when the signature is valid, or @'Just' err@ describing
-- why verification failed.
verifySignatureHmac
    :: (SecretKey -> ByteString -> Signature)  -- ^ signing function, e.g. 'signSHA256'
    -> SecretKey                               -- ^ secret key
    -> RequestPayload                          -- ^ request including the auth header
    -> Maybe LBS.ByteString
verifySignatureHmac signer sk signedPayload = case unsignedPayload of
    Left err -> Just err
    Right (pay, sig)
        | sig `signatureEq` requestSignature signer sk pay -> Nothing
        | otherwise -> Just "Signatures don't match"
  where
    -- Compare in constant time. The derived '==' on ByteString stops at the
    -- first differing byte, so response timing could reveal how many leading
    -- bytes of a forged signature were correct.
    signatureEq :: Signature -> Signature -> Bool
    signatureEq (Signature a) (Signature b) = BA.constEq a b

    -- Split the request into (payload without auth header, claimed signature).
    unsignedPayload :: Either LBS.ByteString (RequestPayload, Signature)
    unsignedPayload = case extractOn isAuthHeader $ rpHeaders signedPayload of
        (Nothing, _) -> Left "No 'Authentication' header"
        (Just (_, val), headers) -> case BS.stripPrefix "HMAC " val of
            Just sig -> Right
                ( signedPayload { rpHeaders = headers }
                , Signature sig
                )
            Nothing -> Left "Can not strip 'HMAC' prefix in header"
-- | Name of the header that carries the @HMAC <signature>@ value.
authHeaderName :: HeaderName
authHeaderName = "Authentication"
-- | Whether a header is the 'authHeaderName' header (case-insensitive name match).
isAuthHeader :: Header -> Bool
isAuthHeader = (== authHeaderName) . fst
-- | Raw (not hex- or Base64-encoded) MD5 digest bytes of the input.
hashMD5 :: ByteString -> ByteString
hashMD5 body = BA.convert (hash @_ @MD5 body)
-- | Remove the first element that satisfies the predicate and return it along
-- with the remaining elements in their original order.
--
-- If no element matches, the result is @(Nothing, l)@ with the list unchanged.
extractOn :: (a -> Bool) -> [a] -> (Maybe a, [a])
extractOn p l =
    let (before, after) = break p l
    in case uncons after of
        Nothing      -> (Nothing, l)
        Just (x, xs) -> (Just x, before ++ xs)