{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-orphans #-}

-- | Server implementation of the 'JWTAuth'' trait.
module WebGear.Server.Trait.Auth.JWT where

import Control.Arrow (arr, returnA, (>>>))
import Control.Monad.Except (MonadError (throwError), runExceptT, withExceptT)
import Control.Monad.Time (MonadTime)
import Control.Monad.Trans (lift)
import qualified Crypto.JWT as JWT
import Data.ByteString.Lazy (fromStrict)
import Data.Void (Void)
import WebGear.Core.Handler (arrM)
import WebGear.Core.Modifiers
import WebGear.Core.Request (Request)
import WebGear.Core.Trait (Get (..), HasTrait, With, from, pick)
import WebGear.Core.Trait.Auth.Common (
  AuthToken (..),
  AuthorizationHeader,
 )
import WebGear.Core.Trait.Auth.JWT (JWTAuth' (..), JWTAuthError (..))
import WebGear.Server.Handler (ServerHandler)

-- | Server-side extraction and validation of a JWT carried in the
-- @Authorization@ header.
--
-- The outcome is reported as:
--
--   * @'Left' 'JWTAuthHeaderMissing'@ when the header trait yields 'Nothing';
--   * @'Left' 'JWTAuthSchemeMismatch'@ when the header trait yields a 'Left'
--     (presumably a header that is present but not in the expected @scheme@);
--   * @'Left' ('JWTAuthTokenBadFormat' err)@ when the token cannot be decoded
--     as a compact signed JWT, or when 'JWT.verifyClaims' rejects it;
--   * @'Left' ('JWTAuthAttributeError' err)@ when 'toJWTAttribute' rejects
--     the verified claims;
--   * @'Right' attr@ otherwise.
--
-- 'MonadTime' is required because 'JWT.verifyClaims' demands it.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Required scheme m e a) Request where
  {-# INLINE getTrait #-}
  getTrait ::
    (HasTrait (AuthorizationHeader scheme) ts) =>
    JWTAuth' Required scheme m e a ->
    ServerHandler m (Request `With` ts) (Either (JWTAuthError e) a)
  getTrait JWTAuth'{..} = proc request -> do
    -- The 'HasTrait' constraint guarantees the header attribute is already
    -- attached to the request, so it is only looked up here.
    let result = pick @(AuthorizationHeader scheme) $ from request
    case result of
      Nothing -> returnA -< Left JWTAuthHeaderMissing
      Just (Left _) -> returnA -< Left JWTAuthSchemeMismatch
      Just (Right token) ->
        case parseJWT token of
          Left err -> returnA -< Left (JWTAuthTokenBadFormat err)
          Right jwt -> validateJWT -< jwt
    where
      -- Decode the raw token bytes as a compact-serialised signed JWT.
      parseJWT :: AuthToken scheme -> Either JWT.JWTError JWT.SignedJWT
      parseJWT AuthToken{..} = JWT.decodeCompact $ fromStrict authToken

      -- Verify the JWT against the configured settings and key set, then
      -- convert the resulting claims into the trait attribute. Verification
      -- failures and attribute failures are kept in distinct constructors.
      validateJWT :: ServerHandler m JWT.SignedJWT (Either (JWTAuthError e) a)
      validateJWT = arrM $ \jwt -> runExceptT $ do
        claims <-
          withExceptT JWTAuthTokenBadFormat $
            JWT.verifyClaims jwtValidationSettings jwkSet jwt
        lift (toJWTAttribute claims)
          >>= either (throwError . JWTAuthAttributeError) pure

-- | Optional variant of JWT authentication.
--
-- Delegates to the 'Required' instance and wraps its result in 'Right'.
-- This trait therefore never reports an absence (hence 'Void'). Every
-- failure, including a missing header ('JWTAuthHeaderMissing'), appears in
-- the inner 'Either' for the handler to inspect.
instance (MonadTime m, Get (ServerHandler m) (AuthorizationHeader scheme) Request) => Get (ServerHandler m) (JWTAuth' Optional scheme m e a) Request where
  {-# INLINE getTrait #-}
  getTrait ::
    (HasTrait (AuthorizationHeader scheme) ts) =>
    JWTAuth' Optional scheme m e a ->
    ServerHandler m (Request `With` ts) (Either Void (Either (JWTAuthError e) a))
  getTrait JWTAuth'{..} =
    -- Rebuild the same configuration at the 'Required' index so the
    -- 'Required' instance can be reused unchanged.
    getTrait (JWTAuth'{..} :: JWTAuth' Required scheme m e a) >>> arr Right