Fix scope checking in refresh token flow

pull/1611/head
Andrea Ceccanti 2020-01-15 16:33:16 +01:00
parent caa687f979
commit 2c48a4625c
6 changed files with 479 additions and 449 deletions

View File

@ -22,7 +22,7 @@
<parent> <parent>
<artifactId>openid-connect-parent</artifactId> <artifactId>openid-connect-parent</artifactId>
<groupId>org.mitre</groupId> <groupId>org.mitre</groupId>
<version>1.3.5.cnaf.v20191003</version> <version>1.3.5.cnaf.20200115</version>
<relativePath>..</relativePath> <relativePath>..</relativePath>
</parent> </parent>
<artifactId>openid-connect-client</artifactId> <artifactId>openid-connect-client</artifactId>

View File

@ -22,7 +22,7 @@
<parent> <parent>
<artifactId>openid-connect-parent</artifactId> <artifactId>openid-connect-parent</artifactId>
<groupId>org.mitre</groupId> <groupId>org.mitre</groupId>
<version>1.3.5.cnaf.v20191003</version> <version>1.3.5.cnaf.20200115</version>
<relativePath>..</relativePath> <relativePath>..</relativePath>
</parent> </parent>
<artifactId>openid-connect-common</artifactId> <artifactId>openid-connect-common</artifactId>

View File

@ -23,7 +23,7 @@
<parent> <parent>
<groupId>org.mitre</groupId> <groupId>org.mitre</groupId>
<artifactId>openid-connect-parent</artifactId> <artifactId>openid-connect-parent</artifactId>
<version>1.3.5.cnaf.v20191003</version> <version>1.3.5.cnaf.20200115</version>
<relativePath>..</relativePath> <relativePath>..</relativePath>
</parent> </parent>
<build> <build>

View File

