{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- | Server implementation of the 'JWTAuth'' trait.
module WebGear.Server.Trait.Auth.JWT where

import Control.Arrow (arr, returnA, (>>>))
import Control.Monad.Except (MonadError (throwError), runExceptT, withExceptT)
import Control.Monad.Time (MonadTime)
import Control.Monad.Trans (lift)
import qualified Crypto.JWT as JWT
import Data.ByteString.Lazy (fromStrict)
import Data.Void (Void)
import WebGear.Core.Handler (arrM)
import WebGear.Core.Modifiers
import WebGear.Core.Request (Request)
import WebGear.Core.Trait (Get (..), With)
import WebGear.Core.Trait.Auth.Common (
  AuthToken (..),
  AuthorizationHeader,
  getAuthorizationHeaderTrait,
 )
import WebGear.Core.Trait.Auth.JWT (JWTAuth' (..), JWTAuthError (..))
import WebGear.Server.Handler (ServerHandler)

-- | Extract a JWT-derived attribute from a request whose @Authorization@
-- header is required.
--
-- The header is read with 'getAuthorizationHeaderTrait' at @scheme@ and the
-- outcome is mapped as follows:
--
--   * no header                                   -> 'JWTAuthHeaderMissing'
--   * header present but not a @scheme@ token     -> 'JWTAuthSchemeMismatch'
--   * token is not a compact-serialised JWT       -> 'JWTAuthTokenBadFormat'
--   * 'JWT.verifyClaims' rejects the token        -> 'JWTAuthTokenBadFormat'
--   * 'toJWTAttribute' returns @'Left' err@       -> 'JWTAuthAttributeError'
--   * otherwise                                   -> the attribute, in 'Right'
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Required scheme m e a) Request where
  {-# INLINE getTrait #-}
  getTrait ::
    JWTAuth' Required scheme m e a ->
    ServerHandler m (Request `With` ts) (Either (JWTAuthError e) a)
  getTrait JWTAuth'{..} = proc request -> do
    result <- getAuthorizationHeaderTrait @scheme -< request
    case result of
      Nothing -> returnA -< Left JWTAuthHeaderMissing
      (Just (Left _)) -> returnA -< Left JWTAuthSchemeMismatch
      (Just (Right token)) ->
        case parseJWT token of
          Left e -> returnA -< Left (JWTAuthTokenBadFormat e)
          Right jwt -> validateJWT -< jwt
    where
      -- Decode the raw (strict) token bytes as a compact-serialised signed JWT.
      parseJWT :: AuthToken scheme -> Either JWT.JWTError JWT.SignedJWT
      parseJWT AuthToken{..} = JWT.decodeCompact $ fromStrict authToken

      -- Verify the JWT against the configured settings and key set, then
      -- convert its claims into the user's attribute. 'JWTAuthError' has no
      -- dedicated "verification failed" constructor, so verification errors
      -- are reported through 'JWTAuthTokenBadFormat'.
      validateJWT :: ServerHandler m JWT.SignedJWT (Either (JWTAuthError e) a)
      validateJWT = arrM $ \jwt -> runExceptT $ do
        claims <- withExceptT JWTAuthTokenBadFormat $ JWT.verifyClaims jwtValidationSettings jwkSet jwt
        lift (toJWTAttribute claims) >>= either (throwError . JWTAuthAttributeError) pure

-- | Optional variant of the JWT trait. It delegates to the 'Required'
-- instance and wraps that result in 'Right'. This trait therefore never
-- reports absence (the absence type is 'Void'). Any authentication failure
-- appears as a 'Left' inside the 'Right'.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Optional scheme m e a) Request where
  {-# INLINE getTrait #-}
  getTrait ::
    JWTAuth' Optional scheme m e a ->
    ServerHandler m (Request `With` ts) (Either Void (Either (JWTAuthError e) a))
  getTrait JWTAuth'{..} =
    -- The record fields do not depend on the existence index, so the same
    -- settings can be rebuilt at 'Required' to reuse that instance.
    getTrait (JWTAuth'{..} :: JWTAuth' Required scheme m e a) >>> arr Right