module OpenID.Connect.Client.Authentication
( applyRequestAuthentication
) where
import Control.Lens ((&), (?~), (.~), (^?), (#))
import Control.Monad.Except
import qualified Crypto.JOSE.Compact as JOSE
import qualified Crypto.JOSE.Error as JOSE
import Crypto.JOSE.JWK (JWK)
import qualified Crypto.JOSE.JWK as JWK
import Crypto.JWT (ClaimsSet)
import qualified Crypto.JWT as JWT
import Crypto.Random (MonadRandom(..))
import Data.ByteArray.Encoding
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy.Char8 as LChar8
import Data.Functor ((<&>))
import Data.Text (Text)
import qualified Data.Text.Encoding as Text
import Data.Time.Clock (UTCTime, addUTCTime)
import qualified Network.HTTP.Client as HTTP
import OpenID.Connect.Authentication
import OpenID.Connect.JSON
-- | Authenticate a request using the first client-authentication method in
-- @methods@ that is compatible with the client's credentials.
--
-- The result is a function of the 'HTTP.Request' to modify. It yields:
--
--   * @Just req'@ — @req'@ carries a URL-encoded form body built from @body@
--     plus whatever authentication data the chosen method requires.
--
--   * @Nothing@ — none of the permitted methods can be used with these
--     credentials, or signing the client-assertion JWT failed.
--
-- Method precedence, by kind of credential:
--
--   * 'AssignedSecretText': 'ClientSecretBasic', then 'ClientSecretPost'.
--   * 'AssignedAssertionText': 'ClientSecretJwt' (HMAC key built from the text).
--   * 'AssertionPrivateKey': 'PrivateKeyJwt' (JWT signed with the given key).
--
-- For every kind, if no credential-specific method is permitted but 'None'
-- is, the request is sent with @body@ and no client authentication.
--
-- NOTE(review): this definition relies on @ScopedTypeVariables@ (the
-- where-clause signatures mention @m@) and @OverloadedStrings@ (form-field
-- literals). Presumably these are enabled for this module; confirm.
applyRequestAuthentication
  :: forall m. MonadRandom m
  => Credentials                -- ^ Client ID plus secret or key material.
  -> [ClientAuthentication]     -- ^ Authentication methods that may be used.
  -> URI                        -- ^ Audience (@aud@) of any client-assertion JWT.
  -> UTCTime                    -- ^ Current time (JWT @iat@/@exp@, jti salt).
  -> [(ByteString, ByteString)] -- ^ Form fields to send in the request body.
  -> HTTP.Request               -- ^ Request to authenticate.
  -> m (Maybe HTTP.Request)
applyRequestAuthentication creds methods uri now body =
  case clientSecret creds of
    AssignedSecretText secret
      | ClientSecretBasic `elem` methods -> done (useBasic secret)
      | ClientSecretPost  `elem` methods -> done (useBody secret)
      | otherwise                        -> fallback
    AssignedAssertionText key
      | ClientSecretJwt `elem` methods -> hmacWithKey key
      | otherwise                      -> fallback
    AssertionPrivateKey key
      | PrivateKeyJwt `elem` methods -> signWithKey key
      | otherwise                    -> fallback
  where
    -- Wrap a pure request transformation as a successful result.
    done
      :: (HTTP.Request -> HTTP.Request)
      -> HTTP.Request
      -> m (Maybe HTTP.Request)
    done f = pure . Just . f

    -- Shared by every credential kind: send the request unauthenticated when
    -- 'None' is permitted, otherwise report that no method applies.
    fallback :: HTTP.Request -> m (Maybe HTTP.Request)
    fallback
      | None `elem` methods = done (pass body)
      | otherwise           = const (pure Nothing)

    -- Replace the request body with the given URL-encoded form fields.
    pass :: [(ByteString, ByteString)] -> HTTP.Request -> HTTP.Request
    pass = HTTP.urlEncodedBody

    -- client_secret_post: send the secret as an extra form field.
    useBody :: Text -> HTTP.Request -> HTTP.Request
    useBody secret =
      pass (body <> [ ("client_secret", Text.encodeUtf8 secret) ])

    -- client_secret_basic: send the form body and put the client ID and
    -- secret in an HTTP Basic Authorization header.
    --
    -- NOTE(review): RFC 6749 §2.3.1 requires the client ID and secret to be
    -- form-urlencoded before use as Basic-auth username/password. They are
    -- passed here as raw UTF-8, so values containing reserved characters
    -- (e.g. ':' or '%') may be rejected by a strict server. Verify before
    -- changing, since some servers do not decode these values.
    useBasic :: Text -> HTTP.Request -> HTTP.Request
    useBasic secret =
      HTTP.applyBasicAuth
        (Text.encodeUtf8 (assignedClientId creds))
        (Text.encodeUtf8 secret)
      . pass body

    -- client_secret_jwt: use the shared text as a symmetric (octet) JWK.
    hmacWithKey :: Text -> HTTP.Request -> m (Maybe HTTP.Request)
    hmacWithKey keyText =
      signWithKey (JWK.fromOctets (Text.encodeUtf8 keyText))

    -- Sign a fresh claims set with @key@, using the algorithm chosen by
    -- 'JWK.bestJWSAlg'. Then send the compact JWT as a jwt-bearer client
    -- assertion alongside @body@.
    signWithKey :: JWK -> HTTP.Request -> m (Maybe HTTP.Request)
    signWithKey key req = do
      claims <- makeClaims <$> makeJti
      res <- runExceptT $ do
        alg <- JWK.bestJWSAlg key
        JWT.signClaims key (JWT.newJWSHeader ((), alg)) claims
      pure $ case res of
        -- The JOSE error is dropped on purpose: this interface can only
        -- signal failure with Nothing.
        Left (_ :: JOSE.Error) -> Nothing
        Right jwt ->
          Just $ HTTP.urlEncodedBody
            (body <>
              [ ( "client_assertion_type"
                , "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
                )
              , ( "client_assertion"
                , LChar8.toStrict (JOSE.encodeCompact jwt)
                )
              ])
            req

    -- Claims for the client assertion:
    --   * iss and sub are the client ID;
    --   * aud is @uri@;
    --   * jti is the given value;
    --   * iat is @now@, and exp is @now@ plus 60 seconds.
    -- Because iss/sub are set from a '(^?)' lookup, they are left unset when
    -- the client ID is not accepted by 'JWT.stringOrUri'.
    makeClaims :: Text -> ClaimsSet
    makeClaims jti =
      JWT.emptyClaimsSet
        & JWT.claimIss .~ (assignedClientId creds ^? JWT.stringOrUri)
        & JWT.claimSub .~ (assignedClientId creds ^? JWT.stringOrUri)
        & JWT.claimAud ?~ JWT.Audience (pure (JWT.uri # getURI uri))
        & JWT.claimJti ?~ jti
        & JWT.claimExp ?~ JWT.NumericDate (addUTCTime 60 now)
        & JWT.claimIat ?~ JWT.NumericDate now

    -- JWT ID: 64 random bytes followed by the textual timestamp,
    -- encoded as unpadded base64url.
    makeJti :: m Text
    makeJti =
      (getRandomBytes 64 :: m ByteString)
        <&> (<> Char8.pack (show now))
        <&> convertToBase Base64URLUnpadded
        <&> Text.decodeUtf8