diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java index 113bf3c37..ee66603f7 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java @@ -19,6 +19,13 @@ */ package org.mitre.oauth2.service.impl; +import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE; +import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_CHALLENGE_METHOD; +import static org.mitre.openid.connect.request.ConnectRequestParameters.CODE_VERIFIER; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.util.Collection; import java.util.Date; import java.util.HashSet; @@ -30,6 +37,7 @@ import org.mitre.oauth2.model.AuthenticationHolderEntity; import org.mitre.oauth2.model.ClientDetailsEntity; import org.mitre.oauth2.model.OAuth2AccessTokenEntity; import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; +import org.mitre.oauth2.model.PKCEAlgorithm; import org.mitre.oauth2.model.SystemScope; import org.mitre.oauth2.repository.AuthenticationHolderRepository; import org.mitre.oauth2.repository.OAuth2TokenRepository; @@ -44,9 +52,9 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.core.AuthenticationException; import org.springframework.security.oauth2.common.exceptions.InvalidClientException; +import org.springframework.security.oauth2.common.exceptions.InvalidRequestException; import org.springframework.security.oauth2.common.exceptions.InvalidScopeException; import org.springframework.security.oauth2.common.exceptions.InvalidTokenException; -import org.springframework.security.oauth2.provider.ClientAlreadyExistsException; import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.TokenRequest; @@ -54,6 +62,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer; import org.springframework.stereotype.Service; import com.google.common.collect.Sets; +import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.PlainJWT; @@ -169,14 +178,43 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentication) throws AuthenticationException, InvalidClientException { if (authentication != null && authentication.getOAuth2Request() != null) { // look up our client - OAuth2Request clientAuth = authentication.getOAuth2Request(); + OAuth2Request request = authentication.getOAuth2Request(); - ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientAuth.getClientId()); + ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); if (client == null) { - throw new InvalidClientException("Client not found: " + clientAuth.getClientId()); + throw new InvalidClientException("Client not found: " + request.getClientId()); } + + // handle the PKCE code challenge if present + if (request.getExtensions().containsKey(CODE_CHALLENGE)) { + String challenge = (String) request.getExtensions().get(CODE_CHALLENGE); + PKCEAlgorithm alg = PKCEAlgorithm.parse((String) request.getExtensions().get(CODE_CHALLENGE_METHOD)); + + String verifier = request.getRequestParameters().get(CODE_VERIFIER); + + if (alg.equals(PKCEAlgorithm.plain)) { + // do a direct string comparison + if (!challenge.equals(verifier)) { + throw new InvalidRequestException("Code challenge and verifier do not match"); + } + } else if (alg.equals(PKCEAlgorithm.S256)) { + // hash the verifier + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + String hash = Base64URL.encode(digest.digest(verifier.getBytes(StandardCharsets.US_ASCII))).toString(); + if (!challenge.equals(hash)) { + throw new InvalidRequestException("Code challenge and verifier do not match"); + } + } catch (NoSuchAlgorithmException e) { + logger.error("Unknown algorithm for PKCE digest", e); + } + } + + } + + OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();//accessTokenFactory.createNewAccessToken(); // attach the client @@ -185,7 +223,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi // inherit the scope from the auth, but make a new set so it is //not unmodifiable. Unmodifiables don't play nicely with Eclipselink, which //wants to use the clone operation. - Set scopes = scopeService.fromStrings(clientAuth.getScope()); + Set scopes = scopeService.fromStrings(request.getScope()); // remove any of the special system scopes scopes = scopeService.removeReservedScopes(scopes); diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java index e5c97c356..7d697e078 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectOAuth2RequestFactory.java @@ -17,7 +17,7 @@ package org.mitre.openid.connect.request; -import static org.mitre.openid.connect.request.ConnectRequestParameters.AUD; +import static org.mitre.openid.connect.request.ConnectRequestParameters.*; import static org.mitre.openid.connect.request.ConnectRequestParameters.CLAIMS; import static org.mitre.openid.connect.request.ConnectRequestParameters.CLIENT_ID; import static org.mitre.openid.connect.request.ConnectRequestParameters.DISPLAY; @@ -41,6 +41,7 @@ import org.mitre.jwt.encryption.service.JWTEncryptionAndDecryptionService; import org.mitre.jwt.signer.service.JWTSigningAndValidationService; import org.mitre.jwt.signer.service.impl.ClientKeyCacheService; import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.model.PKCEAlgorithm; import org.mitre.oauth2.service.ClientDetailsEntityService; import org.mitre.oauth2.service.SystemScopeService; import org.slf4j.Logger; @@ -138,6 +139,16 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { request.getExtensions().put(AUD, inputParams.get(AUD)); } + if (inputParams.containsKey(CODE_CHALLENGE)) { + request.getExtensions().put(CODE_CHALLENGE, inputParams.get(CODE_CHALLENGE)); + if (inputParams.containsKey(CODE_CHALLENGE_METHOD)) { + request.getExtensions().put(CODE_CHALLENGE_METHOD, inputParams.get(CODE_CHALLENGE_METHOD)); + } else { + // if the client doesn't specify a code challenge transformation method, it's "plain" + request.getExtensions().put(CODE_CHALLENGE_METHOD, PKCEAlgorithm.plain.getName()); + } + + } if (inputParams.containsKey(REQUEST)) { request.getExtensions().put(REQUEST, inputParams.get(REQUEST)); diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectRequestParameters.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectRequestParameters.java index f0858423a..3b71edb63 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectRequestParameters.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/request/ConnectRequestParameters.java @@ -46,5 +46,10 @@ public interface ConnectRequestParameters { // audience public String AUD = "aud"; - + + // PKCE + public String CODE_CHALLENGE = "code_challenge"; + public String CODE_CHALLENGE_METHOD = "code_challenge_method"; + public String CODE_VERIFIER = "code_verifier"; + }