{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Crypto.JWT
  (
  
    signClaims
  , SignedJWT
  
  , defaultJWTValidationSettings
  , verifyClaims
  , verifyClaimsAt
  , HasAllowedSkew(..)
  , HasAudiencePredicate(..)
  , HasIssuerPredicate(..)
  , HasCheckIssuedAt(..)
  , JWTValidationSettings
  , HasJWTValidationSettings(..)
  
  , ClaimsSet
  , claimAud
  , claimExp
  , claimIat
  , claimIss
  , claimJti
  , claimNbf
  , claimSub
  , unregisteredClaims
  , addClaim
  , emptyClaimsSet
  , validateClaimsSet
  
  , JWTError(..)
  , AsJWTError(..)
  
  , Audience(..)
  , StringOrURI
  , stringOrUri
  , string
  , uri
  , NumericDate(..)
  , module Crypto.JOSE
  ) where
import Control.Applicative
import Control.Monad
import Control.Monad.Time (MonadTime(..))
#if ! MIN_VERSION_monad_time(0,2,0)
import Control.Monad.Time.Instances ()
#endif
import Data.Foldable (traverse_)
import Data.Functor.Identity
import Data.Maybe
import qualified Data.String
import Control.Lens (
  makeClassy, makeClassyPrisms, makePrisms,
  Lens', _Just, over, preview, view,
  Prism', prism', Cons, iso, AsEmpty)
import Control.Lens.Cons.Extras (recons)
import Control.Monad.Error.Lens (throwing, throwing_)
import Control.Monad.Except (MonadError)
import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Data.Aeson
import qualified Data.HashMap.Strict as M
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime)
import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds)
import Network.URI (parseURI, uriToString)
import Crypto.JOSE
import Crypto.JOSE.Types
-- | Errors raised while decoding or validating a JWT.
data JWTError
  = JWSError Error
  -- ^ An 'Error' from the underlying JOSE/JWS layer.
  | JWTClaimsSetDecodeError String
  -- ^ The payload could not be decoded as a 'ClaimsSet'; carries the
  -- decoder's message.
  | JWTExpired
  -- ^ The current time is not before @exp@ plus the allowed skew.
  | JWTNotYetValid
  -- ^ The current time is before @nbf@ minus the allowed skew.
  | JWTNotInIssuer
  -- ^ The @iss@ claim was rejected by the issuer predicate.
  | JWTNotInAudience
  -- ^ No @aud@ value was accepted by the audience predicate.
  | JWTIssuedAtFuture
  -- ^ The @iat@ claim is later than now plus the allowed skew (only
  -- reported when the issued-at check is enabled).
  deriving (Eq, Show)
-- Generates the 'AsJWTError' class and one prism per constructor
-- (e.g. '_JWTExpired', '_JWSError').
makeClassyPrisms ''JWTError
-- | Lets JOSE-layer 'Error's be thrown into, and matched from, a 'JWTError'
-- through the 'JWSError' constructor.
instance AsError JWTError where
  _Error = _JWSError
-- | A JWT @StringOrURI@ value: either arbitrary text or a parsed 'URI'.
-- Convert from text with 'stringOrUri'; values containing @':'@ are parsed
-- as URIs.
data StringOrURI = Arbitrary T.Text | OrURI URI deriving (Eq, Show)
-- | Build a 'StringOrURI' from a string literal (e.g. under
-- @OverloadedStrings@).
--
-- The literal is classified with 'stringOrUri'. A literal containing @':'@
-- that does not parse as a 'URI' is a programming error. It calls 'error'
-- with a message naming the offending literal, instead of the opaque
-- failure that 'fromJust' would give.
instance Data.String.IsString StringOrURI where
  fromString s =
    fromMaybe
      (error ("Crypto.JWT.fromString: invalid StringOrURI: " ++ show s))
      (preview stringOrUri s)
