Cleaned up indentation, whitespace, and imports.

pull/604/head
Justin Richer 2014-05-27 13:02:49 -04:00
parent 8185171119
commit 525f3aa2a8
36 changed files with 245 additions and 259 deletions

View File

@ -97,7 +97,7 @@ public class IntrospectingTokenService implements ResourceServerTokenServices {
public void setIntrospectionConfigurationService(IntrospectionConfigurationService introspectionUrlProvider) { public void setIntrospectionConfigurationService(IntrospectionConfigurationService introspectionUrlProvider) {
this.introspectionConfigurationService = introspectionUrlProvider; this.introspectionConfigurationService = introspectionUrlProvider;
} }
/** /**
* @param introspectionAuthorityGranter the introspectionAuthorityGranter to set * @param introspectionAuthorityGranter the introspectionAuthorityGranter to set
*/ */
@ -111,7 +111,7 @@ public class IntrospectingTokenService implements ResourceServerTokenServices {
public IntrospectionAuthorityGranter getIntrospectionAuthorityGranter() { public IntrospectionAuthorityGranter getIntrospectionAuthorityGranter() {
return introspectionAuthorityGranter; return introspectionAuthorityGranter;
} }
// Check if there is a token and authentication in the cache // Check if there is a token and authentication in the cache
// and check if it is not expired. // and check if it is not expired.
private TokenCacheObject checkCache(String key) { private TokenCacheObject checkCache(String key) {

View File

@ -96,10 +96,10 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
@Autowired @Autowired
private JWKSetCacheService validationServices; private JWKSetCacheService validationServices;
@Autowired(required=false) @Autowired(required=false)
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@Autowired(required=false) @Autowired(required=false)
private JwtSigningAndValidationService authenticationSignerService; private JwtSigningAndValidationService authenticationSignerService;
@ -113,7 +113,7 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
// private helpers to handle target link URLs // private helpers to handle target link URLs
private TargetLinkURIAuthenticationSuccessHandler targetSuccessHandler = new TargetLinkURIAuthenticationSuccessHandler(); private TargetLinkURIAuthenticationSuccessHandler targetSuccessHandler = new TargetLinkURIAuthenticationSuccessHandler();
private TargetLinkURIChecker deepLinkFilter; private TargetLinkURIChecker deepLinkFilter;
protected int httpSocketTimeout = HTTP_SOCKET_TIMEOUT; protected int httpSocketTimeout = HTTP_SOCKET_TIMEOUT;
/** /**
@ -128,17 +128,17 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
@Override @Override
public void afterPropertiesSet() { public void afterPropertiesSet() {
super.afterPropertiesSet(); super.afterPropertiesSet();
// if our JOSE validators don't get wired in, drop defaults into place // if our JOSE validators don't get wired in, drop defaults into place
if (validationServices == null) { if (validationServices == null) {
validationServices = new JWKSetCacheService(); validationServices = new JWKSetCacheService();
} }
if (symmetricCacheService == null) { if (symmetricCacheService == null) {
symmetricCacheService = new SymmetricCacheService(); symmetricCacheService = new SymmetricCacheService();
} }
} }
/* /*
@ -206,7 +206,7 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
// there's a target URL in the response, we should save this so we can forward to it later // there's a target URL in the response, we should save this so we can forward to it later
session.setAttribute(TARGET_SESSION_VARIABLE, issResp.getTargetLinkUri()); session.setAttribute(TARGET_SESSION_VARIABLE, issResp.getTargetLinkUri());
} }
if (Strings.isNullOrEmpty(issuer)) { if (Strings.isNullOrEmpty(issuer)) {
logger.error("No issuer found: " + issuer); logger.error("No issuer found: " + issuer);
throw new AuthenticationServiceException("No issuer found: " + issuer); throw new AuthenticationServiceException("No issuer found: " + issuer);
@ -315,37 +315,37 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
return httpRequest; return httpRequest;
} }
}; };
} else { } else {
// we're not doing basic auth, figure out what other flavor we have // we're not doing basic auth, figure out what other flavor we have
restTemplate = new RestTemplate(factory); restTemplate = new RestTemplate(factory);
if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) || PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) { if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) || PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) {
// do a symmetric secret signed JWT for auth // do a symmetric secret signed JWT for auth
JwtSigningAndValidationService signer = null; JwtSigningAndValidationService signer = null;
JWSAlgorithm alg = clientConfig.getTokenEndpointAuthSigningAlg(); JWSAlgorithm alg = clientConfig.getTokenEndpointAuthSigningAlg();
if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) && if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) &&
(alg.equals(JWSAlgorithm.HS256) (alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384) || alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512))) { || alg.equals(JWSAlgorithm.HS512))) {
// generate one based on client secret // generate one based on client secret
signer = symmetricCacheService.getSymmetricValidtor(clientConfig.getClient()); signer = symmetricCacheService.getSymmetricValidtor(clientConfig.getClient());
} else if (PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) { } else if (PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) {
// needs to be wired in to the bean // needs to be wired in to the bean
signer = authenticationSignerService; signer = authenticationSignerService;
} }
if (signer == null) { if (signer == null) {
throw new AuthenticationServiceException("Couldn't find required signer service for use with private key auth."); throw new AuthenticationServiceException("Couldn't find required signer service for use with private key auth.");
} }
JWTClaimsSet claimsSet = new JWTClaimsSet(); JWTClaimsSet claimsSet = new JWTClaimsSet();
claimsSet.setIssuer(clientConfig.getClientId()); claimsSet.setIssuer(clientConfig.getClientId());
claimsSet.setSubject(clientConfig.getClientId()); claimsSet.setSubject(clientConfig.getClientId());
claimsSet.setAudience(Lists.newArrayList(serverConfig.getTokenEndpointUri())); claimsSet.setAudience(Lists.newArrayList(serverConfig.getTokenEndpointUri()));
@ -353,15 +353,15 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
// TODO: make this configurable // TODO: make this configurable
Date exp = new Date(System.currentTimeMillis() + (60 * 1000)); // auth good for 60 seconds Date exp = new Date(System.currentTimeMillis() + (60 * 1000)); // auth good for 60 seconds
claimsSet.setExpirationTime(exp); claimsSet.setExpirationTime(exp);
Date now = new Date(System.currentTimeMillis()); Date now = new Date(System.currentTimeMillis());
claimsSet.setIssueTime(now); claimsSet.setIssueTime(now);
claimsSet.setNotBeforeTime(now); claimsSet.setNotBeforeTime(now);
SignedJWT jwt = new SignedJWT(new JWSHeader(alg), claimsSet); SignedJWT jwt = new SignedJWT(new JWSHeader(alg), claimsSet);
signer.signJwt(jwt, alg); signer.signJwt(jwt, alg);
form.add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); form.add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
form.add("client_assertion", jwt.serialize()); form.add("client_assertion", jwt.serialize());
} else { } else {
@ -369,7 +369,7 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
form.add("client_id", clientConfig.getClientId()); form.add("client_id", clientConfig.getClientId());
form.add("client_secret", clientConfig.getClientSecret()); form.add("client_secret", clientConfig.getClientSecret());
} }
} }
logger.debug("tokenEndpointURI = " + serverConfig.getTokenEndpointUri()); logger.debug("tokenEndpointURI = " + serverConfig.getTokenEndpointUri());
@ -630,22 +630,22 @@ public class OIDCAuthenticationFilter extends AbstractAuthenticationProcessingFi
protected class TargetLinkURIAuthenticationSuccessHandler implements AuthenticationSuccessHandler { protected class TargetLinkURIAuthenticationSuccessHandler implements AuthenticationSuccessHandler {
private AuthenticationSuccessHandler passthrough; private AuthenticationSuccessHandler passthrough;
@Override @Override
public void onAuthenticationSuccess(HttpServletRequest request, public void onAuthenticationSuccess(HttpServletRequest request,
HttpServletResponse response, Authentication authentication) HttpServletResponse response, Authentication authentication)
throws IOException, ServletException { throws IOException, ServletException {
HttpSession session = request.getSession(); HttpSession session = request.getSession();
// check to see if we've got a target // check to see if we've got a target
String target = getStoredSessionString(session, TARGET_SESSION_VARIABLE); String target = getStoredSessionString(session, TARGET_SESSION_VARIABLE);
if (!Strings.isNullOrEmpty(target)) { if (!Strings.isNullOrEmpty(target)) {
session.removeAttribute(TARGET_SESSION_VARIABLE); session.removeAttribute(TARGET_SESSION_VARIABLE);
target = deepLinkFilter.filter(target); target = deepLinkFilter.filter(target);
response.sendRedirect(target); response.sendRedirect(target);
} else { } else {
// if the target was blank, use the default behavior here // if the target was blank, use the default behavior here

View File

@ -10,7 +10,7 @@ package org.mitre.openid.connect.client;
public class StaticPrefixTargetLinkURIChecker implements TargetLinkURIChecker { public class StaticPrefixTargetLinkURIChecker implements TargetLinkURIChecker {
private String prefix = ""; private String prefix = "";
@Override @Override
public String filter(String target) { public String filter(String target) {
if (target == null) { if (target == null) {

View File

@ -48,13 +48,6 @@ import com.google.gson.JsonElement;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import com.google.gson.JsonParser; import com.google.gson.JsonParser;
import static org.mitre.discovery.util.JsonUtils.getAsBoolean;
import static org.mitre.discovery.util.JsonUtils.getAsEncryptionMethodList;
import static org.mitre.discovery.util.JsonUtils.getAsJweAlgorithmList;
import static org.mitre.discovery.util.JsonUtils.getAsJwsAlgorithmList;
import static org.mitre.discovery.util.JsonUtils.getAsString;
import static org.mitre.discovery.util.JsonUtils.getAsStringList;
/** /**
* *
* Dynamically fetches OpenID Connect server configurations based on the issuer. Caches the server configurations. * Dynamically fetches OpenID Connect server configurations based on the issuer. Caches the server configurations.

View File

@ -23,7 +23,6 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import org.apache.http.client.HttpClient; import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.DefaultHttpClient;
import org.apache.http.impl.client.SystemDefaultHttpClient; import org.apache.http.impl.client.SystemDefaultHttpClient;
import org.mitre.jose.keystore.JWKSetKeyStore; import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService; import org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService;
@ -130,10 +129,10 @@ public class JWKSetCacheService {
} }
/** /**
* @author jricher * @author jricher
* *
*/ */
private class JWKSetEncryptorFetcher extends CacheLoader<String, JwtEncryptionAndDecryptionService> { private class JWKSetEncryptorFetcher extends CacheLoader<String, JwtEncryptionAndDecryptionService> {
private HttpClient httpClient = new SystemDefaultHttpClient(); private HttpClient httpClient = new SystemDefaultHttpClient();
private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient); private HttpComponentsClientHttpRequestFactory httpFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
private RestTemplate restTemplate = new RestTemplate(httpFactory); private RestTemplate restTemplate = new RestTemplate(httpFactory);

View File

@ -31,11 +31,11 @@ import com.nimbusds.jose.util.Base64URL;
*/ */
@Service @Service
public class SymmetricCacheService { public class SymmetricCacheService {
private static Logger logger = LoggerFactory.getLogger(SymmetricCacheService.class); private static Logger logger = LoggerFactory.getLogger(SymmetricCacheService.class);
private LoadingCache<String, JwtSigningAndValidationService> validators; private LoadingCache<String, JwtSigningAndValidationService> validators;
public SymmetricCacheService() { public SymmetricCacheService() {
validators = CacheBuilder.newBuilder() validators = CacheBuilder.newBuilder()
@ -43,8 +43,8 @@ public class SymmetricCacheService {
.maximumSize(100) .maximumSize(100)
.build(new SymmetricValidatorBuilder()); .build(new SymmetricValidatorBuilder());
} }
/** /**
* Create a symmetric signing and validation service for the given client * Create a symmetric signing and validation service for the given client
* *
@ -62,7 +62,7 @@ public class SymmetricCacheService {
logger.error("Couldn't create symmetric validator for client " + client.getClientId() + " without a client secret"); logger.error("Couldn't create symmetric validator for client " + client.getClientId() + " without a client secret");
return null; return null;
} }
try { try {
return validators.get(client.getClientSecret()); return validators.get(client.getClientSecret());
} catch (UncheckedExecutionException ue) { } catch (UncheckedExecutionException ue) {
@ -72,28 +72,28 @@ public class SymmetricCacheService {
logger.error("Problem loading client validator", e); logger.error("Problem loading client validator", e);
return null; return null;
} }
} }
public class SymmetricValidatorBuilder extends CacheLoader<String, JwtSigningAndValidationService> { public class SymmetricValidatorBuilder extends CacheLoader<String, JwtSigningAndValidationService> {
@Override @Override
public JwtSigningAndValidationService load(String key) throws Exception { public JwtSigningAndValidationService load(String key) throws Exception {
try { try {
String id = "SYMMETRIC-KEY"; String id = "SYMMETRIC-KEY";
JWK jwk = new OctetSequenceKey(Base64URL.encode(key), Use.SIGNATURE, null, id, null, null, null); JWK jwk = new OctetSequenceKey(Base64URL.encode(key), Use.SIGNATURE, null, id, null, null, null);
Map<String, JWK> keys = ImmutableMap.of(id, jwk); Map<String, JWK> keys = ImmutableMap.of(id, jwk);
JwtSigningAndValidationService service = new DefaultJwtSigningAndValidationService(keys); JwtSigningAndValidationService service = new DefaultJwtSigningAndValidationService(keys);
return service; return service;
} catch (NoSuchAlgorithmException e) { } catch (NoSuchAlgorithmException e) {
logger.error("Couldn't create symmetric validator for client", e); logger.error("Couldn't create symmetric validator for client", e);
} catch (InvalidKeySpecException e) { } catch (InvalidKeySpecException e) {
logger.error("Couldn't create symmetric validator for client", e); logger.error("Couldn't create symmetric validator for client", e);
} }
throw new IllegalArgumentException("Couldn't create symmetric validator for client"); throw new IllegalArgumentException("Couldn't create symmetric validator for client");
} }

View File

@ -351,14 +351,14 @@ public class ClientDetailsEntity implements ClientDetails {
@Override @Override
@Transient @Transient
public boolean isSecretRequired() { public boolean isSecretRequired() {
if (getTokenEndpointAuthMethod() != null && if (getTokenEndpointAuthMethod() != null &&
getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) || getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC) ||
getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) { getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {
return true; return true;
} else { } else {
return false; return false;
} }
} }
/** /**

View File

@ -32,7 +32,7 @@ public interface AuthenticationHolderRepository {
public void remove(AuthenticationHolderEntity a); public void remove(AuthenticationHolderEntity a);
public AuthenticationHolderEntity save(AuthenticationHolderEntity a); public AuthenticationHolderEntity save(AuthenticationHolderEntity a);
public List<AuthenticationHolderEntity> getOrphanedAuthenticationHolders(); public List<AuthenticationHolderEntity> getOrphanedAuthenticationHolders();

View File

@ -57,7 +57,7 @@ public interface OAuth2TokenRepository {
public Set<OAuth2AccessTokenEntity> getAllAccessTokens(); public Set<OAuth2AccessTokenEntity> getAllAccessTokens();
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens(); public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens();
public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(); public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens();
public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(); public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens();

View File

@ -22,8 +22,6 @@ import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.OAuth2AccessTokenEntity; import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.OAuth2Request;
import com.nimbusds.jose.JWSAlgorithm;
/** /**
* Service to create specialty OpenID Connect tokens. * Service to create specialty OpenID Connect tokens.
* *
@ -62,5 +60,5 @@ public interface OIDCTokenService {
* @return * @return
*/ */
public OAuth2AccessTokenEntity createResourceAccessToken(ClientDetailsEntity client); public OAuth2AccessTokenEntity createResourceAccessToken(ClientDetailsEntity client);
} }

View File

@ -51,7 +51,7 @@ public interface StatsService {
* @return * @return
*/ */
public Integer getCountForClientId(Long id); public Integer getCountForClientId(Long id);
/** /**
* Trigger the stats to be recalculated upon next update. * Trigger the stats to be recalculated upon next update.
*/ */

View File

@ -113,8 +113,8 @@ public class DiscoveryEndpoint {
model.addAttribute("code", HttpStatus.NOT_FOUND); model.addAttribute("code", HttpStatus.NOT_FOUND);
return "httpCodeView"; return "httpCodeView";
} }
} else { } else {
logger.info("Unknown URI format: " + resource); logger.info("Unknown URI format: " + resource);
model.addAttribute("code", HttpStatus.NOT_FOUND); model.addAttribute("code", HttpStatus.NOT_FOUND);
@ -261,7 +261,7 @@ public class DiscoveryEndpoint {
Collection<JWSAlgorithm> serverSigningAlgs = signService.getAllSigningAlgsSupported(); Collection<JWSAlgorithm> serverSigningAlgs = signService.getAllSigningAlgsSupported();
Collection<JWSAlgorithm> clientSymmetricSigningAlgs = Lists.newArrayList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512); Collection<JWSAlgorithm> clientSymmetricSigningAlgs = Lists.newArrayList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512);
Collection<JWSAlgorithm> clientSymmetricAndAsymmetricSigningAlgs = Lists.newArrayList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512, JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512); Collection<JWSAlgorithm> clientSymmetricAndAsymmetricSigningAlgs = Lists.newArrayList(JWSAlgorithm.HS256, JWSAlgorithm.HS384, JWSAlgorithm.HS512, JWSAlgorithm.RS256, JWSAlgorithm.RS384, JWSAlgorithm.RS512);
Map<String, Object> m = new HashMap<String, Object>(); Map<String, Object> m = new HashMap<String, Object>();
m.put("issuer", config.getIssuer()); m.put("issuer", config.getIssuer());
m.put("authorization_endpoint", baseUrl + "authorize"); m.put("authorization_endpoint", baseUrl + "authorize");

View File

@ -33,8 +33,8 @@ import org.springframework.transaction.annotation.Transactional;
@Transactional @Transactional
public class JpaAuthenticationHolderRepository implements AuthenticationHolderRepository { public class JpaAuthenticationHolderRepository implements AuthenticationHolderRepository {
private static final int MAXEXPIREDRESULTS = 1000; private static final int MAXEXPIREDRESULTS = 1000;
@PersistenceContext @PersistenceContext
private EntityManager manager; private EntityManager manager;
@ -77,7 +77,7 @@ public class JpaAuthenticationHolderRepository implements AuthenticationHolderRe
public AuthenticationHolderEntity save(AuthenticationHolderEntity a) { public AuthenticationHolderEntity save(AuthenticationHolderEntity a) {
return JpaUtil.saveOrUpdate(a.getId(), manager, a); return JpaUtil.saveOrUpdate(a.getId(), manager, a);
} }
@Override @Override
@Transactional @Transactional
public List<AuthenticationHolderEntity> getOrphanedAuthenticationHolders() { public List<AuthenticationHolderEntity> getOrphanedAuthenticationHolders() {

View File

@ -181,7 +181,7 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository {
List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList(); List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
return JpaUtil.getSingleResult(accessTokens); return JpaUtil.getSingleResult(accessTokens);
} }
@Override @Override
public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() { public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() {
TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery("OAuth2AccessTokenEntity.getAllExpiredByDate", OAuth2AccessTokenEntity.class); TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery("OAuth2AccessTokenEntity.getAllExpiredByDate", OAuth2AccessTokenEntity.class);

View File

@ -78,7 +78,7 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
@Autowired @Autowired
private SystemScopeService scopeService; private SystemScopeService scopeService;
@Autowired @Autowired
private StatsService statsService; private StatsService statsService;
@ -142,9 +142,9 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
client.setScope(scopeService.removeRestrictedScopes(client.getScope())); client.setScope(scopeService.removeRestrictedScopes(client.getScope()));
ClientDetailsEntity c = clientRepository.saveClient(client); ClientDetailsEntity c = clientRepository.saveClient(client);
statsService.resetCache(); statsService.resetCache();
return c; return c;
} }
@ -202,14 +202,14 @@ public class DefaultOAuth2ClientDetailsEntityService implements ClientDetailsEnt
clientRepository.deleteClient(client); clientRepository.deleteClient(client);
statsService.resetCache(); statsService.resetCache();
} }
/** /**
* Update the oldClient with information from the newClient. The * Update the oldClient with information from the newClient. The
* id from oldClient is retained. * id from oldClient is retained.
* *
* Checks to make sure the refresh grant type and * Checks to make sure the refresh grant type and
* the scopes are set appropriately. * the scopes are set appropriately.
* *
* Checks to make sure the redirect URIs aren't blacklisted. * Checks to make sure the redirect URIs aren't blacklisted.

View File

@ -51,8 +51,6 @@ import org.springframework.security.oauth2.provider.TokenRequest;
import org.springframework.security.oauth2.provider.token.TokenEnhancer; import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.PlainJWT;
@ -404,8 +402,8 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
for (OAuth2AccessTokenEntity oAuth2AccessTokenEntity : accessTokens) { for (OAuth2AccessTokenEntity oAuth2AccessTokenEntity : accessTokens) {
try { try {
revokeAccessToken(oAuth2AccessTokenEntity); revokeAccessToken(oAuth2AccessTokenEntity);
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
//An ID token is deleted with its corresponding access token, but then the ID token is on the list of expired tokens as well and there is //An ID token is deleted with its corresponding access token, but then the ID token is on the list of expired tokens as well and there is
//nothing in place to distinguish it from any other. //nothing in place to distinguish it from any other.
//An attempt to delete an already deleted token returns an error, stopping the cleanup dead. We need it to keep going. //An attempt to delete an already deleted token returns an error, stopping the cleanup dead. We need it to keep going.
} }
@ -416,7 +414,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
for (OAuth2RefreshTokenEntity oAuth2RefreshTokenEntity : refreshTokens) { for (OAuth2RefreshTokenEntity oAuth2RefreshTokenEntity : refreshTokens) {
revokeRefreshToken(oAuth2RefreshTokenEntity); revokeRefreshToken(oAuth2RefreshTokenEntity);
} }
Collection<AuthenticationHolderEntity> authHolders = getOrphanedAuthenticationHolders(); Collection<AuthenticationHolderEntity> authHolders = getOrphanedAuthenticationHolders();
logger.info("Found " + authHolders.size() + " orphaned authentication holders"); logger.info("Found " + authHolders.size() + " orphaned authentication holders");
for(AuthenticationHolderEntity authHolder : authHolders) { for(AuthenticationHolderEntity authHolder : authHolders) {
@ -431,7 +429,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
private Collection<OAuth2RefreshTokenEntity> getExpiredRefreshTokens() { private Collection<OAuth2RefreshTokenEntity> getExpiredRefreshTokens() {
return Sets.newHashSet(tokenRepository.getAllExpiredRefreshTokens()); return Sets.newHashSet(tokenRepository.getAllExpiredRefreshTokens());
} }
private Collection<AuthenticationHolderEntity> getOrphanedAuthenticationHolders() { private Collection<AuthenticationHolderEntity> getOrphanedAuthenticationHolders() {
return Sets.newHashSet(authenticationHolderRepository.getOrphanedAuthenticationHolders()); return Sets.newHashSet(authenticationHolderRepository.getOrphanedAuthenticationHolders());
} }

View File

@ -58,22 +58,22 @@ public class TokenApiView extends AbstractView {
JsonObject o = new JsonObject(); JsonObject o = new JsonObject();
o.addProperty("value", src.getValue()); o.addProperty("value", src.getValue());
o.addProperty("id", src.getId()); o.addProperty("id", src.getId());
o.addProperty("idTokenId", src.getIdToken() != null ? src.getIdToken().getId() : null); o.addProperty("idTokenId", src.getIdToken() != null ? src.getIdToken().getId() : null);
o.addProperty("refreshTokenId", src.getRefreshToken() != null ? src.getRefreshToken().getId() : null); o.addProperty("refreshTokenId", src.getRefreshToken() != null ? src.getRefreshToken().getId() : null);
o.add("scopes", context.serialize(src.getScope())); o.add("scopes", context.serialize(src.getScope()));
o.addProperty("clientId", src.getClient().getClientId()); o.addProperty("clientId", src.getClient().getClientId());
o.addProperty("userId", src.getAuthenticationHolder().getAuthentication().getName()); o.addProperty("userId", src.getAuthenticationHolder().getAuthentication().getName());
o.add("expiration", context.serialize(src.getExpiration())); o.add("expiration", context.serialize(src.getExpiration()));
return o; return o;
} }
}) })
.registerTypeAdapter(OAuth2RefreshTokenEntity.class, new JsonSerializer<OAuth2RefreshTokenEntity>() { .registerTypeAdapter(OAuth2RefreshTokenEntity.class, new JsonSerializer<OAuth2RefreshTokenEntity>() {
@ -81,20 +81,20 @@ public class TokenApiView extends AbstractView {
public JsonElement serialize(OAuth2RefreshTokenEntity src, public JsonElement serialize(OAuth2RefreshTokenEntity src,
Type typeOfSrc, JsonSerializationContext context) { Type typeOfSrc, JsonSerializationContext context) {
JsonObject o = new JsonObject(); JsonObject o = new JsonObject();
o.addProperty("value", src.getValue()); o.addProperty("value", src.getValue());
o.addProperty("id", src.getId()); o.addProperty("id", src.getId());
o.add("scopes", context.serialize(src.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope())); o.add("scopes", context.serialize(src.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope()));
o.addProperty("clientId", src.getClient().getClientId()); o.addProperty("clientId", src.getClient().getClientId());
o.addProperty("userId", src.getAuthenticationHolder().getAuthentication().getName()); o.addProperty("userId", src.getAuthenticationHolder().getAuthentication().getName());
o.add("expiration", context.serialize(src.getExpiration())); o.add("expiration", context.serialize(src.getExpiration()));
return o; return o;
} }
}) })
.serializeNulls() .serializeNulls()
.setDateFormat("yyyy-MM-dd'T'HH:mm:ssZ") .setDateFormat("yyyy-MM-dd'T'HH:mm:ssZ")

View File

@ -69,13 +69,13 @@ public class OAuthConfirmationController {
@Autowired @Autowired
private ScopeClaimTranslationService scopeClaimTranslationService; private ScopeClaimTranslationService scopeClaimTranslationService;
@Autowired @Autowired
private UserInfoService userInfoService; private UserInfoService userInfoService;
@Autowired @Autowired
private StatsService statsService; private StatsService statsService;
private static Logger logger = LoggerFactory.getLogger(OAuthConfirmationController.class); private static Logger logger = LoggerFactory.getLogger(OAuthConfirmationController.class);
public OAuthConfirmationController() { public OAuthConfirmationController() {
@ -131,7 +131,7 @@ public class OAuthConfirmationController {
model.put("redirect_uri", redirect_uri); model.put("redirect_uri", redirect_uri);
// pre-process the scopes // pre-process the scopes
Set<SystemScope> scopes = scopeService.fromStrings(authRequest.getScope()); Set<SystemScope> scopes = scopeService.fromStrings(authRequest.getScope());
@ -157,7 +157,7 @@ public class OAuthConfirmationController {
for (SystemScope systemScope : sortedScopes) { for (SystemScope systemScope : sortedScopes) {
Map<String, String> claimValues = new HashMap<String, String>(); Map<String, String> claimValues = new HashMap<String, String>();
Set<String> claims = scopeClaimTranslationService.getClaimsForScope(systemScope.getValue()); Set<String> claims = scopeClaimTranslationService.getClaimsForScope(systemScope.getValue());
for (String claim : claims) { for (String claim : claims) {
if (userJson.has(claim) && userJson.get(claim).isJsonPrimitive()) { if (userJson.has(claim) && userJson.get(claim).isJsonPrimitive()) {
@ -165,23 +165,23 @@ public class OAuthConfirmationController {
claimValues.put(claim, userJson.get(claim).getAsString()); claimValues.put(claim, userJson.get(claim).getAsString());
} }
} }
claimsForScopes.put(systemScope.getValue(), claimValues); claimsForScopes.put(systemScope.getValue(), claimValues);
} }
model.put("claims", claimsForScopes); model.put("claims", claimsForScopes);
// client stats // client stats
Integer count = statsService.getCountForClientId(client.getId()); Integer count = statsService.getCountForClientId(client.getId());
model.put("count", count); model.put("count", count);
// contacts // contacts
if (client.getContacts() != null) { if (client.getContacts() != null) {
String contacts = Joiner.on(", ").join(client.getContacts()); String contacts = Joiner.on(", ").join(client.getContacts());
model.put("contacts", contacts); model.put("contacts", contacts);
} }
// if the client is over a week old and has more than one registration, don't give such a big warning // if the client is over a week old and has more than one registration, don't give such a big warning
// instead, tag as "Generally Recognized As Safe (gras) // instead, tag as "Generally Recognized As Safe (gras)
Date lastWeek = new Date(System.currentTimeMillis() + (60 * 60 * 24 * 7 * 1000)); Date lastWeek = new Date(System.currentTimeMillis() + (60 * 60 * 24 * 7 * 1000));
@ -191,10 +191,10 @@ public class OAuthConfirmationController {
} else { } else {
model.put("gras", false); model.put("gras", false);
} }
// inject a random value for CSRF purposes // inject a random value for CSRF purposes
model.put("csrf", authRequest.getExtensions().get("csrf")); model.put("csrf", authRequest.getExtensions().get("csrf"));
return "approve"; return "approve";
} }

View File

@ -76,7 +76,7 @@ public class TokenAPI {
return "tokenApiView"; return "tokenApiView";
} }
} }
@RequestMapping(value = "/access/{id}", method = RequestMethod.DELETE, produces = "application/json") @RequestMapping(value = "/access/{id}", method = RequestMethod.DELETE, produces = "application/json")
public String deleteAccessTokenById(@PathVariable("id") Long id, ModelMap m, Principal p) { public String deleteAccessTokenById(@PathVariable("id") Long id, ModelMap m, Principal p) {
@ -94,19 +94,19 @@ public class TokenAPI {
return "jsonErrorView"; return "jsonErrorView";
} else { } else {
tokenService.revokeAccessToken(token); tokenService.revokeAccessToken(token);
return "httpCodeView"; return "httpCodeView";
} }
} }
@RequestMapping(value = "/refresh", method = RequestMethod.GET, produces = "application/json") @RequestMapping(value = "/refresh", method = RequestMethod.GET, produces = "application/json")
public String getAllRefreshTokens(ModelMap m, Principal p) { public String getAllRefreshTokens(ModelMap m, Principal p) {
Set<OAuth2RefreshTokenEntity> allTokens = tokenService.getAllRefreshTokensForUser(p.getName()); Set<OAuth2RefreshTokenEntity> allTokens = tokenService.getAllRefreshTokensForUser(p.getName());
m.put("entity", allTokens); m.put("entity", allTokens);
return "tokenApiView"; return "tokenApiView";
} }
@RequestMapping(value = "/refresh/{id}", method = RequestMethod.GET, produces = "application/json") @RequestMapping(value = "/refresh/{id}", method = RequestMethod.GET, produces = "application/json")
@ -129,7 +129,7 @@ public class TokenAPI {
return "tokenApiView"; return "tokenApiView";
} }
} }
@RequestMapping(value = "/refresh/{id}", method = RequestMethod.DELETE, produces = "application/json") @RequestMapping(value = "/refresh/{id}", method = RequestMethod.DELETE, produces = "application/json")
public String deleteRefreshTokenById(@PathVariable("id") Long id, ModelMap m, Principal p) { public String deleteRefreshTokenById(@PathVariable("id") Long id, ModelMap m, Principal p) {
@ -147,9 +147,9 @@ public class TokenAPI {
return "jsonErrorView"; return "jsonErrorView";
} else { } else {
tokenService.revokeRefreshToken(token); tokenService.revokeRefreshToken(token);
return "httpCodeView"; return "httpCodeView";
} }
} }
} }

View File

@ -61,7 +61,7 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
@Autowired @Autowired
private JWKSetCacheService validators; private JWKSetCacheService validators;
@Autowired @Autowired
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@ -124,12 +124,12 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
if (request.getClientId() != null) { if (request.getClientId() != null) {
try { try {
ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId()); ClientDetailsEntity client = clientDetailsService.loadClientByClientId(request.getClientId());
if ((request.getScope() == null || request.getScope().isEmpty())) { if ((request.getScope() == null || request.getScope().isEmpty())) {
Set<String> clientScopes = client.getScope(); Set<String> clientScopes = client.getScope();
request.setScope(clientScopes); request.setScope(clientScopes);
} }
if (request.getExtensions().get("max_age") == null && client.getDefaultMaxAge() != null) { if (request.getExtensions().get("max_age") == null && client.getDefaultMaxAge() != null) {
request.getExtensions().put("max_age", client.getDefaultMaxAge().toString()); request.getExtensions().put("max_age", client.getDefaultMaxAge().toString());
} }
@ -138,12 +138,12 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
} }
} }
// add CSRF protection to the request on first parse // add CSRF protection to the request on first parse
String csrf = UUID.randomUUID().toString(); String csrf = UUID.randomUUID().toString();
request.getExtensions().put("csrf", csrf); request.getExtensions().put("csrf", csrf);
return request; return request;
} }
@ -180,7 +180,7 @@ public class ConnectOAuth2RequestFactory extends DefaultOAuth2RequestFactory {
JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm(); JWSAlgorithm alg = signedJwt.getHeader().getAlgorithm();
if (client.getRequestObjectSigningAlg() == null || if (client.getRequestObjectSigningAlg() == null ||
!client.getRequestObjectSigningAlg().equals(alg)) { !client.getRequestObjectSigningAlg().equals(alg)) {
throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")"); throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
} }

View File

@ -32,6 +32,10 @@ import com.nimbusds.jwt.JWT;
*/ */
public class JwtBearerAssertionAuthenticationToken extends AbstractAuthenticationToken { public class JwtBearerAssertionAuthenticationToken extends AbstractAuthenticationToken {
/**
*
*/
private static final long serialVersionUID = -3138213539914074617L;
private String clientId; private String clientId;
private JWT jwt; private JWT jwt;

View File

@ -55,7 +55,7 @@ public class JwtBearerAuthenticationProvider implements AuthenticationProvider {
// map of verifiers, load keys for clients // map of verifiers, load keys for clients
@Autowired @Autowired
private JWKSetCacheService validators; private JWKSetCacheService validators;
// map of symmetric verifiers for client secrets // map of symmetric verifiers for client secrets
@Autowired @Autowired
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@ -92,15 +92,15 @@ public class JwtBearerAuthenticationProvider implements AuthenticationProvider {
JWSAlgorithm alg = jws.getHeader().getAlgorithm(); JWSAlgorithm alg = jws.getHeader().getAlgorithm();
if (client.getTokenEndpointAuthSigningAlg() != null && if (client.getTokenEndpointAuthSigningAlg() != null &&
!client.getTokenEndpointAuthSigningAlg().equals(alg)) { !client.getTokenEndpointAuthSigningAlg().equals(alg)) {
throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")"); throw new InvalidClientException("Client's registered request object signing algorithm (" + client.getRequestObjectSigningAlg() + ") does not match request object's actual algorithm (" + alg.getName() + ")");
} }
if (client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) && if (client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY) &&
(alg.equals(JWSAlgorithm.RS256) (alg.equals(JWSAlgorithm.RS256)
|| alg.equals(JWSAlgorithm.RS384) || alg.equals(JWSAlgorithm.RS384)
|| alg.equals(JWSAlgorithm.RS512))) { || alg.equals(JWSAlgorithm.RS512))) {
JwtSigningAndValidationService validator = validators.getValidator(client.getJwksUri()); JwtSigningAndValidationService validator = validators.getValidator(client.getJwksUri());
@ -113,24 +113,24 @@ public class JwtBearerAuthenticationProvider implements AuthenticationProvider {
} }
} else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) && } else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT) &&
(alg.equals(JWSAlgorithm.HS256) (alg.equals(JWSAlgorithm.HS256)
|| alg.equals(JWSAlgorithm.HS384) || alg.equals(JWSAlgorithm.HS384)
|| alg.equals(JWSAlgorithm.HS512))) { || alg.equals(JWSAlgorithm.HS512))) {
// it's HMAC, we need to make a validator based on the client secret // it's HMAC, we need to make a validator based on the client secret
JwtSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client); JwtSigningAndValidationService validator = symmetricCacheService.getSymmetricValidtor(client);
if (validator == null) { if (validator == null) {
throw new AuthenticationServiceException("Unable to create signature validator for client's secret: " + client.getClientSecret()); throw new AuthenticationServiceException("Unable to create signature validator for client's secret: " + client.getClientSecret());
} }
if (!validator.validateSignature(jws)) { if (!validator.validateSignature(jws)) {
throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication."); throw new AuthenticationServiceException("Signature did not validate for presented JWT authentication.");
} }
} }
} }
// check the issuer // check the issuer
if (jwtClaims.getIssuer() == null) { if (jwtClaims.getIssuer() == null) {
throw new AuthenticationServiceException("Assertion Token Issuer is null"); throw new AuthenticationServiceException("Assertion Token Issuer is null");

View File

@ -42,7 +42,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException; import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.provider.AuthorizationRequest; import org.springframework.security.oauth2.provider.AuthorizationRequest;
import org.springframework.security.oauth2.provider.OAuth2RequestFactory; import org.springframework.security.oauth2.provider.OAuth2RequestFactory;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -68,7 +67,7 @@ public class PromptFilter extends GenericFilterBean {
@Autowired @Autowired
private ClientDetailsEntityService clientService; private ClientDetailsEntityService clientService;
/** /**
* *
*/ */
@ -77,7 +76,7 @@ public class PromptFilter extends GenericFilterBean {
HttpServletRequest request = (HttpServletRequest) req; HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res; HttpServletResponse response = (HttpServletResponse) res;
// skip everything that's not an authorize URL // skip everything that's not an authorize URL
if (!request.getServletPath().startsWith("/authorize")) { if (!request.getServletPath().startsWith("/authorize")) {
chain.doFilter(req, res); chain.doFilter(req, res);
@ -88,7 +87,7 @@ public class PromptFilter extends GenericFilterBean {
AuthorizationRequest authRequest = authRequestFactory.createAuthorizationRequest(createRequestMap(request.getParameterMap())); AuthorizationRequest authRequest = authRequestFactory.createAuthorizationRequest(createRequestMap(request.getParameterMap()));
ClientDetailsEntity client = null; ClientDetailsEntity client = null;
try { try {
client = clientService.loadClientByClientId(authRequest.getClientId()); client = clientService.loadClientByClientId(authRequest.getClientId());
} catch (InvalidClientException e) { } catch (InvalidClientException e) {
@ -96,7 +95,7 @@ public class PromptFilter extends GenericFilterBean {
} catch (IllegalArgumentException e) { } catch (IllegalArgumentException e) {
// no need to worry about this here, it would be caught elsewhere // no need to worry about this here, it would be caught elsewhere
} }
if (authRequest.getExtensions().get("prompt") != null) { if (authRequest.getExtensions().get("prompt") != null) {
// we have a "prompt" parameter // we have a "prompt" parameter
String prompt = (String)authRequest.getExtensions().get("prompt"); String prompt = (String)authRequest.getExtensions().get("prompt");
@ -156,14 +155,14 @@ public class PromptFilter extends GenericFilterBean {
Integer max = (client != null ? client.getDefaultMaxAge() : null); Integer max = (client != null ? client.getDefaultMaxAge() : null);
String maxAge = (String) authRequest.getExtensions().get("max_age"); String maxAge = (String) authRequest.getExtensions().get("max_age");
if (maxAge != null) { if (maxAge != null) {
max = Integer.parseInt(maxAge); max = Integer.parseInt(maxAge);
} }
if (max != null) { if (max != null) {
HttpSession session = request.getSession(); HttpSession session = request.getSession();
Date authTime = (Date) session.getAttribute(AuthenticationTimeStamper.AUTH_TIMESTAMP); Date authTime = (Date) session.getAttribute(AuthenticationTimeStamper.AUTH_TIMESTAMP);
Date now = new Date(); Date now = new Date();
if (authTime != null) { if (authTime != null) {
long seconds = (now.getTime() - authTime.getTime()) / 1000; long seconds = (now.getTime() - authTime.getTime()) / 1000;

View File

@ -53,7 +53,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
@Autowired @Autowired
private OAuth2TokenRepository tokenRepository; private OAuth2TokenRepository tokenRepository;
@Autowired @Autowired
private StatsService statsService; private StatsService statsService;
@ -90,7 +90,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
} }
approvedSiteRepository.remove(approvedSite); approvedSiteRepository.remove(approvedSite);
statsService.resetCache(); statsService.resetCache();
} }
@ -164,7 +164,7 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
remove(expired); remove(expired);
} }
} }
} }
private Predicate<ApprovedSite> isExpired = new Predicate<ApprovedSite>() { private Predicate<ApprovedSite> isExpired = new Predicate<ApprovedSite>() {

View File

@ -78,7 +78,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
@Autowired @Autowired
private JWKSetCacheService encrypters; private JWKSetCacheService encrypters;
@Autowired @Autowired
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@ -99,7 +99,7 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
if (request.getExtensions().containsKey("max_age") if (request.getExtensions().containsKey("max_age")
|| (request.getExtensions().containsKey("idtoken")) // TODO: parse the ID Token claims (#473) -- for now assume it could be in there || (request.getExtensions().containsKey("idtoken")) // TODO: parse the ID Token claims (#473) -- for now assume it could be in there
|| (client.getRequireAuthTime() != null && client.getRequireAuthTime())) { || (client.getRequireAuthTime() != null && client.getRequireAuthTime())) {
Date authTime = (Date) request.getExtensions().get(AuthenticationTimeStamper.AUTH_TIMESTAMP); Date authTime = (Date) request.getExtensions().get(AuthenticationTimeStamper.AUTH_TIMESTAMP);
if (authTime != null) { if (authTime != null) {
idClaims.setClaim("auth_time", authTime.getTime() / 1000); idClaims.setClaim("auth_time", authTime.getTime() / 1000);
@ -130,42 +130,42 @@ public class DefaultOIDCTokenService implements OIDCTokenService {
Base64URL at_hash = IdTokenHashUtils.getAccessTokenHash(signingAlg, accessToken); Base64URL at_hash = IdTokenHashUtils.getAccessTokenHash(signingAlg, accessToken);
idClaims.setClaim("at_hash", at_hash); idClaims.setClaim("at_hash", at_hash);
} }
if (client.getIdTokenEncryptedResponseAlg() != null && !client.getIdTokenEncryptedResponseAlg().equals(Algorithm.NONE) if (client.getIdTokenEncryptedResponseAlg() != null && !client.getIdTokenEncryptedResponseAlg().equals(Algorithm.NONE)
&& client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE) && client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE)
&& !Strings.isNullOrEmpty(client.getJwksUri())) { && !Strings.isNullOrEmpty(client.getJwksUri())) {
JwtEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri()); JwtEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri());
if (encrypter != null) { if (encrypter != null) {
EncryptedJWT idToken = new EncryptedJWT(new JWEHeader(client.getIdTokenEncryptedResponseAlg(), client.getIdTokenEncryptedResponseEnc()), idClaims); EncryptedJWT idToken = new EncryptedJWT(new JWEHeader(client.getIdTokenEncryptedResponseAlg(), client.getIdTokenEncryptedResponseEnc()), idClaims);
encrypter.encryptJwt(idToken); encrypter.encryptJwt(idToken);
idTokenEntity.setJwt(idToken); idTokenEntity.setJwt(idToken);
} else { } else {
logger.error("Couldn't find encrypter for client: " + client.getClientId()); logger.error("Couldn't find encrypter for client: " + client.getClientId());
} }
} else { } else {
SignedJWT idToken = new SignedJWT(new JWSHeader(signingAlg), idClaims); SignedJWT idToken = new SignedJWT(new JWSHeader(signingAlg), idClaims);
if (signingAlg.equals(JWSAlgorithm.HS256) if (signingAlg.equals(JWSAlgorithm.HS256)
|| signingAlg.equals(JWSAlgorithm.HS384) || signingAlg.equals(JWSAlgorithm.HS384)
|| signingAlg.equals(JWSAlgorithm.HS512)) { || signingAlg.equals(JWSAlgorithm.HS512)) {
JwtSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client); JwtSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client);
// sign it with the client's secret // sign it with the client's secret
signer.signJwt(idToken); signer.signJwt(idToken);
} else { } else {
// sign it with the server's key // sign it with the server's key
jwtService.signJwt(idToken); jwtService.signJwt(idToken);
} }
idTokenEntity.setJwt(idToken); idTokenEntity.setJwt(idToken);
} }

View File

@ -51,37 +51,37 @@ public class DefaultStatsService implements StatsService {
@Autowired @Autowired
private ClientDetailsEntityService clientService; private ClientDetailsEntityService clientService;
// stats cache // stats cache
private Supplier<Map<String, Integer>> summaryCache = createSummaryCache(); private Supplier<Map<String, Integer>> summaryCache = createSummaryCache();
private Supplier<Map<String, Integer>> createSummaryCache() { private Supplier<Map<String, Integer>> createSummaryCache() {
return Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() { return Suppliers.memoizeWithExpiration(new Supplier<Map<String, Integer>>() {
@Override @Override
public Map<String, Integer> get() { public Map<String, Integer> get() {
return computeSummaryStats(); return computeSummaryStats();
} }
}, 10, TimeUnit.MINUTES); }, 10, TimeUnit.MINUTES);
} }
private Supplier<Map<Long, Integer>> byClientIdCache = createByClientIdCache(); private Supplier<Map<Long, Integer>> byClientIdCache = createByClientIdCache();
private Supplier<Map<Long, Integer>> createByClientIdCache() { private Supplier<Map<Long, Integer>> createByClientIdCache() {
return Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() { return Suppliers.memoizeWithExpiration(new Supplier<Map<Long, Integer>>() {
@Override @Override
public Map<Long, Integer> get() { public Map<Long, Integer> get() {
return computeByClientId(); return computeByClientId();
} }
}, 10, TimeUnit.MINUTES); }, 10, TimeUnit.MINUTES);
} }
@Override @Override
public Map<String, Integer> getSummaryStats() { public Map<String, Integer> getSummaryStats() {
return summaryCache.get(); return summaryCache.get();
} }
// do the actual computation // do the actual computation
private Map<String, Integer> computeSummaryStats() { private Map<String, Integer> computeSummaryStats() {
// get all approved sites // get all approved sites
@ -110,7 +110,7 @@ public class DefaultStatsService implements StatsService {
public Map<Long, Integer> getByClientId() { public Map<Long, Integer> getByClientId() {
return byClientIdCache.get(); return byClientIdCache.get();
} }
private Map<Long, Integer> computeByClientId() { private Map<Long, Integer> computeByClientId() {
// get all approved sites // get all approved sites
Collection<ApprovedSite> allSites = approvedSiteService.getAll(); Collection<ApprovedSite> allSites = approvedSiteService.getAll();
@ -162,5 +162,5 @@ public class DefaultStatsService implements StatsService {
summaryCache = createSummaryCache(); summaryCache = createSummaryCache();
byClientIdCache = createByClientIdCache(); byClientIdCache = createByClientIdCache();
} }
} }

View File

@ -33,7 +33,6 @@ import org.mitre.openid.connect.service.UserInfoService;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.OAuth2Request;
@ -41,7 +40,6 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader; import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
@ -72,7 +70,7 @@ public class ConnectTokenEnhancer implements TokenEnhancer {
@Autowired @Autowired
private JWKSetCacheService encryptors; private JWKSetCacheService encryptors;
@Autowired @Autowired
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@ -105,7 +103,7 @@ public class ConnectTokenEnhancer implements TokenEnhancer {
jwtService.signJwt(signed); jwtService.signJwt(signed);
token.setJwt(signed); token.setJwt(signed);
/** /**
* Authorization request scope MUST include "openid" in OIDC, but access token request * Authorization request scope MUST include "openid" in OIDC, but access token request
* may or may not include the scope parameter. As long as the AuthorizationRequest * may or may not include the scope parameter. As long as the AuthorizationRequest
@ -115,18 +113,18 @@ public class ConnectTokenEnhancer implements TokenEnhancer {
* Also, there must be a user authentication involved in the request for it to be considered * Also, there must be a user authentication involved in the request for it to be considered
* OIDC and not OAuth, so we check for that as well. * OIDC and not OAuth, so we check for that as well.
*/ */
if (originalAuthRequest.getScope().contains("openid") if (originalAuthRequest.getScope().contains("openid")
&& !authentication.isClientOnly()) { && !authentication.isClientOnly()) {
String username = authentication.getName(); String username = authentication.getName();
UserInfo userInfo = userInfoService.getByUsernameAndClientId(username, clientId); UserInfo userInfo = userInfoService.getByUsernameAndClientId(username, clientId);
if (userInfo != null) { if (userInfo != null) {
OAuth2AccessTokenEntity idTokenEntity = connectTokenService.createIdToken(client, OAuth2AccessTokenEntity idTokenEntity = connectTokenService.createIdToken(client,
originalAuthRequest, claims.getIssueTime(), originalAuthRequest, claims.getIssueTime(),
userInfo.getSub(), token); userInfo.getSub(), token);
// attach the id token to the parent access token // attach the id token to the parent access token
token.setIdToken(idTokenEntity); token.setIdToken(idTokenEntity);
} else { } else {

View File

@ -97,18 +97,18 @@ public class TofuUserApprovalHandler implements UserApprovalHandler {
} else { } else {
// if not, check to see if the user has approved it // if not, check to see if the user has approved it
if (Boolean.parseBoolean(authorizationRequest.getApprovalParameters().get("user_oauth_approval"))) { // TODO: make parameter name configurable? if (Boolean.parseBoolean(authorizationRequest.getApprovalParameters().get("user_oauth_approval"))) { // TODO: make parameter name configurable?
// check the value of the CSRF parameter // check the value of the CSRF parameter
if (authorizationRequest.getExtensions().get("csrf") != null) { if (authorizationRequest.getExtensions().get("csrf") != null) {
if (authorizationRequest.getExtensions().get("csrf").equals(authorizationRequest.getApprovalParameters().get("csrf"))) { if (authorizationRequest.getExtensions().get("csrf").equals(authorizationRequest.getApprovalParameters().get("csrf"))) {
// make sure the user is actually authenticated // make sure the user is actually authenticated
return userAuthentication.isAuthenticated(); return userAuthentication.isAuthenticated();
} }
} }
} }
// if the above doesn't pass, it's not yet approved // if the above doesn't pass, it's not yet approved
return false; return false;
} }

View File

@ -9,7 +9,6 @@ import java.io.Writer;
import java.text.ParseException; import java.text.ParseException;
import java.util.Date; import java.util.Date;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID; import java.util.UUID;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
@ -28,7 +27,6 @@ import org.springframework.stereotype.Component;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import com.nimbusds.jose.Algorithm; import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JWEHeader; import com.nimbusds.jose.JWEHeader;
@ -46,16 +44,16 @@ import com.nimbusds.jwt.SignedJWT;
public class UserInfoJwtView extends UserInfoView { public class UserInfoJwtView extends UserInfoView {
private static Logger logger = LoggerFactory.getLogger(UserInfoJwtView.class); private static Logger logger = LoggerFactory.getLogger(UserInfoJwtView.class);
@Autowired @Autowired
private JwtSigningAndValidationService jwtService; private JwtSigningAndValidationService jwtService;
@Autowired @Autowired
private ConfigurationPropertiesBean config; private ConfigurationPropertiesBean config;
@Autowired @Autowired
private JWKSetCacheService encrypters; private JWKSetCacheService encrypters;
@Autowired @Autowired
private SymmetricCacheService symmetricCacheService; private SymmetricCacheService symmetricCacheService;
@ -65,40 +63,40 @@ public class UserInfoJwtView extends UserInfoView {
try { try {
ClientDetailsEntity client = (ClientDetailsEntity)model.get("client"); ClientDetailsEntity client = (ClientDetailsEntity)model.get("client");
// use the parser to import the user claims into the object // use the parser to import the user claims into the object
StringWriter writer = new StringWriter(); StringWriter writer = new StringWriter();
gson.toJson(json, writer); gson.toJson(json, writer);
JWTClaimsSet claims = JWTClaimsSet.parse(writer.toString()); JWTClaimsSet claims = JWTClaimsSet.parse(writer.toString());
claims.setAudience(Lists.newArrayList(client.getClientId())); claims.setAudience(Lists.newArrayList(client.getClientId()));
claims.setIssuer(config.getIssuer()); claims.setIssuer(config.getIssuer());
claims.setIssueTime(new Date()); claims.setIssueTime(new Date());
claims.setJWTID(UUID.randomUUID().toString()); // set a random NONCE in the middle of it claims.setJWTID(UUID.randomUUID().toString()); // set a random NONCE in the middle of it
if (client.getIdTokenEncryptedResponseAlg() != null && !client.getIdTokenEncryptedResponseAlg().equals(Algorithm.NONE) if (client.getIdTokenEncryptedResponseAlg() != null && !client.getIdTokenEncryptedResponseAlg().equals(Algorithm.NONE)
&& client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE) && client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE)
&& !Strings.isNullOrEmpty(client.getJwksUri())) { && !Strings.isNullOrEmpty(client.getJwksUri())) {
// encrypt it to the client's key // encrypt it to the client's key
JwtEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri()); JwtEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client.getJwksUri());
if (encrypter != null) { if (encrypter != null) {
EncryptedJWT encrypted = new EncryptedJWT(new JWEHeader(client.getIdTokenEncryptedResponseAlg(), client.getIdTokenEncryptedResponseEnc()), claims); EncryptedJWT encrypted = new EncryptedJWT(new JWEHeader(client.getIdTokenEncryptedResponseAlg(), client.getIdTokenEncryptedResponseEnc()), claims);
encrypter.encryptJwt(encrypted); encrypter.encryptJwt(encrypted);
Writer out = response.getWriter(); Writer out = response.getWriter();
out.write(encrypted.serialize()); out.write(encrypted.serialize());
} else { } else {
logger.error("Couldn't find encrypter for client: " + client.getClientId()); logger.error("Couldn't find encrypter for client: " + client.getClientId());
} }
@ -108,9 +106,9 @@ public class UserInfoJwtView extends UserInfoView {
if (client.getUserInfoSignedResponseAlg() != null) { if (client.getUserInfoSignedResponseAlg() != null) {
signingAlg = client.getUserInfoSignedResponseAlg(); signingAlg = client.getUserInfoSignedResponseAlg();
} }
SignedJWT signed = new SignedJWT(new JWSHeader(signingAlg), claims); SignedJWT signed = new SignedJWT(new JWSHeader(signingAlg), claims);
if (signingAlg.equals(JWSAlgorithm.HS256) if (signingAlg.equals(JWSAlgorithm.HS256)
|| signingAlg.equals(JWSAlgorithm.HS384) || signingAlg.equals(JWSAlgorithm.HS384)
|| signingAlg.equals(JWSAlgorithm.HS512)) { || signingAlg.equals(JWSAlgorithm.HS512)) {
@ -123,16 +121,16 @@ public class UserInfoJwtView extends UserInfoView {
// sign it with the server's key // sign it with the server's key
jwtService.signJwt(signed); jwtService.signJwt(signed);
} }
Writer out = response.getWriter(); Writer out = response.getWriter();
out.write(signed.serialize()); out.write(signed.serialize());
} }
} catch (IOException e) { } catch (IOException e) {
logger.error("IO Exception in UserInfoJwtView", e); logger.error("IO Exception in UserInfoJwtView", e);
} catch (ParseException e) { } catch (ParseException e) {
// TODO Auto-generated catch block // TODO Auto-generated catch block
e.printStackTrace(); e.printStackTrace();
} }
} }
} }

View File

@ -90,29 +90,29 @@ public class UserInfoView extends AbstractView {
response.setContentType("application/json"); response.setContentType("application/json");
JsonObject authorizedClaims = null; JsonObject authorizedClaims = null;
JsonObject requestedClaims = null; JsonObject requestedClaims = null;
if (model.get("authorizedClaims") != null) { if (model.get("authorizedClaims") != null) {
authorizedClaims = jsonParser.parse((String) model.get("authorizedClaims")).getAsJsonObject(); authorizedClaims = jsonParser.parse((String) model.get("authorizedClaims")).getAsJsonObject();
} }
if (model.get("requestedClaims") != null) { if (model.get("requestedClaims") != null) {
requestedClaims = jsonParser.parse((String) model.get("requestedClaims")).getAsJsonObject(); requestedClaims = jsonParser.parse((String) model.get("requestedClaims")).getAsJsonObject();
} }
JsonObject json = toJsonFromRequestObj(userInfo, scope, authorizedClaims, requestedClaims); JsonObject json = toJsonFromRequestObj(userInfo, scope, authorizedClaims, requestedClaims);
writeOut(json, model, request, response); writeOut(json, model, request, response);
} }
protected void writeOut(JsonObject json, Map<String, Object> model, HttpServletRequest request, HttpServletResponse response) { protected void writeOut(JsonObject json, Map<String, Object> model, HttpServletRequest request, HttpServletResponse response) {
try { try {
Writer out = response.getWriter(); Writer out = response.getWriter();
gson.toJson(json, out); gson.toJson(json, out);
} catch (IOException e) { } catch (IOException e) {
logger.error("IOException in UserInfoView.java: ", e); logger.error("IOException in UserInfoView.java: ", e);
} }
} }
/** /**

View File

@ -65,7 +65,7 @@ public class ClientDynamicRegistrationEndpoint {
@Autowired @Autowired
private SystemScopeService scopeService; private SystemScopeService scopeService;
@Autowired @Autowired
private BlacklistedSiteService blacklistService; private BlacklistedSiteService blacklistService;
@ -125,33 +125,33 @@ public class ClientDynamicRegistrationEndpoint {
newClient.setGrantTypes(Sets.newHashSet("authorization_code")); // allow authorization code grant type by default newClient.setGrantTypes(Sets.newHashSet("authorization_code")); // allow authorization code grant type by default
} }
} }
// check to make sure this client registered a redirect URI if using a redirect flow // check to make sure this client registered a redirect URI if using a redirect flow
if (newClient.getGrantTypes().contains("authorization_code") || newClient.getGrantTypes().contains("implicit")) { if (newClient.getGrantTypes().contains("authorization_code") || newClient.getGrantTypes().contains("implicit")) {
if (newClient.getRedirectUris() == null || newClient.getRedirectUris().isEmpty()) { if (newClient.getRedirectUris() == null || newClient.getRedirectUris().isEmpty()) {
// return an error // return an error
m.addAttribute("error", "invalid_client_uri"); m.addAttribute("error", "invalid_client_uri");
m.addAttribute("errorMessage", "Clients using a redirect-based grant type must register at least one redirect URI."); m.addAttribute("errorMessage", "Clients using a redirect-based grant type must register at least one redirect URI.");
m.addAttribute("code", HttpStatus.BAD_REQUEST); m.addAttribute("code", HttpStatus.BAD_REQUEST);
return "jsonErrorView"; return "jsonErrorView";
} }
for (String uri : newClient.getRedirectUris()) { for (String uri : newClient.getRedirectUris()) {
if (blacklistService.isBlacklisted(uri)) { if (blacklistService.isBlacklisted(uri)) {
// return an error // return an error
m.addAttribute("error", "invalid_client_uri"); m.addAttribute("error", "invalid_client_uri");
m.addAttribute("errorMessage", "Redirect URI is not allowed: " + uri); m.addAttribute("errorMessage", "Redirect URI is not allowed: " + uri);
m.addAttribute("code", HttpStatus.BAD_REQUEST); m.addAttribute("code", HttpStatus.BAD_REQUEST);
return "jsonErrorView"; return "jsonErrorView";
} }
} }
} }
// set default response types if needed // set default response types if needed
// TODO: these aren't checked by SECOAUTH // TODO: these aren't checked by SECOAUTH
// TODO: the consistency between the response_type and grant_type needs to be checked by the client service, most likely // TODO: the consistency between the response_type and grant_type needs to be checked by the client service, most likely
if (newClient.getResponseTypes() == null || newClient.getResponseTypes().isEmpty()) { if (newClient.getResponseTypes() == null || newClient.getResponseTypes().isEmpty()) {
newClient.setResponseTypes(Sets.newHashSet("code")); // default to allowing only the auth code flow newClient.setResponseTypes(Sets.newHashSet("code")); // default to allowing only the auth code flow
} }
@ -175,7 +175,7 @@ public class ClientDynamicRegistrationEndpoint {
// this client has been dynamically registered (obviously) // this client has been dynamically registered (obviously)
newClient.setDynamicallyRegistered(true); newClient.setDynamicallyRegistered(true);
// this client can't do token introspection // this client can't do token introspection
newClient.setAllowIntrospection(false); newClient.setAllowIntrospection(false);

View File

@ -65,7 +65,7 @@ public class ProtectedResourceRegistrationEndpoint {
@Autowired @Autowired
private SystemScopeService scopeService; private SystemScopeService scopeService;
@Autowired @Autowired
private BlacklistedSiteService blacklistService; private BlacklistedSiteService blacklistService;
@ -121,7 +121,7 @@ public class ProtectedResourceRegistrationEndpoint {
newClient.setGrantTypes(new HashSet<String>()); newClient.setGrantTypes(new HashSet<String>());
newClient.setResponseTypes(new HashSet<String>()); newClient.setResponseTypes(new HashSet<String>());
newClient.setRedirectUris(new HashSet<String>()); newClient.setRedirectUris(new HashSet<String>());
if (newClient.getTokenEndpointAuthMethod() == null) { if (newClient.getTokenEndpointAuthMethod() == null) {
newClient.setTokenEndpointAuthMethod(AuthMethod.SECRET_BASIC); newClient.setTokenEndpointAuthMethod(AuthMethod.SECRET_BASIC);
} }
@ -133,12 +133,12 @@ public class ProtectedResourceRegistrationEndpoint {
// we need to generate a secret // we need to generate a secret
newClient = clientService.generateClientSecret(newClient); newClient = clientService.generateClientSecret(newClient);
} }
// don't issue tokens to this client // don't issue tokens to this client
newClient.setAccessTokenValiditySeconds(0); newClient.setAccessTokenValiditySeconds(0);
newClient.setIdTokenValiditySeconds(0); newClient.setIdTokenValiditySeconds(0);
newClient.setRefreshTokenValiditySeconds(0); newClient.setRefreshTokenValiditySeconds(0);
// clear out unused fields // clear out unused fields
newClient.setDefaultACRvalues(new HashSet<String>()); newClient.setDefaultACRvalues(new HashSet<String>());
newClient.setDefaultMaxAge(null); newClient.setDefaultMaxAge(null);
@ -158,7 +158,7 @@ public class ProtectedResourceRegistrationEndpoint {
// this client has been dynamically registered (obviously) // this client has been dynamically registered (obviously)
newClient.setDynamicallyRegistered(true); newClient.setDynamicallyRegistered(true);
// this client has access to the introspection endpoint // this client has access to the introspection endpoint
newClient.setAllowIntrospection(true); newClient.setAllowIntrospection(true);

View File

@ -31,7 +31,6 @@ import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.ui.Model; import org.springframework.ui.Model;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod; import org.springframework.web.bind.annotation.RequestMethod;
@ -50,7 +49,7 @@ public class UserInfoEndpoint {
@Autowired @Autowired
private UserInfoService userInfoService; private UserInfoService userInfoService;
@Autowired @Autowired
private ClientDetailsEntityService clientService; private ClientDetailsEntityService clientService;
@ -61,7 +60,7 @@ public class UserInfoEndpoint {
*/ */
@PreAuthorize("hasRole('ROLE_USER') and #oauth2.hasScope('openid')") @PreAuthorize("hasRole('ROLE_USER') and #oauth2.hasScope('openid')")
@RequestMapping(value="/userinfo", method= {RequestMethod.GET, RequestMethod.POST}, produces = {"application/json", "application/jwt"}) @RequestMapping(value="/userinfo", method= {RequestMethod.GET, RequestMethod.POST}, produces = {"application/json", "application/jwt"})
public String getInfo(@RequestParam(value="claims", required=false) String claimsRequestJsonString, public String getInfo(@RequestParam(value="claims", required=false) String claimsRequestJsonString,
@RequestHeader(value="Accept") String acceptHeader, @RequestHeader(value="Accept") String acceptHeader,
OAuth2Authentication auth, Model model) { OAuth2Authentication auth, Model model) {
@ -93,9 +92,9 @@ public class UserInfoEndpoint {
// content negotiation // content negotiation
List<MediaType> mediaTypes = MediaType.parseMediaTypes(acceptHeader); List<MediaType> mediaTypes = MediaType.parseMediaTypes(acceptHeader);
MediaType.sortBySpecificityAndQuality(mediaTypes); MediaType.sortBySpecificityAndQuality(mediaTypes);
MediaType jose = new MediaType("application", "jwt"); MediaType jose = new MediaType("application", "jwt");
for (MediaType m : mediaTypes) { for (MediaType m : mediaTypes) {
if (!m.isWildcardType() && m.isCompatibleWith(jose)) { if (!m.isWildcardType() && m.isCompatibleWith(jose)) {
ClientDetailsEntity client = clientService.loadClientByClientId(auth.getOAuth2Request().getClientId()); ClientDetailsEntity client = clientService.loadClientByClientId(auth.getOAuth2Request().getClientId());
@ -104,8 +103,8 @@ public class UserInfoEndpoint {
return "userInfoJwtView"; return "userInfoJwtView";
} }
} }
return "userInfoView"; return "userInfoView";
} }

View File

@ -68,7 +68,7 @@ public class TestDefaultOAuth2ClientDetailsEntityService {
@Mock @Mock
private SystemScopeService scopeService; private SystemScopeService scopeService;
@Mock @Mock
private StatsService statsService; private StatsService statsService;

View File

@ -51,7 +51,7 @@ public class TestDefaultApprovedSiteService {
@Mock @Mock
private ApprovedSiteRepository repository; private ApprovedSiteRepository repository;
@Mock @Mock
private StatsService statsService; private StatsService statsService;

View File

@ -99,7 +99,7 @@ public class TestDefaultStatsService {
Mockito.when(ap5.getUserId()).thenReturn(userId2); Mockito.when(ap5.getUserId()).thenReturn(userId2);
Mockito.when(ap5.getClientId()).thenReturn(clientId1); Mockito.when(ap5.getClientId()).thenReturn(clientId1);
Mockito.when(ap6.getUserId()).thenReturn(userId1); Mockito.when(ap6.getUserId()).thenReturn(userId1);
Mockito.when(ap6.getClientId()).thenReturn(clientId4); Mockito.when(ap6.getClientId()).thenReturn(clientId4);
@ -170,10 +170,10 @@ public class TestDefaultStatsService {
assertThat(service.getCountForClientId(3L), is(1)); assertThat(service.getCountForClientId(3L), is(1));
assertThat(service.getCountForClientId(4L), is(0)); assertThat(service.getCountForClientId(4L), is(0));
} }
@Test @Test
public void cacheAndReset() { public void cacheAndReset() {
Map<String, Integer> stats = service.getSummaryStats(); Map<String, Integer> stats = service.getSummaryStats();
assertThat(stats.get("approvalCount"), is(4)); assertThat(stats.get("approvalCount"), is(4));
@ -181,22 +181,22 @@ public class TestDefaultStatsService {
assertThat(stats.get("clientCount"), is(3)); assertThat(stats.get("clientCount"), is(3));
Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6)); Mockito.when(approvedSiteService.getAll()).thenReturn(Sets.newHashSet(ap1, ap2, ap3, ap4, ap5, ap6));
Map<String, Integer> stats2 = service.getSummaryStats(); Map<String, Integer> stats2 = service.getSummaryStats();
// cache should remain the same due to memoized functions // cache should remain the same due to memoized functions
assertThat(stats2.get("approvalCount"), is(4)); assertThat(stats2.get("approvalCount"), is(4));
assertThat(stats2.get("userCount"), is(2)); assertThat(stats2.get("userCount"), is(2));
assertThat(stats2.get("clientCount"), is(3)); assertThat(stats2.get("clientCount"), is(3));
// reset the cache and make sure the count goes up // reset the cache and make sure the count goes up
service.resetCache(); service.resetCache();
Map<String, Integer> stats3 = service.getSummaryStats(); Map<String, Integer> stats3 = service.getSummaryStats();
assertThat(stats3.get("approvalCount"), is(6)); assertThat(stats3.get("approvalCount"), is(6));
assertThat(stats3.get("userCount"), is(2)); assertThat(stats3.get("userCount"), is(2));
assertThat(stats3.get("clientCount"), is(4)); assertThat(stats3.get("clientCount"), is(4));
} }
} }