module Servant.Auth.Server.Internal.JWT where

import           Control.Lens
import           Control.Monad (MonadPlus(..), guard)
import           Control.Monad.Reader
import qualified Crypto.JOSE          as Jose
import qualified Crypto.JWT           as Jose
import           Data.ByteArray       (constEq)
import qualified Data.ByteString      as BS
import qualified Data.ByteString.Lazy as BSL
import           Data.Maybe           (fromMaybe)
import           Data.Time            (UTCTime)
import           Network.Wai          (requestHeaders)

import Servant.Auth.JWT               (FromJWT(..), ToJWT(..))
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types


-- | A JWT @AuthCheck@. You likely won't need to use this directly unless you
-- are protecting a @Raw@ endpoint.
--
-- The check reads the @Authorization@ header of the incoming request,
-- requires it to start with the literal prefix @"Bearer "@, and passes the
-- remainder of the header to 'verifyJWT'.
--
-- * If the header is absent, or does not start with that prefix, the check
--   results in 'mempty'.
-- * If 'verifyJWT' returns 'Nothing' for the token, the check results in
--   'mzero'.
-- * Otherwise the decoded value is returned.
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck jwtSettings = do
  req <- ask
  token <- maybe mempty return $ do
    authHdr <- lookup "Authorization" $ requestHeaders req
    let bearer = "Bearer "
        -- Split at the prefix length; a header shorter than the prefix simply
        -- produces a short 'mbearer' that fails the comparison below.
        (mbearer, rest) = BS.splitAt (BS.length bearer) authHdr
    -- The prefix is compared with 'constEq' from "Data.ByteArray" rather
    -- than with '=='.
    guard (mbearer `constEq` bearer)
    return rest
  verifiedJWT <- liftIO $ verifyJWT jwtSettings token
  maybe mzero return verifiedJWT

-- | Creates a JWT containing the specified data. The data is stored in the
-- @dat@ claim. The 'Maybe UTCTime' argument indicates the time at which the
-- token expires.
--
-- The token is signed with @signingKey cfg@. The signature algorithm is
-- @jwtAlg cfg@ when that is set, and otherwise whatever 'Jose.bestJWSAlg'
-- selects for the signing key. When the expiry argument is 'Nothing' the
-- claims produced by 'encodeJWT' are signed unchanged; when it is @Just e@
-- the @exp@ claim is set to @e@ first.
--
-- The result is the compact serialisation of the signed token, or the
-- 'Jose.Error' raised inside the JOSE computation.
makeJWT :: ToJWT a
  => a -> JWTSettings -> Maybe UTCTime -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = Jose.runJOSE $ do
  bestAlg <- Jose.bestJWSAlg $ signingKey cfg
  -- An explicitly configured algorithm takes precedence over the derived one.
  let alg = fromMaybe bestAlg $ jwtAlg cfg
  ejwt <- Jose.signClaims (signingKey cfg)
                          (Jose.newJWSHeader ((), alg))
                          (addExp $ encodeJWT v)
  return $ Jose.encodeCompact ejwt
  where
    -- Set the expiry claim only when an expiry time was supplied.
    addExp :: Jose.ClaimsSet -> Jose.ClaimsSet
    addExp claims = case expiry of
      Nothing -> claims
      Just e  -> claims & Jose.claimExp ?~ Jose.NumericDate e


-- | Decode a compact-serialised JWT, verify it, and convert its claims with
-- 'decodeJWT'.
--
-- The verification keys are obtained by running @validationKeys jwtCfg@ and
-- the validation settings come from 'jwtSettingsToJwtValidationSettings'.
--
-- Returns 'Nothing' when decoding or verification raises a 'Jose.JWTError',
-- or when 'decodeJWT' returns 'Left'. The reason for the failure is not
-- reported to the caller.
verifyJWT :: FromJWT a => JWTSettings -> BS.ByteString -> IO (Maybe a)
verifyJWT jwtCfg input = do
  keys <- validationKeys jwtCfg
  verifiedJWT <- Jose.runJOSE $ do
    -- 'Jose.decodeCompact' is given a lazy ByteString, so convert first.
    unverifiedJWT <- Jose.decodeCompact (BSL.fromStrict input)
    Jose.verifyClaims
      (jwtSettingsToJwtValidationSettings jwtCfg)
      keys
      unverifiedJWT
  return $ case verifiedJWT of
    -- The pattern signature fixes the error type 'Jose.runJOSE' is run at.
    Left (_ :: Jose.JWTError) -> Nothing
    Right claims -> either (const Nothing) Just (decodeJWT claims)