Browse Source

Use query by user sub to get all tokens for user

pull/1378/head
Sauli Ketola 7 years ago
parent
commit
3f277047e3
  1. 30
      openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java
  2. 28
      openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java

30
openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java

@ -66,7 +66,6 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Autowired
private ApprovedSiteService approvedSiteService;
@Override
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {
Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2AccessTokenEntity token : all) {
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
return results;
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String sub) {
return tokenRepository.getAccessTokensBySub(sub);
}
@Override
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) {
Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens();
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2RefreshTokenEntity token : all) {
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
return results;
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String sub) {
return tokenRepository.getRefreshTokensBySub(sub);
}
@Override

28
openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java

@ -21,6 +21,7 @@ import java.util.Date;
import java.util.HashSet;
import java.util.Set;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -52,6 +53,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import com.google.common.collect.Sets;
import static com.google.common.collect.Sets.newHashSet;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
@ -60,7 +62,8 @@ import static org.hamcrest.CoreMatchers.nullValue;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@ -83,7 +86,9 @@ public class TestDefaultOAuth2ProviderTokenService {
private String badClientId = "bad_client";
private Set<String> scope = Sets.newHashSet("openid", "profile", "email", "offline_access");
private OAuth2RefreshTokenEntity refreshToken;
private OAuth2AccessTokenEntity accessToken;
private String refreshTokenValue = "refresh_token_value";
private String userSub = "6a50ac11786d402a9591d3e592ac770f";
private TokenRequest tokenRequest;
// for use when refreshing access tokens
@ -142,6 +147,8 @@ public class TestDefaultOAuth2ProviderTokenService {
Mockito.when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken);
Mockito.when(refreshToken.getClient()).thenReturn(client);
Mockito.when(refreshToken.isExpired()).thenReturn(false);
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
tokenRequest = new TokenRequest(null, clientId, null, null);
@ -542,5 +549,22 @@ public class TestDefaultOAuth2ProviderTokenService {
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens));
}
@Test
public void getAllAccessTokensForUser(){
Mockito.when(tokenRepository.getAccessTokensBySub(userSub)).thenReturn(newHashSet(accessToken));
Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userSub);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(accessToken));
}
@Test
public void getAllRefreshTokensForUser(){
Mockito.when(tokenRepository.getRefreshTokensBySub(userSub)).thenReturn(newHashSet(refreshToken));
Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userSub);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(refreshToken));
}
}

Loading…
Cancel
Save