diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java index 9aeff131e..1c7f315bc 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java @@ -195,8 +195,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi token.setRefreshToken(refreshToken); } - tokenEnhancer.enhance(token, authentication); - + token = (OAuth2AccessTokenEntity) tokenEnhancer.enhance(token, authentication); tokenRepository.saveAccessToken(token); //Add approved site reference, if any diff --git a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java index 7f1ffd01c..0d0c35beb 100644 --- a/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/test/java/org/mitre/oauth2/service/impl/TestDefaultOAuth2ProviderTokenService.java @@ -43,6 +43,8 @@ import org.mockito.InjectMocks; import org.mockito.Matchers; import org.mockito.Mock; import org.mockito.Mockito; +import org.mockito.stubbing.Answer; +import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.oauth2.common.exceptions.InvalidClientException; @@ -137,6 +139,17 @@ public class TestDefaultOAuth2ProviderTokenService { Mockito.when(authenticationHolderRepository.save(Matchers.any(AuthenticationHolderEntity.class))).thenReturn(storedAuthHolder); Mockito.when(scopeService.removeRestrictedScopes(Matchers.anySet())).then(AdditionalAnswers.returnsFirstArg()); + + Mockito.when(tokenEnhancer.enhance(Matchers.any(OAuth2AccessTokenEntity.class), Matchers.any(OAuth2Authentication.class))) + .thenAnswer(new Answer(){ + @Override + public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + return (OAuth2AccessTokenEntity) args[0]; + } + }); + + } /** @@ -185,8 +198,8 @@ public class TestDefaultOAuth2ProviderTokenService { Mockito.verify(clientDetailsService).loadClientByClientId(Matchers.anyString()); Mockito.verify(authenticationHolderRepository).save(Matchers.any(AuthenticationHolderEntity.class)); - Mockito.verify(tokenEnhancer).enhance(token, authentication); - Mockito.verify(tokenRepository).saveAccessToken(token); + Mockito.verify(tokenEnhancer).enhance(Matchers.any(OAuth2AccessTokenEntity.class), Mockito.eq(authentication)); + Mockito.verify(tokenRepository).saveAccessToken(Matchers.any(OAuth2AccessTokenEntity.class)); Mockito.verify(tokenRepository, Mockito.never()).saveRefreshToken(Matchers.any(OAuth2RefreshTokenEntity.class)); assertThat(token.getRefreshToken(), is(nullValue()));