module OpenID.Connect.Client.Authentication
( applyRequestAuthentication
) where
import Control.Lens ((&), (?~), (.~), (^?), (#))
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JOSE.JWK (JWK)
import qualified Crypto.JOSE.JWK as JWK
import Crypto.JWT (ClaimsSet)
import qualified Crypto.JWT as JWT
import Crypto.Random (MonadRandom(..))
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import qualified Network.HTTP.Client as HTTP
import OpenID.Connect.Authentication
import OpenID.Connect.JSON
-- | Attach client authentication to an outgoing form-encoded request.
--
-- The method is chosen from the kind of secret held in the 'Credentials'
-- and the list of methods the provider allows:
--
--   * 'AssignedSecretText': @client_secret_basic@ is preferred over
--     @client_secret_post@.
--   * 'AssignedAssertionText': @client_secret_jwt@ (the secret is used as
--     an HMAC key to sign a client assertion).
--   * 'AssertionPrivateKey': @private_key_jwt@ (the assertion is signed
--     with the given JWK).
--
-- If the credential-specific method is not allowed but 'None' is, the
-- form body is sent without credentials (a @client_id@ parameter is added
-- when the caller did not supply one).
--
-- Returns 'Nothing' when no allowed method fits the credentials, or when
-- building/signing a JWT assertion fails.
applyRequestAuthentication
  :: forall m. MonadRandom m
  => Credentials
  -- ^ Client identifier and secret material.
  -> [ClientAuthentication]
  -- ^ Authentication methods the provider accepts.
  -> URI
  -- ^ Placed in the JWT @aud@ claim (presumably the token endpoint).
  -> UTCTime
  -- ^ Current time, used for the JWT @iat@, @exp@ and @jti@ claims.
  -> [(ByteString, ByteString)]
  -- ^ Form parameters to send in the request body.
  -> HTTP.Request
  -- ^ Request to modify.
  -> m (Maybe HTTP.Request)
applyRequestAuthentication creds methods uri now body =
  case clientSecret creds of
    AssignedSecretText secret
      | ClientSecretBasic `elem` methods -> pure . Just . useBasic secret
      | ClientSecretPost `elem` methods -> pure . Just . useBody secret
      | otherwise -> unauthenticated
    AssignedAssertionText key
      | ClientSecretJwt `elem` methods -> hmacWithKey key
      | otherwise -> unauthenticated
    AssertionPrivateKey key
      | PrivateKeyJwt `elem` methods -> signWithKey key
      | otherwise -> unauthenticated
  where
    -- Fallback when the credential-specific method is not allowed:
    -- send the body without credentials if 'None' is permitted,
    -- otherwise give up.
    unauthenticated :: HTTP.Request -> m (Maybe HTTP.Request)
    unauthenticated
      | None `elem` methods = pure . Just . pass (withClientId body)
      | otherwise = pure . const Nothing

    -- Encode the form parameters as the request body.
    pass :: [(ByteString, ByteString)] -> HTTP.Request -> HTTP.Request
    pass = HTTP.urlEncodedBody

    -- RFC 6749 §2.3.1 (client_secret_post) and §4.1.3 (unauthenticated
    -- clients) require client_id in the body.  It is only appended when
    -- the caller has not already supplied it, because a parameter must
    -- not be included more than once (RFC 6749 §3.2).
    withClientId
      :: [(ByteString, ByteString)] -> [(ByteString, ByteString)]
    withClientId params
      | "client_id" `elem` map fst params = params
      | otherwise =
          params <> [("client_id", Text.encodeUtf8 (assignedClientId creds))]

    -- client_secret_post: both credentials travel in the form body.
    useBody :: Text -> HTTP.Request -> HTTP.Request
    useBody secret =
      pass (withClientId body <> [("client_secret", Text.encodeUtf8 secret)])

    -- client_secret_basic: credentials travel in an Authorization header.
    -- NOTE(review): RFC 6749 §2.3.1 says the id and secret should be
    -- form-urlencoded before Basic encoding; they are used verbatim here.
    useBasic :: Text -> HTTP.Request -> HTTP.Request
    useBasic secret =
      HTTP.applyBasicAuth
        (Text.encodeUtf8 (assignedClientId creds))
        (Text.encodeUtf8 secret)
        . pass body

    -- client_secret_jwt: the shared secret becomes a symmetric JWK.
    hmacWithKey :: Text -> HTTP.Request -> m (Maybe HTTP.Request)
    hmacWithKey keyBytes =
      signWithKey (JWK.fromOctets (Text.encodeUtf8 keyBytes))

    -- Build and sign a client assertion JWT, then attach it to the form
    -- body.  Any JOSE error (e.g. no usable algorithm for the key) is
    -- reported as 'Nothing' rather than thrown.
    signWithKey :: JWK -> HTTP.Request -> m (Maybe HTTP.Request)
    signWithKey key req = do
      claims <- makeClaims <$> makeJti
      res <- JWT.runJOSE $ do
        alg <- JWK.bestJWSAlg key
        JWT.signClaims key (JWT.newJWSHeader ((), alg)) claims
      case res of
        Left (_ :: JOSE.Error) -> pure Nothing
        Right jwt ->
          pure . Just $
            pass
              (body <>
                [ ( "client_assertion_type"
                  , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
                  )
                , ( "client_assertion"
                  , LChar8.toStrict (JOSE.encodeCompact jwt)
                  )
                ])
              req

    -- Assertion claims: iss and sub are the client id (left unset if the
    -- id does not parse as a StringOrURI), aud is @uri@, and the
    -- assertion expires 60 seconds after @now@.
    makeClaims :: Text -> ClaimsSet
    makeClaims jti =
      JWT.emptyClaimsSet
        & JWT.claimIss .~ (assignedClientId creds ^? JWT.stringOrUri)
        & JWT.claimSub .~ (assignedClientId creds ^? JWT.stringOrUri)
        & JWT.claimAud ?~ JWT.Audience (pure (JWT.uri # getURI uri))
        & JWT.claimJti ?~ jti
        & JWT.claimExp ?~ JWT.NumericDate (addUTCTime 60 now)
        & JWT.claimIat ?~ JWT.NumericDate now

    -- Token id: 64 random bytes followed by the rendered current time,
    -- encoded as unpadded base64url (pure ASCII, so decodeUtf8 is safe).
    makeJti :: m Text
    makeJti =
      (getRandomBytes 64 :: m ByteString)
        <&> (<> Char8.pack (show now))
        <&> convertToBase Base64URLUnpadded
        <&> Text.decodeUtf8