module Web.JWT
(
decode
, decodeAndVerifySignature
, encodeSigned
, encodeUnsigned
, tokenIssuer
, secret
, claims
, header
, signature
, module Data.Default
, UnverifiedJWT
, VerifiedJWT
, Signature
, Secret
, JWT
, JSON
, Algorithm(..)
, JWTClaimsSet(..)
#ifdef TEST
, IntDate(..)
, JWTHeader(..)
, base64Encode
, base64Decode
#endif
) where
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as BL (fromStrict, toStrict)
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Control.Applicative
import Control.Monad
import qualified Crypto.Hash.SHA256 as SHA
import qualified Crypto.MAC.HMAC as HMAC
import Data.Aeson hiding (decode, encode)
import qualified Data.Aeson as JSON
import qualified Data.ByteString.Base64.URL as BASE64
import Data.Default
import qualified Data.HashMap.Strict as StrictMap
import qualified Data.Map as Map
import Data.Maybe
import Data.Scientific
import Prelude hiding (exp)
-- | Textual form of an encoded token as produced by 'encodeSigned' /
-- 'encodeUnsigned' and consumed by 'decode'. It is a plain alias for
-- 'T.Text', so it offers no type safety over any other text value.
type JSON = T.Text
-- | The shared key used to compute and verify token signatures.
-- Construct values with 'secret'.
newtype Secret = Secret T.Text deriving (Eq, Show)
-- | The signature segment of a token, kept in its encoded textual form
-- exactly as it appeared in the input (see 'decodeAndVerifySignature').
newtype Signature = Signature T.Text deriving (Eq, Show)
-- | Uninhabited tag type: marks a 'JWT' whose signature has not been checked.
data UnverifiedJWT
-- | Uninhabited tag type: marks a 'JWT' whose signature has been checked.
data VerifiedJWT
-- | A decoded token. The type index @r@ records whether the signature was
-- checked ('VerifiedJWT') or not ('UnverifiedJWT'); only the verified form
-- carries a 'Signature'.
data JWT r where
    -- | Header and claims of a token whose signature was not checked.
    Unverified
        :: JWTHeader
        -> JWTClaimsSet
        -> JWT UnverifiedJWT
    -- | Header, claims and the accepted signature of a verified token.
    Verified
        :: JWTHeader
        -> JWTClaimsSet
        -> Signature
        -> JWT VerifiedJWT

deriving instance Show (JWT r)
-- | Extract the claims set from a token, verified or not.
claims :: JWT r -> JWTClaimsSet
claims jwt = case jwt of
    Unverified _ claimsSet -> claimsSet
    Verified _ claimsSet _ -> claimsSet
-- | Extract the header from a token, verified or not.
header :: JWT r -> JWTHeader
header jwt = case jwt of
    Unverified hdr _ -> hdr
    Verified hdr _ _ -> hdr
-- | The signature of a token: 'Just' for a verified token, 'Nothing' for an
-- unverified one (which does not store a signature at all).
signature :: JWT r -> Maybe Signature
signature jwt = case jwt of
    Verified _ _ sig -> Just sig
    Unverified _ _   -> Nothing
-- | Integer-valued date used by the @exp@, @nbf@ and @iat@ claims.
-- NOTE(review): presumably seconds since the Unix epoch, as in the JWT
-- specification; nothing in this module enforces or interprets the unit.
newtype IntDate = IntDate Integer deriving (Eq, Show)
-- | Signing algorithms understood by this module. Only @HS256@ is available.
data Algorithm
    = HS256
    deriving (Eq, Show)
-- | The token header. Every field is optional; absent fields are omitted
-- when the header is serialised.
data JWTHeader = JWTHeader
    { typ :: Maybe T.Text     -- ^ value of the @typ@ header field
    , cty :: Maybe T.Text     -- ^ value of the @cty@ header field
    , alg :: Maybe Algorithm  -- ^ value of the @alg@ header field
    } deriving (Eq, Show)
-- | The default header has no fields set.
instance Default JWTHeader where
    def = JWTHeader
        { typ = Nothing
        , cty = Nothing
        , alg = Nothing
        }
-- | The claims carried by a token: the seven registered claims, each
-- optional, plus a map holding every other claim.
data JWTClaimsSet = JWTClaimsSet
    { iss                :: Maybe T.Text   -- ^ the @iss@ claim
    , sub                :: Maybe T.Text   -- ^ the @sub@ claim
    , aud                :: Maybe T.Text   -- ^ the @aud@ claim
    , exp                :: Maybe IntDate  -- ^ the @exp@ claim
    , nbf                :: Maybe IntDate  -- ^ the @nbf@ claim
    , iat                :: Maybe IntDate  -- ^ the @iat@ claim
    , jti                :: Maybe T.Text   -- ^ the @jti@ claim
    , unregisteredClaims :: ClaimsMap      -- ^ all claims other than the seven above
    } deriving (Eq, Show)
-- | The default claims set has no registered claims and no other claims.
instance Default JWTClaimsSet where
    def = JWTClaimsSet
        { iss                = Nothing
        , sub                = Nothing
        , aud                = Nothing
        , exp                = Nothing
        , nbf                = Nothing
        , iat                = Nothing
        , jti                = Nothing
        , unregisteredClaims = Map.empty
        }
-- | Serialise a claims set into a signed token of the form
-- @header.claims.signature@.
--
-- The header always carries @typ = "JWT"@ and the given algorithm. The
-- signature is computed over the text @header.claims@ (both segments in
-- their encoded form) with the supplied secret.
encodeSigned :: Algorithm -> Secret -> JWTClaimsSet -> JSON
encodeSigned algo key claimsSet = dotted [signingInput, digest]
  where
    encodedHeader = encodeJWT def
        { typ = Just "JWT"
        , alg = Just algo
        }
    encodedClaims = encodeJWT claimsSet
    -- The exact text that gets signed; joining it with the digest yields
    -- the same three dot-separated segments.
    signingInput  = dotted [encodedHeader, encodedClaims]
    digest        = calculateDigest algo key signingInput
-- | Serialise a claims set into a token with an empty signature segment,
-- i.e. @header.claims.@ (note the trailing dot).
--
-- NOTE(review): the header still advertises @alg = HS256@ although nothing
-- is signed. 'Algorithm' has no constructor for an unsecured token, so this
-- cannot be expressed differently without extending that type.
encodeUnsigned :: JWTClaimsSet -> JSON
encodeUnsigned claimsSet = dotted [encodedHeader, encodedClaims, ""]
  where
    encodedClaims = encodeJWT claimsSet
    encodedHeader = encodeJWT def
        { typ = Just "JWT"
        , alg = Just HS256
        }
-- | Decode a token WITHOUT checking its signature.
--
-- The input is split on @.@; the first segment is parsed as the header and
-- the second as the claims. Any further segments (including the signature)
-- are ignored. Returns 'Nothing' when there are fewer than two segments or
-- when either segment fails to parse.
decode :: JSON -> Maybe (JWT UnverifiedJWT)
decode input = case T.splitOn "." input of
    (rawHeader : rawClaims : _) ->
        Unverified <$> parseJWT rawHeader <*> parseJWT rawClaims
    _ -> Nothing
-- | Decode a token and check its signature against the given secret.
--
-- The input is split on @.@ and must contain at least three segments
-- (header, claims, signature); additional segments are ignored. The result
-- is 'Nothing' when
--
-- * there are fewer than three segments,
-- * the header or the claims segment fails to parse,
-- * the header does not name an algorithm, or
-- * the signature segment differs from the digest computed over the
--   received @header.claims@ text.
decodeAndVerifySignature :: Secret -> T.Text -> Maybe (JWT VerifiedJWT)
decodeAndVerifySignature secret' input = do
    (h, c, s) <- extractElems $ T.splitOn "." input
    header'   <- parseJWT h
    claims'   <- parseJWT c
    algo      <- alg header'
    -- Sign the segments exactly as received, not a re-serialisation of them.
    let expected = calculateDigest algo secret' (dotted [h, c])
    guard (s `matches` expected)
    return $ Verified header' claims' (Signature s)
  where
    extractElems (h:c:s:_) = Just (h, c, s)
    extractElems _         = Nothing
    -- Compare every character instead of stopping at the first mismatch, so
    -- the time taken does not reveal how long a matching prefix an attacker
    -- has guessed. Only the length (fixed for a given algorithm) can leak.
    matches a b =
        T.length a == T.length b
            && sum [fromEnum (x /= y) | (x, y) <- T.zip a b] == (0 :: Int)
