check sector identifier URI's contents and match against redirect URIs, addresses #504

pull/516/head
Justin Richer 2013-09-13 14:22:24 -04:00
parent 1aa5fe25c6
commit 9b72c6b1f3
1 changed files with 102 additions and 0 deletions

View File

@ -18,11 +18,17 @@ package org.mitre.oauth2.service.impl;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.codec.binary.Base64;
import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.repository.OAuth2ClientRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
@ -31,16 +37,27 @@ import org.mitre.openid.connect.model.WhitelistedSite;
import org.mitre.openid.connect.service.ApprovedSiteService;
import org.mitre.openid.connect.service.BlacklistedSiteService;
import org.mitre.openid.connect.service.WhitelistedSiteService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import com.google.common.base.Strings;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
@Service
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
private static Logger logger = LoggerFactory.getLogger(DefaultOAuth2ClientDetailsEntityService.class);
@Autowired
private OAuth2ClientRepository clientRepository;
@ -56,6 +73,12 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
@Autowired
private BlacklistedSiteService blacklistedSiteService;
// map of sector URI -> list of redirect URIs
private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
.expireAfterAccess(1, TimeUnit.HOURS)
.maximumSize(100)
.build(new SectorIdentifierLoader());
@Override
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
@ -85,6 +108,26 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
// timestamp this to right now
client.setCreatedAt(new Date());
// check the sector URI
if (!Strings.isNullOrEmpty(client.getSectorIdentifierUri())) {
try {
List<String> redirects = sectorRedirects.get(client.getSectorIdentifierUri());
if (client.getRegisteredRedirectUri() != null) {
for (String uri : client.getRegisteredRedirectUri()) {
if (!redirects.contains(uri)) {
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
}
}
}
} catch (ExecutionException e) {
throw new IllegalArgumentException("Unable to load sector identifier URI: " + client.getSectorIdentifierUri());
}
}
return clientRepository.saveClient(client);
}
@ -165,6 +208,24 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
newClient.getScope().remove("offline_access");
}
// check the sector URI
if (!Strings.isNullOrEmpty(newClient.getSectorIdentifierUri())) {
try {
List<String> redirects = sectorRedirects.get(newClient.getSectorIdentifierUri());
if (newClient.getRegisteredRedirectUri() != null) {
for (String uri : newClient.getRegisteredRedirectUri()) {
if (!redirects.contains(uri)) {
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
}
}
}
} catch (ExecutionException e) {
throw new IllegalArgumentException("Unable to load sector identifier URI: " + newClient.getSectorIdentifierUri());
}
}
return clientRepository.updateClient(oldClient.getId(), newClient);
}
throw new IllegalArgumentException("Neither old client or new client can be null!");
@ -196,4 +257,45 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
return client;
}
/**
* Utility class to load a sector identifier's set of authorized redirect URIs.
*
* @author jricher
*
*/
private class SectorIdentifierLoader extends CacheLoader<String, List<String>> {
private HttpClient httpClient = new DefaultHttpClient();
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
private RestTemplate restTemplate = new RestTemplate(httpFactory);
private JsonParser parser = new JsonParser();
@Override
public List<String> load(String key) throws Exception {
if (!key.startsWith("https")) {
// TODO: this should optionally throw an error (#506)
logger.error("Sector identifier doesn't start with https, loading anyway...");
}
// key is the sector URI
String jsonString = restTemplate.getForObject(key, String.class);
JsonElement json = parser.parse(jsonString);
if (json.isJsonArray()) {
List<String> redirectUris = new ArrayList<String>();
for (JsonElement el : json.getAsJsonArray()) {
redirectUris.add(el.getAsString());
}
logger.info("Found " + redirectUris + " for sector " + key);
return redirectUris;
} else {
return null;
}
}
}
}