diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java index d518282e0..bac354398 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java @@ -16,12 +16,10 @@ ******************************************************************************/ package org.mitre.openid.connect; -import java.io.Serializable; import java.security.NoSuchAlgorithmException; import java.security.spec.InvalidKeySpecException; import java.text.ParseException; import java.util.Collections; -import java.util.HashMap; import java.util.Map; import java.util.Set; @@ -45,7 +43,9 @@ import org.springframework.stereotype.Component; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; import com.nimbusds.jose.Algorithm; import com.nimbusds.jose.JWEObject.State; import com.nimbusds.jose.JWSAlgorithm; @@ -76,6 +76,8 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { @Autowired private JwtEncryptionAndDecryptionService encryptionService; + private JsonParser parser = new JsonParser(); + /** * Constructor with arguments * @@ -97,43 +99,43 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { @Override public AuthorizationRequest createAuthorizationRequest(Map inputParams) { - Map parameters = processRequestObject(inputParams); - String clientId = parameters.get("client_id"); - ClientDetails client = null; - - if (clientId != null) { - client = clientDetailsService.loadClientByClientId(clientId); - } - - AuthorizationRequest request = new AuthorizationRequest(parameters, Collections. emptyMap(), - parameters.get(OAuth2Utils.CLIENT_ID), - OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.SCOPE)), null, - null, false, parameters.get(OAuth2Utils.STATE), - parameters.get(OAuth2Utils.REDIRECT_URI), - OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.RESPONSE_TYPE))); - - Set scopes = OAuth2Utils.parseParameterList(parameters.get("scope")); - if ((scopes == null || scopes.isEmpty()) && client != null) { - Set clientScopes = client.getScope(); - scopes = clientScopes; - } - - request.setScope(scopes); + AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections. emptyMap(), + inputParams.get(OAuth2Utils.CLIENT_ID), + OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null, + null, false, inputParams.get(OAuth2Utils.STATE), + inputParams.get(OAuth2Utils.REDIRECT_URI), + OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE))); //Add extension parameters to the 'extensions' map - Map extensions = Maps.newHashMap(); - if (parameters.containsKey("prompt")) { - extensions.put("prompt", parameters.get("prompt")); + + if (inputParams.containsKey("prompt")) { + request.getExtensions().put("prompt", inputParams.get("prompt")); } - if (parameters.containsKey("request")) { - extensions.put("request", parameters.get("request")); - } - if (parameters.containsKey("nonce")) { - extensions.put("nonce", parameters.get("nonce")); + if (inputParams.containsKey("nonce")) { + request.getExtensions().put("nonce", inputParams.get("nonce")); } - request.setExtensions(extensions); + if (inputParams.containsKey("claims")) { + JsonObject claimsRequest = parseClaimRequest(inputParams.get("claims")); + if (claimsRequest != null) { + request.getExtensions().put("claims", claimsRequest.toString()); + } + } + + if (inputParams.containsKey("request")) { + request.getExtensions().put("request", inputParams.get("request")); + processRequestObject(inputParams.get("request"), request); + } + + + if ((request.getScope() == null || request.getScope().isEmpty())) { + if (request.getClientId() != null) { + ClientDetails client = clientDetailsService.loadClientByClientId(request.getClientId()); + Set clientScopes = client.getScope(); + request.setScope(clientScopes); + } + } return request; } @@ -142,49 +144,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { * @param inputParams * @return */ - private Map processRequestObject(Map inputParams) { - - String jwtString = inputParams.get("request"); - - // if there's no request object, bail early - if (Strings.isNullOrEmpty(jwtString)) { - return inputParams; - } - - // start by copying over what's already in there - Map parameters = new HashMap(inputParams); + private void processRequestObject(String jwtString, AuthorizationRequest request) { // parse the request object try { JWT jwt = JWTParser.parse(jwtString); - /* - if (jwt instanceof EncryptedJWT) { - // TODO: it's an encrypted JWT, decrypt it and use it - } else { - // it's not encrypted... - } - */ - - - - // TODO: check parameter consistency, move keys to constants if (jwt instanceof SignedJWT) { // it's a signed JWT, check the signature SignedJWT signedJwt = (SignedJWT)jwt; - - String clientId = inputParams.get("client_id"); - if (clientId == null) { - clientId = signedJwt.getJWTClaimsSet().getStringClaim("client_id"); + + // need to check clientId first so that we can load the client to check other fields + if (request.getClientId() == null) { + request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim("client_id")); } - ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId); + ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); if (client == null) { - throw new InvalidClientException("Client not found: " + clientId); + throw new InvalidClientException("Client not found: " + request.getClientId()); } @@ -239,15 +220,15 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { } else if (jwt instanceof PlainJWT) { PlainJWT plainJwt = (PlainJWT)jwt; - String clientId = inputParams.get("client_id"); - if (clientId == null) { - clientId = plainJwt.getJWTClaimsSet().getStringClaim("client_id"); + // need to check clientId first so that we can load the client to check other fields + if (request.getClientId() == null) { + request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim("client_id")); } - ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId); + ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); if (client == null) { - throw new InvalidClientException("Client not found: " + clientId); + throw new InvalidClientException("Client not found: " + request.getClientId()); } if (client.getRequestObjectSigningAlg() == null) { @@ -270,13 +251,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { throw new InvalidClientException("Unable to decrypt the request object"); } + // need to check clientId first so that we can load the client to check other fields + if (request.getClientId() == null) { + request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim("client_id")); + } + + ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); + + if (client == null) { + throw new InvalidClientException("Client not found: " + request.getClientId()); + } + + } + /* + * Claims precedence order logic: + * * if (in Claims): * if (in params): * if (equal): - * all set + * OK * else (not equal): * error * else (not in params): @@ -285,64 +281,102 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory { * we don't care */ + // now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims + ReadOnlyJWTClaimsSet claims = jwt.getJWTClaimsSet(); - String clientId = claims.getStringClaim("client_id"); - if (clientId != null) { - parameters.put("client_id", clientId); - } - - String responseTypes = claims.getStringClaim("response_type"); - if (responseTypes != null) { - parameters.put("response_type", responseTypes); + Set responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim("response_type")); + if (responseTypes != null && !responseTypes.isEmpty()) { + if (request.getResponseTypes() == null || request.getResponseTypes().isEmpty()) { + // if it's null or empty, we fill in the value with what we were passed + request.setResponseTypes(responseTypes); + } else if (!request.getResponseTypes().equals(responseTypes)) { + // FIXME: throw an error + } } - if (claims.getStringClaim("redirect_uri") != null) { - if (inputParams.containsKey("redirect_uri") == false) { - parameters.put("redirect_uri", claims.getStringClaim("redirect_uri")); + String redirectUri = claims.getStringClaim("redirect_uri"); + if (redirectUri != null) { + if (request.getRedirectUri() == null) { + request.setRedirectUri(redirectUri); + } else if (!request.getRedirectUri().equals(redirectUri)) { + // FIXME: throw an error } } String state = claims.getStringClaim("state"); if(state != null) { - if (inputParams.containsKey("state") == false) { - parameters.put("state", state); + if (request.getState() == null) { + request.setState(state); + } else if (!request.getState().equals(state)) { + // FIXME: throw an error } } String nonce = claims.getStringClaim("nonce"); if(nonce != null) { - if (inputParams.containsKey("nonce") == false) { - parameters.put("nonce", nonce); + if (request.getExtensions().get("nonce") == null) { + request.getExtensions().put("nonce", nonce); + } else if (!request.getExtensions().get("nonce").equals(nonce)) { + // FIXME: throw an error } } String display = claims.getStringClaim("display"); if (display != null) { - if (inputParams.containsKey("display") == false) { - parameters.put("display", display); + if (request.getExtensions().get("display") == null) { + request.getExtensions().put("display", display); + } else if (!request.getExtensions().get("display").equals(display)) { + // FIXME: throw an error } } String prompt = claims.getStringClaim("prompt"); if (prompt != null) { - if (inputParams.containsKey("prompt") == false) { - parameters.put("prompt", prompt); + if (request.getExtensions().get("prompt") == null) { + request.getExtensions().put("prompt", prompt); + } else if (!request.getExtensions().get("prompt").equals(prompt)) { + // FIXME: throw an error } } - - String scope = claims.getStringClaim("scope"); - if (scope != null) { - if (inputParams.containsKey("scope") == false) { - parameters.put("scope", scope); + + Set scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope")); + if (scope != null && !scope.isEmpty()) { + if (request.getScope() == null || request.getScope().isEmpty()) { + request.setScope(scope); + } else if (!request.getScope().equals(scope)) { + // FIXME: throw an error } } + + JsonObject claimRequest = parseClaimRequest(claims.getStringClaim("claims")); + if (claimRequest != null) { + if (request.getExtensions().get("claims") == null) { + // we save the string because the object might not serialize + request.getExtensions().put("claims", claimRequest.toString()); + } else if (parseClaimRequest(request.getExtensions().get("claims").toString()).equals(claimRequest)) { + // FIXME: throw an error + } + } + } catch (ParseException e) { logger.error("ParseException while parsing RequestObject:", e); } - return parameters; } + /** + * @param claimRequestString + * @return + */ + private JsonObject parseClaimRequest(String claimRequestString) { + JsonElement el = parser .parse(claimRequestString); + if (el != null && el.isJsonObject()) { + return el.getAsJsonObject(); + } else { + return null; + } + } + /** * Create a symmetric signing and validation service for the given client *