-- | Read the @iss@ claim of a token WITHOUT checking its signature.
-- 'Nothing' when the token cannot be decoded or has no @iss@ claim.
tokenIssuer :: JSON -> Maybe T.Text
tokenIssuer token = decode token >>= iss . claims
-- | Wrap a text key as a 'Secret'.
secret :: T.Text -> Secret
secret key = Secret key
-- | Serialise a value to JSON and encode it as one token segment
-- (URL-safe base64 with the trailing @=@ padding removed).
encodeJWT :: ToJSON a => a -> T.Text
encodeJWT value = base64Encode jsonText
  where
    jsonText = TE.decodeUtf8 (BL.toStrict (JSON.encode value))
-- | Decode one token segment (URL-safe base64, leniently decoded) and parse
-- the resulting bytes as JSON. Returns 'Nothing' when the bytes are not a
-- valid JSON encoding of the requested type.
--
-- The decoded bytes go straight to the JSON parser. Converting them to
-- 'T.Text' first with 'TE.decodeUtf8' would throw an exception from pure
-- code whenever an (untrusted) segment decodes to invalid UTF-8, instead
-- of yielding 'Nothing'.
parseJWT :: FromJSON a => T.Text -> Maybe a
parseJWT = JSON.decode . BL.fromStrict . BASE64.decodeLenient . TE.encodeUtf8
-- | Join the given segments with a single @.@ between neighbours.
dotted :: [T.Text] -> T.Text
dotted segments = T.intercalate (T.pack ".") segments
-- | Leniently decode URL-safe base64 text and interpret the result as UTF-8.
--
-- NOTE(review): the final 'TE.decodeUtf8' is partial; input that decodes to
-- bytes which are not valid UTF-8 raises an exception rather than failing
-- gracefully.
base64Decode :: T.Text -> T.Text
base64Decode encoded =
    TE.decodeUtf8 (BASE64.decodeLenient (TE.encodeUtf8 encoded))
-- | Encode text (as UTF-8) in URL-safe base64 and drop the trailing @=@
-- padding.
base64Encode :: T.Text -> T.Text
base64Encode plain = removePaddingBase64Encoding padded
  where
    padded = operateOnText BASE64.encode plain
-- | Apply a byte-level transformation to text: encode the text as UTF-8,
-- run the function, and decode the result as UTF-8 again.
--
-- The decoding step uses 'TE.decodeUtf8', so the function must produce
-- valid UTF-8; otherwise an exception is raised.
operateOnText :: (B.ByteString -> B.ByteString) -> T.Text -> T.Text
operateOnText f txt = TE.decodeUtf8 (f (TE.encodeUtf8 txt))
-- | Strip every trailing @=@ character; @=@ characters elsewhere are kept.
removePaddingBase64Encoding :: T.Text -> T.Text
removePaddingBase64Encoding encoded = T.dropWhileEnd isPadding encoded
  where
    isPadding ch = ch == '='
-- | Compute the signature segment for a message: an HMAC using the SHA-256
-- hash function keyed with the secret, rendered as URL-safe base64 without
-- trailing @=@ padding.
--
-- The algorithm argument is ignored; 'HS256' is the only algorithm there is.
calculateDigest :: Algorithm -> Secret -> T.Text -> T.Text
calculateDigest _ (Secret key) msg =
    removePaddingBase64Encoding . TE.decodeUtf8 . BASE64.encode $ mac
  where
    -- 64 is passed as the block-size argument of 'HMAC.hmac'
    -- (presumably the SHA-256 block size in bytes).
    mac = HMAC.hmac SHA.hash 64 (TE.encodeUtf8 key) (TE.encodeUtf8 msg)
-- | Claims keyed by name, each holding an arbitrary JSON value. Used for the
-- 'unregisteredClaims' field of 'JWTClaimsSet'.
type ClaimsMap = Map.Map T.Text Value
-- | Copy the key/value pairs of a JSON object into a 'ClaimsMap'.
fromHashMap :: Object -> ClaimsMap
fromHashMap obj = StrictMap.foldrWithKey Map.insert Map.empty obj
-- | Drop the seven registered claim names (@iss@, @sub@, @aud@, @exp@,
-- @nbf@, @iat@, @jti@) from a claims map, leaving all other entries intact.
removeRegisteredClaims :: ClaimsMap -> ClaimsMap
removeRegisteredClaims input = Map.difference input registered
  where
    -- Only the keys matter; the values are placeholders.
    registered = Map.fromList
        [ (name, Null)
        | name <- ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
        ]
