diff --git a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java index c020daf3c..cef7ead69 100644 --- a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java @@ -100,6 +100,8 @@ public class TestDefaultOAuth2ProviderTokenService { public void prepare() { Mockito.reset(tokenRepository, authenticationHolderRepository, clientDetailsService, tokenEnhancer); + + authentication = Mockito.mock(OAuth2Authentication.class); OAuth2Request clientAuth = new OAuth2Request(null, clientId, null, true, scope, null, null, null); Mockito.when(authentication.getOAuth2Request()).thenReturn(clientAuth); @@ -116,16 +118,18 @@ public class TestDefaultOAuth2ProviderTokenService { Mockito.when(refreshToken.getClient()).thenReturn(client); Mockito.when(refreshToken.isExpired()).thenReturn(false); - tokenRequest = Mockito.mock(TokenRequest.class); + tokenRequest = new TokenRequest(null, clientId, null, null); - storedAuthRequest = Mockito.mock(OAuth2Request.class); - storedAuthentication = Mockito.mock(OAuth2Authentication.class); + storedAuthentication = authentication; + storedAuthRequest = clientAuth; storedAuthHolder = Mockito.mock(AuthenticationHolderEntity.class); storedScope = Sets.newHashSet(scope); Mockito.when(refreshToken.getAuthenticationHolder()).thenReturn(storedAuthHolder); Mockito.when(storedAuthHolder.getAuthentication()).thenReturn(storedAuthentication); Mockito.when(storedAuthentication.getOAuth2Request()).thenReturn(storedAuthRequest); + + Mockito.when(authenticationHolderRepository.save(Matchers.any(AuthenticationHolderEntity.class))).thenReturn(storedAuthHolder); } /** @@ -187,8 +191,8 @@ public class TestDefaultOAuth2ProviderTokenService { @Test public void createAccessToken_yesRefresh() { - OAuth2Request clientAuth = authentication.getOAuth2Request(); - Mockito.when(clientAuth.getScope()).thenReturn(Sets.newHashSet("offline_access")); + OAuth2Request clientAuth = new OAuth2Request(null, clientId, null, true, Sets.newHashSet("offline_access"), null, null, null); + Mockito.when(authentication.getOAuth2Request()).thenReturn(clientAuth); Mockito.when(client.isAllowRefresh()).thenReturn(true); OAuth2AccessTokenEntity token = service.createAccessToken(authentication); @@ -307,7 +311,7 @@ public class TestDefaultOAuth2ProviderTokenService { Set lessScope = Sets.newHashSet("openid", "profile"); - Mockito.when(tokenRequest.getScope()).thenReturn(lessScope); + tokenRequest.setScope(lessScope); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); @@ -321,7 +325,7 @@ public class TestDefaultOAuth2ProviderTokenService { moreScope.add("address"); moreScope.add("phone"); - Mockito.when(tokenRequest.getScope()).thenReturn(moreScope); + tokenRequest.setScope(moreScope); service.refreshAccessToken(refreshTokenValue, tokenRequest); } @@ -335,7 +339,7 @@ public class TestDefaultOAuth2ProviderTokenService { Set mixedScope = Sets.newHashSet("openid", "profile", "address", "phone"); // no email or offline_access - Mockito.when(tokenRequest.getScope()).thenReturn(mixedScope); + tokenRequest.setScope(mixedScope); service.refreshAccessToken(refreshTokenValue, tokenRequest); } @@ -345,7 +349,7 @@ public class TestDefaultOAuth2ProviderTokenService { Set emptyScope = Sets.newHashSet(); - Mockito.when(tokenRequest.getScope()).thenReturn(emptyScope); + tokenRequest.setScope(emptyScope); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest); @@ -355,7 +359,7 @@ public class TestDefaultOAuth2ProviderTokenService { @Test public void refreshAccessToken_requestingNullScope() { - Mockito.when(tokenRequest.getScope()).thenReturn(null); + tokenRequest.setScope(null); OAuth2AccessTokenEntity token = service.refreshAccessToken(refreshTokenValue, tokenRequest);