client refactoring, and javadocing

pull/59/head M-1
U-MITRE\mjwalsh 2012-03-26 14:18:54 -04:00
parent b8c953281e
commit c84c751991
1 changed files with 373 additions and 333 deletions

View File

@ -9,13 +9,13 @@ import java.security.KeyPair;
import java.security.KeyPairGenerator; import java.security.KeyPairGenerator;
import java.security.PrivateKey; import java.security.PrivateKey;
import java.security.PublicKey; import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Signature; import java.security.Signature;
import java.util.Arrays; import java.util.Arrays;
import java.util.Enumeration; import java.util.Enumeration;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
@ -23,8 +23,6 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import org.apache.commons.codec.binary.Base64; import org.apache.commons.codec.binary.Base64;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.HttpClient; import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient; import org.apache.http.impl.client.DefaultHttpClient;
import org.mitre.openid.connect.model.IdToken; import org.mitre.openid.connect.model.IdToken;
@ -33,6 +31,7 @@ import org.springframework.security.authentication.AuthenticationServiceExceptio
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpClientErrorException;
@ -93,21 +92,78 @@ public class OpenIdConnectAuthenticationFilter extends
private final static String NONCE_SIGNATURE_COOKIE_NAME = "nonce"; private final static String NONCE_SIGNATURE_COOKIE_NAME = "nonce";
private final static String FILTER_PROCESSES_URL = "/openid_connect_login"; private final static String FILTER_PROCESSES_URL = "/openid_connect_login";
/**
* Builds the redirect_uri that will be sent to the Authorization Endpoint.
* By default returns the URL of the current request minus zero or more
* fields of the URL's query string.
*
* @param request
* the current request which is being processed by this filter
* @param ignoreFields
* an array of field names to ignore.
* @return a URL built from the messaged parameters.
*/
public static String buildRedirectURI(HttpServletRequest request,
String[] ignoreFields) {
List<String> ignore = (ignoreFields != null) ? Arrays
.asList(ignoreFields) : null;
boolean isFirst = true;
StringBuffer sb = request.getRequestURL();
for (Enumeration<?> e = request.getParameterNames(); e
.hasMoreElements();) {
String name = (String) e.nextElement();
if ((ignore == null) || (!ignore.contains(name))) {
// Assume for simplicity that there is only one value
String value = request.getParameter(name);
if (value == null) {
continue;
}
if (isFirst) {
sb.append("?");
isFirst = false;
}
sb.append(name).append("=").append(value);
if (e.hasMoreElements()) {
sb.append("&");
}
}
}
return sb.toString();
}
/** /**
* Return the URL w/ GET parameters * Return the URL w/ GET parameters
* *
* @param baseURI * @param baseURI
* @param params * A String containing the protocol, server address, path, and
* @return * program as per "http://server/path/program"
* @param queryStringFields
* A map where each key is the field name and the associated
* key's value is the field value used to populate the URL's
* query string
* @return A String representing the URL in form of
* http://server/path/program?query_string from the messaged
* parameters.
*/ */
public static String buildURL(String baseURI, public static String buildURL(String baseURI,
Map<String, String> urlVariables) { Map<String, String> queryStringFields) {
StringBuilder URLBuilder = new StringBuilder(baseURI); StringBuilder URLBuilder = new StringBuilder(baseURI);
char appendChar = '?'; char appendChar = '?';
for (Map.Entry<String, String> param : urlVariables.entrySet()) { for (Map.Entry<String, String> param : queryStringFields.entrySet()) {
try { try {
URLBuilder.append(appendChar).append(param.getKey()) URLBuilder.append(appendChar).append(param.getKey())
.append('=') .append('=')
@ -124,7 +180,13 @@ public class OpenIdConnectAuthenticationFilter extends
/** /**
* Returns the signature text for the byte array of data * Returns the signature text for the byte array of data
* *
* @return * @param signer
* The algorithm to sign with
* @param privateKey
* The private key to sign with
* @param data
* The data to be signed
* @return The signature text
*/ */
public static String sign(Signature signer, PrivateKey privateKey, public static String sign(Signature signer, PrivateKey privateKey,
byte[] data) { byte[] data) {
@ -151,8 +213,10 @@ public class OpenIdConnectAuthenticationFilter extends
* Verifies the signature text against the data * Verifies the signature text against the data
* *
* @param data * @param data
* The data
* @param sigText * @param sigText
* @return * The signature text
* @return True if valid, false if not
*/ */
public static boolean verify(Signature signer, PublicKey publicKey, public static boolean verify(Signature signer, PublicKey publicKey,
String data, String sigText) { String data, String sigText) {
@ -197,7 +261,7 @@ public class OpenIdConnectAuthenticationFilter extends
private Signature signer; private Signature signer;
/** /**
* * OpenIdConnectAuthenticationFilter constructor
*/ */
protected OpenIdConnectAuthenticationFilter() { protected OpenIdConnectAuthenticationFilter() {
super(FILTER_PROCESSES_URL); super(FILTER_PROCESSES_URL);
@ -213,34 +277,21 @@ public class OpenIdConnectAuthenticationFilter extends
public void afterPropertiesSet() { public void afterPropertiesSet() {
super.afterPropertiesSet(); super.afterPropertiesSet();
if (errorRedirectURI == null) { Assert.notNull(errorRedirectURI,
throw new IllegalArgumentException(
"An Error Redirect URI must be supplied"); "An Error Redirect URI must be supplied");
}
if (authorizationEndpointURI == null) { Assert.notNull(authorizationEndpointURI,
throw new IllegalArgumentException(
"An Authorization Endpoint URI must be supplied"); "An Authorization Endpoint URI must be supplied");
}
if (tokenEndpointURI == null) { Assert.notNull(tokenEndpointURI,
throw new IllegalArgumentException(
"A Token ID Endpoint URI must be supplied"); "A Token ID Endpoint URI must be supplied");
}
if (checkIDEndpointURI == null) { Assert.notNull(checkIDEndpointURI,
throw new IllegalArgumentException(
"A Check ID Endpoint URI must be supplied"); "A Check ID Endpoint URI must be supplied");
}
if (clientId == null) { Assert.notNull(clientId, "A Client ID must be supplied");
throw new IllegalArgumentException("A Client ID must be supplied");
}
if (clientSecret == null) { Assert.notNull(clientSecret, "A Client Secret must be supplied");
throw new IllegalArgumentException(
"A Client Secret must be supplied");
}
KeyPairGenerator keyPairGenerator; KeyPairGenerator keyPairGenerator;
try { try {
@ -256,19 +307,11 @@ public class OpenIdConnectAuthenticationFilter extends
throw new IllegalStateException(generalSecurityException); throw new IllegalStateException(generalSecurityException);
} }
// prepend the spec necessary scope // prepend the spec necessary SCOPE
setScope((scope != null && !scope.isEmpty()) ? SCOPE + " " + scope setScope((scope != null && !scope.isEmpty()) ? SCOPE + " " + scope
: SCOPE); : SCOPE);
} }
/*
* (non-Javadoc)
*
* @see org.springframework.security.web.authentication.
* AbstractAuthenticationProcessingFilter
* #attemptAuthentication(javax.servlet.http.HttpServletRequest,
* javax.servlet.http.HttpServletResponse)
*/
/* /*
* (non-Javadoc) * (non-Javadoc)
* *
@ -282,52 +325,59 @@ public class OpenIdConnectAuthenticationFilter extends
HttpServletResponse response) throws AuthenticationException, HttpServletResponse response) throws AuthenticationException,
IOException, ServletException { IOException, ServletException {
final boolean debug = logger.isDebugEnabled();
if (request.getParameter("error") != null) { if (request.getParameter("error") != null) {
// Handle Authorization Endpoint error handleError(request, response);
String error = request.getParameter("error");
String errorDescription = request.getParameter("error_description");
String errorURI = request.getParameter("error_uri");
Map<String, String> requestParams = new HashMap<String, String>();
requestParams.put("error", error);
if (errorDescription != null) {
requestParams.put("error_description", errorDescription);
}
if (errorURI != null) {
requestParams.put("error_uri", errorURI);
}
response.sendRedirect(buildURL(errorRedirectURI, requestParams));
} else { } else {
// Determine if the Authorization Endpoint issued an // Determine if the Authorization Endpoint issued an
// authorization grant // authorization grant
if (request.getParameter("code") != null) {
return handleAuthorizationGrantResponse(request);
} else {
handleAuthorizationRequest(request, response);
}
}
return null;
}
/**
* Handles the authorization grant response
*
* @param request
* The request from which to extract parameters and perform the
* authentication
* @return The authenticated user token, or null if authentication is
* incomplete.
*/
private Authentication handleAuthorizationGrantResponse(
HttpServletRequest request) {
final boolean debug = logger.isDebugEnabled();
String authorizationGrant = request.getParameter("code"); String authorizationGrant = request.getParameter("code");
if (authorizationGrant != null) {
// Handle Token Endpoint interaction // Handle Token Endpoint interaction
HttpClient httpClient = new DefaultHttpClient(); HttpClient httpClient = new DefaultHttpClient();
httpClient.getParams().setParameter("http.socket.timeout", httpClient.getParams().setParameter("http.socket.timeout",
new Integer(httpSocketTimeout)); new Integer(httpSocketTimeout));
// //
// TODO: basic auth is untested (it wasn't working last I tested) // TODO: basic auth is untested (it wasn't working last I
// UsernamePasswordCredentials credentials = new UsernamePasswordCredentials( // tested)
// clientId, clientSecret); // UsernamePasswordCredentials credentials = new
// ((DefaultHttpClient) httpClient).getCredentialsProvider() // UsernamePasswordCredentials(
// .setCredentials(AuthScope.ANY, credentials); // clientId, clientSecret);
// ((DefaultHttpClient) httpClient).getCredentialsProvider()
// .setCredentials(AuthScope.ANY, credentials);
//
HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory( HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(
httpClient); httpClient);
@ -337,14 +387,13 @@ public class OpenIdConnectAuthenticationFilter extends
MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>(); MultiValueMap<String, String> form = new LinkedMultiValueMap<String, String>();
form.add("grant_type", "authorization_code"); form.add("grant_type", "authorization_code");
form.add("code", authorizationGrant); form.add("code", authorizationGrant);
form.add("redirect_uri", form.add("redirect_uri", OpenIdConnectAuthenticationFilter
buildRedirectURI(request, new String[] { "code" })); .buildRedirectURI(request, new String[] { "code" }));
// pass clientId and clientSecret in post of request // pass clientId and clientSecret in post of request
form.add("client_id", clientId); form.add("client_id", clientId);
form.add("client_secret", clientSecret); form.add("client_secret", clientSecret);
if (debug) { if (debug) {
logger.debug("tokenEndpointURI = " + tokenEndpointURI); logger.debug("tokenEndpointURI = " + tokenEndpointURI);
logger.debug("form = " + form); logger.debug("form = " + form);
@ -353,8 +402,8 @@ public class OpenIdConnectAuthenticationFilter extends
String jsonString = null; String jsonString = null;
try { try {
jsonString = restTemplate.postForObject(tokenEndpointURI, jsonString = restTemplate.postForObject(tokenEndpointURI, form,
form, String.class); String.class);
} catch (HttpClientErrorException httpClientErrorException) { } catch (HttpClientErrorException httpClientErrorException) {
// Handle error // Handle error
@ -402,7 +451,8 @@ public class OpenIdConnectAuthenticationFilter extends
// e.printStackTrace(); // e.printStackTrace();
throw new AuthenticationServiceException( throw new AuthenticationServiceException(
"Problem parsing id_token return from Token endpoint: " + e); "Problem parsing id_token return from Token endpoint: "
+ e);
} }
} else { } else {
@ -422,31 +472,28 @@ public class OpenIdConnectAuthenticationFilter extends
httpClient.getParams().setParameter("http.socket.timeout", httpClient.getParams().setParameter("http.socket.timeout",
new Integer(httpSocketTimeout)); new Integer(httpSocketTimeout));
factory = new HttpComponentsClientHttpRequestFactory( factory = new HttpComponentsClientHttpRequestFactory(httpClient);
httpClient);
restTemplate = new RestTemplate(factory); restTemplate = new RestTemplate(factory);
form = new LinkedMultiValueMap<String, String>(); form = new LinkedMultiValueMap<String, String>();
form.add("access_token", form.add("access_token", jsonRoot.getAsJsonObject().get("id_token")
jsonRoot.getAsJsonObject().get("id_token")
.getAsString()); .getAsString());
jsonString = null; jsonString = null;
try { try {
jsonString = restTemplate.postForObject( jsonString = restTemplate.postForObject(checkIDEndpointURI,
checkIDEndpointURI, form, String.class); form, String.class);
} catch (HttpClientErrorException httpClientErrorException) { } catch (HttpClientErrorException httpClientErrorException) {
// Handle error // Handle error
logger.error("Check ID Endpoint error response: " logger.error("Check ID Endpoint error response: "
+ httpClientErrorException.getStatusText() + httpClientErrorException.getStatusText() + " : "
+ " : " + httpClientErrorException.getMessage()); + httpClientErrorException.getMessage());
throw new AuthenticationServiceException( throw new AuthenticationServiceException("Unable check token.");
"Unable check token.");
} }
jsonRoot = new JsonParser().parse(jsonString); jsonRoot = new JsonParser().parse(jsonString);
@ -478,8 +525,7 @@ public class OpenIdConnectAuthenticationFilter extends
logger.error("Possible replay attack detected! " logger.error("Possible replay attack detected! "
+ "The comparison of the nonce in the returned " + "The comparison of the nonce in the returned "
+ "ID Token to the signed session " + "ID Token to the signed session "
+ NONCE_SIGNATURE_COOKIE_NAME + NONCE_SIGNATURE_COOKIE_NAME + " failed.");
+ " failed.");
throw new AuthenticationServiceException( throw new AuthenticationServiceException(
"Possible replay attack detected! " "Possible replay attack detected! "
@ -493,7 +539,8 @@ public class OpenIdConnectAuthenticationFilter extends
logger.error(NONCE_SIGNATURE_COOKIE_NAME logger.error(NONCE_SIGNATURE_COOKIE_NAME
+ " was found, but was null or empty."); + " was found, but was null or empty.");
throw new AuthenticationServiceException(NONCE_SIGNATURE_COOKIE_NAME throw new AuthenticationServiceException(
NONCE_SIGNATURE_COOKIE_NAME
+ " was found, but was null or empty."); + " was found, but was null or empty.");
} }
@ -502,8 +549,8 @@ public class OpenIdConnectAuthenticationFilter extends
logger.error(NONCE_SIGNATURE_COOKIE_NAME logger.error(NONCE_SIGNATURE_COOKIE_NAME
+ " cookie was not found."); + " cookie was not found.");
throw new AuthenticationServiceException(NONCE_SIGNATURE_COOKIE_NAME throw new AuthenticationServiceException(
+ " cookie was not found."); NONCE_SIGNATURE_COOKIE_NAME + " cookie was not found.");
} }
// Create an Authentication object for the token, and // Create an Authentication object for the token, and
@ -512,16 +559,28 @@ public class OpenIdConnectAuthenticationFilter extends
OpenIdConnectAuthenticationToken token = new OpenIdConnectAuthenticationToken( OpenIdConnectAuthenticationToken token = new OpenIdConnectAuthenticationToken(
userId, idToken); userId, idToken);
Authentication authentication = this Authentication authentication = this.getAuthenticationManager()
.getAuthenticationManager().authenticate(token); .authenticate(token);
return authentication; return authentication;
} }
}
} else { /**
* Initiate an Authorization request
// Initiate an Authorization request *
* @param request
* The request from which to extract parameters and perform the
* authentication
* @param response
* The response, needed to set a cookie and do a redirect as part
* of a multi-stage authentication process
* @throws IOException
* If an input or output exception occurs
*/
private void handleAuthorizationRequest(HttpServletRequest request,
HttpServletResponse response) throws IOException {
Map<String, String> urlVariables = new HashMap<String, String>(); Map<String, String> urlVariables = new HashMap<String, String>();
@ -530,8 +589,8 @@ public class OpenIdConnectAuthenticationFilter extends
urlVariables.put("response_type", "code"); urlVariables.put("response_type", "code");
urlVariables.put("client_id", clientId); urlVariables.put("client_id", clientId);
urlVariables.put("scope", scope); urlVariables.put("scope", scope);
urlVariables.put("redirect_uri", urlVariables.put("redirect_uri", OpenIdConnectAuthenticationFilter
buildRedirectURI(request, null)); .buildRedirectURI(request, null));
// Create a string value used to associate a user agent session // Create a string value used to associate a user agent session
// with an ID Token to mitigate replay attacks. The value is // with an ID Token to mitigate replay attacks. The value is
@ -539,10 +598,10 @@ public class OpenIdConnectAuthenticationFilter extends
// store a random value as a signed session cookie, and pass the // store a random value as a signed session cookie, and pass the
// value in the nonce parameter. // value in the nonce parameter.
String nonce = new BigInteger(50, new Random()).toString(16); String nonce = new BigInteger(50, new SecureRandom()).toString(16);
Cookie nonceCookie = new Cookie(NONCE_SIGNATURE_COOKIE_NAME, Cookie nonceCookie = new Cookie(NONCE_SIGNATURE_COOKIE_NAME, sign(
sign(signer, privateKey, nonce.getBytes())); signer, privateKey, nonce.getBytes()));
response.addCookie(nonceCookie); response.addCookie(nonceCookie);
@ -552,61 +611,42 @@ public class OpenIdConnectAuthenticationFilter extends
// TODO: display, prompt, request, request_uri // TODO: display, prompt, request, request_uri
response.sendRedirect(buildURL(authorizationEndpointURI, response.sendRedirect(OpenIdConnectAuthenticationFilter.buildURL(
urlVariables)); authorizationEndpointURI, urlVariables));
}
}
return null;
} }
/** /**
* Builds the redirect_uri that will be sent to the Authorization Endpoint. * Handle Authorization Endpoint error
* By default returns the URL of the current request.
* *
* @param request * @param request
* the current request which is being processed by this filter * The request from which to extract parameters and handle the
* @param ingoreParameters * error
* an array of parameter names to ignore. * @param response
* @return * The response, needed to do a redirect to display the error
* @throws IOException
* If an input or output exception occurs
*/ */
private String buildRedirectURI(HttpServletRequest request, private void handleError(HttpServletRequest request,
String[] ingoreParameters) { HttpServletResponse response) throws IOException {
List<String> ignore = (ingoreParameters != null) ? Arrays String error = request.getParameter("error");
.asList(ingoreParameters) : null; String errorDescription = request.getParameter("error_description");
String errorURI = request.getParameter("error_uri");
boolean isFirst = true; Map<String, String> requestParams = new HashMap<String, String>();
StringBuffer sb = request.getRequestURL(); requestParams.put("error", error);
for (Enumeration<?> e = request.getParameterNames(); e if (errorDescription != null) {
.hasMoreElements();) { requestParams.put("error_description", errorDescription);
String name = (String) e.nextElement();
if ((ignore == null) || (!ignore.contains(name))) {
// Assume for simplicity that there is only one value
String value = request.getParameter(name);
if (value == null) {
continue;
} }
if (isFirst) { if (errorURI != null) {
sb.append("?"); requestParams.put("error_uri", errorURI);
isFirst = false;
} }
sb.append(name).append("=").append(value); response.sendRedirect(OpenIdConnectAuthenticationFilter.buildURL(
errorRedirectURI, requestParams));
if (e.hasMoreElements()) {
sb.append("&");
}
}
}
return sb.toString();
} }
public void setAuthorizationEndpointURI(String authorizationEndpointURI) { public void setAuthorizationEndpointURI(String authorizationEndpointURI) {