-- | Prism between a textual value and a 'StringOrURI'.
--
-- Parsing: a value containing @':'@ must parse as a 'URI', otherwise the
-- prism fails. Any other value is kept as 'Arbitrary' text.
--
-- Rendering: a 'URI' is rendered with 'uriToString' and the identity
-- user-info map, so it is reproduced exactly. The 'Show' instance for
-- 'URI' replaces a password in the user-info component with an ellipsis,
-- which would silently change the value on a round trip.
stringOrUri :: (Cons s s Char Char, AsEmpty s) => Prism' s StringOrURI
stringOrUri = iso (view recons) (view recons) . prism' rev fwd
  where
  rev (Arbitrary s) = s
  rev (OrURI x) = T.pack (uriToString id x "")
  fwd s
    | T.any (== ':') s = OrURI <$> parseURI (T.unpack s)
    | otherwise = pure (Arbitrary s)
-- | Prism onto the plain-text ('Arbitrary') case of a 'StringOrURI'.
-- Reviewing wraps text as 'Arbitrary'; previewing a URI yields 'Nothing'.
string :: Prism' StringOrURI T.Text
string = prism' Arbitrary matchText
  where
    matchText v = case v of
      Arbitrary txt -> Just txt
      OrURI _ -> Nothing
-- | Prism onto the 'URI' ('OrURI') case of a 'StringOrURI'.
-- Reviewing wraps a URI as 'OrURI'; previewing plain text yields 'Nothing'.
uri :: Prism' StringOrURI URI
uri = prism' OrURI matchUri
  where
    matchUri v = case v of
      OrURI u -> Just u
      Arbitrary _ -> Nothing
-- | Parse a JSON string as a 'StringOrURI' using 'stringOrUri'. Fails
-- when the text contains @':'@ but is not a valid URI.
instance FromJSON StringOrURI where
  parseJSON = withText "StringOrURI" $ \txt ->
    case preview stringOrUri txt of
      Nothing -> fail "failed to parse StringOrURI"
      Just parsed -> pure parsed
-- | Serialise a 'StringOrURI' as a JSON string.
--
-- URIs are rendered with 'uriToString' and the identity user-info map,
-- not 'show'. The 'Show' instance for 'URI' masks a password in the
-- user-info component, which would change the claim value.
instance ToJSON StringOrURI where
  toJSON (Arbitrary s)  = toJSON s
  toJSON (OrURI x)      = toJSON (uriToString id x "")
-- | A JWT @NumericDate@: a 'UTCTime' that is (de)serialised as a JSON
-- number of seconds since the POSIX epoch.
newtype NumericDate = NumericDate UTCTime deriving (Eq, Ord, Show)
-- Generates '_NumericDate', used (e.g. with 'view') to unwrap the time.
makePrisms ''NumericDate
-- | Parse a JSON number of (possibly fractional) seconds since the POSIX
-- epoch into a 'NumericDate'.
--
-- NOTE(review): 'toRational' on a 'Scientific' with an enormous exponent
-- (e.g. @1e1000000000@) materialises a huge 'Rational'. This parses
-- untrusted token contents, so consider bounding the exponent before
-- converting.
instance FromJSON NumericDate where
  parseJSON = withScientific "NumericDate" $ \secs ->
    let posix = fromRational (toRational secs)
    in pure (NumericDate (posixSecondsToUTCTime posix))
-- | Serialise a 'NumericDate' as a JSON number of seconds since the POSIX
-- epoch.
instance ToJSON NumericDate where
  toJSON (NumericDate t) = Number seconds
    where
      seconds = fromRational (toRational (utcTimeToPOSIXSeconds t))
-- | The @aud@ claim: a list of 'StringOrURI' values. In JSON it may be
-- either an array or a single value.
newtype Audience = Audience [StringOrURI] deriving (Eq, Show)
-- Generates '_Audience', used to reach the underlying list.
makePrisms ''Audience
-- | Accept either a JSON array of 'StringOrURI' or a single 'StringOrURI'.
-- A single value is wrapped as a one-element audience.
instance FromJSON Audience where
  parseJSON v = fmap Audience (asList <|> asSingleton)
    where
      asList = parseJSON v
      asSingleton = pure <$> parseJSON v
-- | A one-element audience is serialised as a bare value. Any other
-- audience, including an empty one, is serialised as an array.
instance ToJSON Audience where
  toJSON (Audience auds) = case auds of
    [single] -> toJSON single
    _ -> toJSON auds
-- | A JWT Claims Set: the registered claims as typed optional fields, plus
-- a map of any other (unregistered) claims.
data ClaimsSet = ClaimsSet
  { _claimIss :: Maybe StringOrURI
  -- ^ Issuer (@iss@).
  , _claimSub :: Maybe StringOrURI
  -- ^ Subject (@sub@).
  , _claimAud :: Maybe Audience
  -- ^ Audience (@aud@).
  , _claimExp :: Maybe NumericDate
  -- ^ Expiration time (@exp@).
  , _claimNbf :: Maybe NumericDate
  -- ^ Not-before time (@nbf@).
  , _claimIat :: Maybe NumericDate
  -- ^ Issued-at time (@iat@).
  , _claimJti :: Maybe T.Text
  -- ^ JWT ID (@jti@).
  , _unregisteredClaims :: M.HashMap T.Text Value
  -- ^ All other claims, keyed by name. Registered names are filtered out
  -- when (de)serialising.
  }
  deriving (Eq, Show)
-- | Lens onto the issuer (@iss@) claim.
claimIss :: Lens' ClaimsSet (Maybe StringOrURI)
claimIss f claims = case claims of
  ClaimsSet { _claimIss = old } ->
    (\new -> claims { _claimIss = new }) <$> f old
-- | Lens onto the subject (@sub@) claim.
claimSub :: Lens' ClaimsSet (Maybe StringOrURI)
claimSub f claims = case claims of
  ClaimsSet { _claimSub = old } ->
    (\new -> claims { _claimSub = new }) <$> f old
-- | Lens onto the audience (@aud@) claim.
claimAud :: Lens' ClaimsSet (Maybe Audience)
claimAud f claims = case claims of
  ClaimsSet { _claimAud = old } ->
    (\new -> claims { _claimAud = new }) <$> f old
-- | Lens onto the expiration-time (@exp@) claim.
claimExp :: Lens' ClaimsSet (Maybe NumericDate)
claimExp f claims = case claims of
  ClaimsSet { _claimExp = old } ->
    (\new -> claims { _claimExp = new }) <$> f old
-- | Lens onto the not-before (@nbf@) claim.
claimNbf :: Lens' ClaimsSet (Maybe NumericDate)
claimNbf f claims = case claims of
  ClaimsSet { _claimNbf = old } ->
    (\new -> claims { _claimNbf = new }) <$> f old
-- | Lens onto the issued-at (@iat@) claim.
claimIat :: Lens' ClaimsSet (Maybe NumericDate)
claimIat f claims = case claims of
  ClaimsSet { _claimIat = old } ->
    (\new -> claims { _claimIat = new }) <$> f old
-- | Lens onto the JWT ID (@jti@) claim.
claimJti :: Lens' ClaimsSet (Maybe T.Text)
claimJti f claims = case claims of
  ClaimsSet { _claimJti = old } ->
    (\new -> claims { _claimJti = new }) <$> f old
-- | Lens onto the map of unregistered claims.
unregisteredClaims :: Lens' ClaimsSet (M.HashMap T.Text Value)
unregisteredClaims f claims = case claims of
  ClaimsSet { _unregisteredClaims = old } ->
    (\new -> claims { _unregisteredClaims = new }) <$> f old
-- | A claims set with no registered claims and no unregistered claims.
emptyClaimsSet :: ClaimsSet
emptyClaimsSet = ClaimsSet
  { _claimIss = Nothing
  , _claimSub = Nothing
  , _claimAud = Nothing
  , _claimExp = Nothing
  , _claimNbf = Nothing
  , _claimIat = Nothing
  , _claimJti = Nothing
  , _unregisteredClaims = M.empty
  }
-- | Insert a claim into the unregistered-claims map, replacing any existing
-- value stored under the same name.
--
-- NOTE: registered names (@iss@, @sub@, @aud@, @exp@, @nbf@, @iat@, @jti@)
-- are filtered out of this map when the set is serialised to JSON. Set
-- those through their dedicated lenses (e.g. 'claimIss') instead.
addClaim :: T.Text -> Value -> ClaimsSet -> ClaimsSet
addClaim k v claims = over unregisteredClaims (M.insert k v) claims
-- | Remove the registered claim names from a claim map, keeping only the
-- unregistered claims.
filterUnregistered :: M.HashMap T.Text Value -> M.HashMap T.Text Value
filterUnregistered = M.filterWithKey keep
  where
    keep key _ = not (key `elem` registeredNames)
    registeredNames :: [T.Text]
    registeredNames = ["iss", "sub", "aud", "exp", "nbf", "iat", "jti"]
-- | Parse a JSON object as a 'ClaimsSet'. Each registered claim is
-- optional. Every remaining member, with registered names removed, is kept
-- as an unregistered claim.
instance FromJSON ClaimsSet where
  parseJSON = withObject "JWT Claims Set" $ \obj -> do
    iss <- obj .:? "iss"
    sub <- obj .:? "sub"
    aud <- obj .:? "aud"
    expiry <- obj .:? "exp"
    nbf <- obj .:? "nbf"
    iat <- obj .:? "iat"
    jti <- obj .:? "jti"
    pure (ClaimsSet iss sub aud expiry nbf iat jti (filterUnregistered obj))
-- | Serialise a 'ClaimsSet' as a JSON object: every present registered
-- claim under its name, followed by the unregistered claims with any
-- registered names removed.
instance ToJSON ClaimsSet where
  toJSON (ClaimsSet iss sub aud exp' nbf iat jti o) =
    object (registered ++ unregistered)
    where
      registered = catMaybes
        [ ("iss" .=) <$> iss
        , ("sub" .=) <$> sub
        , ("aud" .=) <$> aud
        , ("exp" .=) <$> exp'
        , ("nbf" .=) <$> nbf
        , ("iat" .=) <$> iat
        , ("jti" .=) <$> jti
        ]
      unregistered = M.toList (filterUnregistered o)
-- | Settings for JWT validation. They are consumed through the 'Has*'
-- classes by 'validateClaimsSet' and 'verifyClaims'.
data JWTValidationSettings = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings :: ValidationSettings
  -- ^ JWS-level validation settings, exposed via 'HasValidationSettings'.
  , _jwtValidationSettingsAllowedSkew :: NominalDiffTime
  -- ^ Clock-skew tolerance for the @exp@, @nbf@ and @iat@ checks. Its
  -- absolute value is used.
  , _jwtValidationSettingsCheckIssuedAt :: Bool
  -- ^ Whether to reject tokens whose @iat@ lies in the future.
  -- When present, an @aud@ claim must contain at least one value that
  -- satisfies the audience predicate.
  , _jwtValidationSettingsAudiencePredicate :: StringOrURI -> Bool
  -- ^ Predicate applied to each @aud@ value.
  , _jwtValidationSettingsIssuerPredicate :: StringOrURI -> Bool
  -- ^ When present, the @iss@ claim must satisfy this predicate.
  }
-- Generates the 'HasJWTValidationSettings' class and field lenses such as
-- 'jwtValidationSettingsAllowedSkew'.
makeClassy ''JWTValidationSettings
-- | Any type carrying 'JWTValidationSettings' also provides the JWS-level
-- 'ValidationSettings' stored inside it. This catch-all instance is one
-- reason UndecidableInstances is enabled.
instance HasJWTValidationSettings a => HasValidationSettings a where
  validationSettings = jwtValidationSettingsValidationSettings
-- | Types that provide the clock-skew tolerance used by the time-based
-- claim checks.
class HasAllowedSkew s where
  allowedSkew :: Lens' s NominalDiffTime
-- | Types that provide the predicate @aud@ values are tested against.
class HasAudiencePredicate s where
  audiencePredicate :: Lens' s (StringOrURI -> Bool)
