{-# LANGUAGE OverloadedStrings #-}

module Network.DomainAuth.Pubkey.RSAPub (
    lookupPublicKey,
) where

import Crypto.PubKey.RSA (PublicKey)
import Data.ASN1.BinaryEncoding (DER (..))
import Data.ASN1.Encoding (decodeASN1')
import Data.ASN1.Types (fromASN1)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS ()
import Data.X509 (PubKey (PubKeyRSA))
import Network.DNS (Domain)
import qualified Network.DNS as DNS
import Network.DomainAuth.Mail
import qualified Network.DomainAuth.Pubkey.Base64 as B

-- $setup
-- >>> import Network.DNS

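-- | Look up the RSA public key published in the TXT records of a domain.
-- The result is 'Nothing' when the DNS query fails, when no @p@ value is
-- found, or when the value cannot be decoded as an RSA public key.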
--
-- >>> rs <- DNS.makeResolvSeed DNS.defaultResolvConf
-- >>> withResolver rs $ \rslv -> lookupPublicKey rslv "dk200510._domainkey.yahoo.co.jp"
-- Just (PublicKey {public_size = 128, public_n = 124495277115430906234131617223399742059624761592171426860362133400468320289284068350453787798555522712914036293436636386707903510390018044090096883314714401752103035965668114514933570840775088208966674120428191313530595210688523478828022953238411688594634270571841869051696953556782155414877029327479844990933, public_e = 65537})
-- >>> withResolver rs $ \rslv -> lookupPublicKey rslv "20230601._domainkey.gmail.com"
-- Just (PublicKey {public_size = 256, public_n = 20054049931062868895890884170436368122145070743595938421415808271536128118589158095389269883866014690926251520949836343482211446965168263353397278625494421205505467588876376305465260221818103647257858226961376710643349248303872103127777544119851941320649869060657585270523355729363214754986381410240666592048188131951162530964876952500210032559004364102337827202989395200573305906145708107347940692172630683838117810759589085094521858867092874903269345174914871903592244831151967447426692922405241398232069182007622735165026000699140578092635934951967194944536539675594791745699200646238889064236642593556016708235359, public_e = 65537})
lookupPublicKey :: DNS.Resolver -> Domain -> IO (Maybe PublicKey)
lookupPublicKey resolver domain =
    -- The Maybe bind skips decoding when no key text was found, so a failed
    -- lookup and an undecodable key both surface as 'Nothing'.
    (>>= decodeRSAPublicyKey) <$> lookupPublicKey' resolver domain

-- | Query the TXT records of @domain@ and hand them to 'extractPub'.
--
-- Any DNS error is discarded and reported as 'Nothing'; on success the
-- result is whatever 'extractPub' finds in the returned records.
lookupPublicKey' :: DNS.Resolver -> Domain -> IO (Maybe ByteString)
lookupPublicKey' resolver domain =
    either (const Nothing) extractPub <$> DNS.lookupTXT resolver domain

-- | Extract the value stored under the key @\"p\"@ from a list of TXT records.
--
-- Only the first record is inspected; any further records are ignored.
-- The record is split with 'parseTaggedValue' and the resulting association
-- list is searched for @\"p\"@. An empty list yields 'Nothing'.
extractPub :: [ByteString] -> Maybe ByteString
extractPub [] = Nothing
extractPub (x : _) = lookup "p" (parseTaggedValue x)

-- | Decode the textual key value into an RSA 'PublicKey'.
--
-- The input is first passed through 'B.decode', the resulting bytes are
-- parsed as DER-encoded ASN.1, and the ASN.1 stream is interpreted as an
-- X.509 'PubKey'. 'Nothing' is returned when the ASN.1 decoding fails, when
-- the key is not an RSA key, or when ASN.1 elements are left over after the
-- key.
decodeRSAPublicyKey :: ByteString -> Maybe PublicKey
decodeRSAPublicyKey b64 =
    case decodeASN1' DER der of
        Right asn1
            -- Require the whole stream to be consumed: the remainder must be [].
            | Right (PubKeyRSA p, []) <- fromASN1 asn1 -> Just p
        _ -> Nothing
  where
    der :: ByteString
    der = B.decode b64