module Jose.Jws
( jwkEncode
, hmacEncode
, hmacDecode
, rsaEncode
, rsaDecode
, ecDecode
)
where
import Control.Applicative
import Control.Monad (unless)
import qualified Crypto.PubKey.ECC.ECDSA as ECDSA
import Crypto.PubKey.RSA (PrivateKey(..), PublicKey(..), generateBlinder)
import Crypto.Random (CPRG)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import Jose.Types
import qualified Jose.Internal.Base64 as B64
import Jose.Internal.Crypto
import Jose.Jwa
import Jose.Jwk (Jwk (..))
-- | Create a compact JWS for @payload@, signed with the supplied JWK.
--
-- The key id bound from the JWK is written into the JWS header.
-- Only 'RsaPrivateJwk' and 'SymmetricJwk' keys can sign; any other key
-- produces a 'BadAlgorithm' error.
--
-- The random generator is consumed only on the RSA path, where it is
-- used to create a blinder. On every other path it is returned unchanged.
jwkEncode :: (CPRG g)
          => g
          -> JwsAlg
          -> Jwk
          -> ByteString
          -> (Either JwtError ByteString, g)
jwkEncode rng a key payload =
  case key of
    RsaPrivateJwk privKey keyId _ _ ->
      rsaEncodeInternal rng a privKey (signingInput keyId)
    SymmetricJwk secret keyId _ _ ->
      (hmacEncodeInternal a secret (signingInput keyId), rng)
    -- NOTE(review): this fallback matches every non-signing constructor,
    -- not only EC keys, so the message may be misleading for other
    -- key types. Confirm against the Jwk definition.
    _ ->
      (Left $ BadAlgorithm "EC signing is not supported", rng)
  where
    signingInput keyId = sigTarget a keyId payload
-- | Create a compact JWS by MAC-ing @payload@ with the shared @key@.
--
-- The header is 'defJwsHdr' with the algorithm set to @a@; no key id is
-- included.
hmacEncode :: JwsAlg
           -> ByteString
           -> ByteString
           -> Either JwtError ByteString
hmacEncode a key = hmacEncodeInternal a key . sigTarget a Nothing
-- | Compute the MAC of the signing input @st@ and append it to @st@
-- (after a @.@) to form the complete compact JWS.
--
-- Any error from 'hmacSign' (for example an unsuitable algorithm) is
-- passed through unchanged.
hmacEncodeInternal :: JwsAlg
                   -> ByteString
                   -> ByteString
                   -> Either JwtError ByteString
hmacEncodeInternal a key st = withSignature <$> hmacSign a key st
  where
    withSignature mac = B.concat [st, ".", B64.encode mac]
-- | Decode a compact JWS and check its MAC against the shared @key@.
--
-- The algorithm taken from the token's own header is passed to
-- 'hmacVerify'.
hmacDecode :: ByteString
           -> ByteString
           -> Either JwtError Jws
hmacDecode key = decode checkMac
  where
    checkMac alg = hmacVerify alg key
-- | Create a compact JWS for @payload@, signed with an RSA private key.
--
-- The header is 'defJwsHdr' with the algorithm set to @a@ and no key id.
-- The generator is used to create a blinder for the signature, and the
-- updated generator is returned alongside the result.
rsaEncode :: CPRG g
          => g
          -> JwsAlg
          -> PrivateKey
          -> ByteString
          -> (Either JwtError ByteString, g)
rsaEncode rng a pk = rsaEncodeInternal rng a pk . sigTarget a Nothing
-- | Sign the signing input @st@ with an RSA private key and append the
-- signature (after a @.@) to form the complete compact JWS.
--
-- A blinder sized by the key's public modulus is drawn from @rng@ and
-- passed to 'rsaSign'. The advanced generator is returned even when
-- signing fails. Errors from 'rsaSign' are passed through unchanged.
rsaEncodeInternal :: CPRG g
                  => g
                  -> JwsAlg
                  -> PrivateKey
                  -> ByteString
                  -> (Either JwtError ByteString, g)
rsaEncodeInternal rng a pk st =
    (withSignature <$> rsaSign (Just blinder) a pk st, nextRng)
  where
    modulus = public_n (private_pub pk)
    (blinder, nextRng) = generateBlinder rng modulus
    withSignature sig = B.concat [st, ".", B64.encode sig]
-- | Decode a compact JWS and verify its signature with an RSA public key.
--
-- The algorithm taken from the token's own header is passed to
-- 'rsaVerify'.
rsaDecode :: PublicKey
          -> ByteString
          -> Either JwtError Jws
rsaDecode key = decode (flip rsaVerify key)
-- | Decode a compact JWS and verify its signature with an ECDSA public key.
--
-- The algorithm taken from the token's own header is passed to
-- 'ecVerify'.
ecDecode :: ECDSA.PublicKey
         -> ByteString
         -> Either JwtError Jws
ecDecode key = decode checkSig
  where
    checkSig alg = ecVerify alg key
-- | Build the JWS signing input: the encoded header and the payload, each
-- encoded with "Jose.Internal.Base64", joined by a single @.@.
--
-- The header is 'defJwsHdr' with its algorithm and optional key id
-- overridden.
sigTarget :: JwsAlg -> Maybe KeyId -> ByteString -> ByteString
sigTarget a kid payload =
    B.concat [B64.encode headerBytes, ".", B64.encode payload]
  where
    headerBytes = encodeHeader (defJwsHdr { jwsAlg = a, jwsKid = kid })
-- | A signature check used by 'decode'.
--
-- It is called with three arguments:
--
-- * the algorithm taken from the parsed JWS header;
-- * the signing input (the still-encoded @header.payload@ text);
-- * the decoded signature bytes.
--
-- It returns 'True' if the signature is acceptable.
type JwsVerifier = JwsAlg -> ByteString -> ByteString -> Bool
-- | Split a compact-serialized JWS into its three parts, check the
-- signature with the supplied verifier, and return the parsed header
-- together with the decoded payload.
--
-- Failures are reported in this order:
--
-- * 'BadDots' if the input does not contain exactly two @.@ separators;
-- * any error from base64-decoding the signature, then the header, then
--   the payload;
-- * any error from 'parseHeader', or 'BadHeader' if the header describes
--   a JWE rather than a JWS;
-- * 'BadSignature' if the verifier rejects the signature.
--
-- The three pieces are taken from one total pattern match instead of a
-- refutable do-bind. That avoids needing a @MonadFail@ instance for
-- 'Either', which newer GHCs do not provide. It also avoids the partial
-- 'B.init'.
decode :: JwsVerifier -> ByteString -> Either JwtError Jws
decode verify jwt = case BC.split '.' jwt of
  -- 'BC.split' yields exactly three pieces iff there are exactly two dots.
  [hdr64, payload64, sig64] -> do
    sigBytes <- B64.decode sig64
    hdrBytes <- B64.decode hdr64
    payload <- B64.decode payload64
    hdr <- case parseHeader hdrBytes of
      Right (JwsH jwsHdr) -> Right jwsHdr
      Right (JweH _) -> Left (BadHeader "Header is for a JWE")
      Left e -> Left e
    -- The signature covers the still-encoded "header.payload" prefix,
    -- i.e. everything in the input before the final dot.
    let signingInput = B.take (B.length hdr64 + 1 + B.length payload64) jwt
    if verify (jwsAlg hdr) signingInput sigBytes
      then Right (hdr, payload)
      else Left BadSignature
  _ -> Left (BadDots 2)