-- | Types that provide the predicate the @iss@ value is tested against.
class HasIssuerPredicate s where
  issuerPredicate :: Lens' s (StringOrURI -> Bool)
-- | Types that provide the flag enabling the @iat@-in-the-future check.
class HasCheckIssuedAt s where
  checkIssuedAt :: Lens' s Bool
-- Every 'HasJWTValidationSettings' type gets each individual setting class
-- by focusing on the corresponding 'JWTValidationSettings' field. These
-- catch-all instances require UndecidableInstances.
instance HasJWTValidationSettings a => HasAllowedSkew a where
  allowedSkew = jwtValidationSettingsAllowedSkew
instance HasJWTValidationSettings a => HasAudiencePredicate a where
  audiencePredicate = jwtValidationSettingsAudiencePredicate
instance HasJWTValidationSettings a => HasIssuerPredicate a where
  issuerPredicate = jwtValidationSettingsIssuerPredicate
instance HasJWTValidationSettings a => HasCheckIssuedAt a where
  checkIssuedAt = jwtValidationSettingsCheckIssuedAt
-- | Build validation settings that use:
--
-- * the default JWS 'ValidationSettings',
-- * zero allowed clock skew,
-- * the issued-at check enabled,
-- * the given audience predicate, and
-- * an issuer predicate that accepts any issuer.
defaultJWTValidationSettings :: (StringOrURI -> Bool) -> JWTValidationSettings
defaultJWTValidationSettings p = JWTValidationSettings
  { _jwtValidationSettingsValidationSettings = defaultValidationSettings
  , _jwtValidationSettingsAllowedSkew = 0
  , _jwtValidationSettingsCheckIssuedAt = True
  , _jwtValidationSettingsAudiencePredicate = p
  , _jwtValidationSettingsIssuerPredicate = const True
  }
-- | Validate a claims set against the settings. The checks run in this
-- order: @exp@, @iat@, @nbf@, @iss@, @aud@. The first failing check throws
-- its 'JWTError'. If all pass, the claims are returned unchanged.
validateClaimsSet
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , AsJWTError e, MonadError e m
    )
  => a
  -> ClaimsSet
  -> m ClaimsSet
validateClaimsSet conf claims = do
  forM_
    [ validateExpClaim
    , validateIatClaim
    , validateNbfClaim
    , validateIssClaim
    , validateAudClaim
    ]
    (\check -> check conf claims)
  pure claims
-- | Check the @exp@ claim, if present. The current time must be strictly
-- before the expiry plus the absolute allowed skew; otherwise 'JWTExpired'
-- is thrown. An absent claim passes.
validateExpClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateExpClaim conf claims =
  forM_ (preview (claimExp . _Just) claims) $ \(NumericDate expiry) -> do
    now <- currentTime
    let deadline = addUTCTime (abs (view allowedSkew conf)) expiry
    unless (now < deadline) $
      throwing_ _JWTExpired