@ -29,7 +29,6 @@ import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.Collection; import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.UUID; import java.util.UUID;
@ -66,6 +65,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.PlainJWT;
@ -331,33 +331,52 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();
// get the stored scopes from the authentication holder's authorization request; these are the scopes associated with the refresh token Set<String> reservedScopes = scopeService.toStrings(scopeService.getReserved());
Set<String> refreshScopesRequested = new HashSet<>(refreshToken.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope());
Set<SystemScope> refreshScopes = scopeService.fromStrings(refreshScopesRequested);
// remove any of the special system scopes
refreshScopes = scopeService.removeReservedScopes(refreshScopes);
Set<String> scopeRequested = authRequest.getScope() == null ? new HashSet<String>() : new HashSet<>(authRequest.getScope()); // Scopes linked to the refresh token, i.e. authorized by the user
Set<SystemScope> scope = scopeService.fromStrings(scopeRequested); Set<String> authorizedScopes = Sets.newHashSet(refreshToken.getAuthenticationHolder().getAuthentication().getOAuth2Request().getScope());
authorizedScopes.removeAll(reservedScopes);
// remove any of the special system scopes // Scopes requested in this refresh token flow
scope = scopeService.removeReservedScopes(scope); Set<String> requestedScopes = Sets.newHashSet();
if (authRequest.getScope() != null) {
if (scope != null && !scope.isEmpty()) { requestedScopes.addAll(authRequest.getScope());
// ensure a proper subset of scopes
if (refreshScopes != null && refreshScopes.containsAll(scope)) {
// set the scope of the new access token if requested
token.setScope(scopeService.toStrings(scope));
} else {
String errorMsg = "Up-scoping is not allowed.";
logger.error(errorMsg);
throw new InvalidScopeException(errorMsg);
}
} else {
// otherwise inherit the scope of the refresh token (if it's there -- this can return a null scope set)
token.setScope(scopeService.toStrings(refreshScopes));
} }
requestedScopes.removeAll(reservedScopes);
if (!requestedScopes.isEmpty()) {
// Check for upscoping
if (scopeService.scopesMatch(authorizedScopes, requestedScopes)) {
token.setScope(requestedScopes);
} else {
String errorMsg = "Up-scoping is not allowed.";
logger.error(errorMsg);
throw new InvalidScopeException(errorMsg);
}
} else {
// Preserve scopes linked to the original refresh token
token.setScope(authorizedScopes);
}
// if (scope != null && !scope.isEmpty()) {
// // ensure a proper subset of scopes
// // FIXME: ugly and inefficient translation to/from strings for no added value, just to work around
// // a terribly designed API
// if (refreshScopes != null && scopeService.scopesMatch(scopeService.toStrings(refreshScopes), scopeService.toStrings(scope))) {
// // set the scope of the new access token if requested
// token.setScope(scopeService.toStrings(scope));
// } else {
// String errorMsg = "Up-scoping is not allowed.";
// logger.error(errorMsg);
// throw new InvalidScopeException(errorMsg);
// }
// } else {
// // otherwise inherit the scope of the refresh token (if it's there -- this can return a null scope set)
// token.setScope(scopeService.toStrings(refreshScopes));
// }
token.setClient(client); token.setClient(client);
if (client.getAccessTokenValiditySeconds() != null) { if (client.getAccessTokenValiditySeconds() != null) {

View File

@ -17,6 +17,27 @@
*******************************************************************************/ *******************************************************************************/
package org.mitre.oauth2.service.impl; package org.mitre.oauth2.service.impl;
import static com.google.common.collect.Sets.newHashSet;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anySet;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Date; import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
@ -49,27 +70,6 @@ import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.TokenRequest; import org.springframework.security.oauth2.provider.TokenRequest;
import org.springframework.security.oauth2.provider.token.TokenEnhancer; import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import static com.google.common.collect.Sets.newHashSet;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anySet;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/** /**
* @author wkim * @author wkim
* *
@ -77,470 +77,481 @@ import static org.junit.Assert.fail;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class TestDefaultOAuth2ProviderTokenService { public class TestDefaultOAuth2ProviderTokenService {
// Grace period for time-sensitive tests. // Grace period for time-sensitive tests.
private static final long DELTA = 100L; private static final long DELTA = 100L;
// Test Fixture: // Test Fixture:
private OAuth2Authentication authentication; private OAuth2Authentication authentication;
private ClientDetailsEntity client; private ClientDetailsEntity client;
private ClientDetailsEntity badClient; private ClientDetailsEntity badClient;
private String clientId = "test_client"; private String clientId = "test_client";
private String badClientId = "bad_client"; private String badClientId = "bad_client";
private Set<String> scope = newHashSet("openid", "profile", "email", "offline_access"); private Set<String> scope = newHashSet("openid", "profile", "email", "offline_access");
private OAuth2RefreshTokenEntity refreshToken; private OAuth2RefreshTokenEntity refreshToken;
private OAuth2AccessTokenEntity accessToken; private OAuth2AccessTokenEntity accessToken;
private String refreshTokenValue = "refresh_token_value"; private String refreshTokenValue = "refresh_token_value";
private String userName = "6a50ac11786d402a9591d3e592ac770f"; private String userName = "6a50ac11786d402a9591d3e592ac770f";
private TokenRequest tokenRequest; private TokenRequest tokenRequest;
// for use when refreshing access tokens // for use when refreshing access tokens
private OAuth2Request storedAuthRequest; private OAuth2Request storedAuthRequest;
private OAuth2Authentication storedAuthentication; private OAuth2Authentication storedAuthentication;
private AuthenticationHolderEntity storedAuthHolder; private AuthenticationHolderEntity storedAuthHolder;
private Set<String> storedScope; private Set<String> storedScope;
@Mock @Mock
private OAuth2TokenRepository tokenRepository; private OAuth2TokenRepository tokenRepository;
@Mock @Mock
private AuthenticationHolderRepository authenticationHolderRepository; private AuthenticationHolderRepository authenticationHolderRepository;
@Mock @Mock
private ClientDetailsEntityService clientDetailsService; private ClientDetailsEntityService clientDetailsService;
@Mock @Mock
private TokenEnhancer tokenEnhancer; private TokenEnhancer tokenEnhancer;
@Mock @Mock
private SystemScopeService scopeService; private SystemScopeService scopeService;
@InjectMocks @InjectMocks
private DefaultOAuth2ProviderTokenService service; private DefaultOAuth2ProviderTokenService service;
/** /**
* Set up a mock authentication and mock client to work with. * Set up a mock authentication and mock client to work with.
*/ */
@Before @Before
public void prepare() { public void prepare() {
reset(tokenRepository, authenticationHolderRepository, clientDetailsService, tokenEnhancer); reset(tokenRepository, authenticationHolderRepository, clientDetailsService, tokenEnhancer);
authentication = Mockito.mock(OAuth2Authentication.class); authentication = Mockito.mock(OAuth2Authentication.class);
OAuth2Request clientAuth = new OAuth2Request(null, clientId, null, true, scope, null, null, null, null); OAuth2Request clientAuth =
when(authentication.getOAuth2Request()).thenReturn(clientAuth); new OAuth2Request(null, clientId, null, true, scope, null, null, null, null);
when(authentication.getOAuth2Request()).thenReturn(clientAuth);
client = Mockito.mock(ClientDetailsEntity.class);
when(client.getClientId()).thenReturn(clientId); client = Mockito.mock(ClientDetailsEntity.class);
when(clientDetailsService.loadClientByClientId(clientId)).thenReturn(client); when(client.getClientId()).thenReturn(clientId);
when(client.isReuseRefreshToken()).thenReturn(true); when(clientDetailsService.loadClientByClientId(clientId)).thenReturn(client);
when(client.isReuseRefreshToken()).thenReturn(true);
// by default in tests, allow refresh tokens
when(client.isAllowRefresh()).thenReturn(true); // by default in tests, allow refresh tokens
when(client.isAllowRefresh()).thenReturn(true);
// by default, clear access tokens on refresh
when(client.isClearAccessTokensOnRefresh()).thenReturn(true); // by default, clear access tokens on refresh
when(client.isClearAccessTokensOnRefresh()).thenReturn(true);
badClient = Mockito.mock(ClientDetailsEntity.class);
when(badClient.getClientId()).thenReturn(badClientId); badClient = Mockito.mock(ClientDetailsEntity.class);
when(clientDetailsService.loadClientByClientId(badClientId)).thenReturn(badClient); when(badClient.getClientId()).thenReturn(badClientId);
when(clientDetailsService.loadClientByClientId(badClientId)).thenReturn(badClient);
refreshToken = Mockito.mock(OAuth2RefreshTokenEntity.class);
when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken); refreshToken = Mockito.mock(OAuth2RefreshTokenEntity.class);
when(refreshToken.getClient()).thenReturn(client); when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken);
when(refreshToken.isExpired()).thenReturn(false); when(refreshToken.getClient()).thenReturn(client);
when(refreshToken.isExpired()).thenReturn(false);
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
tokenRequest = new TokenRequest(null, clientId, null, null);
tokenRequest = new TokenRequest(null, clientId, null, null);
storedAuthentication = authentication;
storedAuthRequest = clientAuth; storedAuthentication = authentication;
storedAuthHolder = mock(AuthenticationHolderEntity.class); storedAuthRequest = clientAuth;
storedScope = newHashSet(scope); storedAuthHolder = mock(AuthenticationHolderEntity.class);
storedScope = newHashSet(scope);
when(refreshToken.getAuthenticationHolder()).thenReturn(storedAuthHolder);
when(storedAuthHolder.getAuthentication()).thenReturn(storedAuthentication); when(refreshToken.getAuthenticationHolder()).thenReturn(storedAuthHolder);
when(storedAuthentication.getOAuth2Request()).thenReturn(storedAuthRequest); when(storedAuthHolder.getAuthentication()).thenReturn(storedAuthentication);
when(storedAuthentication.getOAuth2Request()).thenReturn(storedAuthRequest);
when(authenticationHolderRepository.save(any(AuthenticationHolderEntity.class))).thenReturn(storedAuthHolder);
when(authenticationHolderRepository.save(any(AuthenticationHolderEntity.class)))
when(scopeService.fromStrings(anySet())).thenAnswer(new Answer<Set<SystemScope>>() { .thenReturn(storedAuthHolder);
@Override
public Set<SystemScope> answer(InvocationOnMock invocation) throws Throwable { when(scopeService.fromStrings(anySet())).thenAnswer(new Answer<Set<SystemScope>>() {
Object[] args = invocation.getArguments(); @Override
Set<String> input = (Set<String>) args[0]; public Set<SystemScope> answer(InvocationOnMock invocation) throws Throwable {
Set<SystemScope> output = new HashSet<>(); Object[] args = invocation.getArguments();
for (String scope : input) { Set<String> input = (Set<String>) args[0];
output.add(new SystemScope(scope)); Set<SystemScope> output = new HashSet<>();
} for (String scope : input) {
return output; output.add(new SystemScope(scope));
} }
}); return output;
}
when(scopeService.toStrings(anySet())).thenAnswer(new Answer<Set<String>>() { });
@Override
public Set<String> answer(InvocationOnMock invocation) throws Throwable { when(scopeService.toStrings(anySet())).thenAnswer(new Answer<Set<String>>() {
Object[] args = invocation.getArguments(); @Override
Set<SystemScope> input = (Set<SystemScope>) args[0]; public Set<String> answer(InvocationOnMock invocation) throws Throwable {
Set<String> output = new HashSet<>(); Object[] args = invocation.getArguments();
for (SystemScope scope : input) { Set<SystemScope> input = (Set<SystemScope>) args[0];
output.add(scope.getValue()); Set<String> output = new HashSet<>();
} for (SystemScope scope : input) {
return output; output.add(scope.getValue());
} }
}); return output;
}
// we're not testing restricted or reserved scopes here, just pass through });
when(scopeService.removeReservedScopes(anySet())).then(returnsFirstArg());
when(scopeService.removeRestrictedAndReservedScopes(anySet())).then(returnsFirstArg()); when(scopeService.scopesMatch(anySet(), anySet())).thenAnswer(new Answer<Boolean>() {
@Override
when(tokenEnhancer.enhance(any(OAuth2AccessTokenEntity.class), any(OAuth2Authentication.class))) public Boolean answer(InvocationOnMock invocation) throws Throwable {
.thenAnswer(new Answer<OAuth2AccessTokenEntity>(){ Object[] args = invocation.getArguments();
@Override Set<String> expected = (Set<String>) args[0];
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { Set<String> actual = (Set<String>) args[1];
Object[] args = invocation.getArguments(); return expected.containsAll(actual);
return (OAuth2AccessTokenEntity) args[0]; }
} });
});
// we're not testing restricted or reserved scopes here, just pass through
when(tokenRepository.saveAccessToken(any(OAuth2AccessTokenEntity.class))) when(scopeService.removeReservedScopes(anySet())).then(returnsFirstArg());
.thenAnswer(new Answer<OAuth2AccessTokenEntity>() { when(scopeService.removeRestrictedAndReservedScopes(anySet())).then(returnsFirstArg());
@Override
public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { when(tokenEnhancer.enhance(any(OAuth2AccessTokenEntity.class), any(OAuth2Authentication.class)))
Object[] args = invocation.getArguments(); .thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
return (OAuth2AccessTokenEntity) args[0]; @Override
} public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
}); return (OAuth2AccessTokenEntity) args[0];
}
when(tokenRepository.saveRefreshToken(any(OAuth2RefreshTokenEntity.class))) });
.thenAnswer(new Answer<OAuth2RefreshTokenEntity>() {
@Override when(tokenRepository.saveAccessToken(any(OAuth2AccessTokenEntity.class)))
public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { .thenAnswer(new Answer<OAuth2AccessTokenEntity>() {
Object[] args = invocation.getArguments(); @Override
return (OAuth2RefreshTokenEntity) args[0]; public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable {
} Object[] args = invocation.getArguments();
}); return (OAuth2AccessTokenEntity) args[0];
}
}
});
/**
* Tests exception handling for null authentication or null authorization. when(tokenRepository.saveRefreshToken(any(OAuth2RefreshTokenEntity.class)))
*/ .thenAnswer(new Answer<OAuth2RefreshTokenEntity>() {
@Test @Override
public void createAccessToken_nullAuth() { public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable {
when(authentication.getOAuth2Request()).thenReturn(null); Object[] args = invocation.getArguments();
return (OAuth2RefreshTokenEntity) args[0];
try { }
service.createAccessToken(null); });
fail("Authentication parameter is null. Excpected a AuthenticationCredentialsNotFoundException.");
} catch (AuthenticationCredentialsNotFoundException e) { }
assertThat(e, is(notNullValue()));
} /**
* Tests exception handling for null authentication or null authorization.
try { */
service.createAccessToken(authentication); @Test
fail("AuthorizationRequest is null. Excpected a AuthenticationCredentialsNotFoundException."); public void createAccessToken_nullAuth() {
} catch (AuthenticationCredentialsNotFoundException e) { when(authentication.getOAuth2Request()).thenReturn(null);
assertThat(e, is(notNullValue()));
} try {
} service.createAccessToken(null);
fail(
/** "Authentication parameter is null. Excpected a AuthenticationCredentialsNotFoundException.");
* Tests exception handling for clients not found. } catch (AuthenticationCredentialsNotFoundException e) {
*/ assertThat(e, is(notNullValue()));
@Test(expected = InvalidClientException.class) }
public void createAccessToken_nullClient() {
when(clientDetailsService.loadClientByClientId(anyString())).thenReturn(null); try {
service.createAccessToken(authentication);
service.createAccessToken(authentication); fail("AuthorizationRequest is null. Excpected a AuthenticationCredentialsNotFoundException.");
} } catch (AuthenticationCredentialsNotFoundException e) {
assertThat(e, is(notNullValue()));
/** }
* Tests the creation of access tokens for clients that are not allowed to have refresh tokens. }
*/
@Test /**
public void createAccessToken_noRefresh() { * Tests exception handling for clients not found.
when(client.isAllowRefresh()).thenReturn(false); */
@Test(expected = InvalidClientException.class)
OAuth2AccessTokenEntity token = service.createAccessToken(authentication); public void createAccessToken_nullClient() {
when(clientDetailsService.loadClientByClientId(anyString())).thenReturn(null);
verify(clientDetailsService).loadClientByClientId(anyString());
verify(authenticationHolderRepository).save(any(AuthenticationHolderEntity.class)); service.createAccessToken(authentication);
verify(tokenEnhancer).enhance(any(OAuth2AccessTokenEntity.class), Matchers.eq(authentication)); }
verify(tokenRepository).saveAccessToken(any(OAuth2AccessTokenEntity.class));
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); /**
* Tests the creation of access tokens for clients that are not allowed to have refresh tokens.
verify(tokenRepository, Mockito.never()).saveRefreshToken(any(OAuth2RefreshTokenEntity.class)); */
@Test
assertThat(token.getRefreshToken(), is(nullValue())); public void createAccessToken_noRefresh() {
} when(client.isAllowRefresh()).thenReturn(false);
/** OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
* Tests the creation of access tokens for clients that are allowed to have refresh tokens.
*/ verify(clientDetailsService).loadClientByClientId(anyString());
@Test verify(authenticationHolderRepository).save(any(AuthenticationHolderEntity.class));
public void createAccessToken_yesRefresh() { verify(tokenEnhancer).enhance(any(OAuth2AccessTokenEntity.class), Matchers.eq(authentication));
OAuth2Request clientAuth = new OAuth2Request(null, clientId, null, true, newHashSet(SystemScopeService.OFFLINE_ACCESS), null, null, null, null); verify(tokenRepository).saveAccessToken(any(OAuth2AccessTokenEntity.class));
when(authentication.getOAuth2Request()).thenReturn(clientAuth); verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
when(client.isAllowRefresh()).thenReturn(true);
verify(tokenRepository, Mockito.never()).saveRefreshToken(any(OAuth2RefreshTokenEntity.class));
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
assertThat(token.getRefreshToken(), is(nullValue()));
// Note: a refactor may be appropriate to only save refresh tokens once to the repository during creation. }
verify(tokenRepository, atLeastOnce()).saveRefreshToken(any(OAuth2RefreshTokenEntity.class));
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); /**
* Tests the creation of access tokens for clients that are allowed to have refresh tokens.
assertThat(token.getRefreshToken(), is(notNullValue())); */
} @Test
public void createAccessToken_yesRefresh() {
/** OAuth2Request clientAuth = new OAuth2Request(null, clientId, null, true,
* Checks to see that the expiration date of new tokens is being set accurately to within some delta for time skew. newHashSet(SystemScopeService.OFFLINE_ACCESS), null, null, null, null);
*/ when(authentication.getOAuth2Request()).thenReturn(clientAuth);
@Test when(client.isAllowRefresh()).thenReturn(true);
public void createAccessToken_expiration() {
Integer accessTokenValiditySeconds = 3600; OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
Integer refreshTokenValiditySeconds = 600;
// Note: a refactor may be appropriate to only save refresh tokens once to the repository during
when(client.getAccessTokenValiditySeconds()).thenReturn(accessTokenValiditySeconds); // creation.
when(client.getRefreshTokenValiditySeconds()).thenReturn(refreshTokenValiditySeconds); verify(tokenRepository, atLeastOnce()).saveRefreshToken(any(OAuth2RefreshTokenEntity.class));
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
assertThat(token.getRefreshToken(), is(notNullValue()));
}
/**
* Checks to see that the expiration date of new tokens is being set accurately to within some
* delta for time skew.
*/
@Test
public void createAccessToken_expiration() {
Integer accessTokenValiditySeconds = 3600;
Integer refreshTokenValiditySeconds = 600;
when(client.getAccessTokenValiditySeconds()).thenReturn(accessTokenValiditySeconds);
when(client.getRefreshTokenValiditySeconds()).thenReturn(refreshTokenValiditySeconds);
long start = System.currentTimeMillis();
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
long end = System.currentTimeMillis();
// Accounting for some delta for time skew on either side.
Date lowerBoundAccessTokens = new Date(start + (accessTokenValiditySeconds * 1000L) - DELTA);
Date upperBoundAccessTokens = new Date(end + (accessTokenValiditySeconds * 1000L) + DELTA);
Date lowerBoundRefreshTokens = new Date(start + (refreshTokenValiditySeconds * 1000L) - DELTA);
Date upperBoundRefreshTokens = new Date(end + (refreshTokenValiditySeconds * 1000L) + DELTA);
long start = System.currentTimeMillis(); verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
long end = System.currentTimeMillis();
// Accounting for some delta for time skew on either side. assertTrue(token.getExpiration().after(lowerBoundAccessTokens)
Date lowerBoundAccessTokens = new Date(start + (accessTokenValiditySeconds * 1000L) - DELTA); && token.getExpiration().before(upperBoundAccessTokens));
Date upperBoundAccessTokens = new Date(end + (accessTokenValiditySeconds * 1000L) + DELTA); assertTrue(token.getRefreshToken().getExpiration().after(lowerBoundRefreshTokens)
Date lowerBoundRefreshTokens = new Date(start + (refreshTokenValiditySeconds * 1000L) - DELTA); && token.getRefreshToken().getExpiration().before(upperBoundRefreshTokens));
Date upperBoundRefreshTokens = new Date(end + (refreshTokenValiditySeconds * 1000L) + DELTA); }
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); @Test
public void createAccessToken_checkClient() {
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens)); verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
assertTrue(token.getRefreshToken().getExpiration().after(lowerBoundRefreshTokens) && token.getRefreshToken().getExpiration().before(upperBoundRefreshTokens));
}
@Test assertThat(token.getClient().getClientId(), equalTo(clientId));
public void createAccessToken_checkClient() { }
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); @Test
public void createAccessToken_checkScopes() {
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
assertThat(token.getClient().getClientId(), equalTo(clientId)); verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
}
@Test assertThat(token.getScope(), equalTo(scope));
public void createAccessToken_checkScopes() { }
OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); @Test
public void createAccessToken_checkAttachedAuthentication() {
AuthenticationHolderEntity authHolder = mock(AuthenticationHolderEntity.class);
when(authHolder.getAuthentication()).thenReturn(authentication);
assertThat(token.getScope(), equalTo(scope)); when(authenticationHolderRepository.save(any(AuthenticationHolderEntity.class)))
} .thenReturn(authHolder);
@Test OAuth2AccessTokenEntity token = service.createAccessToken(authentication);
public void createAccessToken_checkAttachedAuthentication() {
AuthenticationHolderEntity authHolder = mock(AuthenticationHolderEntity.class);
when(authHolder.getAuthentication()).thenReturn(authentication);
when(authenticationHolderRepository.save(any(AuthenticationHolderEntity.class))).thenReturn(authHolder); assertThat(token.getAuthenticationHolder().getAuthentication(), equalTo(authentication));
verify(authenticationHolderRepository).save(any(AuthenticationHolderEntity.class));
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
}
OAuth2AccessTokenEntity token = service.createAccessToken(authentication); @Test(expected = InvalidTokenException.class)
public void refreshAccessToken_noRefreshToken() {
when(tokenRepository.getRefreshTokenByValue(anyString())).thenReturn(null);
assertThat(token.getAuthenticationHolder().getAuthentication(), equalTo(authentication)); service.refreshAccessToken(refreshTokenValue, tokenRequest);
verify(authenticationHolderRepository).save(any(AuthenticationHolderEntity.class)); }
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
}
@Test(expected = InvalidTokenException.class) @Test(expected = InvalidClientException.class)
public void refreshAccessToken_noRefreshToken() { public void refreshAccessToken_notAllowRefresh() {
when(tokenRepository.getRefreshTokenByValue(anyString())).thenReturn(null); when(client.isAllowRefresh()).thenReturn(false);
service.refreshAccessToken(refreshTokenValue, tokenRequest); service.refreshAccessToken(refreshTokenValue, tokenRequest);
} }
@Test(expected = InvalidClientException.class) @Test(expected = InvalidClientException.class)
public void refreshAccessToken_notAllowRefresh() { public void refreshAccessToken_clientMismatch() {
when(client.isAllowRefresh()).thenReturn(false); tokenRequest = new TokenRequest(null, badClientId, null, null);
service.refreshAccessToken(refreshTokenValue, tokenRequest); service.refreshAccessToken(refreshTokenValue, tokenRequest);
} }
@Test(expected = InvalidClientException.class) @Test(expected = InvalidTokenException.class)
public void refreshAccessToken_clientMismatch() { public void refreshAccessToken_expired() {
tokenRequest = new TokenRequest(null, badClientId, null, null); when(refreshToken.isExpired()).thenReturn(true);
service.refreshAccessToken(refreshTokenValue, tokenRequest); service.refreshAccessToken(refreshTokenValue, tokenRequest);
} }
@Test(expected = InvalidTokenException.class) @Test
public void refreshAccessToken_expired() { public void refreshAccessToken_verifyAcessToken() {
when(refreshToken.isExpired()).thenReturn(true); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
service.refreshAccessToken(refreshTokenValue, tokenRequest); verify(tokenRepository).clearAccessTokensForRefreshToken(refreshToken);
}
@Test assertThat(token.getClient(), equalTo(client));
public void refreshAccessToken_verifyAcessToken() { assertThat(token.getRefreshToken(), equalTo(refreshToken));
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenRepository).clearAccessTokensForRefreshToken(refreshToken); verify(tokenEnhancer).enhance(token, storedAuthentication);
verify(tokenRepository).saveAccessToken(token);
assertThat(token.getClient(), equalTo(client)); }
assertThat(token.getRefreshToken(), equalTo(refreshToken));
assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenEnhancer).enhance(token, storedAuthentication); @Test
verify(tokenRepository).saveAccessToken(token); public void refreshAccessToken_rotateRefreshToken() {
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); when(client.isReuseRefreshToken()).thenReturn(false);
} OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
@Test verify(tokenRepository).clearAccessTokensForRefreshToken(refreshToken);
public void refreshAccessToken_rotateRefreshToken() {
when(client.isReuseRefreshToken()).thenReturn(false);
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); assertThat(token.getClient(), equalTo(client));
assertThat(token.getRefreshToken(), not(equalTo(refreshToken)));
assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenRepository).clearAccessTokensForRefreshToken(refreshToken); verify(tokenEnhancer).enhance(token, storedAuthentication);
verify(tokenRepository).saveAccessToken(token);
verify(tokenRepository).removeRefreshToken(refreshToken);
assertThat(token.getClient(), equalTo(client)); }
assertThat(token.getRefreshToken(), not(equalTo(refreshToken)));
assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenEnhancer).enhance(token, storedAuthentication); @Test
verify(tokenRepository).saveAccessToken(token); public void refreshAccessToken_keepAccessTokens() {
verify(tokenRepository).removeRefreshToken(refreshToken); when(client.isClearAccessTokensOnRefresh()).thenReturn(false);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet());
} OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
@Test verify(tokenRepository, never()).clearAccessTokensForRefreshToken(refreshToken);
public void refreshAccessToken_keepAccessTokens() {
when(client.isClearAccessTokensOnRefresh()).thenReturn(false);
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); assertThat(token.getClient(), equalTo(client));
assertThat(token.getRefreshToken(), equalTo(refreshToken));
assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenRepository, never()).clearAccessTokensForRefreshToken(refreshToken); verify(tokenEnhancer).enhance(token, storedAuthentication);
verify(tokenRepository).saveAccessToken(token);
assertThat(token.getClient(), equalTo(client)); }
assertThat(token.getRefreshToken(), equalTo(refreshToken));
assertThat(token.getAuthenticationHolder(), equalTo(storedAuthHolder));
verify(tokenEnhancer).enhance(token, storedAuthentication); @Test
verify(tokenRepository).saveAccessToken(token); public void refreshAccessToken_requestingSameScope() {
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
}
@Test assertThat(token.getScope(), equalTo(storedScope));
public void refreshAccessToken_requestingSameScope() { }
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); @Test
public void refreshAccessToken_requestingLessScope() {
Set<String> lessScope = newHashSet("openid", "profile");
assertThat(token.getScope(), equalTo(storedScope)); tokenRequest.setScope(lessScope);
}
@Test OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
public void refreshAccessToken_requestingLessScope() {
Set<String> lessScope = newHashSet("openid", "profile");
tokenRequest.setScope(lessScope); assertThat(token.getScope(), equalTo(lessScope));
}
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); @Test(expected = InvalidScopeException.class)
public void refreshAccessToken_requestingMoreScope() {
Set<String> moreScope = newHashSet(storedScope);
moreScope.add("address");
moreScope.add("phone");
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); tokenRequest.setScope(moreScope);
assertThat(token.getScope(), equalTo(lessScope)); service.refreshAccessToken(refreshTokenValue, tokenRequest);
} }
@Test(expected = InvalidScopeException.class) /**
public void refreshAccessToken_requestingMoreScope() { * Tests the case where only some of the valid scope values are being requested along with other
Set<String> moreScope = newHashSet(storedScope); * extra unauthorized scope values.
moreScope.add("address"); */
moreScope.add("phone"); @Test(expected = InvalidScopeException.class)
public void refreshAccessToken_requestingMixedScope() {
Set<String> mixedScope = newHashSet("openid", "profile", "address", "phone"); // no email or
// offline_access
tokenRequest.setScope(moreScope); tokenRequest.setScope(mixedScope);
service.refreshAccessToken(refreshTokenValue, tokenRequest); service.refreshAccessToken(refreshTokenValue, tokenRequest);
} }
/** @Test
* Tests the case where only some of the valid scope values are being requested along with public void refreshAccessToken_requestingEmptyScope() {
* other extra unauthorized scope values. Set<String> emptyScope = newHashSet();
*/
@Test(expected = InvalidScopeException.class)
public void refreshAccessToken_requestingMixedScope() {
Set<String> mixedScope = newHashSet("openid", "profile", "address", "phone"); // no email or offline_access
tokenRequest.setScope(mixedScope); tokenRequest.setScope(emptyScope);
service.refreshAccessToken(refreshTokenValue, tokenRequest); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
}
@Test assertThat(token.getScope(), equalTo(storedScope));
public void refreshAccessToken_requestingEmptyScope() { }
Set<String> emptyScope = newHashSet();
tokenRequest.setScope(emptyScope); @Test
public void refreshAccessToken_requestingNullScope() {
tokenRequest.setScope(null);
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); assertThat(token.getScope(), equalTo(storedScope));
assertThat(token.getScope(), equalTo(storedScope)); }
}
@Test /**
public void refreshAccessToken_requestingNullScope() { * Checks to see that the expiration date of refreshed tokens is being set accurately to within
tokenRequest.setScope(null); * some delta for time skew.
*/
@Test
public void refreshAccessToken_expiration() {
Integer accessTokenValiditySeconds = 3600;
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); when(client.getAccessTokenValiditySeconds()).thenReturn(accessTokenValiditySeconds);
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); long start = System.currentTimeMillis();
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);
long end = System.currentTimeMillis();
assertThat(token.getScope(), equalTo(storedScope)); // Accounting for some delta for time skew on either side.
Date lowerBoundAccessTokens = new Date(start + (accessTokenValiditySeconds * 1000L) - DELTA);
Date upperBoundAccessTokens = new Date(end + (accessTokenValiditySeconds * 1000L) + DELTA);
}
/** assertTrue(token.getExpiration().after(lowerBoundAccessTokens)
* Checks to see that the expiration date of refreshed tokens is being set accurately to within some delta for time skew. && token.getExpiration().before(upperBoundAccessTokens));
*/ }
@Test
public void refreshAccessToken_expiration() {
Integer accessTokenValiditySeconds = 3600;
when(client.getAccessTokenValiditySeconds()).thenReturn(accessTokenValiditySeconds); @Test
public void getAllAccessTokensForUser() {
when(tokenRepository.getAccessTokensByUserName(userName)).thenReturn(newHashSet(accessToken));
long start = System.currentTimeMillis(); Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userName);
OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); assertEquals(1, tokens.size());
long end = System.currentTimeMillis(); assertTrue(tokens.contains(accessToken));
}
// Accounting for some delta for time skew on either side. @Test
Date lowerBoundAccessTokens = new Date(start + (accessTokenValiditySeconds * 1000L) - DELTA); public void getAllRefreshTokensForUser() {
Date upperBoundAccessTokens = new Date(end + (accessTokenValiditySeconds * 1000L) + DELTA); when(tokenRepository.getRefreshTokensByUserName(userName)).thenReturn(newHashSet(refreshToken));
verify(scopeService, atLeastOnce()).removeReservedScopes(anySet()); Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userName);
assertEquals(1, tokens.size());
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens)); assertTrue(tokens.contains(refreshToken));
} }
@Test
public void getAllAccessTokensForUser(){
when(tokenRepository.getAccessTokensByUserName(userName)).thenReturn(newHashSet(accessToken));
Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userName);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(accessToken));
}
@Test
public void getAllRefreshTokensForUser(){
when(tokenRepository.getRefreshTokensByUserName(userName)).thenReturn(newHashSet(refreshToken));
Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userName);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(refreshToken));
}
} }

View File

@ -20,7 +20,7 @@
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>org.mitre</groupId> <groupId>org.mitre</groupId>
<artifactId>openid-connect-parent</artifactId> <artifactId>openid-connect-parent</artifactId>
<version>1.3.5.cnaf.v20191003</version> <version>1.3.5.cnaf.20200115</version>
<name>MITREid Connect</name> <name>MITREid Connect</name>
<packaging>pom</packaging> <packaging>pom</packaging>
<parent> <parent>