refactor processing of request object
parent
47d304851d
commit
b396610f35
|
@ -16,12 +16,10 @@
|
|||
******************************************************************************/
|
||||
package org.mitre.openid.connect;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.security.NoSuchAlgorithmException;
|
||||
import java.security.spec.InvalidKeySpecException;
|
||||
import java.text.ParseException;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -45,7 +43,9 @@ import org.springframework.stereotype.Component;
|
|||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
import com.nimbusds.jose.Algorithm;
|
||||
import com.nimbusds.jose.JWEObject.State;
|
||||
import com.nimbusds.jose.JWSAlgorithm;
|
||||
|
@ -76,6 +76,8 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
@Autowired
|
||||
private JwtEncryptionAndDecryptionService encryptionService;
|
||||
|
||||
private JsonParser parser = new JsonParser();
|
||||
|
||||
/**
|
||||
* Constructor with arguments
|
||||
*
|
||||
|
@ -97,43 +99,43 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
@Override
|
||||
public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {
|
||||
|
||||
Map<String, String> parameters = processRequestObject(inputParams);
|
||||
|
||||
String clientId = parameters.get("client_id");
|
||||
ClientDetails client = null;
|
||||
|
||||
if (clientId != null) {
|
||||
client = clientDetailsService.loadClientByClientId(clientId);
|
||||
}
|
||||
|
||||
AuthorizationRequest request = new AuthorizationRequest(parameters, Collections.<String, String> emptyMap(),
|
||||
parameters.get(OAuth2Utils.CLIENT_ID),
|
||||
OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.SCOPE)), null,
|
||||
null, false, parameters.get(OAuth2Utils.STATE),
|
||||
parameters.get(OAuth2Utils.REDIRECT_URI),
|
||||
OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.RESPONSE_TYPE)));
|
||||
|
||||
Set<String> scopes = OAuth2Utils.parseParameterList(parameters.get("scope"));
|
||||
if ((scopes == null || scopes.isEmpty()) && client != null) {
|
||||
Set<String> clientScopes = client.getScope();
|
||||
scopes = clientScopes;
|
||||
}
|
||||
|
||||
request.setScope(scopes);
|
||||
AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
|
||||
inputParams.get(OAuth2Utils.CLIENT_ID),
|
||||
OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
|
||||
null, false, inputParams.get(OAuth2Utils.STATE),
|
||||
inputParams.get(OAuth2Utils.REDIRECT_URI),
|
||||
OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));
|
||||
|
||||
//Add extension parameters to the 'extensions' map
|
||||
Map<String, Serializable> extensions = Maps.newHashMap();
|
||||
if (parameters.containsKey("prompt")) {
|
||||
extensions.put("prompt", parameters.get("prompt"));
|
||||
|
||||
if (inputParams.containsKey("prompt")) {
|
||||
request.getExtensions().put("prompt", inputParams.get("prompt"));
|
||||
}
|
||||
if (parameters.containsKey("request")) {
|
||||
extensions.put("request", parameters.get("request"));
|
||||
}
|
||||
if (parameters.containsKey("nonce")) {
|
||||
extensions.put("nonce", parameters.get("nonce"));
|
||||
if (inputParams.containsKey("nonce")) {
|
||||
request.getExtensions().put("nonce", inputParams.get("nonce"));
|
||||
}
|
||||
|
||||
request.setExtensions(extensions);
|
||||
if (inputParams.containsKey("claims")) {
|
||||
JsonObject claimsRequest = parseClaimRequest(inputParams.get("claims"));
|
||||
if (claimsRequest != null) {
|
||||
request.getExtensions().put("claims", claimsRequest.toString());
|
||||
}
|
||||
}
|
||||
|
||||
if (inputParams.containsKey("request")) {
|
||||
request.getExtensions().put("request", inputParams.get("request"));
|
||||
processRequestObject(inputParams.get("request"), request);
|
||||
}
|
||||
|
||||
|
||||
if ((request.getScope() == null || request.getScope().isEmpty())) {
|
||||
if (request.getClientId() != null) {
|
||||
ClientDetails client = clientDetailsService.loadClientByClientId(request.getClientId());
|
||||
Set<String> clientScopes = client.getScope();
|
||||
request.setScope(clientScopes);
|
||||
}
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
|
@ -142,49 +144,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
* @param inputParams
|
||||
* @return
|
||||
*/
|
||||
private Map<String, String> processRequestObject(Map<String, String> inputParams) {
|
||||
|
||||
String jwtString = inputParams.get("request");
|
||||
|
||||
// if there's no request object, bail early
|
||||
if (Strings.isNullOrEmpty(jwtString)) {
|
||||
return inputParams;
|
||||
}
|
||||
|
||||
// start by copying over what's already in there
|
||||
Map<String, String> parameters = new HashMap<String, String>(inputParams);
|
||||
private void processRequestObject(String jwtString, AuthorizationRequest request) {
|
||||
|
||||
// parse the request object
|
||||
try {
|
||||
JWT jwt = JWTParser.parse(jwtString);
|
||||
|
||||
/*
|
||||
if (jwt instanceof EncryptedJWT) {
|
||||
// TODO: it's an encrypted JWT, decrypt it and use it
|
||||
} else {
|
||||
// it's not encrypted...
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
|
||||
|
||||
// TODO: check parameter consistency, move keys to constants
|
||||
|
||||
if (jwt instanceof SignedJWT) {
|
||||
// it's a signed JWT, check the signature
|
||||
|
||||
SignedJWT signedJwt = (SignedJWT)jwt;
|
||||
|
||||
String clientId = inputParams.get("client_id");
|
||||
if (clientId == null) {
|
||||
clientId = signedJwt.getJWTClaimsSet().getStringClaim("client_id");
|
||||
|
||||
// need to check clientId first so that we can load the client to check other fields
|
||||
if (request.getClientId() == null) {
|
||||
request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim("client_id"));
|
||||
}
|
||||
|
||||
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
|
||||
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
|
||||
|
||||
if (client == null) {
|
||||
throw new InvalidClientException("Client not found: " + clientId);
|
||||
throw new InvalidClientException("Client not found: " + request.getClientId());
|
||||
}
|
||||
|
||||
|
||||
|
@ -239,15 +220,15 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
} else if (jwt instanceof PlainJWT) {
|
||||
PlainJWT plainJwt = (PlainJWT)jwt;
|
||||
|
||||
String clientId = inputParams.get("client_id");
|
||||
if (clientId == null) {
|
||||
clientId = plainJwt.getJWTClaimsSet().getStringClaim("client_id");
|
||||
// need to check clientId first so that we can load the client to check other fields
|
||||
if (request.getClientId() == null) {
|
||||
request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim("client_id"));
|
||||
}
|
||||
|
||||
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
|
||||
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
|
||||
|
||||
if (client == null) {
|
||||
throw new InvalidClientException("Client not found: " + clientId);
|
||||
throw new InvalidClientException("Client not found: " + request.getClientId());
|
||||
}
|
||||
|
||||
if (client.getRequestObjectSigningAlg() == null) {
|
||||
|
@ -270,13 +251,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
throw new InvalidClientException("Unable to decrypt the request object");
|
||||
}
|
||||
|
||||
// need to check clientId first so that we can load the client to check other fields
|
||||
if (request.getClientId() == null) {
|
||||
request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim("client_id"));
|
||||
}
|
||||
|
||||
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
|
||||
|
||||
if (client == null) {
|
||||
throw new InvalidClientException("Client not found: " + request.getClientId());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Claims precedence order logic:
|
||||
*
|
||||
* if (in Claims):
|
||||
* if (in params):
|
||||
* if (equal):
|
||||
* all set
|
||||
* OK
|
||||
* else (not equal):
|
||||
* error
|
||||
* else (not in params):
|
||||
|
@ -285,64 +281,102 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
|
|||
* we don't care
|
||||
*/
|
||||
|
||||
// now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
|
||||
|
||||
ReadOnlyJWTClaimsSet claims = jwt.getJWTClaimsSet();
|
||||
|
||||
String clientId = claims.getStringClaim("client_id");
|
||||
if (clientId != null) {
|
||||
parameters.put("client_id", clientId);
|
||||
}
|
||||
|
||||
String responseTypes = claims.getStringClaim("response_type");
|
||||
if (responseTypes != null) {
|
||||
parameters.put("response_type", responseTypes);
|
||||
Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim("response_type"));
|
||||
if (responseTypes != null && !responseTypes.isEmpty()) {
|
||||
if (request.getResponseTypes() == null || request.getResponseTypes().isEmpty()) {
|
||||
// if it's null or empty, we fill in the value with what we were passed
|
||||
request.setResponseTypes(responseTypes);
|
||||
} else if (!request.getResponseTypes().equals(responseTypes)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
if (claims.getStringClaim("redirect_uri") != null) {
|
||||
if (inputParams.containsKey("redirect_uri") == false) {
|
||||
parameters.put("redirect_uri", claims.getStringClaim("redirect_uri"));
|
||||
String redirectUri = claims.getStringClaim("redirect_uri");
|
||||
if (redirectUri != null) {
|
||||
if (request.getRedirectUri() == null) {
|
||||
request.setRedirectUri(redirectUri);
|
||||
} else if (!request.getRedirectUri().equals(redirectUri)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
String state = claims.getStringClaim("state");
|
||||
if(state != null) {
|
||||
if (inputParams.containsKey("state") == false) {
|
||||
parameters.put("state", state);
|
||||
if (request.getState() == null) {
|
||||
request.setState(state);
|
||||
} else if (!request.getState().equals(state)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
String nonce = claims.getStringClaim("nonce");
|
||||
if(nonce != null) {
|
||||
if (inputParams.containsKey("nonce") == false) {
|
||||
parameters.put("nonce", nonce);
|
||||
if (request.getExtensions().get("nonce") == null) {
|
||||
request.getExtensions().put("nonce", nonce);
|
||||
} else if (!request.getExtensions().get("nonce").equals(nonce)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
String display = claims.getStringClaim("display");
|
||||
if (display != null) {
|
||||
if (inputParams.containsKey("display") == false) {
|
||||
parameters.put("display", display);
|
||||
if (request.getExtensions().get("display") == null) {
|
||||
request.getExtensions().put("display", display);
|
||||
} else if (!request.getExtensions().get("display").equals(display)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
String prompt = claims.getStringClaim("prompt");
|
||||
if (prompt != null) {
|
||||
if (inputParams.containsKey("prompt") == false) {
|
||||
parameters.put("prompt", prompt);
|
||||
if (request.getExtensions().get("prompt") == null) {
|
||||
request.getExtensions().put("prompt", prompt);
|
||||
} else if (!request.getExtensions().get("prompt").equals(prompt)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
String scope = claims.getStringClaim("scope");
|
||||
if (scope != null) {
|
||||
if (inputParams.containsKey("scope") == false) {
|
||||
parameters.put("scope", scope);
|
||||
|
||||
Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope"));
|
||||
if (scope != null && !scope.isEmpty()) {
|
||||
if (request.getScope() == null || request.getScope().isEmpty()) {
|
||||
request.setScope(scope);
|
||||
} else if (!request.getScope().equals(scope)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
JsonObject claimRequest = parseClaimRequest(claims.getStringClaim("claims"));
|
||||
if (claimRequest != null) {
|
||||
if (request.getExtensions().get("claims") == null) {
|
||||
// we save the string because the object might not serialize
|
||||
request.getExtensions().put("claims", claimRequest.toString());
|
||||
} else if (parseClaimRequest(request.getExtensions().get("claims").toString()).equals(claimRequest)) {
|
||||
// FIXME: throw an error
|
||||
}
|
||||
}
|
||||
|
||||
} catch (ParseException e) {
|
||||
logger.error("ParseException while parsing RequestObject:", e);
|
||||
}
|
||||
return parameters;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param claimRequestString
|
||||
* @return
|
||||
*/
|
||||
private JsonObject parseClaimRequest(String claimRequestString) {
|
||||
JsonElement el = parser .parse(claimRequestString);
|
||||
if (el != null && el.isJsonObject()) {
|
||||
return el.getAsJsonObject();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a symmetric signing and validation service for the given client
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue