Use query by user sub to get all tokens for user

pull/1378/head
Sauli Ketola 7 years ago
parent 417a6b7c74
commit 3f277047e3

@ -66,7 +66,6 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.Sets;
import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT; import com.nimbusds.jwt.PlainJWT;
@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Autowired @Autowired
private ApprovedSiteService approvedSiteService; private ApprovedSiteService approvedSiteService;
@Override @Override
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) { public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String sub) {
return tokenRepository.getAccessTokensBySub(sub);
Set<OAuth2AccessTokenEntity> all = tokenRepository.getAllAccessTokens();
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2AccessTokenEntity token : all) {
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
return results;
} }
@Override @Override
public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String id) { public Set<OAuth2RefreshTokenEntity> getAllRefreshTokensForUser(String sub) {
Set<OAuth2RefreshTokenEntity> all = tokenRepository.getAllRefreshTokens(); return tokenRepository.getRefreshTokensBySub(sub);
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2RefreshTokenEntity token : all) {
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
return results;
} }
@Override @Override

@ -21,6 +21,7 @@ import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import org.junit.Assert;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -52,6 +53,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import static com.google.common.collect.Sets.newHashSet;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.not;
@ -60,7 +62,8 @@ import static org.hamcrest.CoreMatchers.nullValue;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
@ -83,7 +86,9 @@ public class TestDefaultOAuth2ProviderTokenService {
private String badClientId = "bad_client"; private String badClientId = "bad_client";
private Set<String> scope = Sets.newHashSet("openid", "profile", "email", "offline_access"); private Set<String> scope = Sets.newHashSet("openid", "profile", "email", "offline_access");
private OAuth2RefreshTokenEntity refreshToken; private OAuth2RefreshTokenEntity refreshToken;
private OAuth2AccessTokenEntity accessToken;
private String refreshTokenValue = "refresh_token_value"; private String refreshTokenValue = "refresh_token_value";
private String userSub = "6a50ac11786d402a9591d3e592ac770f";
private TokenRequest tokenRequest; private TokenRequest tokenRequest;
// for use when refreshing access tokens // for use when refreshing access tokens
@ -143,6 +148,8 @@ public class TestDefaultOAuth2ProviderTokenService {
Mockito.when(refreshToken.getClient()).thenReturn(client); Mockito.when(refreshToken.getClient()).thenReturn(client);
Mockito.when(refreshToken.isExpired()).thenReturn(false); Mockito.when(refreshToken.isExpired()).thenReturn(false);
accessToken = Mockito.mock(OAuth2AccessTokenEntity.class);
tokenRequest = new TokenRequest(null, clientId, null, null); tokenRequest = new TokenRequest(null, clientId, null, null);
storedAuthentication = authentication; storedAuthentication = authentication;
@ -543,4 +550,21 @@ public class TestDefaultOAuth2ProviderTokenService {
assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens)); assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens));
} }
@Test
public void getAllAccessTokensForUser(){
Mockito.when(tokenRepository.getAccessTokensBySub(userSub)).thenReturn(newHashSet(accessToken));
Set<OAuth2AccessTokenEntity> tokens = service.getAllAccessTokensForUser(userSub);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(accessToken));
}
@Test
public void getAllRefreshTokensForUser(){
Mockito.when(tokenRepository.getRefreshTokensBySub(userSub)).thenReturn(newHashSet(refreshToken));
Set<OAuth2RefreshTokenEntity> tokens = service.getAllRefreshTokensForUser(userSub);
assertEquals(1, tokens.size());
assertTrue(tokens.contains(refreshToken));
}
} }

Loading…
Cancel
Save