{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK prune #-}
module Jose.Internal.Parser
( parseJwt
, DecodableJwt (..)
, EncryptedCEK (..)
, Payload (..)
, IV (..)
, Tag (..)
, AAD (..)
, Sig (..)
, SigTarget (..)
)
where
import Data.Bifunctor (first)
import Data.Aeson (eitherDecodeStrict')
import Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as P
import qualified Data.Attoparsec.ByteString.Char8 as PC
import Data.ByteArray.Encoding (convertFromBase, Base(..))
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Jose.Jwa
import Jose.Types (JwtError(..), JwtHeader(..), JwsHeader(..), JweHeader(..))
-- | A compact-serialized token split into its decoded parts.
--
-- The parser in this module only splits and base64url-decodes the
-- segments; it does not verify signatures or decrypt anything.
data DecodableJwt
    = Unsecured ByteString
      -- ^ Decoded payload of an unsecured token.
    | DecodableJws JwsHeader Payload Sig SigTarget
      -- ^ Header, decoded payload, decoded signature and the signing input.
    | DecodableJwe JweHeader EncryptedCEK IV Payload Tag AAD
      -- ^ Header, encrypted key, IV, ciphertext, auth tag and the AAD.
-- | A JWE authentication tag, tagged with its length in bytes
-- (16, 24 or 32). The length is checked by 'authTag'.
data Tag
    = Tag16 ByteString
    | Tag24 ByteString
    | Tag32 ByteString
-- | A JWE initialization vector: 12 bytes for the GCM modes, 16 bytes
-- for the other content encryption algorithms. The length is checked
-- by 'iv'.
data IV
    = IV12 ByteString
    | IV16 ByteString
-- | Decoded JWS signature bytes.
newtype Sig = Sig ByteString

-- | The JWS signing input: the still-encoded header, a @.@, and the
-- still-encoded payload.
newtype SigTarget = SigTarget ByteString

-- | JWE additional authenticated data: the still-encoded header segment.
newtype AAD = AAD ByteString

-- | A decoded payload segment: the plaintext payload of a JWS, or the
-- ciphertext of a JWE.
newtype Payload = Payload ByteString

-- | The decoded JWE encrypted content-encryption-key segment.
newtype EncryptedCEK = EncryptedCEK ByteString
-- | Parse a compact-serialized JWT into its decoded components.
--
-- The header segment is decoded first and selects how the remaining
-- segments are parsed (unsecured, JWS or JWE).
--
-- The entire input must be consumed. 'P.parseOnly' does not enforce this
-- on its own, and the unsecured branch stops right after the payload's
-- trailing @.@. Without the explicit 'P.endOfInput', a token such as
-- @hdr.payload.sig@ with an unsecured header would be accepted and its
-- signature segment silently dropped. An unsecured token must have an
-- empty signature segment, so trailing bytes are now a parse failure.
--
-- Every failure (bad base64url, bad JSON header, wrong IV or tag length,
-- trailing input) is reported as 'BadCrypto'; the attoparsec error message
-- is discarded.
parseJwt :: ByteString -> Either JwtError DecodableJwt
parseJwt bs = first (const BadCrypto) $ P.parseOnly (jwt <* P.endOfInput) bs
-- | Top-level token parser.
--
-- Reads the header segment, then dispatches on the header type:
--
-- * unsecured: the next @.@-terminated base64url segment is the payload;
-- * JWS: the payload segment is followed by the signature, and the signing
--   input is rebuilt from the still-encoded header and payload;
-- * JWE: the encrypted key, IV, ciphertext and auth tag segments follow, and
--   the still-encoded header is kept as the additional authenticated data.
jwt :: Parser DecodableJwt
jwt = jwtHeader >>= \(header, rawHeader) ->
    case header of
        UnsecuredH ->
            fmap Unsecured base64Chunk
        JwsH jwsHdr -> do
            encodedBody <- PC.takeWhile (/= '.') <* PC.char '.'
            body <- b64Decode encodedBody
            signature <- sig (jwsAlg jwsHdr)
            -- The signature covers the encoded bytes, not the decoded ones.
            let signingInput = SigTarget (B.concat [rawHeader, ".", encodedBody])
            pure (DecodableJws jwsHdr (Payload body) signature signingInput)
        JweH jweHdr -> do
            let enc = jweEnc jweHdr
            cek <- encryptedCEK
            nonce <- iv enc
            ciphertext <- encryptedPayload
            tag <- authTag enc
            pure (DecodableJwe jweHdr cek nonce ciphertext tag (AAD rawHeader))
-- | Parse the signature segment of a JWS.
--
-- Consumes all remaining input and base64url-decodes it. The algorithm
-- argument is not consulted, so no signature-length check happens here.
sig :: JwsAlg -> Parser Sig
sig _ = fmap Sig (P.takeByteString >>= b64Decode)
-- | Parse the authentication tag segment of a JWE.
--
-- Consumes all remaining input, base64url-decodes it and checks that the
-- decoded length matches the content encryption algorithm:
--
-- * 16 bytes for A128GCM, A192GCM, A256GCM and A128CBC-HS256;
-- * 24 bytes for A192CBC-HS384;
-- * 32 bytes for A256CBC-HS512.
--
-- A wrong length fails the parser with @"invalid auth tag"@.
authTag :: Enc -> Parser Tag
authTag e = P.takeByteString >>= b64Decode >>= checked
  where
    checked = case e of
        A128GCM       -> exactly 16 Tag16
        A192GCM       -> exactly 16 Tag16
        A256GCM       -> exactly 16 Tag16
        A128CBC_HS256 -> exactly 16 Tag16
        A192CBC_HS384 -> exactly 24 Tag24
        A256CBC_HS512 -> exactly 32 Tag32

    -- Accept the tag only if it has exactly @n@ bytes, wrapping it with @mk@.
    exactly :: Int -> (ByteString -> Tag) -> ByteString -> Parser Tag
    exactly n mk t
        | B.length t == n = pure (mk t)
        | otherwise       = fail "invalid auth tag"
-- | Parse the initialization-vector segment of a JWE.
--
-- Reads one @.@-terminated base64url segment and checks its decoded length
-- against the content encryption algorithm: 12 bytes for the GCM modes and
-- 16 bytes for every other algorithm. A wrong length fails the parser with
-- @"invalid iv"@.
iv :: Enc -> Parser IV
iv e = base64Chunk >>= sized
  where
    sized = case e of
        A128GCM -> require 12 IV12
        A192GCM -> require 12 IV12
        A256GCM -> require 12 IV12
        _       -> require 16 IV16

    -- Accept the IV only if it has exactly @n@ bytes, wrapping it with @mk@.
    require :: Int -> (ByteString -> IV) -> ByteString -> Parser IV
    require n mk bytes
        | B.length bytes /= n = fail "invalid iv"
        | otherwise           = pure (mk bytes)
-- | Parse the encrypted content-encryption-key segment of a JWE: one
-- @.@-terminated base64url segment, decoded. No length check is applied.
encryptedCEK :: Parser EncryptedCEK
encryptedCEK = fmap EncryptedCEK base64Chunk
-- | Parse the ciphertext segment of a JWE: one @.@-terminated base64url
-- segment, decoded and wrapped as a 'Payload'.
encryptedPayload :: Parser Payload
encryptedPayload = fmap Payload base64Chunk
-- | Parse the header segment of a compact-serialized token.
--
-- Reads everything up to the first @.@ (consuming the @.@), base64url-decodes
-- it and decodes the result as a JSON 'JwtHeader'; a JSON error fails the
-- parser with aeson's message. Returns the decoded header together with the
-- still-encoded header bytes, which 'jwt' uses for the JWS signing input and
-- the JWE additional authenticated data.
jwtHeader :: P.Parser (JwtHeader, ByteString)
jwtHeader = do
    encoded <- PC.takeWhile (/= '.') <* PC.char '.'
    decoded <- b64Decode encoded
    header <- either fail pure (eitherDecodeStrict' decoded)
    pure (header, encoded)
-- | Parse one @.@-terminated segment: take bytes up to the next @.@, consume
-- the @.@, and base64url-decode what was taken. Fails if no @.@ follows or if
-- the segment is not valid unpadded base64url.
base64Chunk :: P.Parser ByteString
base64Chunk = PC.takeWhile (/= '.') <* PC.char '.' >>= b64Decode
-- | Decode unpadded base64url (the JOSE segment encoding) inside the parser.
-- Any decoding error fails the parser with @"Invalid Base64"@; the decoder's
-- own error message is discarded.
b64Decode :: ByteString -> P.Parser ByteString
b64Decode bs = case convertFromBase Base64URLUnpadded bs of
    Left _        -> fail "Invalid Base64"
    Right decoded -> pure decoded