module Jose.Jwt
( module Jose.Types
, encode
, decode
, decodeClaims
)
where
import Control.Monad.State.Strict
import Control.Monad.Trans.Either
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..))
import Crypto.Random (CPRG)
import Data.Aeson (decodeStrict')
import Data.ByteString (ByteString)
import Data.List (find)
import Data.Maybe (fromJust, isJust, isNothing)
import qualified Data.ByteString.Char8 as BC
import qualified Jose.Internal.Base64 as B64
import Jose.Types
import Jose.Jwk
import Jose.Jwa
import qualified Jose.Jws as Jws
import qualified Jose.Jwe as Jwe
-- | Encode a payload as a compact-serialized JWT.
--
-- For @JwsEncoding None@ no key is used and an Unsecured JWT is produced.
-- Only a 'Claims' payload is allowed in that case; 'Nested' yields
-- 'BadClaims'. For any other JWS or JWE encoding, the first key in the list
-- accepted by 'canEncodeJws' / 'canEncodeJwe' is used. If no key is accepted,
-- the result is a 'KeyError'.
--
-- The random generator is threaded through the JWS/JWE encoder. The updated
-- generator is returned alongside the result.
encode :: (CPRG g)
       => g            -- ^ Random number generator
       -> [Jwk]        -- ^ Candidate keys (the first suitable one is used)
       -> JwtEncoding  -- ^ Algorithm(s) to sign or encrypt with
       -> Payload      -- ^ Claims or nested JWT to encode
       -> (Either JwtError Jwt, g)
encode rng jwks encoding msg = flip runState rng $ runEitherT $ case encoding of
    JwsEncoding None -> case msg of
        -- An Unsecured JWT is a JWS whose signature is the empty string, so
        -- its compact form still has three segments and ends with '.'
        -- (RFC 7519 section 6.1). Without the empty third segment the token
        -- would be rejected with BadDots by 'decode' and 'decodeClaims'.
        Claims p -> return $ Jwt $ BC.intercalate "." [unsecuredHdr, B64.encode p, ""]
        Nested _ -> left BadClaims
    JwsEncoding a -> case filter (canEncodeJws a) jwks of
        []    -> left (KeyError "No matching key found for JWS algorithm")
        (k:_) -> hoistEither =<< state (\g -> Jws.jwkEncode g a k msg)
    JweEncoding a e -> case filter (canEncodeJwe a) jwks of
        []    -> left (KeyError "No matching key found for JWE algorithm")
        (k:_) -> hoistEither =<< state (\g -> Jwe.jwkEncode g a e k msg)
  where
    -- Base64url-encoded protected header for an unsecured token.
    unsecuredHdr = B64.encode "{\"alg\":\"none\"}"
-- | Decode a compact-serialized JWT using a set of candidate keys.
--
-- The header (first segment) is parsed to decide how the token is handled:
--
-- * Unsecured header: accepted only when the expected encoding is exactly
--   @Just (JwsEncoding None)@. The second segment is returned as 'Unsecured'.
-- * JWS header: each key selected by 'findDecodingKeys' is tried with the
--   matching @Jws.*Decode@ function.
-- * JWE header: each selected key is tried with 'Jwe.rsaDecode'.
--
-- If an expected encoding is supplied, it must match the header. Otherwise
-- the result is 'BadAlgorithm'. Fewer than three dot-separated segments gives
-- @BadDots 2@. If no suitable key exists, or none of them decodes the token,
-- the result is a 'KeyError'. When several keys succeed, the result from the
-- earliest key in the list wins.
--
-- The random generator is threaded through JWE decoding. The updated
-- generator is returned alongside the result.
decode :: CPRG g
       => g                 -- ^ Random number generator
       -> [Jwk]             -- ^ Candidate decoding keys
       -> Maybe JwtEncoding -- ^ Expected encoding, if the caller requires one
       -> ByteString        -- ^ The compact-serialized token
       -> (Either JwtError JwtContent, g)
decode rng keySet encoding jwt = flip runState rng $ runEitherT $ do
    -- Only the header and second segment are needed here. The JWS/JWE
    -- decoders below receive the whole token and do their own splitting.
    (hdrSeg, payloadSeg) <- case BC.split '.' jwt of
        h : p : _ : _ -> return (h, p)
        _             -> left $ BadDots 2
    hdr <- B64.decode hdrSeg >>= hoistEither . parseHeader
    ks <- findDecodingKeys hdr keySet
    decodings <- case hdr of
        UnsecuredH -> do
            -- Never accept an unsecured token unless the caller explicitly
            -- asked for alg "none".
            unless (encoding == Just (JwsEncoding None)) $ left (BadAlgorithm "JWT is unsecured but expected 'alg' was not 'none'")
            p <- B64.decode payloadSeg
            return [Just (Unsecured p)]
        JwsH h -> do
            unless (isNothing encoding || encoding == Just (JwsEncoding (jwsAlg h))) $ left (BadAlgorithm "Expected 'alg' doesn't match JWS header")
            mapM decodeWithJws ks
        JweH h -> do
            unless (isNothing encoding || encoding == Just (JweEncoding (jweAlg h) (jweEnc h))) $ left (BadAlgorithm "Expected encoding doesn't match JWE header")
            mapM decodeWithJwe ks
    -- mapM has tried every candidate key. Take the first one that succeeded.
    case [c | Just c <- decodings] of
        c : _ -> return c
        []    -> left $ KeyError "None of the keys was able to decode the JWT"
  where
    -- Try to decode the JWS with a single key. A failure with this key
    -- becomes Nothing, so the remaining keys can still be tried.
    decodeWithJws :: CPRG g => Jwk -> EitherT JwtError (State g) (Maybe JwtContent)
    decodeWithJws k = either (const $ return Nothing) (return . Just . Jws) $ case k of
        RsaPublicJwk kPub _ _ _   -> Jws.rsaDecode kPub jwt
        RsaPrivateJwk kPr _ _ _   -> Jws.rsaDecode (private_pub kPr) jwt
        EcPublicJwk kPub _ _ _ _  -> Jws.ecDecode kPub jwt
        EcPrivateJwk kPr _ _ _ _  -> Jws.ecDecode (ECDSA.toPublicKey kPr) jwt
        SymmetricJwk kb _ _ _     -> Jws.hmacDecode kb jwt

    -- Try to decrypt the JWE with a single RSA private key, drawing
    -- randomness from the threaded generator. A decryption failure becomes
    -- Nothing. Any other key type aborts the whole decode. This is presumably
    -- unreachable because 'canDecodeJwe' filtered the keys; verify if new
    -- JWE key types are added.
    decodeWithJwe :: CPRG g => Jwk -> EitherT JwtError (State g) (Maybe JwtContent)
    decodeWithJwe k = case k of
        RsaPrivateJwk kPr _ _ _ -> do
            e <- state (\g -> Jwe.rsaDecode g kPr jwt)
            either (const $ return Nothing) (return . Just . Jwe) e
        _ -> left $ KeyError "Not a JWE key (shouldn't happen)"
-- | Parse the header and claims of a three-segment JWT without checking its
-- signature.
--
-- No key is involved: the first segment is parsed as the header and the
-- second as the JSON claims set. The result is therefore unverified and
-- should only be trusted after the token has been validated separately.
--
-- Any token that does not split into exactly three dot-separated segments
-- yields @BadDots 2@. Claims that do not parse as JSON yield 'BadClaims'.
decodeClaims :: ByteString
             -> Either JwtError (JwtHeader, JwtClaims)
decodeClaims jwt = case BC.split '.' jwt of
    [hdrSeg, claimsSeg, _sig] -> do
        hdr    <- parseHeader =<< B64.decode hdrSeg
        claims <- jsonClaims =<< B64.decode claimsSeg
        return (hdr, claims)
    _ -> Left $ BadDots 2
  where
    jsonClaims raw = case decodeStrict' raw of
        Nothing -> Left BadClaims
        Just c  -> Right c
-- | Select the keys that may be used to decode a token with the given header.
--
-- An unsecured header needs no key, so the empty list is returned. For JWS
-- and JWE headers, keys are kept (in their original order) when
-- 'canDecodeJws' / 'canDecodeJwe' accepts them. An empty selection is
-- reported as a 'KeyError'.
findDecodingKeys :: Monad m => JwtHeader -> [Jwk] -> EitherT JwtError m [Jwk]
findDecodingKeys hdr jwks = case hdr of
    UnsecuredH -> return []
    JwsH h     -> keysMatching (canDecodeJws h)
    JweH h     -> keysMatching (canDecodeJwe h)
  where
    keysMatching usable = case filter usable jwks of
        []     -> left $ KeyError "No suitable key was found to decode the JWT"
        chosen -> return chosen