-- | Check the @iat@ claim, if present. When the issued-at check is enabled,
-- an @iat@ later than now plus the absolute allowed skew throws
-- 'JWTIssuedAtFuture'. The current time is read before the flag is
-- consulted, as in the original ordering.
validateIatClaim
  :: (MonadTime m, HasCheckIssuedAt a, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateIatClaim conf claims =
  forM_ (preview (claimIat . _Just) claims) $ \(NumericDate issued) -> do
    now <- currentTime
    let latestAllowed = addUTCTime (abs (view allowedSkew conf)) now
    when (view checkIssuedAt conf && issued > latestAllowed) $
      throwing_ _JWTIssuedAtFuture
-- | Check the @nbf@ claim, if present. The current time must be at or
-- after the not-before time minus the absolute allowed skew; otherwise
-- 'JWTNotYetValid' is thrown.
validateNbfClaim
  :: (MonadTime m, HasAllowedSkew a, AsJWTError e, MonadError e m)
  => a
  -> ClaimsSet
  -> m ()
validateNbfClaim conf claims =
  forM_ (preview (claimNbf . _Just) claims) $ \(NumericDate notBefore) -> do
    now <- currentTime
    let earliest = addUTCTime (negate (abs (view allowedSkew conf))) notBefore
    when (now < earliest) $
      throwing_ _JWTNotYetValid
-- | Check the @aud@ claim, if present. At least one audience value must
-- satisfy the audience predicate; otherwise 'JWTNotInAudience' is thrown.
-- A present but empty audience list therefore fails.
validateAudClaim
  :: (HasAudiencePredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateAudClaim conf claims =
  forM_ (preview (claimAud . _Just . _Audience) claims) $ \auds ->
    unless (any (view audiencePredicate conf) auds) $
      throwing_ _JWTNotInAudience
-- | Check the @iss@ claim, if present. It must satisfy the issuer
-- predicate; otherwise 'JWTNotInIssuer' is thrown.
validateIssClaim
  :: (HasIssuerPredicate s, AsJWTError e, MonadError e m)
  => s
  -> ClaimsSet
  -> m ()
validateIssClaim conf claims =
  forM_ (preview (claimIss . _Just) claims) $ \iss ->
    unless (view issuerPredicate conf iss) $
      throwing_ _JWTNotInIssuer
-- | A signed JWT: a 'CompactJWS' whose header type is 'JWSHeader'. It is
-- produced by 'signClaims' and consumed by 'verifyClaims'.
type SignedJWT = CompactJWS JWSHeader
-- | Reader environment used by 'verifyClaimsAt' to supply a fixed time.
newtype WrappedUTCTime = WrappedUTCTime { getUTCTime :: UTCTime }
-- | Under @ReaderT WrappedUTCTime@, 'currentTime' returns the wrapped time
-- instead of consulting a clock.
instance Monad m => MonadTime (ReaderT WrappedUTCTime m) where
  currentTime = fmap getUTCTime ask
-- | Verify a signed JWT's signature and then validate its claims.
--
-- The key store is parameterised by 'ClaimsSet' (see the constraint), so
-- the payload is decoded to a 'ClaimsSet' as part of JWS verification. A
-- payload that fails to decode throws 'JWTClaimsSetDecodeError' with the
-- decoder's message. The resulting claims are then checked with
-- 'validateClaimsSet'.
verifyClaims
  ::
    ( MonadTime m, HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore m (JWSHeader ()) ClaimsSet k
    )
  => a
  -> k
  -> SignedJWT
  -> m ClaimsSet
verifyClaims conf k jws = do
  claims <- verifyJWSWithPayload decodeClaims conf k jws
  validateClaimsSet conf claims
  where
    decodeClaims payload = case eitherDecode payload of
      Left err -> throwing _JWTClaimsSetDecodeError err
      Right decoded -> pure decoded
-- | Like 'verifyClaims', but time-based checks use the supplied time
-- rather than the ambient clock. It runs 'verifyClaims' under a
-- @ReaderT WrappedUTCTime@ layer whose 'currentTime' returns that time.
verifyClaimsAt
  ::
    ( HasAllowedSkew a, HasAudiencePredicate a
    , HasIssuerPredicate a
    , HasCheckIssuedAt a
    , HasValidationSettings a
    , AsError e, AsJWTError e, MonadError e m
    , VerificationKeyStore (ReaderT WrappedUTCTime m) (JWSHeader ()) ClaimsSet k
    )
  => a
  -> k
  -> UTCTime
  -> SignedJWT
  -> m ClaimsSet
verifyClaimsAt a k t jwt = runReaderT action (WrappedUTCTime t)
  where
    action = verifyClaims a k jwt
-- | Sign a claims set with a single key and header. The payload is the
-- JSON encoding of the claims.
signClaims
  :: (MonadRandom m, MonadError e m, AsError e)
  => JWK
  -> JWSHeader ()
  -> ClaimsSet
  -> m SignedJWT
signClaims k h c = signJWS payload (Identity (h, k))
  where
    payload = encode c