check sector identifier URI's contents and match against redirect URIs, addresses #504
parent
1aa5fe25c6
commit
9b72c6b1f3
|
@ -18,11 +18,17 @@ package org.mitre.oauth2.service.impl;
|
||||||
|
|
||||||
import java.math.BigInteger;
|
import java.math.BigInteger;
|
||||||
import java.security.SecureRandom;
|
import java.security.SecureRandom;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
import org.apache.commons.codec.binary.Base64;
|
import org.apache.commons.codec.binary.Base64;
|
||||||
|
import org.apache.http.client.HttpClient;
|
||||||
|
import org.apache.http.impl.client.DefaultHttpClient;
|
||||||
import org.mitre.oauth2.model.ClientDetailsEntity;
|
import org.mitre.oauth2.model.ClientDetailsEntity;
|
||||||
import org.mitre.oauth2.repository.OAuth2ClientRepository;
|
import org.mitre.oauth2.repository.OAuth2ClientRepository;
|
||||||
import org.mitre.oauth2.repository.OAuth2TokenRepository;
|
import org.mitre.oauth2.repository.OAuth2TokenRepository;
|
||||||
|
@ -31,16 +37,27 @@ import org.mitre.openid.connect.model.WhitelistedSite;
|
||||||
import org.mitre.openid.connect.service.ApprovedSiteService;
|
import org.mitre.openid.connect.service.ApprovedSiteService;
|
||||||
import org.mitre.openid.connect.service.BlacklistedSiteService;
|
import org.mitre.openid.connect.service.BlacklistedSiteService;
|
||||||
import org.mitre.openid.connect.service.WhitelistedSiteService;
|
import org.mitre.openid.connect.service.WhitelistedSiteService;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
|
||||||
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
|
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
|
||||||
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
|
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
import com.google.common.base.Strings;
|
import com.google.common.base.Strings;
|
||||||
|
import com.google.common.cache.CacheBuilder;
|
||||||
|
import com.google.common.cache.CacheLoader;
|
||||||
|
import com.google.common.cache.LoadingCache;
|
||||||
|
import com.google.gson.JsonElement;
|
||||||
|
import com.google.gson.JsonParser;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
|
public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEntityService {
|
||||||
|
|
||||||
|
private static Logger logger = LoggerFactory.getLogger(DefaultOAuth2ClientDetailsEntityService.class);
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OAuth2ClientRepository clientRepository;
|
private OAuth2ClientRepository clientRepository;
|
||||||
|
|
||||||
|
@ -56,6 +73,12 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
||||||
@Autowired
|
@Autowired
|
||||||
private BlacklistedSiteService blacklistedSiteService;
|
private BlacklistedSiteService blacklistedSiteService;
|
||||||
|
|
||||||
|
// map of sector URI -> list of redirect URIs
|
||||||
|
private LoadingCache<String, List<String>> sectorRedirects = CacheBuilder.newBuilder()
|
||||||
|
.expireAfterAccess(1, TimeUnit.HOURS)
|
||||||
|
.maximumSize(100)
|
||||||
|
.build(new SectorIdentifierLoader());
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
|
public ClientDetailsEntity saveNewClient(ClientDetailsEntity client) {
|
||||||
if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
|
if (client.getId() != null) { // if it's not null, it's already been saved, this is an error
|
||||||
|
@ -85,6 +108,26 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
||||||
|
|
||||||
// timestamp this to right now
|
// timestamp this to right now
|
||||||
client.setCreatedAt(new Date());
|
client.setCreatedAt(new Date());
|
||||||
|
|
||||||
|
|
||||||
|
// check the sector URI
|
||||||
|
if (!Strings.isNullOrEmpty(client.getSectorIdentifierUri())) {
|
||||||
|
try {
|
||||||
|
List<String> redirects = sectorRedirects.get(client.getSectorIdentifierUri());
|
||||||
|
|
||||||
|
if (client.getRegisteredRedirectUri() != null) {
|
||||||
|
for (String uri : client.getRegisteredRedirectUri()) {
|
||||||
|
if (!redirects.contains(uri)) {
|
||||||
|
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new IllegalArgumentException("Unable to load sector identifier URI: " + client.getSectorIdentifierUri());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
return clientRepository.saveClient(client);
|
return clientRepository.saveClient(client);
|
||||||
}
|
}
|
||||||
|
@ -165,6 +208,24 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
||||||
newClient.getScope().remove("offline_access");
|
newClient.getScope().remove("offline_access");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check the sector URI
|
||||||
|
if (!Strings.isNullOrEmpty(newClient.getSectorIdentifierUri())) {
|
||||||
|
try {
|
||||||
|
List<String> redirects = sectorRedirects.get(newClient.getSectorIdentifierUri());
|
||||||
|
|
||||||
|
if (newClient.getRegisteredRedirectUri() != null) {
|
||||||
|
for (String uri : newClient.getRegisteredRedirectUri()) {
|
||||||
|
if (!redirects.contains(uri)) {
|
||||||
|
throw new IllegalArgumentException("Requested Redirect URI " + uri + " is not listed at sector identifier " + redirects);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new IllegalArgumentException("Unable to load sector identifier URI: " + newClient.getSectorIdentifierUri());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return clientRepository.updateClient(oldClient.getId(), newClient);
|
return clientRepository.updateClient(oldClient.getId(), newClient);
|
||||||
}
|
}
|
||||||
throw new IllegalArgumentException("Neither old client or new client can be null!");
|
throw new IllegalArgumentException("Neither old client or new client can be null!");
|
||||||
|
@ -196,4 +257,45 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
|
||||||
return client;
|
return client;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility class to load a sector identifier's set of authorized redirect URIs.
|
||||||
|
*
|
||||||
|
* @author jricher
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private class SectorIdentifierLoader extends CacheLoader<String, List<String>> {
|
||||||
|
private HttpClient httpClient = new DefaultHttpClient();
|
||||||
|
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
|
||||||
|
private RestTemplate restTemplate = new RestTemplate(httpFactory);
|
||||||
|
private JsonParser parser = new JsonParser();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<String> load(String key) throws Exception {
|
||||||
|
|
||||||
|
if (!key.startsWith("https")) {
|
||||||
|
// TODO: this should optionally throw an error (#506)
|
||||||
|
logger.error("Sector identifier doesn't start with https, loading anyway...");
|
||||||
|
}
|
||||||
|
|
||||||
|
// key is the sector URI
|
||||||
|
String jsonString = restTemplate.getForObject(key, String.class);
|
||||||
|
JsonElement json = parser.parse(jsonString);
|
||||||
|
|
||||||
|
if (json.isJsonArray()) {
|
||||||
|
List<String> redirectUris = new ArrayList<String>();
|
||||||
|
for (JsonElement el : json.getAsJsonArray()) {
|
||||||
|
redirectUris.add(el.getAsString());
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info("Found " + redirectUris + " for sector " + key);
|
||||||
|
|
||||||
|
return redirectUris;
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue