From b396610f35dae52656631f9bcdaeada89d734ede Mon Sep 17 00:00:00 2001
From: Justin Richer <jricher@mitre.org>
Date: Wed, 18 Sep 2013 17:13:34 -0400
Subject: [PATCH] refactor processing of request object

---
 .../connect/ConnectOAuth2RequestFactory.java  | 220 ++++++++++--------
 1 file changed, 127 insertions(+), 93 deletions(-)

diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java
index d518282e0..bac354398 100644
--- a/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java
+++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/ConnectOAuth2RequestFactory.java
@@ -16,12 +16,10 @@
  ******************************************************************************/
 package org.mitre.openid.connect;
 
-import java.io.Serializable;
 import java.security.NoSuchAlgorithmException;
 import java.security.spec.InvalidKeySpecException;
 import java.text.ParseException;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
@@ -45,7 +43,9 @@ import org.springframework.stereotype.Component;
 
 import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.Maps;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
 import com.nimbusds.jose.Algorithm;
 import com.nimbusds.jose.JWEObject.State;
 import com.nimbusds.jose.JWSAlgorithm;
@@ -76,6 +76,8 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 	@Autowired
 	private JwtEncryptionAndDecryptionService encryptionService;
 
+	private JsonParser parser = new JsonParser();
+
 	/**
 	 * Constructor with arguments
 	 * 
@@ -97,43 +99,43 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 	@Override
 	public AuthorizationRequest createAuthorizationRequest(Map<String, String> inputParams) {
 
-		Map<String, String> parameters = processRequestObject(inputParams);
 
-		String clientId = parameters.get("client_id");
-		ClientDetails client = null;
-
-		if (clientId != null) {
-			client = clientDetailsService.loadClientByClientId(clientId);
-		}
-
-		AuthorizationRequest request = new AuthorizationRequest(parameters, Collections.<String, String> emptyMap(),
-				parameters.get(OAuth2Utils.CLIENT_ID),
-				OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.SCOPE)), null,
-				null, false, parameters.get(OAuth2Utils.STATE),
-				parameters.get(OAuth2Utils.REDIRECT_URI),
-				OAuth2Utils.parseParameterList(parameters.get(OAuth2Utils.RESPONSE_TYPE)));
-
-		Set<String> scopes = OAuth2Utils.parseParameterList(parameters.get("scope"));
-		if ((scopes == null || scopes.isEmpty()) && client != null) {
-			Set<String> clientScopes = client.getScope();
-			scopes = clientScopes;
-		}
-
-		request.setScope(scopes);
+		AuthorizationRequest request = new AuthorizationRequest(inputParams, Collections.<String, String> emptyMap(),
+				inputParams.get(OAuth2Utils.CLIENT_ID),
+				OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.SCOPE)), null,
+				null, false, inputParams.get(OAuth2Utils.STATE),
+				inputParams.get(OAuth2Utils.REDIRECT_URI),
+				OAuth2Utils.parseParameterList(inputParams.get(OAuth2Utils.RESPONSE_TYPE)));
 		
 		//Add extension parameters to the 'extensions' map
-		Map<String, Serializable> extensions = Maps.newHashMap();
-		if (parameters.containsKey("prompt")) {
-			extensions.put("prompt", parameters.get("prompt"));
+		
+		if (inputParams.containsKey("prompt")) {
+			request.getExtensions().put("prompt", inputParams.get("prompt"));
 		}
-		if (parameters.containsKey("request")) {
-			extensions.put("request", parameters.get("request"));
-		}
-		if (parameters.containsKey("nonce")) {
-			extensions.put("nonce", parameters.get("nonce"));
+		if (inputParams.containsKey("nonce")) {
+			request.getExtensions().put("nonce", inputParams.get("nonce"));
 		}
 		
-		request.setExtensions(extensions);
+		if (inputParams.containsKey("claims")) {
+			JsonObject claimsRequest = parseClaimRequest(inputParams.get("claims"));
+			if (claimsRequest != null) {
+				request.getExtensions().put("claims", claimsRequest.toString());
+			}
+		}
+		
+		if (inputParams.containsKey("request")) {
+			request.getExtensions().put("request", inputParams.get("request"));
+			processRequestObject(inputParams.get("request"), request);
+		}
+		
+
+		if ((request.getScope() == null || request.getScope().isEmpty())) {
+			if (request.getClientId() != null) {
+				ClientDetails client = clientDetailsService.loadClientByClientId(request.getClientId());
+				Set<String> clientScopes = client.getScope();
+				request.setScope(clientScopes);
+			}
+		}
 
 		return request;
 	}
@@ -142,49 +144,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 	 * @param inputParams
 	 * @return
 	 */
