Use query by user sub to get all tokens for user
parent
417a6b7c74
commit
3f277047e3
|
@ -66,7 +66,6 @@ import org.springframework.stereotype.Service;
|
|||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.nimbusds.jose.util.Base64URL;
|
||||
import com.nimbusds.jwt.JWTClaimsSet;
|
||||
import com.nimbusds.jwt.PlainJWT;
|
||||
|
@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
|
|||
@Autowired
|
||||
private ApprovedSiteService approvedSiteService;
|
||||
|
||||
|
||||
@Override
|
||||
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {
|
||||
|
||||
Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
|
||||
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
|
||||
|
||||
for (OAuth2AccessTokenEntity token : all) {
|
||||
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
|
||||
results.add(token);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String sub) {
|
||||
return tokenRepository.getAccessTokensBySub(sub);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) {
|
||||
Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens();
|
||||
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
|
||||
|
||||
for (OAuth2RefreshTokenEntity token : all) {
|
||||
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
|
||||
results.add(token);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String sub) {
|
||||
return tokenRepository.getRefreshTokensBySub(sub);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.util.Date;
|
|||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
|
@ -52,6 +53,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer;
|
|||
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import static com.google.common.collect.Sets.newHashSet;
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.CoreMatchers.is;
|
||||
import static org.hamcrest.CoreMatchers.not;
|
||||
|
@ -60,7 +62,8 @@ import static org.hamcrest.CoreMatchers.nullValue;
|
|||
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertThat;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.junit.Assert.fail;
|
||||
|
@ -83,7 +86,9 @@ public class TestDefaultOAuth2ProviderTokenService {
|
|||
private String badClientId = "bad_client";
|
||||
private Set<String> scope = Sets.newHashSet("openid", "profile", "email", "offline_access");
|
||||
private OAuth2RefreshTokenEntity refreshToken;
|
||||
private OAuth2AccessTokenEntity accessToken;
|
||||
private String refreshTokenValue = "refresh_token_value";
|
||||
private String userSub = "6a50ac11786d402a9591d3e592ac770f";
|
||||
private TokenRequest tokenRequest;
|
||||
|
||||
// for use when refreshing access tokens
|
||||
|
@ -142,6 +147,8 @@ public class TestDefaultOAuth2ProviderTokenService {
|
|||
Mockito.when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken);
|
||||
Mockito.when(refreshToken.getClient()).thenReturn(client);
|
||||
Mockito.when(refreshToken.isExpired()).thenReturn(false);
|
||||
|
||||
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
|
||||
|
||||
tokenRequest = new TokenRequest(null, clientId, null, null);
|
||||
|
||||
|
@ -542,5 +549,22 @@ public class TestDefaultOAuth2ProviderTokenService {
|
|||
|
||||
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void getAllAccessTokensForUser(){
|
||||
Mockito.when(tokenRepository.getAccessTokensBySub(userSub)).thenReturn(newHashSet(accessToken));
|
||||
|
||||
Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userSub);
|
||||
assertEquals(1, tokens.size());
|
||||
assertTrue(tokens.contains(accessToken));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void getAllRefreshTokensForUser(){
|
||||
Mockito.when(tokenRepository.getRefreshTokensBySub(userSub)).thenReturn(newHashSet(refreshToken));
|
||||
|
||||
Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userSub);
|
||||
assertEquals(1, tokens.size());
|
||||
assertTrue(tokens.contains(refreshToken));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue