module Jose.Jwk
( KeyType
, KeyUse (..)
, KeyId
, Jwk (..)
, JwkSet (..)
, validateForJws
, findMatchingJwsKeys
, findMatchingJweKeys
)
where
import Control.Applicative (pure)
import Control.Monad (when)
import qualified Crypto.PubKey.RSA as RSA
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import qualified Crypto.Types.PubKey.ECC as ECC
import Crypto.Number.Serialize
import Data.Aeson (genericToJSON, Value(..), FromJSON(..), ToJSON(..), withText)
import Data.Aeson.Types (Parser, Options (..), defaultOptions)
import Data.ByteString (ByteString)
import Data.Text (Text)
import qualified Data.Text.Encoding as TE
import GHC.Generics (Generic)
import qualified Jose.Internal.Base64 as B64
import Jose.Jwa
import Jose.Types (JwtError(..), KeyId, JwsHeader(..), JweHeader(..))
-- | Key family of a JWK (its @kty@ member): RSA, elliptic curve, or a
-- symmetric octet sequence.
data KeyType
    = Rsa
    | Ec
    | Oct
    deriving (Show, Eq)
-- | Named curves supported for EC keys; in JSON these are the @crv@ values
-- @P-256@, @P-384@ and @P-521@.
data EcCurve
    = P_256
    | P_384
    | P_521
    deriving (Show, Eq)
-- | Intended use of a key (the JWK @use@ member): signing or encryption.
data KeyUse
    = Sig
    | Enc
    deriving (Show, Eq)
-- | A JSON Web Key decoded into the concrete key material it carries.
--
-- Every constructor also records the optional key ID, key use and algorithm
-- of the JWK. The EC constructors additionally keep the named curve, so the
-- key can be written back out with its @crv@ value.
data Jwk
    = RsaPublicJwk  RSA.PublicKey   (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
    | RsaPrivateJwk RSA.PrivateKey  (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
    | EcPublicJwk   ECDSA.PublicKey (Maybe KeyId) (Maybe KeyUse) (Maybe Alg) EcCurve
    | EcPrivateJwk  ECDSA.KeyPair   (Maybe KeyId) (Maybe KeyUse) (Maybe Alg) EcCurve
    | SymmetricJwk  ByteString      (Maybe KeyId) (Maybe KeyUse) (Maybe Alg)
    deriving (Eq, Show)
-- | A JWK Set: a collection of keys held in the @keys@ field. Its JSON
-- instances rely on the 'Generic' instance derived here.
data JwkSet = JwkSet
    { keys :: [Jwk]
    }
    deriving (Eq, Show, Generic)
-- | Check that a key may be used with the given JWS algorithm.
--
-- A key whose @use@ is 'Enc' is always rejected. Otherwise the algorithm
-- family decides the required key kind: HMAC needs a symmetric key, RSxxx an
-- RSA key (public or private) and ESxxx an EC key (public or private).
-- 'None' is rejected because it needs no key at all.
validateForJws :: JwsAlg -> Jwk -> Either JwtError ()
validateForJws a jwk
    | jwkUse jwk == Just Enc = Left (KeyError "JWK is for encryption only")
    | otherwise              = either (Left . KeyError) Right requirement
  where
    requirement = case a of
        None  -> Left "JWS with alg 'None' does not require a key"
        HS256 -> needSymmetric
        HS384 -> needSymmetric
        HS512 -> needSymmetric
        RS256 -> needRsa
        RS384 -> needRsa
        RS512 -> needRsa
        ES256 -> needEc
        ES384 -> needEc
        ES512 -> needEc
    needSymmetric = case jwk of
        SymmetricJwk {} -> Right ()
        _               -> Left "JWK must be symmetric"
    needRsa = case jwk of
        RsaPublicJwk {}  -> Right ()
        RsaPrivateJwk {} -> Right ()
        _                -> Left "JWK must be an RSA key"
    needEc = case jwk of
        EcPublicJwk {}  -> Right ()
        EcPrivateJwk {} -> Right ()
        _               -> Left "JWK must be an EC key"
-- | 'True' exactly when 'validateForJws' accepts the key for the algorithm.
canDecodeJws :: JwsAlg -> Jwk -> Bool
canDecodeJws al jwk = case validateForJws al jwk of
    Right _ -> True
    Left _  -> False
-- | Whether a key can be used to decrypt a JWE.
--
-- Only RSA private keys qualify, and only when their @use@ is not 'Sig'.
-- The algorithm argument is currently not consulted.
canDecodeJwe :: JweAlg -> Jwk -> Bool
canDecodeJwe _ jwk = case jwk of
    RsaPrivateJwk {} -> jwkUse jwk /= Just Sig
    _                -> False
-- | Look up the SEC curve parameters for a JWK curve name.
curve :: EcCurve -> ECC.Curve
curve = ECC.getCurveByName . secName
  where
    secName P_256 = ECC.SEC_p256r1
    secName P_384 = ECC.SEC_p384r1
    secName P_521 = ECC.SEC_p521r1
-- | The optional key ID stored in any kind of key.
jwkId :: Jwk -> Maybe KeyId
jwkId (RsaPublicJwk  _ i _ _)   = i
jwkId (RsaPrivateJwk _ i _ _)   = i
jwkId (EcPublicJwk   _ i _ _ _) = i
jwkId (EcPrivateJwk  _ i _ _ _) = i
jwkId (SymmetricJwk  _ i _ _)   = i
-- | The optional intended use ('Sig' or 'Enc') stored in any kind of key.
jwkUse :: Jwk -> Maybe KeyUse
jwkUse (RsaPublicJwk  _ _ u _)   = u
jwkUse (RsaPrivateJwk _ _ u _)   = u
jwkUse (EcPublicJwk   _ _ u _ _) = u
jwkUse (EcPrivateJwk  _ _ u _ _) = u
jwkUse (SymmetricJwk  _ _ u _)   = u
-- | The first key in the list whose key ID equals the given one, or
-- 'Nothing'. Keys without a key ID never match.
findKeyById :: [Jwk] -> KeyId -> Maybe Jwk
findKeyById ks keyId = case filter hasId ks of
    match : _ -> Just match
    []        -> Nothing
  where
    hasId key = jwkId key == Just keyId
-- | Keys from a set that are candidates for verifying a JWS.
--
-- Candidates come from 'filterById' using the header's @kid@. Each one is
-- kept only if 'canDecodeJws' accepts it for the header's @alg@.
findMatchingJwsKeys :: [Jwk] -> JwsHeader -> [Jwk]
findMatchingJwsKeys jwks hdr =
    [ jwk | jwk <- filterById (jwsKid hdr) jwks, canDecodeJws (jwsAlg hdr) jwk ]
-- | Narrow a key set to the keys carrying the requested key ID.
--
-- RFC 7517 section 4.5 allows several keys in one set to share a @kid@, for
-- example an RSA key and an EC key that an application treats as
-- alternatives. Every key with a matching ID is therefore kept, not just the
-- first one; callers then filter the candidates by algorithm. Keeping only
-- the first match would reject a token whose valid key happens to come
-- second.
--
-- If no ID is requested, or no key carries the requested ID, the whole set
-- is returned unchanged.
filterById :: Maybe KeyId -> [Jwk] -> [Jwk]
filterById Nothing jwks = jwks
filterById (Just i) jwks = case filter ((== Just i) . jwkId) jwks of
    []      -> jwks
    matches -> matches
-- | Keys from a set that are candidates for decrypting a JWE.
--
-- Candidates come from 'filterById' using the header's @kid@. Each one is
-- kept only if 'canDecodeJwe' (given the header's @alg@) accepts it.
findMatchingJweKeys :: [Jwk] -> JweHeader -> [Jwk]
findMatchingJweKeys jwks hdr =
    [ jwk | jwk <- filterById (jweKid hdr) jwks, canDecodeJwe (jweAlg hdr) jwk ]
-- | Raw bytes of a binary JWK member. In JSON these bytes appear as a
-- base64 string, handled by "Jose.Internal.Base64".
newtype JwkBytes = JwkBytes
    { bytes :: ByteString
    }
    deriving Show
-- | Parse a @kty@ string. Only "RSA", "EC" and "oct" are recognised.
instance FromJSON KeyType where
    parseJSON = withText "KeyType" $ \t ->
        maybe (fail "unsupported key type") pure (lookup t keyTypeNames)
      where
        keyTypeNames = [("RSA", Rsa), ("EC", Ec), ("oct", Oct)]
-- | Render a key type as its @kty@ string.
instance ToJSON KeyType where
    toJSON Rsa = String "RSA"
    toJSON Ec  = String "EC"
    toJSON Oct = String "oct"
-- | Parse a @use@ string. Only "sig" and "enc" are recognised.
instance FromJSON KeyUse where
    parseJSON = withText "KeyUse" parseUse
      where
        parseUse t
            | t == "sig" = pure Sig
            | t == "enc" = pure Enc
            | otherwise  = fail "'use' value must be either 'sig' or 'enc'"
-- | Render a key use as its @use@ string.
instance ToJSON KeyUse where
    toJSON = String . useName
      where
        useName Sig = "sig"
        useName Enc = "enc"
-- | Parse a @crv@ string. Only "P-256", "P-384" and "P-521" are recognised.
instance FromJSON EcCurve where
    parseJSON = withText "EcCurve" $ \t ->
        maybe (fail "unsupported 'crv' value") pure (lookup t curveNames)
      where
        curveNames = [("P-256", P_256), ("P-384", P_384), ("P-521", P_521)]
-- | Render a curve as its @crv@ string.
instance ToJSON EcCurve where
    toJSON P_256 = String "P-256"
    toJSON P_384 = String "P-384"
    toJSON P_521 = String "P-521"
-- | Decode a JSON string with 'B64.decode'. The decoder's own error is
-- replaced by a fixed message.
instance FromJSON JwkBytes where
    parseJSON = withText "JwkBytes" $
        either (const (fail "could not base64 decode bytes")) (pure . JwkBytes)
            . B64.decode
            . TE.encodeUtf8
-- | Encode the bytes with 'B64.encode' and emit them as a JSON string.
instance ToJSON JwkBytes where
    toJSON = String . TE.decodeUtf8 . B64.encode . bytes
-- | Parse a JWK object. The members are first read into 'JwkData', then
-- 'createJwk' turns them into a typed key. Any 'createJwk' error becomes a
-- parse failure.
instance FromJSON Jwk where
    parseJSON v = case v of
        Object _ -> do
            raw <- parseJSON v :: Parser JwkData
            either fail return (createJwk raw)
        _ -> fail "Jwk must be a JSON object"
-- | Serialise a key through 'JwkData', starting from 'defJwk'.
--
-- RSA integers are written in minimal big-endian form. Optional CRT
-- parameters equal to zero are omitted.
--
-- EC coordinates and the EC private scalar are written at the full byte
-- length of the curve (32, 48 or 66 bytes), as RFC 7518 sections 6.2.1.2
-- and 6.2.2.1 require. Plain 'i2osp' drops leading zero bytes, so the
-- padding is explicit.
instance ToJSON Jwk where
    toJSON jwk = toJSON $ case jwk of
        RsaPublicJwk pubKey mId mUse mAlg ->
            createPubData pubKey mId mUse mAlg
        RsaPrivateJwk privKey mId mUse mAlg ->
            let pubData = createPubData (RSA.private_pub privKey) mId mUse mAlg
            in  pubData
                  { d  = Just . JwkBytes . i2osp $ RSA.private_d privKey
                  , p  = i2b $ RSA.private_p privKey
                  , q  = i2b $ RSA.private_q privKey
                  , dp = i2b $ RSA.private_dP privKey
                  , dq = i2b $ RSA.private_dQ privKey
                  , qi = i2b $ RSA.private_qinv privKey
                  }
        SymmetricJwk bs mId mUse mAlg -> defJwk
            { kty = Oct
            , k   = Just $ JwkBytes bs
            , kid = mId
            , use = mUse
            , alg = mAlg
            }
        EcPublicJwk pubKey mId mUse mAlg c -> defJwk
            { kty = Ec
            , x   = fst (ecPoint c pubKey)
            , y   = snd (ecPoint c pubKey)
            , kid = mId
            , use = mUse
            , alg = mAlg
            , crv = Just c
            }
        EcPrivateJwk kp mId mUse mAlg c -> defJwk
            { kty = Ec
            , x   = fst (ecPoint c (ECDSA.toPublicKey kp))
            , y   = snd (ecPoint c (ECDSA.toPublicKey kp))
            , d   = Just (ecBytes c (ECDSA.private_d (ECDSA.toPrivateKey kp)))
            , kid = mId
            , use = mUse
            , alg = mAlg
            , crv = Just c
            }
      where
        -- Minimal big-endian encoding for RSA values. Zero marks an absent
        -- optional value, which matches how createJwk reads missing CRT
        -- parameters back as 0.
        i2b 0 = Nothing
        i2b i = Just . JwkBytes . i2osp $ i
        -- Fixed-length encoding for EC values. A value too large for the
        -- curve can only come from unvalidated input. It falls back to the
        -- minimal form instead of crashing the encoder.
        ecBytes c i = JwkBytes $ case i2ospOf (ecFieldBytes c) i of
            Just padded -> padded
            Nothing     -> i2osp i
        ecFieldBytes c = case c of
            P_256 -> 32
            P_384 -> 48
            P_521 -> 66
        -- The point at infinity has no affine coordinates to write.
        ecPoint c pk = case ECDSA.public_q pk of
            ECC.Point xi yi -> (Just (ecBytes c xi), Just (ecBytes c yi))
            _               -> (Nothing, Nothing)
        createPubData pubKey mId mUse mAlg = defJwk
            { n   = i2b (RSA.public_n pubKey)
            , e   = i2b (RSA.public_e pubKey)
            , kid = mId
            , use = mUse
            , alg = mAlg
            }
-- JSON instances for 'JwkSet' have empty bodies, so they use the class
-- default methods. Those defaults are driven by the 'Generic' instance
-- derived on the record.
instance ToJSON JwkSet
instance FromJSON JwkSet
-- | Generic encoding options for 'JwkData': aeson's defaults, except that
-- fields holding 'Nothing' are omitted from the encoded object.
aesonOptions :: Options
aesonOptions =
    let base = defaultOptions
    in  base { omitNothingFields = True }
-- | Untyped view of a JWK's JSON members. It is the intermediate form used
-- when decoding (see 'createJwk') and encoding (the 'Jwk' 'ToJSON' instance).
--
-- NOTE: 'createJwk' pattern-matches the 'J' constructor positionally, so
-- the order of the fields below must not change.
data JwkData = J
    { kty :: KeyType          -- ^ key type
    , n   :: Maybe JwkBytes   -- ^ RSA modulus
    , e   :: Maybe JwkBytes   -- ^ RSA public exponent
    , d   :: Maybe JwkBytes   -- ^ RSA private exponent, or EC private scalar
    , p   :: Maybe JwkBytes   -- ^ RSA first prime factor
    , q   :: Maybe JwkBytes   -- ^ RSA second prime factor
    , dp  :: Maybe JwkBytes   -- ^ RSA first factor CRT exponent
    , dq  :: Maybe JwkBytes   -- ^ RSA second factor CRT exponent
    , qi  :: Maybe JwkBytes   -- ^ RSA CRT coefficient
    , k   :: Maybe JwkBytes   -- ^ symmetric key bytes
    , crv :: Maybe EcCurve    -- ^ EC curve name
    , x   :: Maybe JwkBytes   -- ^ EC public point x coordinate
    , y   :: Maybe JwkBytes   -- ^ EC public point y coordinate
    , use :: Maybe KeyUse     -- ^ intended key use
    , alg :: Maybe Alg        -- ^ intended algorithm
    , kid :: Maybe Text       -- ^ key ID
    , x5u :: Maybe Text       -- ^ X.509 URL (accepted only for RSA keys by createJwk)
    , x5c :: Maybe [Text]     -- ^ X.509 chain (accepted only for RSA keys by createJwk)
    , x5t :: Maybe Text       -- ^ X.509 thumbprint (accepted only for RSA keys by createJwk)
    }
    deriving (Generic, Show)
-- Decoding uses the generic class default. Encoding uses 'aesonOptions', so
-- members that are 'Nothing' are left out of the emitted object.
instance FromJSON JwkData
instance ToJSON JwkData where
    toJSON = genericToJSON aesonOptions
-- | Template 'JwkData' with every optional member absent, except @use@,
-- which is preset to @Just Sig@. @kty@ is a placeholder ('Rsa'). Callers fill
-- in the members relevant to their key with a record update.
defJwk :: JwkData
defJwk = J
    { kty = Rsa
    , use = Just Sig
    , alg = Nothing
    , kid = Nothing
    -- RSA parameters (d also carries the EC private scalar)
    , n   = Nothing
    , e   = Nothing
    , d   = Nothing
    , p   = Nothing
    , q   = Nothing
    , dp  = Nothing
    , dq  = Nothing
    , qi  = Nothing
    -- symmetric key
    , k   = Nothing
    -- elliptic-curve parameters
    , crv = Nothing
    , x   = Nothing
    , y   = Nothing
    -- X.509 references
    , x5u = Nothing
    , x5c = Nothing
    , x5t = Nothing
    }
-- | Turn the raw members of a JWK into a typed 'Jwk'.
--
-- Each key kind accepts exactly one combination of present and absent
-- members:
--
-- * RSA public: @n@ and @e@ only.
-- * RSA private: @n@, @e@ and @d@, with optional CRT values. Missing CRT
--   values become 0.
-- * Symmetric: @k@ only.
-- * EC public: @crv@, @x@ and @y@.
-- * EC private: @crv@, @x@, @y@ and @d@.
--
-- The X.509 members are tolerated (and ignored) only for RSA keys.
--
-- On failure the message names the key type and the members that were
-- present, but never their values. The input may hold private or symmetric
-- key material, and parse errors often end up in logs.
createJwk :: JwkData -> Either String Jwk
createJwk kd = case kd of
    J Rsa (Just nb) (Just eb) Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing u a i _ _ _ ->
        return $ RsaPublicJwk (rsaPub nb eb) i u a
    J Rsa (Just nb) (Just eb) (Just db) mp mq mdp mdq mqi Nothing Nothing Nothing Nothing u a i _ _ _ ->
        return $ RsaPrivateJwk (RSA.PrivateKey (rsaPub nb eb) (os2ip $ bytes db) (os2mip mp) (os2mip mq) (os2mip mdp) (os2mip mdq) (os2mip mqi)) i u a
    J Oct Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing (Just kb) Nothing Nothing Nothing u a i Nothing Nothing Nothing ->
        return $ SymmetricJwk (bytes kb) i u a
    J Ec Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing Nothing (Just crv') (Just xb) (Just yb) u a i Nothing Nothing Nothing ->
        return $ EcPublicJwk (ECDSA.PublicKey (curve crv') (ecPoint xb yb)) i u a crv'
    J Ec Nothing Nothing (Just db) Nothing Nothing Nothing Nothing Nothing Nothing (Just crv') (Just xb) (Just yb) u a i Nothing Nothing Nothing ->
        return $ EcPrivateJwk (ECDSA.KeyPair (curve crv') (ecPoint xb yb) (os2ip (bytes db))) i u a crv'
    _ -> Left $ "Invalid key data. Didn't match any known JWK parameter combinations for key type "
             ++ show (kty kd) ++ " with parameters: " ++ unwords presentParams
  where
    rsaPub nb eb = let m  = os2ip $ bytes nb
                       ex = os2ip $ bytes eb
                   in RSA.PublicKey (rsaSize m 1) m ex
    -- Smallest byte count that can hold the modulus.
    rsaSize m i = if (2 ^ (i * 8)) > m then i else rsaSize m (i+1)
    os2mip = maybe 0 (os2ip . bytes)
    ecPoint xb yb = ECC.Point (os2ip (bytes xb)) (os2ip (bytes yb))
    -- Names (never values) of the key-material members that were supplied.
    presentParams = concat
        [ has "n" (n kd), has "e" (e kd), has "d" (d kd)
        , has "p" (p kd), has "q" (q kd), has "dp" (dp kd)
        , has "dq" (dq kd), has "qi" (qi kd), has "k" (k kd)
        , has "crv" (crv kd), has "x" (x kd), has "y" (y kd)
        , has "x5u" (x5u kd), has "x5c" (x5c kd), has "x5t" (x5t kd)
        ]
    has :: String -> Maybe b -> [String]
    has name = maybe [] (const [name])