{-# LANGUAGE RecordWildCards, OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Web.WebPush.Internal where
import           Control.Monad.IO.Class
import           Crypto.Cipher.AES               (AES128)
import qualified Crypto.Cipher.Types             as Cipher
import qualified Crypto.ECC
import           Crypto.Error                    (CryptoError,
                                                  eitherCryptoError)
import           Crypto.Hash.Algorithms          (SHA256 (..))
import qualified Crypto.MAC.HMAC                 as HMAC
import qualified Crypto.PubKey.ECC.DH            as ECDH
import qualified Crypto.PubKey.ECC.ECDSA         as ECDSA
import qualified Crypto.PubKey.ECC.P256          as P256
import qualified Crypto.PubKey.ECC.Types         as ECC
import qualified Crypto.PubKey.ECC.Types         as ECCTypes
import           Crypto.Random
import           Data.Aeson                      ((.=))
import qualified Data.Aeson                      as A
import           Data.Bifunctor
import qualified Data.Binary                     as Binary
import qualified Data.Bits                       as Bits
import qualified Data.ByteArray                  as ByteArray
import           Data.ByteString                 (ByteString)
import qualified Data.ByteString                 as BS
import qualified Data.ByteString.Lazy            as LB
import qualified Data.ByteString.Lazy.Base64.URL as B64.URL
import           Data.Data
import           Data.Text                       (Text)
import qualified Data.Text                       as T
import           Data.Word                       (Word16, Word64, Word8)
import           GHC.Int                         (Int64)
import           Network.HTTP.Types
import           Network.URI
-- | Claims identifying the application server in the Web Push JWT built by
-- 'webPushJWT'. The 'A.ToJSON' instance serializes the fields as the JWT
-- @aud@, @exp@ and @sub@ claims.
data ServerIdentification = ServerIdentification
  { serverIdentificationAudience   :: Text
    -- ^ Serialized as the @aud@ claim.
  , serverIdentificationExpiration :: Int
    -- ^ Serialized as the @exp@ claim. NOTE(review): presumably seconds
    -- since the Unix epoch (JWT NumericDate) — confirm with callers.
  , serverIdentificationSubject    :: Text
    -- ^ Serialized as the @sub@ claim.
  } deriving (Show)
 
-- | Encodes a 'ServerIdentification' as the JWT claim object
-- @{"aud": ..., "exp": ..., "sub": ...}@ used as the payload of 'webPushJWT'.
instance A.ToJSON ServerIdentification where
  toJSON ServerIdentification{..} = A.object
    [ "aud" .= serverIdentificationAudience
    , "exp" .= serverIdentificationExpiration
    , "sub" .= serverIdentificationSubject
    ]
-- | Build a signed ES256 JSON Web Token for Web Push authentication.
--
-- The result is @header.payload.signature@, where each segment is unpadded
-- base64url:
--
-- * header: @{"typ":"JWT","alg":"ES256"}@
-- * payload: the JSON encoding of the 'ServerIdentification'
-- * signature: ECDSA over SHA-256 of @header.payload@
--
-- 'ECDSA.sign' draws its nonce from the 'MonadRandom' context, so repeated
-- calls yield different signatures for the same input.
webPushJWT :: MonadRandom m => ECDSA.PrivateKey -> ServerIdentification -> m BS.ByteString
webPushJWT privateKey payload = do
  signature <- ECDSA.sign privateKey SHA256 jwtMessage
  pure $ jwtMessage <> "." <> jwtSignature signature
  where
    -- The JWS signing input: encoded header and payload joined by '.'.
    jwtMessage :: ByteString
    jwtMessage = LB.toStrict $ jwtHeader <> "." <> jwtPayload

    jwtHeader :: LB.ByteString
    jwtHeader = B64.URL.encodeBase64Unpadded' . A.encode $ A.object
      [ "typ" .= ("JWT" :: Text)
      , "alg" .= ("ES256" :: Text)
      ]

    jwtPayload :: LB.ByteString
    jwtPayload = B64.URL.encodeBase64Unpadded' . A.encode $ payload

    -- ES256 wants the raw fixed-width R followed by S (not DER). Each value
    -- is laid out as four Word64 limbs by 'int32Bytes'; 'Binary.encode'
    -- concatenates the tuple components with no framing.
    -- 'LB.toStrict' (rather than 'BS.toStrict') matches the rest of the
    -- module and works on bytestring versions before 0.11.1.
    jwtSignature :: ECDSA.Signature -> ByteString
    jwtSignature (ECDSA.Signature signR signS) =
      LB.toStrict . B64.URL.encodeBase64Unpadded' $
        Binary.encode (int32Bytes signR, int32Bytes signS)
-- | Everything 'webPushEncrypt' needs to encrypt one message.
data WebPushEncryptionInput = EncryptionInput
  { applicationServerPrivateKey :: ECDH.PrivateNumber
    -- ^ Application server's private scalar on P-256; the matching public
    -- key is derived from it during encryption.
  , userAgentPublicKeyBytes     :: ByteString
    -- ^ User agent's P-256 public key, as bytes accepted by
    -- 'ecBytesToPublicKey'.
  , authenticationSecret        :: ByteString
    -- ^ HMAC key for the first key-derivation step.
  , salt                        :: ByteString
    -- ^ HMAC key for the second key-derivation step.
  , plainText                   :: LB.ByteString
    -- ^ Message to encrypt.
  , paddingLength               :: Int64
    -- ^ Number of zero bytes placed before the message (after a 2-byte
    -- length header).
  } deriving (Show)
-- | Result of 'webPushEncrypt': the final ciphertext plus every intermediate
-- value of the key derivation, exposed as raw bytes.
data WebPushEncryptionOutput = EncryptionOutput
  { sharedECDHSecretBytes       :: ByteString
    -- ^ P-256 ECDH shared secret.
  , inputKeyingMaterialBytes    :: ByteString
    -- ^ HMAC-derived input keying material.
  , contentEncryptionKeyContext :: ByteString
    -- ^ Info string used to derive the content-encryption key.
  , contentEncryptionKey        :: ByteString
    -- ^ 16-byte AES-128 key.
  , nonceContext                :: ByteString
    -- ^ Info string used to derive the nonce.
  , nonce                       :: ByteString
    -- ^ 12-byte AES-GCM nonce.
  , paddedPlainText             :: ByteString
    -- ^ Length header, zero padding and message, as encrypted.
  , encryptedMessage            :: ByteString
    -- ^ Ciphertext followed by the 16-byte GCM authentication tag.
  }
-- | Failures reported by the Web Push encryption code.
data EncryptError
  = EncodeApplicationPublicKeyError String
    -- ^ NOTE(review): not constructed by 'webPushEncrypt' in this module;
    -- possibly used elsewhere — confirm before removing.
  | EncryptCipherInitError CryptoError
    -- ^ AES-128 initialisation with the derived key failed.
  | EncryptAeadInitError CryptoError
    -- ^ AES-GCM initialisation with the derived nonce failed.
  | EncryptInputPublicKeyError CryptoError
    -- ^ The user agent's public key bytes could not be decoded.
  | EncryptInputApplicationPublicKeyError String
    -- ^ The application server's public key could not be encoded
    -- (for example, it is the point at infinity).
  deriving (Eq, Show)
-- | Encrypt a Web Push payload using the @aesgcm@ content encoding.
--
-- Runs in 'Either', so the first failure short-circuits:
--
-- * Decode the user agent's public key ('EncryptInputPublicKeyError').
--   Derive and encode the application server's public key
--   ('EncryptInputApplicationPublicKeyError').
-- * Compute the P-256 ECDH shared secret.
-- * Derive keys with single-block HMAC-SHA256 steps:
--
--     1. The authentication secret keys an HMAC over the shared secret.
--     2. That result keys an HMAC over the auth info plus 0x01, giving the IKM.
--     3. The salt keys an HMAC over the IKM.
--     4. That key yields the 16-byte content-encryption key and the 12-byte
--        nonce from their info strings plus 0x01.
--
-- * AES-128-GCM encrypt the padded plaintext with empty associated data and
--   append the 16-byte tag. Setup failures are reported as
--   'EncryptCipherInitError' or 'EncryptAeadInitError'.
--
-- All intermediate values are returned in 'WebPushEncryptionOutput'.
webPushEncrypt :: WebPushEncryptionInput -> Either EncryptError WebPushEncryptionOutput
webPushEncrypt EncryptionInput{..} = do
  userAgentPublicKey <-
    first EncryptInputPublicKeyError $ ecBytesToPublicKey userAgentPublicKeyBytes
  applicationServerPublicKey <-
    first EncryptInputApplicationPublicKeyError . ecPublicKeyToBytes $
      ECDH.calculatePublic curveP256 applicationServerPrivateKey
  let
    sharedECDHSecret =
      ECDH.getShared curveP256 applicationServerPrivateKey userAgentPublicKey

    pseudoRandomKeyCombine =
      HMAC.hmac authenticationSecret sharedECDHSecret :: HMAC.HMAC SHA256

    inputKeyingMaterial =
      HMAC.hmac pseudoRandomKeyCombine (authInfo <> "\x01") :: HMAC.HMAC SHA256

    pseudoRandomKeyEncryption =
      HMAC.hmac salt inputKeyingMaterial :: HMAC.HMAC SHA256

    -- One HMAC block keyed by the encryption PRK over @info <> 0x01@,
    -- truncated to the first @n@ bytes. Shared by the key and the nonce
    -- so both are derived identically.
    expandTo :: Int -> ByteString -> ByteString
    expandTo n info =
      BS.take n . ByteArray.convert $
        (HMAC.hmac pseudoRandomKeyEncryption (info <> "\x01") :: HMAC.HMAC SHA256)

    contentEncryptionKey = expandTo 16 contentEncryptionKeyContext
    nonce = expandTo 12 nonceContext

    -- "P-256", a 0x00 separator, then each public key prefixed by the
    -- 2-byte length 0x00 0x41. The literal 0x41 (65) assumes both keys
    -- are 65-byte encodings.
    context :: ByteString
    context = "P-256" <> "\x00"
      <> "\x00" <> "\x41" <> userAgentPublicKeyBytes
      <> "\x00" <> "\x41" <> applicationServerPublicKey

    contentEncryptionKeyContext :: ByteString
    contentEncryptionKeyContext = "Content-Encoding: aesgcm" <> "\x00" <> context

    nonceContext :: ByteString
    nonceContext = "Content-Encoding: nonce" <> "\x00" <> context

  aesCipher :: AES128 <-
    first EncryptCipherInitError . eitherCryptoError $
      Cipher.cipherInit contentEncryptionKey
  aeadGcmCipher <-
    first EncryptAeadInitError . eitherCryptoError $
      Cipher.aeadInit Cipher.AEAD_GCM aesCipher nonce
  let
    -- Empty associated data and a full 16-byte tag.
    (authTagBytes, cipherText) =
      Cipher.aeadSimpleEncrypt aeadGcmCipher BS.empty paddedPlainText 16
    authTag :: ByteString
    authTag = ByteArray.convert $ Cipher.unAuthTag authTagBytes
    encryptedMessage :: ByteString
    encryptedMessage = cipherText <> authTag

    -- The HMAC and shared-key values are exposed as plain bytes in the output.
    inputKeyingMaterialBytes :: ByteString
    inputKeyingMaterialBytes = ByteArray.convert inputKeyingMaterial
    sharedECDHSecretBytes :: ByteString
    sharedECDHSecretBytes = ByteArray.convert sharedECDHSecret
  pure EncryptionOutput {..}
  where
    -- A 2-byte padding-length header ('Binary.encode' of a 'Word16'), that
    -- many zero bytes, then the message.
    -- NOTE(review): 'fromIntegral' to 'Word16' wraps, while 'LB.replicate'
    -- emits nothing for non-positive counts. A paddingLength outside
    -- [0, 65535] therefore yields a header that disagrees with the
    -- padding actually emitted.
    paddedPlainText :: ByteString
    paddedPlainText = LB.toStrict $
      Binary.encode (fromIntegral paddingLength :: Word16)
        <> LB.replicate paddingLength (0 :: Word8)
        <> plainText

    authInfo :: ByteString
    authInfo = "Content-Encoding: auth" <> "\x00"

    curveP256 = ECCTypes.getCurveByName ECCTypes.SEC_p256r1
-- | Request headers that authenticate the application server to the push
-- service: a single @Authorization: WebPush \<jwt\>@ header whose token comes
-- from 'webPushJWT'.
--
-- NOTE(review): the 'MonadIO' constraint is not used by this body; it is
-- kept so the public signature is unchanged.
hostHeaders :: (MonadIO m, MonadRandom m)
            => ECDSA.PrivateKey
            -> ServerIdentification
            -> m [Header]
hostHeaders privateKey serverIdentification = do
  jwt <- webPushJWT privateKey serverIdentification
  pure [(hAuthorization, "WebPush " <> jwt)]
-- | The origin of a URI, @scheme://host[:port]@, or 'Nothing' when the URI
-- has no authority component.
--
-- This follows the ASCII origin serialization (RFC 6454 §6.1): the port is
-- included only when it is present and differs from the scheme's default
-- (80 for @http@, 443 for @https@). Dropping the port unconditionally would
-- give the wrong origin for hosts on non-default ports.
-- NOTE(review): presumably used as the JWT @aud@ value — confirm at call sites.
uriHost :: URI -> Maybe T.Text
uriHost uri = do
  auth <- uriAuthority uri
  let scheme = uriScheme uri
  pure . T.pack $
    scheme <> "//" <> uriRegName auth <> originPort scheme (uriPort auth)
  where
    -- 'uriPort' is "" or ':' followed by the (possibly empty) port digits.
    originPort :: String -> String -> String
    originPort scheme port
      | port == ":"                        = ""
      | (scheme, port) `elem` defaultPorts = ""
      | otherwise                          = port

    -- 'uriScheme' keeps its trailing ':'.
    defaultPorts :: [(String, String)]
    defaultPorts = [("http:", ":80"), ("https:", ":443")]
 
-- | Encode an affine elliptic-curve point as bytes using 'ecPublicKeyToBytes''.
--
-- The point at infinity has no affine coordinates, so it is rejected with
-- an error message instead.
ecPublicKeyToBytes :: ECC.Point -> Either String ByteString
ecPublicKeyToBytes ECC.PointO        = Left "Invalid public key infinity point"
ecPublicKeyToBytes (ECC.Point x y)   = Right (ecPublicKeyToBytes' (x, y))
-- | Encode affine P-256 coordinates with 'Crypto.ECC.encodePoint', after
-- converting them with 'P256.pointFromIntegers'.
ecPublicKeyToBytes' :: (Integer, Integer) -> ByteString
ecPublicKeyToBytes' coords =
  Crypto.ECC.encodePoint
    (Proxy :: Proxy Crypto.ECC.Curve_P256R1)
    (P256.pointFromIntegers coords)
-- | Decode bytes into a P-256 point with 'Crypto.ECC.decodePoint', then
-- convert it to the affine 'ECC.Point' representation used by the ECDH code.
-- Decoding failures are returned as the underlying 'CryptoError'.
ecBytesToPublicKey :: ByteString -> Either CryptoError ECC.Point
ecBytesToPublicKey bytes =
  eitherCryptoError $
    toECCPoint <$> Crypto.ECC.decodePoint (Proxy :: Proxy Crypto.ECC.Curve_P256R1) bytes
  where
    toECCPoint :: P256.Point -> ECC.Point
    toECCPoint = uncurry ECC.Point . P256.pointToIntegers
-- | A 256-bit quantity as four 64-bit limbs, most significant first.
type Bytes32 = (Word64, Word64, Word64, Word64)

-- | Split an 'Integer' into four 64-bit limbs, most significant first.
--
-- Each limb is the corresponding 64-bit window truncated by 'fromIntegral',
-- so only the low 256 bits survive. For @0 <= n < 2^256@ this is the inverse
-- of 'bytes32Int'.
int32Bytes :: Integer -> Bytes32
int32Bytes number = (limb 3, limb 2, limb 1, limb 0)
  where
    -- The k-th 64-bit window, counting from the least significant end.
    limb :: Int -> Word64
    limb k = fromIntegral (number `Bits.shiftR` (64 * k))

-- | Reassemble four 64-bit limbs (most significant first) into an 'Integer'.
bytes32Int :: Bytes32 -> Integer
bytes32Int (d, c, b, a) = widen d 3 + widen c 2 + widen b 1 + widen a 0
  where
    -- Place a limb at the k-th 64-bit position.
    widen :: Word64 -> Int -> Integer
    widen w k = toInteger w `Bits.shiftL` (64 * k)