check sector identifier URI's contents and match against redirect URIs, addresses #504
parent
1aa5fe25c6
commit
9b72c6b1f3
|
@ -18,11 +18,17 @@ package org.mitre.oauth2.service.impl;
|
|||
|
||||
import java.math.BigInteger;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import org.apache.commons.codec.binary.Base64;
|
||||
import org.apache.http.client.HttpClient;
|
||||
import org.apache.http.impl.client.DefaultHttpClient;
|
||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||
import org.mitre.oauth2.repository.OAuth2ClientRepository;
|
||||
import org.mitre.oauth2.repository.OAuth2TokenRepository;
|
||||
|
@ -31,16 +37,27 @@ import org.mitre.openid.connect.model.WhitelistedSite;
|
|||
import org.mitre.openid.connect.service.ApprovedSiteService;
|
||||
import org.mitre.openid.connect.service.BlacklistedSiteService;
|
||||
import org.mitre.openid.connect.service.WhitelistedSiteService;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
|
||||
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
|
||||
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.cache.CacheLoader;
|
||||
import com.google.common.cache.LoadingCache;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonParser;
|
||||
|
||||
@Service
|
||||
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
|
||||
|
||||
private static Logger logger = LoggerFactory.getLogger(DefaultOAuth2ClientDetailsEntityService.class);
|
||||
|
||||
@Autowired
|
||||
private OAuth2ClientRepository clientRepository;
|
||||
|
||||
|
@ -56,6 +73,12 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
|||
@Autowired
|
||||
private BlacklistedSiteService blacklistedSiteService;
|
||||
|
||||
// map of sector URI -> list of redirect URIs
|
||||
private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
|
||||
.expireAfterAccess(1, TimeUnit.HOURS)
|
||||
.maximumSize(100)
|
||||
.build(new SectorIdentifierLoader());
|
||||
|
||||
@Override
|
||||
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
|
||||
if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
|
||||
|
@ -85,6 +108,26 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
|||
|
||||
// timestamp this to right now
|
||||
client.setCreatedAt(new Date());
|
||||
|
||||
|
||||
// check the sector URI
|
||||
if (!Strings.isNullOrEmpty(client.getSectorIdentifierUri())) {
|
||||
try {
|
||||
List<String> redirects = sectorRedirects.get(client.getSectorIdentifierUri());
|
||||
|
||||
if (client.getRegisteredRedirectUri() != null) {
|
||||
for (String uri : client.getRegisteredRedirectUri()) {
|
||||
if (!redirects.contains(uri)) {
|
||||
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (ExecutionException e) {
|
||||
throw new IllegalArgumentException("Unable to load sector identifier URI: " + client.getSectorIdentifierUri());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return clientRepository.saveClient(client);
|
||||
}
|
||||
|
@ -165,6 +208,24 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
|||
newClient.getScope().remove("offline_access");
|
||||
}
|
||||
|
||||
// check the sector URI
|
||||
if (!Strings.isNullOrEmpty(newClient.getSectorIdentifierUri())) {
|
||||
try {
|
||||
List<String> redirects = sectorRedirects.get(newClient.getSectorIdentifierUri());
|
||||
|
||||
if (newClient.getRegisteredRedirectUri() != null) {
|
||||
for (String uri : newClient.getRegisteredRedirectUri()) {
|
||||
if (!redirects.contains(uri)) {
|
||||
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (ExecutionException e) {
|
||||
throw new IllegalArgumentException("Unable to load sector identifier URI: " + newClient.getSectorIdentifierUri());
|
||||
}
|
||||
}
|
||||
|
||||
return clientRepository.updateClient(oldClient.getId(), newClient);
|
||||
}
|
||||
throw new IllegalArgumentException("Neither old client or new client can be null!");
|
||||
|
@ -196,4 +257,45 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
|||
return client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class to load a sector identifier's set of authorized redirect URIs.
|
||||
*
|
||||
* @author jricher
|
||||
*
|
||||
*/
|
||||
private class SectorIdentifierLoader extends CacheLoader<String, List<String>> {
|
||||
private HttpClient httpClient = new DefaultHttpClient();
|
||||
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
|
||||
private RestTemplate restTemplate = new RestTemplate(httpFactory);
|
||||
private JsonParser parser = new JsonParser();
|
||||
|
||||
@Override
|
||||
public List<String> load(String key) throws Exception {
|
||||
|
||||
if (!key.startsWith("https")) {
|
||||
// TODO: this should optionally throw an error (#506)
|
||||
logger.error("Sector identifier doesn't start with https, loading anyway...");
|
||||
}
|
||||
|
||||
// key is the sector URI
|
||||
String jsonString = restTemplate.getForObject(key, String.class);
|
||||
JsonElement json = parser.parse(jsonString);
|
||||
|
||||
if (json.isJsonArray()) {
|
||||
List<String> redirectUris = new ArrayList<String>();
|
||||
for (JsonElement el : json.getAsJsonArray()) {
|
||||
redirectUris.add(el.getAsString());
|
||||
}
|
||||
|
||||
logger.info("Found " + redirectUris + " for sector " + key);
|
||||
|
||||
return redirectUris;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue