{-# LANGUAGE DeriveAnyClass      #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

module Web.WebPush.Keys where

import           Web.WebPush.Internal

import           Control.Exception
import qualified Crypto.ECC
import qualified Crypto.Number.Serialize    as Serialize
import qualified Crypto.PubKey.ECC.ECDSA    as ECDSA
import qualified Crypto.PubKey.ECC.Generate as ECC
import qualified Crypto.PubKey.ECC.Types    as ECC
import           Crypto.Random              (MonadRandom)
import qualified Data.ASN1.BinaryEncoding   as ASN1
import qualified Data.ASN1.Encoding         as ASN1
import           Data.ASN1.Error
import qualified Data.ASN1.Types            as ASN1
import           Data.Bifunctor
import qualified Data.ByteString            as BS
import           Data.PEM
import           Data.Proxy
import           Data.Word                  (Word8)
import           Data.X509
import           Data.X509.EC
import           Data.X509.File

-- | The public/private key pair used to sign the JWT that authenticates
-- push notification requests (VAPID).
--
-- The wrapped value is an ECDSA key pair. Keys produced by
-- 'generateVAPIDKeys' or 'readVapidKeys' are on the P-256 curve; the type
-- itself does not enforce this.
newtype VAPIDKeys = VAPIDKeys
  { unVAPIDKeys :: ECDSA.KeyPair -- ^ The underlying ECDSA key pair
  }
  deriving (Show)

-- | Extract the public half of a VAPID key pair.
vapidPublicKey :: VAPIDKeys -> ECDSA.PublicKey
vapidPublicKey (VAPIDKeys keyPair) = ECDSA.toPublicKey keyPair

-- | Ways in which loading a VAPID key pair from files can fail.
data VAPIDKeysError
  = VAPIDKeysPublicKeyError PublicKeyError
    -- ^ The public key file could not be read or decoded
  | VAPIDKeysPrivateKeyError PrivateKeyError
    -- ^ The private key file could not be read or decoded
  | VAPIDKeysCurveMismatch
    -- ^ The public and private keys are not on the same curve
  deriving (Show)

-- | Load a VAPID key pair from a public-key PEM file and a private-key PEM
-- file.
--
-- Both files are always read. If both fail, the public-key error is
-- reported. If both keys load but their curves differ, the result is
-- 'VAPIDKeysCurveMismatch'.
readVapidKeys
  :: FilePath -- ^ Path to the public key file
  -> FilePath -- ^ Path to the private key file
  -> IO (Either VAPIDKeysError VAPIDKeys)
readVapidKeys pubKeyPath privKeyPath = do
  pubResult  <- readWebPushPublicKey pubKeyPath
  privResult <- readWebPushPrivateKey privKeyPath
  pure $ combineKeys pubResult privResult
  where
    -- Tag each reader's error, then make sure both keys share a curve.
    combineKeys
      :: Either PublicKeyError ECDSA.PublicKey
      -> Either PrivateKeyError ECDSA.PrivateKey
      -> Either VAPIDKeysError VAPIDKeys
    combineKeys pubResult privResult = do
      pub  <- either (Left . VAPIDKeysPublicKeyError) Right pubResult
      priv <- either (Left . VAPIDKeysPrivateKeyError) Right privResult
      if ECDSA.public_curve pub == ECDSA.private_curve priv
        then Right (VAPIDKeys (toKeyPair pub priv))
        else Left VAPIDKeysCurveMismatch

-- | Assemble an ECDSA key pair from separately held public and private keys.
--
-- The curve is taken from the public key; the private key contributes only
-- its private number. This function does not check that the two keys share
-- a curve or that the public point corresponds to the private number.
toKeyPair :: ECDSA.PublicKey -> ECDSA.PrivateKey -> ECDSA.KeyPair
toKeyPair pub priv = ECDSA.KeyPair curve point secret
  where
    curve  = ECDSA.public_curve pub
    point  = ECDSA.public_q pub
    secret = ECDSA.private_d priv

-- | Ways in which loading the VAPID private key from a PEM file can fail.
data PrivateKeyError
  = PrivateKeyPEMParseError PEMError
    -- ^ The file could not be parsed as PEM
  | PrivateKeyUnknownCurveName
    -- ^ The key's curve could not be identified by name
  | PrivateKeyWrongCurve ECC.CurveName
    -- ^ The key is on a named curve other than P-256
  | PrivateKeyInvalidPEM
    -- ^ The file does not contain exactly one EC private key
  deriving (Show)

-- | Load the VAPID private key from a PEM file.
--
-- The file must yield exactly one EC private key on the P-256
-- (@SEC_p256r1@) curve. A 'PEMError' thrown while parsing the file is
-- returned as 'PrivateKeyPEMParseError'. Other exceptions, such as I/O
-- errors opening the file, propagate to the caller.
readWebPushPrivateKey :: FilePath -> IO (Either PrivateKeyError ECDSA.PrivateKey)
readWebPushPrivateKey fp = do
  loaded <- try (readKeyFile fp) :: IO (Either PEMError [PrivKey])
  pure $ do
    privKeys <- first PrivateKeyPEMParseError loaded
    ecKey    <- singleECKey privKeys
    toECDSA ecKey
  where
    -- Accept the file only if it contains a single key, and that key is EC.
    singleECKey :: [PrivKey] -> Either PrivateKeyError PrivKeyEC
    singleECKey [PrivKeyEC ecKey] = Right ecKey
    singleECKey _                 = Left PrivateKeyInvalidPEM

    -- Require the P-256 named curve and rebuild the key as an ECDSA key.
    toECDSA :: PrivKeyEC -> Either PrivateKeyError ECDSA.PrivateKey
    toECDSA ecKey =
      case ecPrivKeyCurveName ecKey of
        Nothing -> Left PrivateKeyUnknownCurveName
        Just ECC.SEC_p256r1 ->
          Right $
            ECDSA.PrivateKey
              (ECC.getCurveByName ECC.SEC_p256r1)
              (privkeyEC_priv ecKey)
        Just other -> Left (PrivateKeyWrongCurve other)

-- | Ways in which loading the VAPID public key from a PEM file can fail.
data PublicKeyError
  = PublicKeyPEMParseError PEMError
    -- ^ The file could not be parsed as PEM
  | PublicKeyASN1Error ASN1Error
    -- ^ The PEM payload is not valid DER-encoded ASN.1
  | PublicKeyFromASN1Error String
    -- ^ The ASN.1 structure could not be converted to a public key
  | PublicKeyUnsupportedKeyType
    -- ^ The public key is not an elliptic-curve key
  | PublicKeyUnknownCurve
    -- ^ The key's curve could not be identified by name
  | PublicKeyUnserialiseError
    -- ^ The serialised EC point could not be decoded on the curve
  | PublicKeyInvalidPEM
    -- ^ The file does not contain exactly one PEM block
  deriving (Show)

-- | Load the VAPID public key from a PEM file.
--
-- The file must contain exactly one PEM block whose DER payload decodes to
-- an elliptic-curve public key on a curve that can be resolved by name. The
-- result pairs that curve with the decoded public point.
--
-- A PEM syntax error is returned as 'PublicKeyPEMParseError'. Exceptions
-- raised by 'BS.readFile', such as a missing file, are not caught.
readWebPushPublicKey :: FilePath -> IO (Either PublicKeyError ECDSA.PublicKey)
readWebPushPublicKey fp = do
  contents <- BS.readFile fp
  pure $ toECDSAPubKey =<< parsePEMPubKey contents
  where
    -- Decode the file's single PEM block into an EC public key.
    parsePEMPubKey :: BS.ByteString -> Either PublicKeyError PubKeyEC
    parsePEMPubKey str =
      case pemParseBS str of
        -- Return PEM syntax errors in the result, as readWebPushPrivateKey
        -- does, instead of throwing an IOError via 'fail'.
        Left err -> Left (PublicKeyPEMParseError (PEMError err))
        Right [pem] -> do
          asn1s <-
            first PublicKeyASN1Error $
              ASN1.decodeASN1' ASN1.DER (pemContent pem)
          (key, _) <- first PublicKeyFromASN1Error $ ASN1.fromASN1 asn1s
          ecPubKey key
        Right _ -> Left PublicKeyInvalidPEM

    -- Only elliptic-curve public keys are supported.
    ecPubKey :: PubKey -> Either PublicKeyError PubKeyEC
    ecPubKey (PubKeyEC key) = Right key
    ecPubKey _              = Left PublicKeyUnsupportedKeyType

    -- Resolve the named curve, then decode the serialised point on it.
    toECDSAPubKey :: PubKeyEC -> Either PublicKeyError ECDSA.PublicKey
    toECDSAPubKey key = do
      curve <-
        maybe (Left PublicKeyUnknownCurve) Right $
          ECC.getCurveByName <$> ecPubKeyCurveName key
      point <-
        maybe (Left PublicKeyUnserialiseError) Right $
          unserializePoint curve (pubkeyEC_pub key)
      pure $ ECDSA.PublicKey curve point

-- | Write the public and private keys to PEM files.
--
-- The public key is written under the label @PUBLIC KEY@ and the private
-- key under @EC PRIVATE KEY@. Both are tagged with the named curve
-- @SEC_p256r1@ and the point is serialised with P-256 coordinate sizes, so
-- this assumes the key pair is on P-256.
--
-- NOTE: This overwrites any existing files. If the keys were created with
-- OpenSSL, the output is not byte-for-byte the format they were read from.
--
-- Both documents are fully encoded before either file is opened. If
-- encoding fails, for example because the public point is the point at
-- infinity, an 'ErrorCall' is thrown and neither file is modified.
writeVAPIDKeys :: FilePath -> FilePath -> VAPIDKeys -> IO ()
writeVAPIDKeys pubKeyPath privKeyPath (VAPIDKeys keyPair) = do
  -- 'BS.writeFile' opens (and truncates) its file before forcing the
  -- contents. Force both encodings first so a serialisation 'error' cannot
  -- leave an emptied key file behind.
  pubPEM <-
    evaluate $
      encodePEM "PUBLIC KEY" $ toPubKey $ ECDSA.toPublicKey keyPair
  privPEM <-
    evaluate $
      encodePEM "EC PRIVATE KEY" $ toPrivKey $ ECDSA.toPrivateKey keyPair
  BS.writeFile pubKeyPath pubPEM
  BS.writeFile privKeyPath privPEM
  where
    -- Wrap a DER-encoded ASN.1 object in a PEM block with the given label.
    encodePEM :: ASN1.ASN1Object a => String -> a -> BS.ByteString
    encodePEM name = pemWriteBS . PEM name [] . encodeASN1

    encodeASN1 :: ASN1.ASN1Object a => a -> BS.ByteString
    encodeASN1 key = ASN1.encodeASN1' ASN1.DER $ ASN1.toASN1 key []

    toPubKey :: ECDSA.PublicKey -> PubKey
    toPubKey =
      PubKeyEC
        . PubKeyEC_Named ECC.SEC_p256r1
        . serializePoint
        . ECDSA.public_q

    toPrivKey :: ECDSA.PrivateKey -> PrivKey
    toPrivKey = PrivKeyEC . PrivKeyEC_Named ECC.SEC_p256r1 . ECDSA.private_d

    -- Encode an affine point as 0x04 || X || Y. Each coordinate is padded
    -- to the P-256 coordinate width.
    serializePoint :: ECC.Point -> SerializedPoint
    serializePoint ECC.PointO = error "can't serialize EC point at infinity"
    serializePoint (ECC.Point x y) =
      SerializedPoint $
        BS.pack [4] <> Serialize.i2ospOf_ bytes x <> Serialize.i2ospOf_ bytes y
      where
        bits :: Int
        bits = Crypto.ECC.curveSizeBits (Proxy :: Proxy Crypto.ECC.Curve_P256R1)
        bytes :: Int
        bytes = (bits + 7) `div` 8

-- | Generate a fresh VAPID key pair: an ECDSA key pair on the P-256 curve.
--
-- Returns 'Left' if the generated public point is the point at infinity.
-- Store the keys securely and reuse them across push notification requests.
generateVAPIDKeys :: MonadRandom m => m (Either String VAPIDKeys)
generateVAPIDKeys = do
  -- SEC_p256r1 is the NIST P-256 curve
  (pub, priv) <- ECC.generate (ECC.getCurveByName ECC.SEC_p256r1)
  let keys = VAPIDKeys (toKeyPair pub priv)
  pure $ case ECDSA.public_q pub of
    ECC.PointO  -> Left "Invalid public key generated, public_q is the point at infinity"
    ECC.Point{} -> Right keys

-- | Raw bytes of the VAPID public key, for use as @applicationServerKey@
-- when calling @subscribe@ on the @PushManager@ of a registered service
-- worker.
--
-- The bytes are produced by @ecPublicKeyToBytes'@ from the public point's
-- coordinates. Returns 'Left' if the point is the point at infinity.
--
-- > applicationServerKey = new Uint8Array( #{toJSON vapidPublicKeyBytes} )
vapidPublicKeyBytes :: ECDSA.PublicKey -> Either String [Word8]
vapidPublicKeyBytes key = pointBytes (ECDSA.public_q key)
  where
    pointBytes ECC.PointO =
      Left "Invalid public key generated, public_q is the point at infinity"
    pointBytes (ECC.Point x y) =
      Right . BS.unpack $ ecPublicKeyToBytes' (x, y)