module Servant.Auth.Server.Internal.JWT where
import Control.Lens
import Control.Monad (MonadPlus(..), guard)
import Control.Monad.Reader
import qualified Crypto.JOSE as Jose
import qualified Crypto.JWT as Jose
import Data.ByteArray (constEq)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Maybe (fromMaybe)
import Data.Time (UTCTime)
import Network.Wai (requestHeaders)
import Servant.Auth.JWT (FromJWT(..), ToJWT(..))
import Servant.Auth.Server.Internal.ConfigTypes
import Servant.Auth.Server.Internal.Types
-- | An 'AuthCheck' that authenticates a request with a JSON Web Token taken
-- from its @Authorization@ header.
--
-- The header value must begin with the exact bytes @"Bearer "@; the remainder
-- is treated as the compact-serialized token.
--
-- * If the header is absent or lacks that prefix, the check yields 'mempty'.
-- * If the token does not pass 'verifyJWT', the check yields 'mzero'.
-- * Otherwise the decoded user value is returned.
jwtAuthCheck :: FromJWT usr => JWTSettings -> AuthCheck usr
jwtAuthCheck jwtSettings = do
  req <- ask
  token <- maybe mempty return (bearerToken (requestHeaders req))
  verifiedJWT <- liftIO (verifyJWT jwtSettings token)
  maybe mzero return verifiedJWT
  where
    -- The prefix the header value must start with.
    -- NOTE(review): the comparison is byte-exact, so "bearer <tok>" is
    -- rejected even though RFC 7235 treats auth-schemes case-insensitively.
    bearer :: BS.ByteString
    bearer = "Bearer "

    -- Extract the token from the Authorization header, if it is present and
    -- carries the Bearer prefix.
    bearerToken hdrs = do
      authHdr <- lookup "Authorization" hdrs
      let (scheme, rest) = BS.splitAt (BS.length bearer) authHdr
      -- constEq is False when the lengths differ, so a header shorter than
      -- the prefix is rejected here as well.
      guard (scheme `constEq` bearer)
      return rest
-- | Encode @v@ as a signed JWT in compact serialization.
--
-- The signing algorithm is 'jwtAlg' from the settings when it is set, and
-- otherwise the result of 'Jose.bestJWSAlg' for the signing key.
-- 'Jose.bestJWSAlg' is run in every case, so its failures are reported even
-- when an explicit 'jwtAlg' is configured.
--
-- When @expiry@ is 'Just', it is stored in the @exp@ claim. Any signing error
-- is returned as 'Left'.
makeJWT :: ToJWT a
  => a -> JWTSettings -> Maybe UTCTime -> IO (Either Jose.Error BSL.ByteString)
makeJWT v cfg expiry = Jose.runJOSE $ do
  bestAlg <- Jose.bestJWSAlg key
  let alg = fromMaybe bestAlg (jwtAlg cfg)
  signed <- Jose.signClaims key (Jose.newJWSHeader ((), alg)) (addExp (encodeJWT v))
  return (Jose.encodeCompact signed)
  where
    key = signingKey cfg

    -- Set the exp claim when an expiry time was supplied; otherwise leave the
    -- claims unchanged.
    addExp :: Jose.ClaimsSet -> Jose.ClaimsSet
    addExp = maybe id (\e -> Jose.claimExp ?~ Jose.NumericDate e) expiry
-- | Decode a compact-serialized JWT, verify it, and convert its claims with
-- 'decodeJWT'.
--
-- Verification uses the keys from 'validationKeys' and the validation
-- settings derived by 'jwtSettingsToJwtValidationSettings'.
--
-- Returns 'Nothing' on any failure: malformed input, failed verification, or
-- claims that 'decodeJWT' rejects. The error detail is discarded.
verifyJWT :: FromJWT a => JWTSettings -> BS.ByteString -> IO (Maybe a)
verifyJWT jwtCfg input = do
  keys <- validationKeys jwtCfg
  verified <- Jose.runJOSE $ do
    unverifiedJWT <- Jose.decodeCompact (BSL.fromStrict input)
    Jose.verifyClaims (jwtSettingsToJwtValidationSettings jwtCfg) keys unverifiedJWT
  -- The annotation fixes runJOSE's error type to JWTError without needing a
  -- scoped pattern signature.
  return $ hush (verified :: Either Jose.JWTError Jose.ClaimsSet) >>= hush . decodeJWT
  where
    -- Keep only a success value, dropping the error.
    hush :: Either e b -> Maybe b
    hush = either (const Nothing) Just