-- | Serialise as a JSON object: registered claims that are set come first,
-- followed by the unregistered claims. Entries of 'unregisteredClaims' that
-- use a registered claim name are dropped so they cannot override the
-- dedicated fields.
instance ToJSON JWTClaimsSet where
    toJSON cs = object (registered ++ others)
      where
        registered = catMaybes
            [ ("iss" .=) <$> iss cs
            , ("sub" .=) <$> sub cs
            , ("aud" .=) <$> aud cs
            , ("exp" .=) <$> exp cs
            , ("nbf" .=) <$> nbf cs
            , ("iat" .=) <$> iat cs
            , ("jti" .=) <$> jti cs
            ]
        others = Map.toList (removeRegisteredClaims (unregisteredClaims cs))
-- | Parse from a JSON object. Each registered claim is optional; every other
-- key of the object ends up in 'unregisteredClaims'.
instance FromJSON JWTClaimsSet where
    parseJSON = withObject "JWTClaimsSet" $ \o -> do
        issuer     <- o .:? "iss"
        subject    <- o .:? "sub"
        audience   <- o .:? "aud"
        expiry     <- o .:? "exp"
        notBefore  <- o .:? "nbf"
        issuedAt   <- o .:? "iat"
        identifier <- o .:? "jti"
        return JWTClaimsSet
            { iss                = issuer
            , sub                = subject
            , aud                = audience
            , exp                = expiry
            , nbf                = notBefore
            , iat                = issuedAt
            , jti                = identifier
            , unregisteredClaims = removeRegisteredClaims (fromHashMap o)
            }
-- | Parse from a JSON object; @typ@, @cty@ and @alg@ are all optional.
instance FromJSON JWTHeader where
    parseJSON = withObject "JWTHeader" $ \o -> do
        tokenType   <- o .:? "typ"
        contentType <- o .:? "cty"
        algorithm   <- o .:? "alg"
        return JWTHeader
            { typ = tokenType
            , cty = contentType
            , alg = algorithm
            }
-- | Serialise as a JSON object containing only the fields that are set.
instance ToJSON JWTHeader where
    toJSON hdr = object present
      where
        present = catMaybes
            [ ("typ" .=) <$> typ hdr
            , ("cty" .=) <$> cty hdr
            , ("alg" .=) <$> alg hdr
            ]
-- | Serialise as a JSON number with the wrapped integer as coefficient and a
-- base-10 exponent of zero.
instance ToJSON IntDate where
    toJSON (IntDate ts) = Number (scientific ts 0)
-- | Parse from a JSON number; any other JSON value is rejected.
--
-- The number's base-10 exponent is applied to its coefficient, so @1.4e9@
-- yields @IntDate 1400000000@ (taking the coefficient alone would give 14).
-- Values with a fractional part are rounded down to the nearest integer.
-- Numbers whose exponent magnitude exceeds 'maxExponent' are rejected so
-- that a crafted input such as @1e1000000000@ cannot force the construction
-- of an enormous 'Integer'.
instance FromJSON IntDate where
    parseJSON (Number x)
        | abs e > maxExponent = mzero
        | e >= 0              = return . IntDate $ c * 10 ^ e
        | otherwise           = return . IntDate $ c `div` (10 ^ negate e)
      where
        c = coefficient x
        e = base10Exponent x
        -- Far beyond any plausible date, small enough to stay cheap.
        maxExponent = 64 :: Int
    parseJSON _ = mzero
-- | Serialise an algorithm as the JSON string naming it.
instance ToJSON Algorithm where
    toJSON algo = case algo of
        HS256 -> String (T.pack "HS256")
-- | Accept exactly the JSON string @"HS256"@; every other value is rejected.
instance FromJSON Algorithm where
    parseJSON value = case value of
        String name | name == T.pack "HS256" -> return HS256
        _                                    -> mzero