module OpenID.Connect.Client.Authentication
( applyRequestAuthentication
) where
import Control.Lens ((&), (?~), (.~), (^?), (#))
import Control.Monad.Except
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JOSE.JWK (JWK)
import qualified Crypto.JOSE.JWK as JWK
import Crypto.JWT (ClaimsSet)
import qualified Crypto.JWT as JWT
import Crypto.Random (MonadRandom(..))
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import qualified Network.HTTP.Client as HTTP
import OpenID.Connect.Authentication
import OpenID.Connect.JSON
applyRequestAuthentication
:: forall m. MonadRandom m
=> Credentials
-> [ClientAuthentication]
-> URI
-> UTCTime
-> [(ByteString, ByteString)]
-> HTTP.Request
-> m (Maybe HTTP.Request)
applyRequestAuthentication creds methods uri now body =
case clientSecret creds of
AssignedSecretText secret
| ClientSecretBasic `elem` methods -> pure . Just . useBasic secret
| ClientSecretPost `elem` methods -> pure . Just . useBody secret
| None `elem` methods -> pure . Just . pass body
| otherwise -> pure . const Nothing
AssignedAssertionText key
| ClientSecretJwt `elem` methods -> hmacWithKey key
| None `elem` methods -> pure . Just . pass body
| otherwise -> pure . const Nothing
AssertionPrivateKey key
| PrivateKeyJwt `elem` methods -> signWithKey key
| None `elem` methods -> pure . Just . pass body
| otherwise -> pure . const Nothing
where
pass :: [(ByteString, ByteString)] -> HTTP.Request -> HTTP.Request
pass = HTTP.urlEncodedBody
useBody :: Text -> HTTP.Request -> HTTP.Request
useBody secret = pass
(body <> [ ("client_secret", Text.encodeUtf8 secret)
])
useBasic :: Text -> HTTP.Request -> HTTP.Request
useBasic secret =
HTTP.applyBasicAuth
(Text.encodeUtf8 (assignedClientId creds))
(Text.encodeUtf8 secret) . pass body
hmacWithKey :: Text -> HTTP.Request -> m (Maybe HTTP.Request)
hmacWithKey keyBytes =
signWithKey (JWK.fromOctets (Text.encodeUtf8 keyBytes))
signWithKey :: JWK -> HTTP.Request -> m (Maybe HTTP.Request)
signWithKey key req = do
claims <- makeClaims <$> makeJti
res <- runExceptT $ do
alg <- JWK.bestJWSAlg key
JWT.signClaims key (JWT.newJWSHeader ((), alg)) claims
case res of
Left (_ :: JOSE.Error) -> pure Nothing
Right jwt -> pure . Just $ HTTP.urlEncodedBody
(body <> [ ( "client_assertion_type"
, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)
, ( "client_assertion"
, LChar8.toStrict (JOSE.encodeCompact jwt)
)
]) req
makeClaims :: Text -> ClaimsSet
makeClaims jti
= JWT.emptyClaimsSet
& JWT.claimIss .~ assignedClientId creds ^? JWT.stringOrUri
& JWT.claimSub .~ assignedClientId creds ^? JWT.stringOrUri
& JWT.claimAud ?~ JWT.Audience (pure (JWT.uri # getURI uri))
& JWT.claimJti ?~ jti
& JWT.claimExp ?~ JWT.NumericDate (addUTCTime 60 now)
& JWT.claimIat ?~ JWT.NumericDate now
makeJti :: m Text
makeJti = (getRandomBytes 64 :: m ByteString)
<&> (<> Char8.pack (show now))
<&> convertToBase Base64URLUnpadded
<&> Text.decodeUtf8