{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE GADTs #-} {-| Module: Web.OIDC.Client Maintainer: krdlab@gmail.com Stability: experimental -} module Web.OIDC.Client ( -- * Client Obtains ID Token and Access Token OIDC , newOIDC , newOIDC' , setProvider , setCredentials , getAuthenticationRequestUrl , requestTokens -- * Types , Provider , Scope, ScopeValue(..) , Code, State , Parameters , Tokens(..), IdToken(..), IdTokenClaims(..) -- * Exception , OpenIdException(..) -- * Re-exports , module Jose.Jwt ) where import Control.Applicative ((<$>)) import Control.Monad (unless) import Control.Monad.Catch (MonadThrow, throwM, MonadCatch, catch) import Crypto.Random (CPRG) import Data.Aeson (decode) import qualified Data.ByteString.Char8 as B import Data.ByteString (ByteString) import Data.IORef (IORef, atomicModifyIORef') import Data.List (nub) import Data.Maybe (fromMaybe, fromJust) import Data.Text (pack) import Data.Text.Encoding (decodeUtf8) import Data.Time.Clock.POSIX (getPOSIXTime) import Data.Tuple (swap) import qualified Jose.Jwk as Jwk import Jose.Jwt (Jwt) import qualified Jose.Jwt as Jwt import Network.HTTP.Client (parseUrl, getUri, setQueryString, applyBasicAuth, urlEncodedBody, Request(..), Manager, httpLbs, responseBody) import Network.URI (URI) import Prelude hiding (exp) import qualified Web.OIDC.Client.Internal as I import qualified Web.OIDC.Types as OT import Web.OIDC.Types (Provider, Scope, ScopeValue(..), Code, State, Parameters, Tokens(..), IdToken(..), IdTokenClaims(..), OpenIdException(..)) -- | This data type represents information needed in the OpenID flow. 
data OIDC = OIDC
    { authorizationSeverUrl :: String      -- ^ authorization endpoint of the OP (filled in by 'setProvider')
    , tokenEndpoint         :: String      -- ^ token endpoint of the OP (filled in by 'setProvider')
    , clientId              :: ByteString  -- ^ client ID (filled in by 'setCredentials')
    , clientSecret          :: ByteString  -- ^ client secret (filled in by 'setCredentials')
    , redirectUri           :: ByteString  -- ^ redirect URI (filled in by 'setCredentials')
    , provider              :: Provider    -- ^ OP's information (filled in by 'setProvider')
    , cprgRef               :: CPRGRef     -- ^ random generator reference used when decoding the ID Token
    }

-- | Optional reference to a random number generator.
--
-- 'Ref' hides the concrete generator type; 'NoRef' means that no generator was supplied.
data CPRGRef where
    Ref   :: (CPRG g) => IORef g -> CPRGRef
    NoRef :: CPRGRef

-- | Empty configuration.
--
-- Every field except 'cprgRef' is an 'error' thunk, so a field must be set
-- ('setProvider', 'setCredentials') before it is read.
def :: OIDC
def = OIDC
    { authorizationSeverUrl = error "You must specify authorizationSeverUrl"
    , tokenEndpoint         = error "You must specify tokenEndpoint"
    , clientId              = error "You must specify clientId"
    , clientSecret          = error "You must specify clientSecret"
    , redirectUri           = error "You must specify redirectUri"
    , provider              = error "You must specify provider"
    , cprgRef               = NoRef
    }

-- | Create OIDC.
--
-- First argument is used in a token decoding on ID Token Validation.
newOIDC :: CPRG g => IORef g -> OIDC
newOIDC ref = def { cprgRef = Ref ref }

-- | Create OIDC without a random generator reference ('cprgRef' is 'NoRef').
--
-- Note that ID Token validation currently fails with @error "not implemented"@
-- for a value created this way.
newOIDC' :: OIDC
newOIDC' = def

-- | Set the provider and copy its authorization and token endpoints into the 'OIDC' value.
setProvider
    :: Provider     -- ^ OP's information (obtain by 'discover')
    -> OIDC
    -> OIDC
setProvider p oidc =
    oidc { authorizationSeverUrl = OT.authorizationEndpoint . OT.configuration $ p
         , tokenEndpoint         = OT.tokenEndpoint . OT.configuration $ p
         , provider              = p
         }

-- | Set the client credentials and the redirect URI.
setCredentials
    :: ByteString   -- ^ client ID
    -> ByteString   -- ^ client secret
    -> ByteString   -- ^ redirect URI
    -> OIDC
    -> OIDC
setCredentials cid secret redirect oidc =
    oidc { clientId     = cid
         , clientSecret = secret
         , redirectUri  = redirect
         }

-- | Build the authentication request URL for the authorization endpoint.
--
-- The query string always contains @response_type=code@, @client_id@,
-- @redirect_uri@ and @scope@ (with 'OpenId' prepended and duplicates removed),
-- followed by @state@ when it is given, followed by the extra parameters.
-- A failure while parsing the endpoint URL is passed to 'OT.rethrow'.
getAuthenticationRequestUrl :: (MonadThrow m, MonadCatch m) => OIDC -> Scope -> Maybe State -> Parameters -> m URI
getAuthenticationRequestUrl oidc scope state params = do
    req <- parseUrl endpoint `catch` OT.rethrow
    return $ getUri $ setQueryString query req
  where
    endpoint  = authorizationSeverUrl oidc
    query     = requireds ++ state' ++ params
    requireds =
        [ ("response_type", Just "code")
        , ("client_id",     Just $ clientId oidc)
        , ("redirect_uri",  Just $ redirectUri oidc)
        , ("scope",         Just $ B.pack . unwords . nub . map show $ OpenId:scope)
        ]
    state' =
        case state of
            Just _  -> [("state", state)]
            Nothing -> []

-- TODO: error response

-- | Request and obtain valid tokens.
--
-- This function requests ID Token and Access Token to a OP's token endpoint, and validates the received ID Token.
-- Returned value is a valid tokens.
--
-- The request is a POST with basic authentication (client ID and secret) and a
-- URL-encoded body. Exceptions raised while sending it are passed to 'OT.rethrow'.
-- A response body that cannot be decoded still ends in a call to 'error'.
requestTokens :: OIDC -> Code -> Manager -> IO Tokens
requestTokens oidc code manager = do
    json <- getTokensJson `catch` OT.rethrow
    case decode json of
        Just ts -> validate oidc ts
        Nothing -> error "failed to decode tokens json" -- TODO
  where
    getTokensJson = do
        req <- parseUrl endpoint
        let req' = applyBasicAuth cid sec $ urlEncodedBody body $ req { method = "POST" }
        res <- httpLbs req' manager
        return $ responseBody res
    endpoint = tokenEndpoint oidc
    cid      = clientId oidc
    sec      = clientSecret oidc
    redirect = redirectUri oidc
    body     =
        [ ("grant_type",   "authorization_code")
        , ("code",         code)
        , ("redirect_uri", redirect)
        ]

-- | Validate the ID Token of a token response and convert the response into 'Tokens'.
validate :: OIDC -> I.TokensResponse -> IO Tokens
validate oidc tres = do
    let jwt' = I.idToken tres
    claims' <- validateIdToken oidc jwt'
    let tokens = Tokens
          { accessToken  = I.accessToken tres
          , tokenType    = I.tokenType tres
          , idToken      = IdToken { claims = OT.toIdTokenClaims claims', jwt = jwt' }
          , expiresIn    = I.expiresIn tres
          , refreshToken = I.refreshToken tres
          }
    return tokens

-- | Decode the ID Token with a key from the provider's key set and check its claims.
--
-- Checks, in this order: the issuer equals the provider's issuer, the client ID
-- is one of the audiences, and the expiration time lies in the future.
--
-- Throws 'JwtExceptoin' when the token cannot be decoded and 'ValidationException'
-- when a check fails, when a required claim (@iss@, @aud@, @exp@) is missing, or
-- when no key for the token is found. Unsecured tokens and a missing random
-- generator ('NoRef') still end in a call to 'error'.
validateIdToken :: OIDC -> Jwt -> IO Jwt.JwtClaims
validateIdToken oidc jwt' = do
    case cprgRef oidc of
        Ref crpg -> do
            decoded <- case Jwt.decodeClaims (Jwt.unJwt jwt') of
                Left  cause     -> throwM $ JwtExceptoin cause
                Right (jwth, _) ->
                    case jwth of
                        (Jwt.JwsH jws) -> do
                            let kid = Jwt.jwsKid jws
                                alg = Jwt.jwsAlg jws
                            -- The key is resolved before the generator reference is touched,
                            -- so a missing key cannot fail inside 'atomicModifyIORef''.
                            jwk <- getJwk kid (OT.jwkSet . provider $ oidc)
                            atomicModifyIORef' crpg $ \g -> swap (Jwt.decode g [jwk] (Just $ Jwt.JwsEncoding alg) (Jwt.unJwt jwt'))
                        (Jwt.JweH jwe) -> do
                            let kid = Jwt.jweKid jwe
                                alg = Jwt.jweAlg jwe
                                enc = Jwt.jweEnc jwe
                            jwk <- getJwk kid (OT.jwkSet . provider $ oidc)
                            atomicModifyIORef' crpg $ \g -> swap (Jwt.decode g [jwk] (Just $ Jwt.JweEncoding alg enc) (Jwt.unJwt jwt'))
                        _ -> error "not supported"
            case decoded of
                Left err -> throwM $ JwtExceptoin err
                Right _  -> return ()
        NoRef -> error "not implemented" -- TODO: request tokeninfo

    claims' <- getClaims

    iss <- getIss claims'
    unless (iss == issuer')
        $ throwM $ ValidationException "issuer"

    aud <- getAud claims'
    unless (clientId' `elem` aud)
        $ throwM $ ValidationException "audience"

    expire <- getExp claims'
    now    <- getCurrentTime
    unless (now < expire)
        $ throwM $ ValidationException "expire"

    return claims'
  where
    -- Pick the key whose ID equals the token's key ID, or the first key of the
    -- set when the token carries no key ID. Throws instead of calling 'head'
    -- on a possibly empty list.
    getJwk kid jwks =
        case candidates of
            (jwk:_) -> return jwk
            []      -> throwM $ ValidationException "jwk was not found"
      where
        candidates =
            case kid of
                Just keyId -> filter (eq keyId) jwks
                Nothing    -> jwks
        eq e jwk = fromMaybe False ((==) e <$> Jwk.jwkId jwk)

    getClaims =
        case Jwt.decodeClaims (Jwt.unJwt jwt') of
            Right (_, c) -> return c
            Left  cause  -> throwM $ JwtExceptoin cause

    issuer'   = pack . OT.issuer . OT.configuration . provider $ oidc
    clientId' = decodeUtf8 . clientId $ oidc

    -- The claims come from a remote party, so a missing claim is reported as
    -- a validation failure rather than through 'fromJust'.
    getIss c =
        case Jwt.jwtIss c of
            Just i  -> return i
            Nothing -> throwM $ ValidationException "iss claim was not found"
    getAud c =
        case Jwt.jwtAud c of
            Just a  -> return a
            Nothing -> throwM $ ValidationException "aud claim was not found"
    getExp c =
        case Jwt.jwtExp c of
            Just e  -> return e
            Nothing -> throwM $ ValidationException "exp claim was not found"

    getCurrentTime = Jwt.IntDate <$> getPOSIXTime