From af60bf7e7f5fc98848b395809c08e28bb855d71d Mon Sep 17 00:00:00 2001 From: arielak Date: Fri, 10 Oct 2014 18:13:24 -0400 Subject: [PATCH] Added tests for ensuring the references between a refresh token and its authentication holder are preserved over import. Minor cleanup of other tests. --- .../impl/TestMITREidDataService_1_0.java | 194 ++++++++++++++---- 1 file changed, 157 insertions(+), 37 deletions(-) diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java index e2f92fd96..af233adfd 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java @@ -23,7 +23,6 @@ import java.util.Set; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import org.junit.Before; @@ -57,6 +56,7 @@ import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.provider.AuthorizationRequest; +import org.springframework.security.oauth2.provider.DefaultAuthorizationRequest; import org.springframework.security.oauth2.provider.OAuth2Authentication; @RunWith(MockitoJUnitRunner.class) @@ -224,7 +224,6 @@ public class TestMITREidDataService_1_0 { } } - @Test public void testImportRefreshTokens() throws IOException, ParseException { String expiration1 = "2014-09-10T22:49:44.090+0000"; @@ -282,7 +281,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 323L; @Override public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; @@ -310,7 +309,7 @@ public class TestMITREidDataService_1_0 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 142L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -330,12 +329,10 @@ public class TestMITREidDataService_1_0 { assertThat(savedRefreshTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedRefreshTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedRefreshTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedRefreshTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(1).getValue(), equalTo(token2.getValue())); } @@ -553,7 +550,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(tokenRepository.saveAccessToken(isA(OAuth2AccessTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 335L; @Override public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2AccessTokenEntity _token = (OAuth2AccessTokenEntity) invocation.getArguments()[0]; @@ -581,7 +578,7 @@ public class TestMITREidDataService_1_0 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 135L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -601,12 +598,10 @@ public class TestMITREidDataService_1_0 { assertThat(savedAccessTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedAccessTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedAccessTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedAccessTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedAccessTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedAccessTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(1).getValue(), equalTo(token2.getValue())); } @@ -1061,7 +1056,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(wlSiteRepository.save(isA(WhitelistedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 332L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = (WhitelistedSite) invocation.getArguments()[0]; @@ -1267,7 +1262,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(approvedSiteRepository.save(isA(ApprovedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 343L; @Override public ApprovedSite answer(InvocationOnMock invocation) throws Throwable { ApprovedSite _site = (ApprovedSite) invocation.getArguments()[0]; @@ -1286,7 +1281,7 @@ public class TestMITREidDataService_1_0 { } }); when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 2L; + Long id = 235L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = mock(WhitelistedSite.class); @@ -1320,19 +1315,17 @@ public class TestMITREidDataService_1_0 { @Test public void testExportAuthenticationHolders() throws IOException { - AuthorizationRequest mockRequest1 = mock(AuthorizationRequest.class); - when(mockRequest1.getAuthorizationParameters()).thenReturn(new HashMap()); - Authentication mockAuth1 = null; - OAuth2Authentication auth1 = new OAuth2Authentication(mockRequest1, mockAuth1); + AuthorizationRequest req1 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client1", new ArrayList()); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); holder1.setId(1L); holder1.setAuthentication(auth1); - - AuthorizationRequest mockRequest2 = mock(AuthorizationRequest.class); - when(mockRequest2.getAuthorizationParameters()).thenReturn(new HashMap()); - Authentication mockAuth2 = null; - OAuth2Authentication auth2 = new OAuth2Authentication(mockRequest2, mockAuth2); + + AuthorizationRequest req2 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client2", new ArrayList()); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); holder2.setId(2L); @@ -1417,19 +1410,17 @@ public class TestMITREidDataService_1_0 { @Test public void testImportAuthenticationHolders() throws IOException { - AuthorizationRequest mockRequest1 = mock(AuthorizationRequest.class); - when(mockRequest1.getAuthorizationParameters()).thenReturn(new HashMap()); - Authentication mockAuth1 = null; - OAuth2Authentication auth1 = new OAuth2Authentication(mockRequest1, mockAuth1); + AuthorizationRequest req1 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client1", new ArrayList()); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); holder1.setId(1L); holder1.setAuthentication(auth1); - - AuthorizationRequest mockRequest2 = mock(AuthorizationRequest.class); - when(mockRequest2.getAuthorizationParameters()).thenReturn(new HashMap()); - Authentication mockAuth2 = null; - OAuth2Authentication auth2 = new OAuth2Authentication(mockRequest2, mockAuth2); + + AuthorizationRequest req2 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client2", new ArrayList()); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); holder2.setId(2L); @@ -1445,8 +1436,8 @@ public class TestMITREidDataService_1_0 { "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + - "{\"id\":1,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":\"\"}}," + - "{\"id\":2,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":\"\"}}" + + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\"},\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\"},\"userAuthentication\":null}}" + " ]" + "}"; @@ -1457,7 +1448,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 323L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _site = (AuthenticationHolderEntity) invocation.getArguments()[0]; @@ -1475,8 +1466,8 @@ public class TestMITREidDataService_1_0 { List savedAuthHolders = capturedAuthHolders.getAllValues(); assertThat(savedAuthHolders.size(), is(2)); - assertThat(savedAuthHolders.get(0).getAuthentication().getName(), equalTo(holder1.getAuthentication().getName())); - assertThat(savedAuthHolders.get(1).getAuthentication().getName(), equalTo(holder2.getAuthentication().getName())); + assertThat(savedAuthHolders.get(0).getAuthentication().getAuthorizationRequest().getClientId(), equalTo(holder1.getAuthentication().getAuthorizationRequest().getClientId())); + assertThat(savedAuthHolders.get(1).getAuthentication().getAuthorizationRequest().getClientId(), equalTo(holder2.getAuthentication().getAuthorizationRequest().getClientId())); } @Test @@ -1661,7 +1652,136 @@ public class TestMITREidDataService_1_0 { assertThat(savedScopes.get(2).isAllowDynReg(), equalTo(scope3.isAllowDynReg())); } - + + @Test + public void testFixRefreshTokenReferencesOnImport() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = DateUtil.utcToDate(expiration1); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + AuthorizationRequest req1 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client1", new ArrayList()); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ."); + token1.setAuthenticationHolder(holder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = DateUtil.utcToDate(expiration2); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + AuthorizationRequest req2 = new DefaultAuthorizationRequest(new HashMap(), new HashMap(), "client2", new ArrayList()); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ."); + token2.setAuthenticationHolder(holder2); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + + + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\"},\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\"},\"userAuthentication\":null}}" + + + " ]," + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.\"}" + + + " ]" + + "}"; + System.err.println(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + final Map fakeRefreshTokenTable = new HashMap(); + final Map fakeAuthHolderTable = new HashMap(); + when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { + Long id = 343L; + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeRefreshTokenTable.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getRefreshTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeRefreshTokenTable.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { + Long id = 356L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _holder = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_holder.getId() == null) { + _holder.setId(id++); + } + fakeAuthHolderTable.put(_holder.getId(), _holder); + return _holder; + } + }); + when(authHolderRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeAuthHolderTable.get(_id); + } + }); + dataService.importData(reader); + + List savedRefreshTokens = new ArrayList(fakeRefreshTokenTable.values()); //capturedRefreshTokens.getAllValues(); + Collections.sort(savedRefreshTokens, new refreshTokenIdComparator()); + + assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getAuthentication().getAuthorizationRequest().getClientId(), + equalTo(token1.getAuthenticationHolder().getAuthentication().getAuthorizationRequest().getClientId())); + assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getAuthentication().getAuthorizationRequest().getClientId(), + equalTo(token2.getAuthenticationHolder().getAuthentication().getAuthorizationRequest().getClientId())); + } + private Set jsonArrayToStringSet(JsonArray a) { Set s = new HashSet(); for (JsonElement jsonElement : a) {