{-# LANGUAGE OverloadedStrings #-}
module JwtAuth where
import Control.Monad ((<=<))
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Types as Aeson
import Data.Bifunctor (first)
import qualified Data.ByteString as SBS
import qualified Data.Map.Strict as Map
import qualified Data.Text.Encoding as Text
import Data.Time.Clock.POSIX (POSIXTime)
import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT)
import qualified Web.JWT as JWT
import AccessControl
-- | Reasons why a raw token can be rejected while being decoded or verified.
data VerificationError
  = TokenUsedTooEarly
    -- ^ Returned by 'verifyNotBefore': the token's @nbf@ claim is not yet reached.
  | TokenExpired
    -- ^ Returned by 'verifyExpiry': the token is past its @exp@ claim.
  | TokenInvalid
    -- ^ Returned by 'decodeToken': the bytes could not be decoded as a JWT.
  | TokenNotFound
    -- ^ Not produced in this module; presumably used by callers when no
    -- token was supplied (NOTE(review): confirm against callers).
  | TokenSignatureInvalid
    -- ^ Returned by 'verifySignature': the signature did not verify.
  deriving (Eq, Show)
-- | Decode raw token bytes and run every verification step on them.
--
-- Steps run in this order, and the first failure is returned: decoding,
-- signature check against @secret@, expiry check, then not-before check.
verifyToken :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either VerificationError (JWT VerifiedJWT)
verifyToken now secret = timeChecks <=< verifySignature secret <=< decodeToken
  where
    -- Time-based checks only ever see tokens whose signature has verified;
    -- expiry is checked before not-before.
    timeChecks = verifyNotBefore now <=< verifyExpiry now
-- | Reject a token whose @nbf@ (not-before) claim lies in the future.
--
-- RFC 7519 §4.1.5 requires the current time to be after or equal to the
-- not-before time. A token is therefore accepted at exactly its @nbf@
-- instant and rejected with 'TokenUsedTooEarly' only when @now < nbf@.
-- Tokens without an @nbf@ claim are passed through unchanged.
verifyNotBefore :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyNotBefore now token =
  case JWT.nbf . JWT.claims $ token of
    Nothing -> Right token
    Just notBefore
      | now < JWT.secondsSinceEpoch notBefore -> Left TokenUsedTooEarly
      | otherwise -> Right token
-- | Reject a token whose @exp@ (expiration) claim has been reached.
--
-- RFC 7519 §4.1.4 requires the current time to be strictly before the
-- expiration time. A token is therefore rejected with 'TokenExpired' once
-- @now >= exp@, including at exactly its @exp@ instant.
-- Tokens without an @exp@ claim are passed through unchanged.
verifyExpiry :: POSIXTime -> JWT VerifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifyExpiry now token =
  case JWT.exp . JWT.claims $ token of
    Nothing -> Right token
    Just expiry
      | now >= JWT.secondsSinceEpoch expiry -> Left TokenExpired
      | otherwise -> Right token
-- | Check the token's signature with @secret@ via 'JWT.verify'.
--
-- On success the token is promoted to a verified JWT. If 'JWT.verify'
-- rejects it, the result is 'TokenSignatureInvalid'.
verifySignature :: JWT.Signer -> JWT UnverifiedJWT -> Either VerificationError (JWT VerifiedJWT)
verifySignature secret =
  maybe (Left TokenSignatureInvalid) Right . JWT.verify secret
-- | Decode raw token bytes into an unverified JWT.
--
-- The bytes may come from an untrusted source (NOTE(review): presumably a
-- request header; confirm against callers). Invalid UTF-8 is reported as
-- 'TokenInvalid' rather than going through 'Text.decodeUtf8', which throws
-- an exception from pure code on malformed input. Text that is valid
-- UTF-8 but not a well-formed JWT is also reported as 'TokenInvalid'.
decodeToken :: SBS.ByteString -> Either VerificationError (JWT UnverifiedJWT)
decodeToken bytes = do
  tokenText <- first (const TokenInvalid) (Text.decodeUtf8' bytes)
  maybe (Left TokenInvalid) Right (JWT.decode tokenText)
-- | Failures from turning raw token bytes into an 'IcepeakClaim'.
data TokenError
  = VerificationError VerificationError
    -- ^ The token failed to decode or verify.
  | ClaimError String
    -- ^ The @icepeak@ claim was missing or could not be parsed; carries
    -- the error message.
  deriving (Eq, Show)
-- | Fully verify raw token bytes with 'verifyToken', then extract the
-- Icepeak claim.
--
-- Verification covers decoding, the signature against @secret@, expiry and
-- not-before. Verification failures are wrapped in 'VerificationError';
-- a missing or unparsable claim is wrapped in 'ClaimError'.
extractClaim :: POSIXTime -> JWT.Signer -> SBS.ByteString -> Either TokenError IcepeakClaim
extractClaim now secret tokenBytes = do
  jwt <- first VerificationError $ verifyToken now secret tokenBytes
  first ClaimError $ getIcepeakClaim jwt
-- | Extract the Icepeak claim from raw token bytes WITHOUT verifying them.
--
-- Only 'decodeToken' is applied. The signature, @exp@ and @nbf@ claims are
-- NOT checked, so the result must not be trusted for authorization
-- decisions. Decode failures are wrapped in 'VerificationError'; a missing
-- or unparsable claim is wrapped in 'ClaimError'.
extractClaimUnverified :: SBS.ByteString -> Either TokenError IcepeakClaim
extractClaimUnverified tokenBytes = do
  jwt <- first VerificationError $ decodeToken tokenBytes
  first ClaimError $ getIcepeakClaim jwt
-- | Look up the @icepeak@ entry among the token's unregistered claims and
-- parse it as an 'IcepeakClaim'.
--
-- Returns @Left "Icepeak claim missing."@ when the key is absent. If the
-- entry is present but cannot be parsed, Aeson's parse error message is
-- returned instead.
getIcepeakClaim :: JWT r -> Either String IcepeakClaim
getIcepeakClaim token =
  case Map.lookup "icepeak" unregistered of
    Nothing -> Left "Icepeak claim missing."
    Just claimJson -> Aeson.parseEither Aeson.parseJSON claimJson
  where
    JWT.ClaimsMap unregistered = JWT.unregisteredClaims (JWT.claims token)
-- | Add @claim@, JSON-encoded, under the @icepeak@ key of the claims set's
-- unregistered claims. All other fields of the claims set are kept.
--
-- The new entry is the left operand of '<>'. NOTE(review): assuming the
-- jwt package's ClaimsMap Semigroup is a left-biased Map union (verify
-- against the version in use), it replaces any existing @icepeak@ entry.
addIcepeakClaim :: IcepeakClaim -> JWT.JWTClaimsSet -> JWT.JWTClaimsSet
addIcepeakClaim claim claims =
  claims { JWT.unregisteredClaims = icepeakEntry <> existingClaims }
  where
    icepeakEntry = JWT.ClaimsMap (Map.singleton "icepeak" (Aeson.toJSON claim))
    existingClaims = JWT.unregisteredClaims claims