-- Copyright (C) 2013, 2014, 2015, 2016, 2017  Fraser Tweedale
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

{-|

JSON Web Token implementation (RFC 7519). A JWT is a JWS
with a payload of /claims/ to be transferred between two
parties.

JWTs use the JWS /compact serialisation/.
See "Crypto.JOSE.Compact" for details.

@
import Crypto.JWT

mkClaims :: IO 'ClaimsSet'
mkClaims = do
  t <- 'currentTime'
  pure $ 'emptyClaimsSet'
    & 'claimIss' ?~ "alice"
    & 'claimAud' ?~ 'Audience' ["bob"]
    & 'claimIat' ?~ 'NumericDate' t

doJwtSign :: 'JWK' -> 'ClaimsSet' -> IO (Either 'JWTError' 'SignedJWT')
doJwtSign jwk claims = 'runJOSE' $ do
  alg \<- 'bestJWSAlg' jwk
  'signClaims' jwk ('newJWSHeader' ((), alg)) claims

doJwtVerify :: 'JWK' -> 'SignedJWT' -> IO (Either 'JWTError' 'ClaimsSet')
doJwtVerify jwk jwt = 'runJOSE' $ do
  let config = 'defaultJWTValidationSettings' (== "bob")
  'verifyClaims' config jwk jwt
@

Some JWT libraries have a function that takes two strings: the
"secret" (a symmetric key) and the raw JWT.  The following function
achieves the same:

@
verify :: L.ByteString -> L.ByteString -> IO (Either 'JWTError' 'ClaimsSet')
verify k s = 'runJOSE' $ do
  let
    k' = 'fromOctets' k      -- turn raw secret into symmetric JWK
    audCheck = const True  -- should be a proper audience check
  jwt \<- 'decodeCompact' s    -- decode JWT
  'verifyClaims' ('defaultJWTValidationSettings' audCheck) k' jwt
@

For applications that use __additional claims__, define a data type that wraps
'ClaimsSet' and includes fields for the additional claims.  You will also need
to define 'FromJSON' if verifying JWTs, and 'ToJSON' if producing JWTs.  The
following example is taken from
<https://datatracker.ietf.org/doc/html/rfc7519#section-3.1 RFC 7519 §3.1>.

@
import qualified Data.Aeson.KeyMap as M

data Super = Super { jwtClaims :: 'ClaimsSet', isRoot :: Bool }

instance 'HasClaimsSet' Super where
  'claimsSet' f s = fmap (\\a' -> s { jwtClaims = a' }) (f (jwtClaims s))

instance FromJSON Super where
  parseJSON = withObject "Super" $ \\o -> Super
    \<$\> parseJSON (Object o)
    \<*\> o .: "http://example.com/is_root"

instance ToJSON Super where
  toJSON s =
    ins "http://example.com/is_root" (isRoot s) (toJSON (jwtClaims s))
    where
      ins k v (Object o) = Object $ M.insert k (toJSON v) o
      ins _ _ a = a
@

__Use 'signJWT' and 'verifyJWT' when using custom payload types__ (instead of
'signClaims' and 'verifyClaims' which are specialised to 'ClaimsSet').

