parse and process PKCE requests

pull/1108/head
Justin Richer 2016-07-24 17:45:43 -04:00
parent 5dcda2812e
commit ac0cafe7b3
3 changed files with 61 additions and 7 deletions

View File

@ -19,6 +19,13 @@
*/
package org.mitre.oauth2.service.impl;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collection;
import java.util.Date;
import java.util.HashSet;
@ -30,6 +37,7 @@ import org.mitre.oauth2.model.AuthenticationHolderEntity;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
import org.mitre.oauth2.model.PKCEAlgorithm;
import org.mitre.oauth2.model.SystemScope;
import org.mitre.oauth2.repository.AuthenticationHolderRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
@ -44,9 +52,9 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
import org.springframework.security.oauth2.common.exceptions.InvalidScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.ClientAlreadyExistsException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.TokenRequest;
@ -54,6 +62,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service;
import com.google.common.collect.Sets;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
@ -169,14 +178,43 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentication) throws AuthenticationException, InvalidClientException {
if (authentication != null && authentication.getOAuth2Request() != null) {
// look up our client
OAuth2Request clientAuth = authentication.getOAuth2Request();
OAuth2Request request = authentication.getOAuth2Request();
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientAuth.getClientId());
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
if (client == null) {
throw new InvalidClientException("Client not found: " + clientAuth.getClientId());
throw new InvalidClientException("Client not found: " + request.getClientId());
}
// handle the PKCE code challenge if present
if (request.getExtensions().containsKey(CODE_CHALLENGE)) {
String challenge = (String) request.getExtensions().get(CODE_CHALLENGE);
PKCEAlgorithm alg = PKCEAlgorithm.parse((String) request.getExtensions().get(CODE_CHALLENGE_METHOD));
String verifier = request.getRequestParameters().get(CODE_VERIFIER);
if (alg.equals(PKCEAlgorithm.plain)) {
// do a direct string comparison
if (!challenge.equals(verifier)) {
throw new InvalidRequestException("Code challenge and verifier do not match");
}
} else if (alg.equals(PKCEAlgorithm.S256)) {
// hash the verifier
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
String hash = Base64URL.encode(digest.digest(verifier.getBytes(StandardCharsets.US_ASCII))).toString();
if (!challenge.equals(hash)) {
throw new InvalidRequestException("Code challenge and verifier do not match");
}
} catch (NoSuchAlgorithmException e) {
logger.error("Unknown algorithm for PKCE digest", e);
}
}
}
OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();//accessTokenFactory.createNewAccessToken();
// attach the client
@ -185,7 +223,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
// inherit the scope from the auth, but make a new set so it is
//not unmodifiable. Unmodifiables don't play nicely with Eclipselink, which
//wants to use the clone operation.
Set<SystemScope> scopes = scopeService.fromStrings(clientAuth.getScope());
Set<SystemScope> scopes = scopeService.fromStrings(request.getScope());
// remove any of the special system scopes
scopes = scopeService.removeReservedScopes(scopes);

View File

@ -17,7 +17,7 @@
package org.mitre.openid.connect.request;
import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD;
import static org.mitre.openid.connect.request.ConnectRequestParameters.*;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS;
import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID;
import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY;
@ -41,6 +41,7 @@ import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService;
import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.PKCEAlgorithm;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.oauth2.service.SystemScopeService;
import org.slf4j.Logger;
@ -138,6 +139,16 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
request.getExtensions().put(AUD, inputParams.get(AUD));
}
if (inputParams.containsKey(CODE_CHALLENGE)) {
request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE));
if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) {
request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD));
} else {
// if the client doesn't specify a code challenge transformation method, it's "plain"
request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName());
}
}
if (inputParams.containsKey(REQUEST)) {
request.getExtensions().put(REQUEST, inputParams.get(REQUEST));

View File

@ -46,5 +46,10 @@ public interface ConnectRequestParameters {
// audience
public String AUD = "aud";
// PKCE
public String CODE_CHALLENGE = "code_challenge";
public String CODE_CHALLENGE_METHOD = "code_challenge_method";
public String CODE_VERIFIER = "code_verifier";
}