{-# LANGUAGE RecordWildCards, OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Web.WebPush.Internal where
import Control.Monad.IO.Class
import Crypto.Cipher.AES (AES128)
import qualified Crypto.Cipher.Types as Cipher
import qualified Crypto.ECC
import Crypto.Error (CryptoError,
eitherCryptoError)
import Crypto.Hash.Algorithms (SHA256 (..))
import qualified Crypto.MAC.HMAC as HMAC
import qualified Crypto.PubKey.ECC.DH as ECDH
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.PubKey.ECC.P256 as P256
import qualified Crypto.PubKey.ECC.Types as ECC
import qualified Crypto.PubKey.ECC.Types as ECCTypes
import Crypto.Random
import Data.Aeson ((.=))
import qualified Data.Aeson as A
import Data.Bifunctor
import qualified Data.Binary as Binary
import qualified Data.Bits as Bits
import qualified Data.ByteArray as ByteArray
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString.Lazy.Base64.URL as B64.URL
import Data.Data
import Data.Text (Text)
import qualified Data.Text as T
import Data.Word (Word16, Word64, Word8)
import GHC.Int (Int64)
import Network.HTTP.Types
import Network.URI
-- | Claims identifying the application server in a VAPID token. Serialised
-- to a JSON object with keys @aud@, @exp@ and @sub@ by its 'A.ToJSON'
-- instance.
data ServerIdentification = ServerIdentification
    { serverIdentificationAudience   :: Text -- ^ Value of the @aud@ claim.
    , serverIdentificationExpiration :: Int  -- ^ Value of the @exp@ claim.
    , serverIdentificationSubject    :: Text -- ^ Value of the @sub@ claim.
    } deriving (Show)
-- | Encodes the identification as a JWT claims object:
-- @{"aud": …, "exp": …, "sub": …}@.
instance A.ToJSON ServerIdentification where
    toJSON ServerIdentification{..} =
        A.object
            [ "aud" .= serverIdentificationAudience
            , "exp" .= serverIdentificationExpiration
            , "sub" .= serverIdentificationSubject
            ]
-- | Build a signed ES256 JWT whose claims are the given 'ServerIdentification'.
--
-- The result is
-- @base64url(header) "." base64url(claims) "." base64url(r ++ s)@, with every
-- base64url part unpadded. The header is @{"typ":"JWT","alg":"ES256"}@.
-- The signature components @r@ and @s@ are each written as four 'Word64's via
-- 'int32Bytes' and 'Binary.encode'.
--
-- Signing uses 'ECDSA.sign' with 'SHA256', which draws its randomness from @m@.
webPushJWT :: MonadRandom m => ECDSA.PrivateKey -> ServerIdentification -> m BS.ByteString
webPushJWT privateKey payload =
    appendSignature <$> ECDSA.sign privateKey SHA256 signingInput
  where
    -- Final token: the signed input, a dot, then the encoded signature.
    appendSignature :: ECDSA.Signature -> BS.ByteString
    appendSignature sig = signingInput <> "." <> encodeSignature sig

    -- r and s as fixed-width 256-bit values, then unpadded base64url.
    encodeSignature :: ECDSA.Signature -> BS.ByteString
    encodeSignature (ECDSA.Signature r s) =
        BS.toStrict
            (B64.URL.encodeBase64Unpadded'
                (Binary.encode (int32Bytes r, int32Bytes s)))

    -- "<header>.<claims>", the bytes that get signed.
    signingInput :: BS.ByteString
    signingInput = BS.toStrict (base64Json jwtHeader <> "." <> base64Json payload)

    jwtHeader :: A.Value
    jwtHeader = A.object [ "typ" .= ("JWT" :: Text), "alg" .= ("ES256" :: Text) ]

    -- JSON-encode a value and base64url it without padding.
    base64Json :: A.ToJSON a => a -> LB.ByteString
    base64Json = B64.URL.encodeBase64Unpadded' . A.encode
-- | Everything 'webPushEncrypt' needs to encrypt a single push message.
data WebPushEncryptionInput = EncryptionInput
    { applicationServerPrivateKey :: ECDH.PrivateNumber
      -- ^ Application server's P-256 private number; 'webPushEncrypt' derives
      -- the matching public key from it.
    , userAgentPublicKeyBytes     :: ByteString
      -- ^ User agent public key bytes, decoded with 'ecBytesToPublicKey'.
    , authenticationSecret        :: ByteString
      -- ^ HMAC key applied to the ECDH shared secret.
    , salt                        :: ByteString
      -- ^ HMAC key applied to the input keying material.
    , plainText                   :: LB.ByteString
      -- ^ Message to encrypt.
    , paddingLength               :: Int64
      -- ^ Number of zero bytes placed before the message.
    } deriving (Show)
-- | Result of 'webPushEncrypt': the encrypted message together with every
-- intermediate value of the key derivation.
data WebPushEncryptionOutput = EncryptionOutput
    { sharedECDHSecretBytes       :: ByteString -- ^ Raw ECDH shared secret.
    , inputKeyingMaterialBytes    :: ByteString -- ^ Keying material derived from the shared secret.
    , contentEncryptionKeyContext :: ByteString -- ^ Info string used to derive the content key.
    , contentEncryptionKey        :: ByteString -- ^ AES-128 content encryption key (16 bytes).
    , nonceContext                :: ByteString -- ^ Info string used to derive the nonce.
    , nonce                       :: ByteString -- ^ GCM nonce (12 bytes).
    , paddedPlainText             :: ByteString -- ^ Padding header, padding and message, before encryption.
    , encryptedMessage            :: ByteString -- ^ Ciphertext followed by the 16-byte GCM tag.
    }
-- | Reasons 'webPushEncrypt' can fail.
data EncryptError
    = EncodeApplicationPublicKeyError String
      -- ^ NOTE(review): not constructed anywhere in this module.
    | EncryptCipherInitError CryptoError
      -- ^ AES-128 initialisation from the derived key failed.
    | EncryptAeadInitError CryptoError
      -- ^ GCM initialisation with the derived nonce failed.
    | EncryptInputPublicKeyError CryptoError
      -- ^ The user agent public key bytes did not decode as a P-256 point.
    | EncryptInputApplicationPublicKeyError String
      -- ^ The derived application server public key could not be serialised.
    deriving (Eq, Show)
-- | Encrypt a push message with the @aesgcm@ content encoding (AES-128-GCM).
--
-- Derivation steps:
--
-- 1. ECDH between the application server private number and the user agent
--    public key on P-256.
-- 2. Input keying material: HMAC-SHA256 keyed by 'authenticationSecret' over
--    the shared secret, then HMAC of that over @"Content-Encoding: auth\\0\\1"@.
-- 3. A pseudo-random key: HMAC-SHA256 keyed by 'salt' over the keying material.
-- 4. From that key, a 16-byte content key and a 12-byte nonce, using info
--    strings built from both public keys.
--
-- The plaintext is prefixed with a 'Word16' padding length and that many zero
-- bytes, then sealed with GCM using an empty AAD and a 16-byte tag.
--
-- Failures:
--
-- * 'EncryptInputPublicKeyError' when the user agent key does not decode.
-- * 'EncryptInputApplicationPublicKeyError' when the derived server key
--   cannot be serialised.
-- * 'EncryptCipherInitError' / 'EncryptAeadInitError' when AES or GCM setup
--   fails.
webPushEncrypt :: WebPushEncryptionInput -> Either EncryptError WebPushEncryptionOutput
webPushEncrypt EncryptionInput{..} = do
    uaPublicPoint <- first EncryptInputPublicKeyError $
        ecBytesToPublicKey userAgentPublicKeyBytes
    asPublicKeyBytes <- first EncryptInputApplicationPublicKeyError $
        ecPublicKeyToBytes (ECDH.calculatePublic curveP256 applicationServerPrivateKey)

    let ecdhSecret = ECDH.getShared curveP256 applicationServerPrivateKey uaPublicPoint
        -- Keying material: extract with the auth secret, one expand block with authInfo.
        ikm = hkdfExpandBlock (hkdfExtract authenticationSecret ecdhSecret) authInfo
        -- Key from which both the content key and the nonce are expanded.
        prk = hkdfExtract salt ikm
        -- "P-256\0", then each public key prefixed with the two bytes 0x00 0x41.
        keyContext =
            "P-256" <> "\x00"
                <> "\x00" <> "\x41" <> userAgentPublicKeyBytes
                <> "\x00" <> "\x41" <> asPublicKeyBytes
        cekInfo = "Content-Encoding: aesgcm" <> "\x00" <> keyContext
        nonceInfo = "Content-Encoding: nonce" <> "\x00" <> keyContext
        cek = BS.take 16 (ByteArray.convert (hkdfExpandBlock prk cekInfo))
        gcmNonce = BS.take 12 (ByteArray.convert (hkdfExpandBlock prk nonceInfo))

    (aes :: AES128) <- first EncryptCipherInitError . eitherCryptoError $
        Cipher.cipherInit cek
    gcm <- first EncryptAeadInitError . eitherCryptoError $
        Cipher.aeadInit Cipher.AEAD_GCM aes gcmNonce

    let (tag, ciphertext) = Cipher.aeadSimpleEncrypt gcm BS.empty paddedInput 16
        tagBytes = ByteArray.convert (Cipher.unAuthTag tag) :: ByteString

    pure $ EncryptionOutput
        { sharedECDHSecretBytes       = ByteArray.convert ecdhSecret
        , inputKeyingMaterialBytes    = ByteArray.convert ikm
        , contentEncryptionKeyContext = cekInfo
        , contentEncryptionKey        = cek
        , nonceContext                = nonceInfo
        , nonce                       = gcmNonce
        , paddedPlainText             = paddedInput
        , encryptedMessage            = ciphertext <> tagBytes
        }
  where
    -- Padding length as an encoded Word16, that many zero bytes, then the message.
    -- NOTE(review): 'fromIntegral' truncates to 'Word16', so a paddingLength
    -- outside [0, 65535] writes a header that disagrees with the padding
    -- actually emitted; callers must keep it in range.
    paddedInput :: ByteString
    paddedInput = LB.toStrict $ mconcat
        [ Binary.encode (fromIntegral paddingLength :: Word16)
        , LB.replicate paddingLength 0
        , plainText
        ]

    authInfo :: ByteString
    authInfo = "Content-Encoding: auth" <> "\x00"

    curveP256 :: ECC.Curve
    curveP256 = ECCTypes.getCurveByName ECCTypes.SEC_p256r1

    -- HMAC-SHA256 keyed by the first argument over the second.
    hkdfExtract :: (ByteArray.ByteArrayAccess s, ByteArray.ByteArrayAccess k)
                => s -> k -> HMAC.HMAC SHA256
    hkdfExtract = HMAC.hmac

    -- HMAC-SHA256 keyed by prk over (info <> 0x01): the first output block.
    hkdfExpandBlock :: ByteArray.ByteArrayAccess k => k -> ByteString -> HMAC.HMAC SHA256
    hkdfExpandBlock prk info = HMAC.hmac prk (info <> "\x01")
-- | Headers authenticating the application server: a single
-- @Authorization: WebPush <jwt>@ header carrying a token freshly signed by
-- 'webPushJWT'.
hostHeaders :: (MonadIO m, MonadRandom m)
            => ECDSA.PrivateKey
            -> ServerIdentification
            -> m [Header]
hostHeaders privateKey serverIdentification =
    authorizationHeader <$> webPushJWT privateKey serverIdentification
  where
    authorizationHeader jwt = [(hAuthorization, "WebPush " <> jwt)]
-- | Serialise the origin of a URI: scheme, @//@, host, and port when it is not
-- the scheme's default. Example: @https://push.example.net@ or
-- @https://push.example.net:8443@.
--
-- Returns 'Nothing' when the URI has no authority component. User info is
-- never included.
--
-- The port is dropped only when it is empty, or is @:443@ for @https:@ or
-- @:80@ for @http:@, following RFC 6454 origin serialisation. A VAPID @aud@
-- claim (RFC 8292) must be the push resource's origin, so an explicit
-- non-default port has to be kept.
uriHost :: URI -> Maybe T.Text
uriHost uri = do
    authority <- uriAuthority uri
    let scheme = uriScheme uri     -- includes the trailing ':', e.g. "https:"
        port = uriPort authority   -- includes the leading ':', e.g. ":8443", or ""
    pure . T.pack $ scheme <> "//" <> uriRegName authority <> originPort scheme port
  where
    originPort :: String -> String -> String
    originPort scheme port
        | port `elem` ["", ":"] = ""
        | (scheme, port) `elem` [("https:", ":443"), ("http:", ":80")] = ""
        | otherwise = port
-- | Serialise a P-256 public point with 'ecPublicKeyToBytes''.
-- The point at infinity has no encoding and is rejected with a message.
ecPublicKeyToBytes :: ECC.Point -> Either String ByteString
ecPublicKeyToBytes point = case point of
    ECC.PointO    -> Left "Invalid public key infinity point"
    ECC.Point x y -> Right (ecPublicKeyToBytes' (x, y))
-- | Encode affine P-256 coordinates with 'Crypto.ECC.encodePoint' for
-- 'Crypto.ECC.Curve_P256R1'.
-- NOTE(review): 'P256.pointFromIntegers' is not expected to check that the
-- coordinates lie on the curve; callers should pass a valid point.
ecPublicKeyToBytes' :: (Integer, Integer) -> ByteString
ecPublicKeyToBytes' coordinates =
    Crypto.ECC.encodePoint p256 (P256.pointFromIntegers coordinates)
  where
    p256 = Proxy :: Proxy Crypto.ECC.Curve_P256R1
-- | Decode a P-256 public key with 'Crypto.ECC.decodePoint' and convert it to
-- an affine 'ECC.Point'. Decoding failures are returned as the
-- 'CryptoError' reported by the decoder.
ecBytesToPublicKey :: ByteString -> Either CryptoError ECC.Point
ecBytesToPublicKey bytes = do
    decoded <- eitherCryptoError $
        Crypto.ECC.decodePoint (Proxy :: Proxy Crypto.ECC.Curve_P256R1) bytes
    let (x, y) = P256.pointToIntegers decoded
    pure (ECC.Point x y)
-- | A 256-bit value as four 64-bit limbs, most significant first.
type Bytes32 = (Word64, Word64, Word64, Word64)

-- | Split an 'Integer' into its low 256 bits as four 'Word64' limbs,
-- most significant first. Higher bits are discarded. Each limb is the
-- integer shifted right by a multiple of 64, truncated with 'fromIntegral'.
int32Bytes :: Integer -> Bytes32
int32Bytes number = (limb 3, limb 2, limb 1, limb 0)
  where
    limb :: Int -> Word64
    limb k = fromIntegral (number `Bits.shiftR` (64 * k))

-- | Reassemble four 'Word64' limbs (most significant first) into an
-- 'Integer'. This is the inverse of 'int32Bytes' for values in @[0, 2^256)@.
bytes32Int :: Bytes32 -> Integer
bytes32Int (d, c, b, a) = foldr addLimb 0 [a, b, c, d]
  where
    addLimb :: Word64 -> Integer -> Integer
    addLimb w acc = fromIntegral w + (acc `Bits.shiftL` 64)