{-# LANGUAGE NoImplicitPrelude   #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |Contains JWT authentication settings
module Chakra.JWT where

import           Control.Lens         (preview)
import           Control.Monad.Except (Monad (return))
import           Crypto.JOSE          (JWK, JWKSet (..))
import           Crypto.JWT           (StringOrURI, string, uri)
import qualified Data.Aeson           as Aeson
import           Data.ByteString      (readFile)
import qualified Data.Text            as T
import           Network.URI          (parseURI)
import           RIO                  (Eq ((==)), IO, Maybe (Just, Nothing),
                                       Semigroup ((<>)), const, either, error,
                                       fromMaybe, id, maybe, ($), (++))
import           Servant.Auth.Server  (IsMatch (..), JWTSettings (..),
                                       generateKey)
import           System.Environment   (lookupEnv)

-- | Build the JWT settings used in the Servant Auth context.
-- Reads the @JWK_PATH@ environment variable to locate the JWK (signature)
-- file, and @JWK_AUDIENCES@ for the value used to verify the incoming
-- token's audience (@aud@) claim.
--
-- Audience handling:
--
-- * When @JWK_AUDIENCES@ is unset, every token audience is accepted.
-- * When it is set, a token audience matches if both the token audience and
--   the configured value are URIs that compare equal, or if the token
--   audience is a plain string equal to the configured text.
--
-- A fresh signing key is generated with 'generateKey' on every call and is
-- handed to 'buildJWTSettings' together with the keys from 'acquireJwks'.
getJWTAuthSettings :: IO JWTSettings
getJWTAuthSettings = do
  jwkSet <- acquireJwks
  signKey <- generateKey
  audienceCfg <- lookupEnv "JWK_AUDIENCES"
  let audMatches :: StringOrURI -> IsMatch
      audMatches = maybe (const Matches) checkAud audienceCfg

      -- Compare one configured audience against a token's audience claim.
      checkAud audConfig tokenAud =
        case (preview uri tokenAud, parseURI audConfig) of
          -- Both sides must really be URIs. Comparing the raw 'Maybe'
          -- values would treat two failed parses (Nothing == Nothing) as a
          -- match and accept any plain-string audience.
          (Just tokenUri, Just configUri)
            | tokenUri == configUri -> Matches
          _ ->
            if preview string tokenAud == Just (T.pack audConfig)
              then Matches
              else DoesNotMatch
  return $ buildJWTSettings signKey jwkSet audMatches

-- | Assemble 'JWTSettings' from a signing key, the loaded validation keys and
-- an audience predicate.
--
-- The signing key is appended to the end of the supplied key set, so tokens
-- signed with it also validate. No algorithm is pinned: 'jwtAlg' is 'Nothing'.
buildJWTSettings :: JWK -> JWKSet -> (StringOrURI -> IsMatch) -> JWTSettings
buildJWTSettings signKey jwkSet audMatches =
  JWTSettings
    { signingKey = signKey
    , jwtAlg = Nothing
    , validationKeys = withSigningKey jwkSet
    , audienceMatches = audMatches
    }
  where
    -- Unwrap the set only when the field is demanded, then add the signing key last.
    withSigningKey (JWKSet loaded) = JWKSet (loaded <> [signKey])

-- | Load the JWK set used to validate incoming tokens.
--
-- The file path comes from the @JWK_PATH@ environment variable and defaults
-- to @secrets/jwk.sig@. The file content is decoded as JSON into a 'JWKSet'.
--
-- Throws an 'error' (\"Invalid JWK file: …\") when decoding fails. Exceptions
-- from 'readFile' (e.g. a missing file) propagate unchanged.
acquireJwks :: IO JWKSet
acquireJwks = do
  envUrl <- lookupEnv "JWK_PATH"
  let jwkPath = fromMaybe "secrets/jwk.sig" envUrl
  fileContent <- readFile jwkPath
  let parsed = Aeson.eitherDecodeStrict fileContent
  -- Put the failure in the IO position so it is raised when this action
  -- runs. @return (either (error ...) id parsed)@ would hand back a lazy
  -- thunk that only explodes when the key set is first used.
  either (\e -> error $ "Invalid JWK file: " <> e) return parsed