-}
module Crypto.JWT
  (
  -- * Creating a JWT
    SignedJWT
  , signClaims
  , signJWT

  -- * Validating a JWT and extracting claims
  , defaultJWTValidationSettings
  , verifyClaims
  , verifyJWT
  , HasAllowedSkew(..)
  , HasAudiencePredicate(..)
  , HasIssuerPredicate(..)
  , HasCheckIssuedAt(..)
  , JWTValidationSettings
  , HasJWTValidationSettings(..)

  -- ** Specifying the verification time
  , WrappedUTCTime(..)
  , verifyClaimsAt
  , verifyJWTAt

  -- * Claims Set
  , HasClaimsSet(..)
  , ClaimsSet
  , emptyClaimsSet
  , addClaim
  , unregisteredClaims
  , validateClaimsSet

  -- * JWT errors
  , JWTError(..)
  , AsJWTError(..)

  -- * Miscellaneous
  , Audience(..)
  , StringOrURI
  , stringOrUri
  , string
  , uri
  , NumericDate(..)

  , module Crypto.JOSE

  ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Time (MonadTime(..))
import Data.Foldable (traverse_)
import Data.Functor.Identity
import Data.Maybe
import qualified Data.String
import Data.Semigroup ((<>))

import Control.Lens (
  makeClassy, makeClassyPrisms, makePrisms,
  Lens', _Just, over, preview, view,
  Prism', prism', Cons, iso, AsEmpty)
import Control.Lens.Cons.Extras (recons)
import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (ReaderT, asks, runReaderT)
import Data.Aeson
import qualified Data.Aeson.Key as Key
import qualified Data.Aeson.KeyMap as KeyMap
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.URI (parseURI, uriToString)

import Crypto.JOSE
import Crypto.JOSE.Types


-- | Errors reported by the JWT signing and verification functions in
-- this module.  Failures from the underlying JOSE layer are wrapped in
-- 'JWSError'.
data JWTError
  = -- | A JOSE error occurred while processing the JWT
    JWSError Error
  | -- | The JWT payload is not a JWT Claims Set
    JWTClaimsSetDecodeError String
  | JWTExpired
  | JWTNotYetValid
  | JWTNotInIssuer
  | JWTNotInAudience
  | JWTIssuedAtFuture
  deriving (Eq, Show)
makeClassyPrisms ''JWTError

-- | A JOSE 'Error' is embedded in 'JWTError' through the 'JWSError'
-- constructor, so JOSE operations can report errors in a 'JWTError'
-- context.
instance AsError JWTError where
  _Error = _JWSError


-- RFC 7519 §2.  Terminology

-- | A JSON string value, with the additional requirement that while
--   arbitrary string values MAY be used, any value containing a @:@
--   character MUST be a URI.
--
-- The 'stringOrUri' prism enforces this: a string containing @:@ is
-- parsed as a URI, and any other string is kept verbatim.  The 'string'
-- and 'uri' prisms focus on the two alternatives.
--
-- __Note__: the 'IsString' instance will fail if the string
-- contains a @:@ but does not parse as a 'URI'.  Use 'stringOrUri'
-- directly in this situation.
--
data StringOrURI
  = Arbitrary T.Text
  -- ^ A string that is not interpreted as a URI
  | OrURI URI
  -- ^ A string that was parsed as a URI
  deriving (Eq, Show)

-- | Non-total.  A string with a @':'@ in it MUST parse as a URI.
-- Otherwise 'fromString' calls 'error' with a message that names the
-- offending string (rather than the uninformative @fromJust@ failure).
-- Use 'stringOrUri' to handle such input without an exception.
instance Data.String.IsString StringOrURI where
  fromString s = fromMaybe invalid (preview stringOrUri s)
    where
      invalid = error $
        "Crypto.JWT.fromString: value contains ':' but is not a valid URI: "
          ++ show s

-- | Prism between a string-like type and 'StringOrURI'.
--
-- Matching: a string containing @:@ must parse with 'parseURI' and
-- becomes the URI alternative.  If it does not parse, the prism fails.
-- Any other string is accepted verbatim as an arbitrary string.
--
-- Building: URIs are rendered in full with 'uriToString', including any
-- password in the user-info component.  The 'Show' instance for 'URI'
-- masks passwords, so using it here would break the round trip.
stringOrUri :: (Cons s s Char Char, AsEmpty s) => Prism' s StringOrURI
stringOrUri = iso (view recons) (view recons) . prism' rev fwd
  where
    rev :: StringOrURI -> T.Text
    rev (Arbitrary s) = s
    rev (OrURI x) = T.pack (uriToString id x "")

    fwd :: T.Text -> Maybe StringOrURI
    fwd s
      | T.any (== ':') s = OrURI <$> parseURI (T.unpack s)
      | otherwise = pure (Arbitrary s)
{-# INLINE stringOrUri #-}

-- | Focus on the arbitrary-string alternative of a 'StringOrURI'.
--
-- Building through this prism does not check for @:@, so it can produce
-- a value that 'stringOrUri' itself would never produce.
string :: Prism' StringOrURI T.Text
string = prism' Arbitrary $ \v -> case v of
  Arbitrary t -> Just t
  OrURI _ -> Nothing

-- | Focus on the URI alternative of a 'StringOrURI'.
uri :: Prism' StringOrURI URI
uri = prism' OrURI $ \v -> case v of
  OrURI u -> Just u
  Arbitrary _ -> Nothing

-- | Accepts a JSON string and validates it with 'stringOrUri'.  A string
-- that contains @:@ but does not parse as a URI is rejected.
instance FromJSON StringOrURI where
  parseJSON = withText "StringOrURI" $ \t ->
    case preview stringOrUri t of
      Just v -> pure v
      Nothing -> fail "failed to parse StringOrURI"

-- | Serialises as a JSON string.  URIs are rendered in full with
-- 'uriToString'.  The 'Show' instance for 'URI' replaces any password in
-- the user-info component with @...@, which would corrupt the claim.
instance ToJSON StringOrURI where
  toJSON (Arbitrary s) = toJSON s
  toJSON (OrURI x) = toJSON (uriToString id x "")


-- | A JSON numeric value representing the number of seconds from
--   1970-01-01T0:0:0Z UTC until the specified UTC date\/time.
--
-- The value is held as a 'UTCTime'.  Conversion to and from the numeric
-- form happens in the JSON instances.
--
newtype NumericDate = NumericDate UTCTime
  deriving (Eq, Ord, Show)
makePrisms ''NumericDate

-- | Parses any JSON number as seconds since the POSIX epoch.  Fractional
-- seconds are kept, to the resolution of 'NominalDiffTime'.
--
-- NOTE(review): 'toRational' on a 'Scientific' with a very large positive
-- or negative exponent (e.g. @1e1000000000@) materialises a huge
-- 'Integer'.  Consider bounding the exponent for untrusted input.
instance FromJSON NumericDate where
  parseJSON = withScientific "NumericDate" $ \n ->
    let secs = fromRational (toRational n)
    in pure (NumericDate (posixSecondsToUTCTime secs))

-- | Serialises as the (possibly fractional) number of seconds since the
-- POSIX epoch.
instance ToJSON NumericDate where
  -- The POSIX offset is a fixed-precision decimal, so the intermediate
  -- 'Rational' has a terminating decimal expansion.  That makes the
  -- 'fromRational' conversion to 'Scientific' safe.
  toJSON (NumericDate t) =
    Number . fromRational . toRational . utcTimeToPOSIXSeconds $ t


-- | Audience data.  In the general case, the /aud/ value is an
-- array of case-sensitive strings, each containing a 'StringOrURI'
-- value.  In the special case when the JWT has one audience, the
-- /aud/ value MAY be a single case-sensitive string containing a
-- 'StringOrURI' value.
--
-- Both forms are accepted when parsing.  The 'ToJSON' instance formats
-- an 'Audience' with exactly one value as a bare string (some
-- non-compliant implementations require this).  Any other number of
-- values is formatted as an array.
--
newtype Audience = Audience [StringOrURI]
  deriving (Eq, Show)
makePrisms ''Audience

-- | Accepts either an array of 'StringOrURI' values or a single
-- 'StringOrURI', which is treated as a one-element list.
instance FromJSON Audience where
  parseJSON v = Audience <$> (asList <|> asSingleton)
    where
      asList = parseJSON v
      asSingleton = (: []) <$> parseJSON v

-- | A single audience is emitted as a bare string.  Any other number of
-- audiences, including none, is emitted as an array.
instance ToJSON Audience where
  toJSON (Audience auds) = case auds of
    [single] -> toJSON single
    _ -> toJSON auds


-- | The JWT Claims Set represents a JSON object whose members are
-- the registered claims defined by RFC 7519.  To construct a
-- @ClaimsSet@, start from 'emptyClaimsSet' and use the lenses of the
-- 'HasClaimsSet' class to set the relevant claims.
--
-- For applications that use additional claims beyond those defined
-- by RFC 7519, define a new data type and instance 'HasClaimsSet'.
-- See the module synopsis for more details and an example.
--
data ClaimsSet = ClaimsSet
  { _claimIss :: Maybe StringOrURI  -- ^ /iss/ (issuer)
  , _claimSub :: Maybe StringOrURI  -- ^ /sub/ (subject)
  , _claimAud :: Maybe Audience     -- ^ /aud/ (audience)
  , _claimExp :: Maybe NumericDate  -- ^ /exp/ (expiration time)
  , _claimNbf :: Maybe NumericDate  -- ^ /nbf/ (not before)
  , _claimIat :: Maybe NumericDate  -- ^ /iat/ (issued at)
  , _claimJti :: Maybe T.Text       -- ^ /jti/ (JWT ID)
  , _unregisteredClaims :: M.Map T.Text Value
    -- ^ Claims not registered by RFC 7519, keyed by claim name
  }
  deriving (Eq, Show)

-- | Types that contain a 'ClaimsSet'.  Only 'claimsSet' must be
-- defined.  Each registered-claim lens defaults to 'claimsSet' composed
-- with the corresponding lens on 'ClaimsSet'.
class HasClaimsSet a where
  -- | Lens onto the embedded 'ClaimsSet'.
  claimsSet :: Lens' a ClaimsSet

  -- | The issuer claim identifies the principal that issued the
  -- JWT.  The processing of this claim is generally application
  -- specific.
  claimIss :: Lens' a (Maybe StringOrURI)
  claimIss = claimsSet . claimIss
  {-# INLINE claimIss #-}

  -- | The subject claim identifies the principal that is the
  -- subject of the JWT.  The Claims in a JWT are normally
  -- statements about the subject.  The subject value MAY be scoped
  -- to be locally unique in the context of the issuer or MAY be
  -- globally unique.  The processing of this claim is generally
  -- application specific.
  claimSub :: Lens' a (Maybe StringOrURI)
  claimSub = claimsSet . claimSub
  {-# INLINE claimSub #-}

  -- | The audience claim identifies the recipients that the JWT is
  -- intended for.  Each principal intended to process the JWT MUST
  -- identify itself with a value in the audience claim.  If the
  -- principal processing the claim does not identify itself with a
  -- value in the /aud/ claim when this claim is present, then the
  -- JWT MUST be rejected.
  claimAud :: Lens' a (Maybe Audience)
  claimAud = claimsSet . claimAud
  {-# INLINE claimAud #-}

  -- | The expiration time claim identifies the expiration time on
  -- or after which the JWT MUST NOT be accepted for processing.
  -- The processing of /exp/ claim requires that the current
  -- date\/time MUST be before expiration date\/time listed in the
  -- /exp/ claim.  Implementers MAY provide for some small leeway,
  -- usually no more than a few minutes, to account for clock skew.
  claimExp :: Lens' a (Maybe NumericDate)
  claimExp = claimsSet . claimExp
  {-# INLINE claimExp #-}

  -- | The not before claim identifies the time before which the JWT
  -- MUST NOT be accepted for processing.  The processing of the
  -- /nbf/ claim requires that the current date\/time MUST be after
  -- or equal to the not-before date\/time listed in the /nbf/
  -- claim.  Implementers MAY provide for some small leeway, usually
  -- no more than a few minutes, to account for clock skew.
  claimNbf :: Lens' a (Maybe NumericDate)
  claimNbf = claimsSet . claimNbf
  {-# INLINE claimNbf #-}

  -- | The issued at claim identifies the time at which the JWT was
  -- issued.  This claim can be used to determine the age of the
  -- JWT.
  claimIat :: Lens' a (Maybe NumericDate)
  claimIat = claimsSet . claimIat
  {-# INLINE claimIat #-}

  -- | The JWT ID claim provides a unique identifier for the JWT.
  -- The identifier value MUST be assigned in a manner that ensures
  -- that there is a negligible probability that the same value will
  -- be accidentally assigned to a different data object.  The /jti/
  -- claim can be used to prevent the JWT from being replayed.  The
  -- /jti/ value is a case-sensitive string.
  claimJti :: Lens' a (Maybe T.Text)
  claimJti = claimsSet . claimJti
  {-# INLINE claimJti #-}

-- | 'ClaimsSet' is trivially its own claims set.  Each registered-claim
-- lens reads the matching record field and rebuilds the record with the
-- new value.
instance HasClaimsSet ClaimsSet where
  claimsSet = id

  claimIss f cs@ClaimsSet { _claimIss = old } =
    (\new -> cs { _claimIss = new }) <$> f old
  {-# INLINE claimIss #-}

  claimSub f cs@ClaimsSet { _claimSub = old } =
    (\new -> cs { _claimSub = new }) <$> f old
  {-# INLINE claimSub #-}

  claimAud f cs@ClaimsSet { _claimAud = old } =
    (\new -> cs { _claimAud = new }) <$> f old
  {-# INLINE claimAud #-}

  claimExp f cs@ClaimsSet { _claimExp = old } =
    (\new -> cs { _claimExp = new }) <$> f old
  {-# INLINE claimExp #-}

  claimNbf f cs@ClaimsSet { _claimNbf = old } =
    (\new -> cs { _claimNbf = new }) <$> f old
  {-# INLINE claimNbf #-}

  claimIat f cs@ClaimsSet { _claimIat = old } =
    (\new -> cs { _claimIat = new }) <$> f old
  {-# INLINE claimIat #-}

  claimJti f cs@ClaimsSet { _claimJti = old } =
    (\new -> cs { _claimJti = new }) <$> f old
  {-# INLINE claimJti #-}

-- | Claim Names can be defined at will by those using JWTs.
-- Use this lens to access the map of non-RFC 7519 claims in the
-- Claims Set object.
--
-- Deprecated: instead, define a type that wraps 'ClaimsSet' and
-- implements 'HasClaimsSet' (see the module documentation).
unregisteredClaims :: Lens' ClaimsSet (M.Map T.Text Value)
unregisteredClaims f cs@ClaimsSet { _unregisteredClaims = old } =
  (\new -> cs { _unregisteredClaims = new }) <$> f old
{-# INLINE unregisteredClaims #-}
{-# DEPRECATED unregisteredClaims "use a sub-type" #-}

-- | Return an empty claims set: no registered claim is set and the map of
-- unregistered claims is empty.
--
emptyClaimsSet :: ClaimsSet
emptyClaimsSet = ClaimsSet
  { _claimIss = Nothing
  , _claimSub = Nothing
  , _claimAud = Nothing
  , _claimExp = Nothing
  , _claimNbf = Nothing
  , _claimIat = Nothing
  , _claimJti = Nothing
  , _unregisteredClaims = M.empty
  }

-- | Add a __non-RFC 7519__ claim.  This replaces any existing
-- unregistered claim with the same name.  Use the lenses from the
-- 'HasClaimsSet' class for setting registered claims.
--
-- Entries whose name is a registered claim (e.g. @"iss"@) are dropped
-- when the claims set is serialised to JSON.
--
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim name = over unregisteredClaims . M.insert name
{-# DEPRECATED addClaim "'unregisteredClaims' is deprecated; use a sub-type" #-}

-- | Names of the claims registered by RFC 7519 §4.1.  These claims have
-- dedicated 'ClaimsSet' fields and are kept out of the
-- unregistered-claims map.
registeredClaims :: S.Set T.Text
registeredClaims =
  -- 'S.fromList' sorts, so the names can be listed in RFC order.
  S.fromList ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]

-- | Remove every registered claim name from a claim map, leaving only the
-- unregistered (application-defined) claims.
filterUnregistered :: M.Map T.Text Value -> M.Map T.Text Value
filterUnregistered claims =
#if MIN_VERSION_containers(0,5,8)
  M.withoutKeys claims registeredClaims
#else
  -- Older containers lack 'M.withoutKeys'; subtract a unit-valued map
  -- with the same keys instead.
  M.difference claims (M.fromSet (const ()) registeredClaims)
#endif

-- | Convert a 'T.Text'-keyed map into an aeson 'KeyMap.KeyMap'.
--
-- NOTE(review): 'M.mapKeysMonotonic' assumes 'Key.fromText' preserves
-- ordering, i.e. aeson's 'Key' orders like 'T.Text'.
toKeyMap :: M.Map T.Text Value -> KeyMap.KeyMap Value
toKeyMap m = KeyMap.fromMap (M.mapKeysMonotonic Key.fromText m)

-- | Convert an aeson 'KeyMap.KeyMap' into a 'T.Text'-keyed map.
--
-- NOTE(review): 'M.mapKeysMonotonic' assumes 'Key.toText' preserves
-- ordering, i.e. aeson's 'Key' orders like 'T.Text'.
fromKeyMap :: KeyMap.KeyMap Value -> M.Map T.Text Value
fromKeyMap km = M.mapKeysMonotonic Key.toText (KeyMap.toMap km)

-- | Parses a JSON object.  Every registered claim is optional.  Members
-- whose names are not registered claims are kept in the
-- unregistered-claims map.
instance FromJSON ClaimsSet where
  parseJSON = withObject "JWT Claims Set" $ \o -> do
    iss <- o .:? "iss"
    sub <- o .:? "sub"
    aud <- o .:? "aud"
    exp' <- o .:? "exp"
    nbf <- o .:? "nbf"
    iat <- o .:? "iat"
    jti <- o .:? "jti"
    pure ClaimsSet
      { _claimIss = iss
      , _claimSub = sub
      , _claimAud = aud
      , _claimExp = exp'
      , _claimNbf = nbf
      , _claimIat = iat
      , _claimJti = jti
      , _unregisteredClaims = filterUnregistered (fromKeyMap o)
      }

-- | Serialises the registered claims that are present, plus the
-- unregistered claims.  Registered claim names are filtered out of the
-- latter so they cannot clash with the dedicated fields.
instance ToJSON ClaimsSet where
  toJSON (ClaimsSet iss sub aud exp' nbf iat jti extra) =
    Object (registered <> toKeyMap (filterUnregistered extra))
    where
      -- Entries are listed in ascending key order, as required by
      -- 'M.fromDistinctAscList'.
      registered = KeyMap.fromMap . M.fromDistinctAscList . catMaybes $
        [ ("aud" .=) <$> aud
        , ("exp" .=) <$> exp'
        , ("iat" .=) <$> iat
        , ("iss" .=) <$> iss
        , ("jti" .=) <$> jti
        , ("nbf" .=) <$> nbf
        , ("sub" .=) <$> sub
        ]


-- | Settings used when verifying a JWT and validating its claims.
-- Obtain a value from 'defaultJWTValidationSettings' and adjust it with
-- the 'HasJWTValidationSettings' lenses, or with the narrower
-- 'HasAllowedSkew', 'HasAudiencePredicate', 'HasIssuerPredicate' and
-- 'HasCheckIssuedAt' classes.
data JWTValidationSettings = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings :: ValidationSettings
  -- ^ Settings for verifying the underlying JWS
  , _jwtValidationSettingsAllowedSkew :: NominalDiffTime
  -- ^ Maximum allowed skew when validating the /nbf/, /exp/ and /iat/
  --   claims.  The allowed skew is interpreted in absolute terms;
  --   a nonzero value always expands the validity period.
  , _jwtValidationSettingsCheckIssuedAt :: Bool
  -- ^ Whether to check that the /iat/ claim is not in the future
  , _jwtValidationSettingsAudiencePredicate :: StringOrURI -> Bool
  -- ^ Predicate for checking values in the /aud/ claim
  , _jwtValidationSettingsIssuerPredicate :: StringOrURI -> Bool
  -- ^ Predicate for checking the /iss/ claim
  }
makeClassy ''JWTValidationSettings

-- | Any type that carries JWT validation settings also exposes the JWS
-- 'ValidationSettings' it embeds.  Code that asks for
-- 'HasValidationSettings' can therefore use it directly.
instance {-# OVERLAPPABLE #-} HasJWTValidationSettings a => HasValidationSettings a where
  validationSettings = jwtValidationSettingsValidationSettings

-- | Types carrying the maximum clock skew tolerated when validating the
-- /nbf/, /exp/ and /iat/ claims.
class HasAllowedSkew s where
  allowedSkew :: Lens' s NominalDiffTime

-- | Types carrying the predicate applied to values of the /aud/ claim.
class HasAudiencePredicate s where
  audiencePredicate :: Lens' s (StringOrURI -> Bool)

-- | Types carrying the predicate applied to the /iss/ claim.
class HasIssuerPredicate s where
  issuerPredicate :: Lens' s (StringOrURI -> Bool)

-- | Types carrying the flag that decides whether the /iat/ claim is
-- checked for lying in the future.
class HasCheckIssuedAt s where
  checkIssuedAt :: Lens' s Bool

-- Each individual settings class is satisfied by any
-- 'HasJWTValidationSettings' type, via the matching field lens that
-- 'makeClassy' generates for 'JWTValidationSettings'.

instance HasJWTValidationSettings a => HasAllowedSkew a where
  allowedSkew = jwtValidationSettingsAllowedSkew

instance HasJWTValidationSettings a => HasAudiencePredicate a where
  audiencePredicate = jwtValidationSettingsAudiencePredicate

instance HasJWTValidationSettings a => HasIssuerPredicate a where
  issuerPredicate = jwtValidationSettingsIssuerPredicate

instance HasJWTValidationSettings a => HasCheckIssuedAt a where
  checkIssuedAt = jwtValidationSettingsCheckIssuedAt

-- | Acquire the default validation settings.
--
-- <https://tools.ietf.org/html/rfc7519#section-4.1.3 RFC 7519 §4.1.3.>
-- requires each principal intended to process a JWT to identify itself
-- with a value in the audience claim.  An audience predicate must
-- therefore be supplied.
--
-- The other defaults are:
--
-- - 'defaultValidationSettings' for JWS verification
-- - Zero clock skew tolerance when validating /nbf/, /exp/ and /iat/ claims
-- - /iat/ claim is checked
-- - /iss/ claim is not restricted (the issuer predicate is @const True@)
--
defaultJWTValidationSettings :: (StringOrURI -> Bool) -> JWTValidationSettings
defaultJWTValidationSettings p = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings = defaultValidationSettings
  , _jwtValidationSettingsAllowedSkew = 0
  , _jwtValidationSettingsCheckIssuedAt = True
  , _jwtValidationSettingsAudiencePredicate = p
  , _jwtValidationSettingsIssuerPredicate = const True
  }

-- | Validate the claims made by a ClaimsSet.
--
-- __You should never need to use this function directly.__
-- These checks are always performed by 'verifyClaims' and 'verifyJWT'.
-- The function is exported mainly for testing purposes.
--
-- The checks run in this order: /exp/, /iat/, /nbf/, /iss/, /aud/.
-- The Claims Set is returned unchanged when every check passes.
--
validateClaimsSet
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , AsJWTError e, MonadError e m
    )
  => a
  -> ClaimsSet
  -> m ClaimsSet
validateClaimsSet conf claims =
  claims <$ traverse_ (\check -> check conf claims)
    [ validateExpClaim
    , validateIatClaim
    , validateNbfClaim
    , validateIssClaim
    , validateAudClaim
    ]

-- | Reject the token when the current time is not strictly before the
-- /exp/ claim, widened by the absolute 'allowedSkew'.  A missing /exp/
-- claim is not checked.
validateExpClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateExpClaim conf =
  traverse_ check . preview (claimExp . _Just)
  where
    skew = abs (view allowedSkew conf)
    check expiry = do
      now <- currentTime
      unless (now < addUTCTime skew (view _NumericDate expiry)) $
        throwing_ _JWTExpired

-- | When 'checkIssuedAt' is enabled, reject the token if its /iat/
-- claim lies later than the current time plus the absolute
-- 'allowedSkew'.  A missing /iat/ claim is not checked.
--
-- When the check is disabled the clock is not consulted at all.
validateIatClaim
  :: (MonadTime m, HasCheckIssuedAt a, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateIatClaim conf
  | view checkIssuedAt conf = traverse_ check . preview (claimIat . _Just)
  | otherwise = const (pure ())
  where
    skew = abs (view allowedSkew conf)
    check issuedAt = do
      now <- currentTime
      when (view _NumericDate issuedAt > addUTCTime skew now) $
        throwing_ _JWTIssuedAtFuture

-- | Reject the token when the current time is earlier than the /nbf/
-- claim minus the absolute 'allowedSkew'.  A missing /nbf/ claim is
-- not checked.
validateNbfClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateNbfClaim conf =
  traverse_ check . preview (claimNbf . _Just)
  where
    skew = abs (view allowedSkew conf)
    check notBefore = do
      now <- currentTime
      unless (now >= addUTCTime (negate skew) (view _NumericDate notBefore)) $
        throwing_ _JWTNotYetValid

-- | When the /aud/ claim is present, require at least one of its
-- values to satisfy 'audiencePredicate'.  A present but empty audience
-- list is therefore rejected.  A missing /aud/ claim is not checked.
validateAudClaim
  :: (HasAudiencePredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateAudClaim conf =
  traverse_ check . preview (claimAud . _Just . _Audience)
  where
    accepts = view audiencePredicate conf
    check auds =
      unless (any accepts auds) $
        throwing_ _JWTNotInAudience

-- | When the /iss/ claim is present, require it to satisfy
-- 'issuerPredicate'.  A missing /iss/ claim is not checked.
validateIssClaim
  :: (HasIssuerPredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateIssClaim conf =
  traverse_ check . preview (claimIss . _Just)
  where
    check iss =
      unless (view issuerPredicate conf iss) $
        throwing_ _JWTNotInIssuer

-- | A digitally signed or MACed JWT: a JWS in the compact
-- serialisation ('CompactJWS') whose header type is 'JWSHeader'.
--
type SignedJWT = CompactJWS JWSHeader


-- | A 'UTCTime' in a newtype, so that @'ReaderT' 'WrappedUTCTime'@ can
-- be given a 'MonadTime' instance that reports this fixed time as the
-- current time.
newtype WrappedUTCTime = WrappedUTCTime { getUTCTime :: UTCTime }

-- | Report the time held in the 'WrappedUTCTime' environment as the
-- current time.  This lets claims be validated at a caller-chosen
-- instant (see 'verifyJWTAt').
instance Monad m => MonadTime (ReaderT WrappedUTCTime m) where
  currentTime = asks getUTCTime
#if MIN_VERSION_monad_time(0,4,0)
  -- /jose/ never calls 'monotonicTime', so it is stubbed out with a
  -- constant: @monotonicTime = pure 0@.
  monotonicTime = pure 0
#endif


-- | Cryptographically verify a JWS JWT, then validate the
-- Claims Set, returning it if valid.  The claims are validated
-- against the time reported by 'currentTime'.
--
-- This is the only way to get at the claims of a JWS JWT,
-- enforcing that the claims are cryptographically and
-- semantically valid before the application can use them.
--
-- This function is abstracted over any payload type with 'HasClaimsSet' and
-- 'FromJSON' instances.  The 'verifyClaims' variant uses 'ClaimsSet' as the
-- payload type.
--
-- A payload that cannot be decoded as JSON raises
-- '_JWTClaimsSetDecodeError'.
--
-- See also 'verifyJWTAt' (and 'verifyClaimsAt'), which let you
-- explicitly specify the time of validation (against which
-- time-related claims will be validated).
--
verifyJWT
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore m (JWSHeader ()) payload k
    , HasClaimsSet payload, FromJSON payload
    )
  => a
  -> k
  -> SignedJWT
  -> m payload
verifyJWT settings keyStore token =
  -- It is important, for security reasons, that the signature get
  -- verified before the claims.
  verifyJWSWithPayload decodePayload settings keyStore token
    >>= claimsSet (validateClaimsSet settings)
  where
    -- Decode the verified payload bytes; JSON errors become JWT errors.
    decodePayload =
      either (throwing _JWTClaimsSetDecodeError) pure . eitherDecode

-- | Variant of 'verifyJWT' that uses 'ClaimsSet' as the payload type.
--
verifyClaims
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore m (JWSHeader ()) ClaimsSet k
    )
  => a
  -> k
  -> SignedJWT
  -> m ClaimsSet
verifyClaims = verifyJWT

-- | Variant of 'verifyJWT' where the validation time is provided by
-- the caller.  If you process many tokens per second, this lets you
-- avoid unnecessary repeated system calls.
--
-- Verification runs in @'ReaderT' 'WrappedUTCTime' m@, so the key
-- store must have a 'VerificationKeyStore' instance for that monad.
--
verifyJWTAt
  ::
    ( HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore (ReaderT WrappedUTCTime m) (JWSHeader ()) payload k
    , HasClaimsSet payload, FromJSON payload
    )
  => a
  -> k
  -> UTCTime
  -> SignedJWT
  -> m payload
verifyJWTAt settings keyStore now token =
  runReaderT (verifyJWT settings keyStore token) (WrappedUTCTime now)

-- | Variant of 'verifyJWT' that uses 'ClaimsSet' as the payload type and
-- where the validation time is provided by the caller
-- (i.e. 'verifyJWTAt' specialised to 'ClaimsSet').
--
verifyClaimsAt
  ::
    ( HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore (ReaderT WrappedUTCTime m) (JWSHeader ()) ClaimsSet k
    )
  => a
  -> k
  -> UTCTime
  -> SignedJWT
  -> m ClaimsSet
verifyClaimsAt = verifyJWTAt


-- | Create a JWS JWT.  The payload can be any type with a 'ToJSON'
-- instance.  See also 'signClaims' which uses 'ClaimsSet' as the
-- payload type.
--
-- The payload is JSON-encoded and signed with a single signature made
-- by the given key under the given header.
--
-- __Does not set any fields in the Claims Set__, such as @"iat"@
-- ("Issued At") Claim.  The payload is encoded as-is.
--
signJWT
  :: ( MonadRandom m, MonadError e m, AsError e
     , ToJSON payload )
  => JWK
  -> JWSHeader ()
  -> payload
  -> m SignedJWT
signJWT signingKey hdr payloadValue =
  signJWS (encode payloadValue) (Identity (hdr, signingKey))

-- | Create a JWS JWT.  Specialisation of 'signJWT' with payload type fixed
-- at 'ClaimsSet'.
--
-- __Does not set any fields in the Claims Set__, such as @"iat"@
-- ("Issued At") Claim.  The payload is encoded as-is.
--
signClaims
  :: (MonadRandom m, MonadError e m, AsError e)
  => JWK
  -> JWSHeader ()
  -> ClaimsSet
  -> m SignedJWT
signClaims = signJWT