{-|

Copyright:

  This file is part of the package openid-connect.  It is subject to
  the license terms in the LICENSE file found in the top-level
  directory of this distribution and at:

    https://code.devalot.com/open/openid-connect

  No part of this package, including this file, may be copied,
  modified, propagated, or distributed except according to the terms
  contained in the LICENSE file.

License: BSD-2-Clause

Client authentication.

-}
module OpenID.Connect.Client.Authentication
  ( applyRequestAuthentication
  ) where

--------------------------------------------------------------------------------
-- Imports:
import Control.Lens ((&), (?~), (.~), (^?), (#))
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JOSE.JWK (JWK)
import qualified Crypto.JOSE.JWK as JWK
import Crypto.JWT (ClaimsSet)
import qualified Crypto.JWT as JWT
import Crypto.Random (MonadRandom(..))
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import qualified Network.HTTP.Client as HTTP
import OpenID.Connect.Authentication
import OpenID.Connect.JSON

--------------------------------------------------------------------------------
-- | Modify a request so that it authenticates the client to the token
-- endpoint using one of the available authentication methods.
--
-- The method is selected from the kind of secret held in the
-- credentials:
--
--   * 'AssignedSecretText': @client_secret_basic@ is preferred over
--     @client_secret_post@.
--
--   * 'AssignedAssertionText': @client_secret_jwt@, a JWT whose HMAC
--     key is the client secret.
--
--   * 'AssertionPrivateKey': @private_key_jwt@, a JWT signed with the
--     given key.
--
-- If the secret's own method is not listed but 'None' is, the form
-- parameters are sent without any client authentication.  If no method
-- applies, or if building/signing the JWT fails, the result is
-- 'Nothing'.
applyRequestAuthentication
  :: forall m. MonadRandom m
  => Credentials                -- ^ Client credentials.
  -> [ClientAuthentication]     -- ^ Available authentication methods.
  -> URI                        -- ^ Token Endpoint URI (the JWT audience).
  -> UTCTime                    -- ^ The current time.
  -> [(ByteString, ByteString)] -- ^ Form parameters for the request body.
  -> HTTP.Request               -- ^ The request to modify.
  -> m (Maybe HTTP.Request)     -- ^ The final request, if a method applied.
applyRequestAuthentication creds methods uri now body =
  case clientSecret creds of
    AssignedSecretText secret
      | ClientSecretBasic `elem` methods -> pure . Just . useBasic secret
      | ClientSecretPost  `elem` methods -> pure . Just . useBody secret
    AssignedAssertionText key
      | ClientSecretJwt `elem` methods   -> hmacWithKey key
    AssertionPrivateKey key
      | PrivateKeyJwt `elem` methods     -> signWithKey key
    -- Reached whenever every guard of the matching alternative above
    -- failed (Haskell falls through to the next alternative).
    _ | None `elem` methods              -> pure . Just . pass body
      | otherwise                        -> const (pure Nothing)

  where
    -- Encode the given form parameters as the request body.
    pass :: [(ByteString, ByteString)] -> HTTP.Request -> HTTP.Request
    pass = HTTP.urlEncodedBody

    -- @client_secret_post@: append the secret to the form parameters.
    --
    -- NOTE(review): only @client_secret@ is added here; presumably the
    -- caller's body already carries @client_id@ -- confirm.
    useBody :: Text -> HTTP.Request -> HTTP.Request
    useBody secret =
      pass (body <> [ ("client_secret", Text.encodeUtf8 secret) ])

    -- @client_secret_basic@: send the client ID and secret in an HTTP
    -- Basic @Authorization@ header, alongside the form parameters.
    --
    -- NOTE(review): RFC 6749 §2.3.1 says both values should be
    -- form-urlencoded before Basic encoding; they are sent as raw UTF-8
    -- here.  Check interoperability before changing this.
    useBasic :: Text -> HTTP.Request -> HTTP.Request
    useBasic secret =
      HTTP.applyBasicAuth
        (Text.encodeUtf8 (assignedClientId creds))
        (Text.encodeUtf8 secret) . pass body

    -- Use the @client_secret@ as a /key/ to sign a JWT (an HMAC).
    hmacWithKey :: Text -> HTTP.Request -> m (Maybe HTTP.Request)
    hmacWithKey keyBytes =
      signWithKey (JWK.fromOctets (Text.encodeUtf8 keyBytes))

    -- Use the given key to /sign/ a JWT.  May create an actual
    -- digital signature or in the case of 'hmacWithKey', create an
    -- HMAC for the header.  The algorithm is whatever
    -- 'JWK.bestJWSAlg' picks for the key; any 'JOSE.Error' (including
    -- no suitable algorithm) yields 'Nothing'.
    signWithKey :: JWK -> HTTP.Request -> m (Maybe HTTP.Request)
    signWithKey key req = do
      claims <- makeClaims <$> makeJti
      res <- JWT.runJOSE $ do
        alg <- JWK.bestJWSAlg key
        JWT.signClaims key (JWT.newJWSHeader ((), alg)) claims
      case res of
        Left (_ :: JOSE.Error) -> pure Nothing
        Right jwt -> pure . Just $ pass
          (body <> [ ( "client_assertion_type"
                     , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
                     )
                   , ( "client_assertion"
                     , LChar8.toStrict (JOSE.encodeCompact jwt)
                     )
                   ]) req

    -- Claims required by OpenID Connect Core §9: issuer and subject are
    -- the client ID, the audience is the token endpoint, and the token
    -- expires 60 seconds after it is issued.
    --
    -- If the client ID is rejected by 'JWT.stringOrUri', @iss@/@sub@
    -- are left unset rather than failing.
    makeClaims :: Text -> ClaimsSet
    makeClaims jti
      = JWT.emptyClaimsSet
      & JWT.claimIss .~ assignedClientId creds ^? JWT.stringOrUri
      & JWT.claimSub .~ assignedClientId creds ^? JWT.stringOrUri
      & JWT.claimAud ?~ JWT.Audience (pure (JWT.uri # getURI uri))
      & JWT.claimJti ?~ jti
      & JWT.claimExp ?~ JWT.NumericDate (addUTCTime 60 now)
      & JWT.claimIat ?~ JWT.NumericDate now

    -- JWT ID.  From the standard: A unique identifier for the token,
    -- which can be used to prevent reuse of the token.  Built from 64
    -- random bytes plus the current time, base64url-encoded (ASCII
    -- only, so 'Text.decodeUtf8' cannot fail on it).
    makeJti :: m Text
    makeJti = (getRandomBytes 64 :: m ByteString)
                <&> (<> Char8.pack (show now))
                <&> convertToBase Base64URLUnpadded
                <&> Text.decodeUtf8