{-|

Copyright:

  This file is part of the package openid-connect.  It is subject to
  the license terms in the LICENSE file found in the top-level
  directory of this distribution and at:

    https://code.devalot.com/open/openid-connect

  No part of this package, including this file, may be copied,
  modified, propagated, or distributed except according to the terms
  contained in the LICENSE file.

License: BSD-2-Clause

Client authentication.

-}
module OpenID.Connect.Client.Authentication
  ( applyRequestAuthentication
  ) where

--------------------------------------------------------------------------------
-- Imports:
import Control.Lens ((&), (?~), (.~), (^?), (#))
import Control.Monad.Except
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JOSE.JWK (JWK)
import qualified Crypto.JOSE.JWK as JWK
import Crypto.JWT (ClaimsSet)
import qualified Crypto.JWT as JWT
import Crypto.Random (MonadRandom(..))
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import qualified Network.HTTP.Client as HTTP
import OpenID.Connect.Authentication
import OpenID.Connect.JSON

--------------------------------------------------------------------------------
-- | Modify a request so that it uses the proper client authentication
-- method (OpenID Connect Core §9).
--
-- The method chosen is the first one that both matches the kind of
-- credentials the client holds /and/ appears in the list of available
-- methods:
--
--   * 'AssignedSecretText': @client_secret_basic@, then
--     @client_secret_post@, then @none@.
--
--   * 'AssignedAssertionText': @client_secret_jwt@ (the secret is used
--     as an HMAC key), then @none@.
--
--   * 'AssertionPrivateKey': @private_key_jwt@, then @none@.
--
-- In every case the given form parameters are encoded into the request
-- body with 'HTTP.urlEncodedBody'.
--
-- Returns 'Nothing' when no suitable method is available, or when a
-- client assertion JWT cannot be built or signed.
applyRequestAuthentication
  :: forall m. MonadRandom m
  => Credentials                -- ^ Client credentials.
  -> [ClientAuthentication]     -- ^ Available authentication methods.
  -> URI                        -- ^ Token endpoint URI (used as the
                                --   audience of client assertions).
  -> UTCTime                    -- ^ The current time.
  -> [(ByteString, ByteString)] -- ^ Form parameters for the request body.
  -> HTTP.Request               -- ^ The request to modify.
  -> m (Maybe HTTP.Request)     -- ^ The final request, or 'Nothing' if
                                --   authentication could not be applied.
applyRequestAuthentication creds methods uri now body =
  case clientSecret creds of
    AssignedSecretText secret
      | ClientSecretBasic `elem` methods -> pure . Just . useBasic secret
      | ClientSecretPost  `elem` methods -> pure . Just . useBody secret
      | None              `elem` methods -> pure . Just . pass body
      | otherwise                        -> pure . const Nothing
    AssignedAssertionText key
      | ClientSecretJwt `elem` methods   -> hmacWithKey key
      | None            `elem` methods   -> pure . Just . pass body
      | otherwise                        -> pure . const Nothing
    AssertionPrivateKey key
      | PrivateKeyJwt `elem` methods     -> signWithKey key
      | None          `elem` methods     -> pure . Just . pass body
      | otherwise                        -> pure . const Nothing

  where
    -- Encode the form parameters into the request body as they are.
    pass :: [(ByteString, ByteString)] -> HTTP.Request -> HTTP.Request
    pass = HTTP.urlEncodedBody

    -- @client_secret_post@: send the secret as an extra form parameter.
    useBody :: Text -> HTTP.Request -> HTTP.Request
    useBody secret =
      pass (body <> [("client_secret", Text.encodeUtf8 secret)])

    -- @client_secret_basic@: send the client ID and secret with HTTP
    -- Basic authentication, alongside the form parameters.
    --
    -- NOTE(review): RFC 6749 §2.3.1 asks for both values to be
    -- form-urlencoded before Basic encoding; here they are only UTF-8
    -- encoded.  Left as-is for compatibility with existing providers.
    useBasic :: Text -> HTTP.Request -> HTTP.Request
    useBasic secret =
      HTTP.applyBasicAuth
        (Text.encodeUtf8 (assignedClientId creds))
        (Text.encodeUtf8 secret)
      . pass body

    -- @client_secret_jwt@: use the shared secret as a /key/ to sign
    -- the client assertion JWT.
    hmacWithKey :: Text -> HTTP.Request -> m (Maybe HTTP.Request)
    hmacWithKey keyBytes =
      signWithKey (JWK.fromOctets (Text.encodeUtf8 keyBytes))

    -- Use the given key to /sign/ a client assertion JWT and add it,
    -- together with its assertion type, to the form parameters.  May
    -- create an actual digital signature or, in the case of
    -- 'hmacWithKey', an HMAC.
    --
    -- Yields 'Nothing' when the client ID cannot be used for the
    -- @iss@/@sub@ claims, or when picking an algorithm or signing fails.
    signWithKey :: JWK -> HTTP.Request -> m (Maybe HTTP.Request)
    signWithKey key req =
      case assignedClientId creds ^? JWT.stringOrUri of
        -- @iss@ and @sub@ are REQUIRED claims; never send an assertion
        -- that silently lacks them.
        Nothing -> pure Nothing
        Just clientId -> do
          claims <- makeClaims clientId <$> makeJti
          res <- runExceptT $ do
            alg <- JWK.bestJWSAlg key
            JWT.signClaims key (JWT.newJWSHeader ((), alg)) claims
          case res of
            Left (_ :: JOSE.Error) -> pure Nothing
            Right jwt -> pure . Just $ HTTP.urlEncodedBody
              (body <> [ ( "client_assertion_type"
                         , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
                         )
                       , ( "client_assertion"
                         , LChar8.toStrict (JOSE.encodeCompact jwt)
                         )
                       ]) req

    -- Claims required by OpenID Connect Core §9.  The client ID is both
    -- issuer and subject, the token endpoint is the audience, and the
    -- assertion expires 60 seconds after @now@.
    makeClaims :: JWT.StringOrURI -> Text -> ClaimsSet
    makeClaims clientId jti
      = JWT.emptyClaimsSet
      & JWT.claimIss .~ Just clientId
      & JWT.claimSub .~ Just clientId
      & JWT.claimAud ?~ JWT.Audience (pure (JWT.uri # getURI uri))
      & JWT.claimJti ?~ jti
      & JWT.claimExp ?~ JWT.NumericDate (addUTCTime 60 now)
      & JWT.claimIat ?~ JWT.NumericDate now

    -- JWT ID.  From the standard: A unique identifier for the token,
    -- which can be used to prevent reuse of the token.  Built from 64
    -- random bytes followed by the current time, then Base64URL-encoded
    -- (an ASCII alphabet, so the UTF-8 decode below is safe).
    makeJti :: m Text
    makeJti = (getRandomBytes 64 :: m ByteString)
                <&> (<> Char8.pack (show now))
                <&> convertToBase Base64URLUnpadded
                <&> Text.decodeUtf8