-	private Map<String, String> processRequestObject(Map<String, String> inputParams) {
-
-		String jwtString = inputParams.get("request");
-
-		// if there's no request object, bail early
-		if (Strings.isNullOrEmpty(jwtString)) {
-			return inputParams;
-		}
-
-		// start by copying over what's already in there
-		Map<String, String> parameters = new HashMap<String, String>(inputParams);
+	private void processRequestObject(String jwtString, AuthorizationRequest request) {
 
 		// parse the request object
 		try {
 			JWT jwt = JWTParser.parse(jwtString);
 
-			/*
-			if (jwt instanceof EncryptedJWT) {
-				// TODO: it's an encrypted JWT, decrypt it and use it
-			} else {
-				// it's not encrypted...
-			}
-			*/
-			
-			
-			
-			
 			// TODO: check parameter consistency, move keys to constants
 			
 			if (jwt instanceof SignedJWT) {
 				// it's a signed JWT, check the signature
 				
 				SignedJWT signedJwt = (SignedJWT)jwt;
-			
-				String clientId = inputParams.get("client_id");
-				if (clientId == null) {
-					clientId = signedJwt.getJWTClaimsSet().getStringClaim("client_id");
+
+				// need to check clientId first so that we can load the client to check other fields
+				if (request.getClientId() == null) {
+					request.setClientId(signedJwt.getJWTClaimsSet().getStringClaim("client_id"));
 				}
 				
-				ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
+				ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
 				
 				if (client == null) {
-					throw new InvalidClientException("Client not found: " + clientId);
+					throw new InvalidClientException("Client not found: " + request.getClientId());
 				}
 				
 				
@@ -239,15 +220,15 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 			} else if (jwt instanceof PlainJWT) {
 				PlainJWT plainJwt = (PlainJWT)jwt;
 				
-				String clientId = inputParams.get("client_id");
-				if (clientId == null) {
-					clientId = plainJwt.getJWTClaimsSet().getStringClaim("client_id");
+				// need to check clientId first so that we can load the client to check other fields
+				if (request.getClientId() == null) {
+					request.setClientId(plainJwt.getJWTClaimsSet().getStringClaim("client_id"));
 				}
 				
-				ClientDetailsEntity client = clientDetailsService.loadClientByClientId(clientId);
+				ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
 				
 				if (client == null) {
-					throw new InvalidClientException("Client not found: " + clientId);
+					throw new InvalidClientException("Client not found: " + request.getClientId());
 				}
 				
 				if (client.getRequestObjectSigningAlg() == null) { 
@@ -270,13 +251,28 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 					throw new InvalidClientException("Unable to decrypt the request object");
 				}
 				
+				// need to check clientId first so that we can load the client to check other fields
+				if (request.getClientId() == null) {
+					request.setClientId(encryptedJWT.getJWTClaimsSet().getStringClaim("client_id"));
+				}
+				
+				ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
+				
+				if (client == null) {
+					throw new InvalidClientException("Client not found: " + request.getClientId());
+				}
+				
+				
 			}
 			
+			
 			/*
+			 * Claims precedence order logic:
+			 * 
 			 * if (in Claims):
 			 * 		if (in params):
 			 * 			if (equal):
-			 * 				all set
+			 * 				OK
 			 * 			else (not equal):
 			 * 				error
 			 * 		else (not in params):
@@ -285,64 +281,102 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
 			 * 		we don't care
 			 */
 
+			// now that we've got the JWT, and it's been parsed, validated, and/or decrypted, we can process the claims
+			
 			ReadOnlyJWTClaimsSet claims = jwt.getJWTClaimsSet();
 			
-			String clientId = claims.getStringClaim("client_id");
-			if (clientId != null) {
-				parameters.put("client_id", clientId);
-			}
-			
-			String responseTypes = claims.getStringClaim("response_type");
-			if (responseTypes != null) {
-				parameters.put("response_type", responseTypes);
+			Set<String> responseTypes = OAuth2Utils.parseParameterList(claims.getStringClaim("response_type"));
+			if (responseTypes != null && !responseTypes.isEmpty()) {
+				if (request.getResponseTypes() == null || request.getResponseTypes().isEmpty()) {
+					// if it's null or empty, we fill in the value with what we were passed
+					request.setResponseTypes(responseTypes);
+				} else if (!request.getResponseTypes().equals(responseTypes)) {
+					// FIXME: throw an error					
+				}
 			}
 
-			if (claims.getStringClaim("redirect_uri") != null) {
-				if (inputParams.containsKey("redirect_uri") == false) {
-					parameters.put("redirect_uri", claims.getStringClaim("redirect_uri"));
+			String redirectUri = claims.getStringClaim("redirect_uri"); 
+			if (redirectUri != null) {
+				if (request.getRedirectUri() == null) {
+					request.setRedirectUri(redirectUri);
+				} else if (!request.getRedirectUri().equals(redirectUri)) {
+					// FIXME: throw an error
 				}
 			}
 
 			String state = claims.getStringClaim("state");
 			if(state != null) {
-				if (inputParams.containsKey("state") == false) {
-					parameters.put("state", state);
+				if (request.getState() == null) {
+					request.setState(state);
+				} else if (!request.getState().equals(state)) {
+					// FIXME: throw an error
 				}
 			}
 
 			String nonce = claims.getStringClaim("nonce");
 			if(nonce != null) {
-				if (inputParams.containsKey("nonce") == false) {
-					parameters.put("nonce", nonce);
+				if (request.getExtensions().get("nonce") == null) {
+					request.getExtensions().put("nonce", nonce);
+				} else if (!request.getExtensions().get("nonce").equals(nonce)) {
+					// FIXME: throw an error
 				}
 			}
 
 			String display = claims.getStringClaim("display");
 			if (display != null) {
-				if (inputParams.containsKey("display") == false) {
-					parameters.put("display", display);
+				if (request.getExtensions().get("display") == null) {
+					request.getExtensions().put("display", display);
+				} else if (!request.getExtensions().get("display").equals(display)) {
+					// FIXME: throw an error
 				}
 			}
 
 			String prompt = claims.getStringClaim("prompt");
 			if (prompt != null) {
-				if (inputParams.containsKey("prompt") == false) {
-					parameters.put("prompt", prompt);
+				if (request.getExtensions().get("prompt") == null) {
+					request.getExtensions().put("prompt", prompt);
+				} else if (!request.getExtensions().get("prompt").equals(prompt)) {
+					// FIXME: throw an error
 				}
 			}
-
-			String scope = claims.getStringClaim("scope");
-			if (scope != null) {
-				if (inputParams.containsKey("scope") == false) {
-					parameters.put("scope", scope);
+			
+			Set<String> scope = OAuth2Utils.parseParameterList(claims.getStringClaim("scope"));
+			if (scope != null && !scope.isEmpty()) {
+				if (request.getScope() == null || request.getScope().isEmpty()) {
+					request.setScope(scope);
+				} else if (!request.getScope().equals(scope)) {
+					// FIXME: throw an error
 				}
 			}
+			
+			JsonObject claimRequest = parseClaimRequest(claims.getStringClaim("claims"));
+			if (claimRequest != null) {
+				if (request.getExtensions().get("claims") == null) {
+					// we save the string because the object might not serialize
+					request.getExtensions().put("claims", claimRequest.toString());
+				} else if (parseClaimRequest(request.getExtensions().get("claims").toString()).equals(claimRequest)) {
+					// FIXME: throw an error
+				}
+			}
+			
 		} catch (ParseException e) {
 			logger.error("ParseException while parsing RequestObject:", e);
 		}
-		return parameters;
 	}
 
+	/**
+	 * @param claimRequestString
+	 * @return
+	 */
+    private JsonObject parseClaimRequest(String claimRequestString) {
+    	JsonElement el = parser .parse(claimRequestString);
+    	if (el != null && el.isJsonObject()) {
+    		return el.getAsJsonObject();
+    	} else {
+    		return null;
+    	}
+    }
+
 	/**
 	 * Create a symmetric signing and validation service for the given client
 	 *