diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java index 262b9ee32..151e48694 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java @@ -66,7 +66,6 @@ import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import com.google.common.base.Strings; -import com.google.common.collect.Sets; import com.nimbusds.jose.util.Base64URL; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.PlainJWT; @@ -102,35 +101,14 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi @Autowired private ApprovedSiteService approvedSiteService; - @Override - public Set getAllAccessTokensForUser(String id) { - - Set all = tokenRepository.getAllAccessTokens(); - Set results = Sets.newLinkedHashSet(); - - for (OAuth2AccessTokenEntity token : all) { - if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) { - results.add(token); - } - } - - return results; + public Set getAllAccessTokensForUser(String sub) { + return tokenRepository.getAccessTokensBySub(sub); } - @Override - public Set getAllRefreshTokensForUser(String id) { - Set all = tokenRepository.getAllRefreshTokens(); - Set results = Sets.newLinkedHashSet(); - - for (OAuth2RefreshTokenEntity token : all) { - if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) { - results.add(token); - } - } - - return results; + public Set getAllRefreshTokensForUser(String sub) { + return tokenRepository.getRefreshTokensBySub(sub); } @Override diff --git a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java index 412d3ac96..9099c3355 100644 --- a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java @@ -21,6 +21,7 @@ import java.util.Date; import java.util.HashSet; import java.util.Set; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,6 +53,7 @@ import org.springframework.security.oauth2.provider.token.TokenEnhancer; import com.google.common.collect.Sets; +import static com.google.common.collect.Sets.newHashSet; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; @@ -60,7 +62,8 @@ import static org.hamcrest.CoreMatchers.nullValue; import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; - +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -83,7 +86,9 @@ public class TestDefaultOAuth2ProviderTokenService { private String badClientId = "bad_client"; private Set scope = Sets.newHashSet("openid", "profile", "email", "offline_access"); private OAuth2RefreshTokenEntity refreshToken; + private OAuth2AccessTokenEntity accessToken; private String refreshTokenValue = "refresh_token_value"; + private String userSub = "6a50ac11786d402a9591d3e592ac770f"; private TokenRequest tokenRequest; // for use when refreshing access tokens @@ -142,6 +147,8 @@ public class TestDefaultOAuth2ProviderTokenService { Mockito.when(tokenRepository.getRefreshTokenByValue(refreshTokenValue)).thenReturn(refreshToken); Mockito.when(refreshToken.getClient()).thenReturn(client); Mockito.when(refreshToken.isExpired()).thenReturn(false); + + accessToken = Mockito.mock(OAuth2AccessTokenEntity.class); tokenRequest = new TokenRequest(null, clientId, null, null); @@ -542,5 +549,22 @@ public class TestDefaultOAuth2ProviderTokenService { assertTrue(token.getExpiration().after(lowerBoundAccessTokens) && token.getExpiration().before(upperBoundAccessTokens)); } - + + @Test + public void getAllAccessTokensForUser(){ + Mockito.when(tokenRepository.getAccessTokensBySub(userSub)).thenReturn(newHashSet(accessToken)); + + Set tokens = service.getAllAccessTokensForUser(userSub); + assertEquals(1, tokens.size()); + assertTrue(tokens.contains(accessToken)); + } + + @Test + public void getAllRefreshTokensForUser(){ + Mockito.when(tokenRepository.getRefreshTokensBySub(userSub)).thenReturn(newHashSet(refreshToken)); + + Set tokens = service.getAllRefreshTokensForUser(userSub); + assertEquals(1, tokens.size()); + assertTrue(tokens.contains(refreshToken)); + } }