diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java index b42a8aaf3..82687aff2 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_0.java @@ -1,8 +1,6 @@ package org.mitre.openid.connect.service.impl; import com.google.common.collect.ImmutableSet; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; import com.google.gson.stream.JsonReader; import java.io.IOException; import java.io.StringReader; @@ -15,7 +13,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import static org.hamcrest.CoreMatchers.*; import static org.junit.Assert.assertThat; import org.junit.Before; @@ -41,6 +38,10 @@ import org.mitre.openid.connect.util.DateUtil; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.InjectMocks; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.isA; +import static org.mockito.Matchers.isNull; import org.mockito.Mock; import org.mockito.Mockito; import static org.mockito.Mockito.*; @@ -48,7 +49,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.runners.MockitoJUnitRunner; import org.mockito.stubbing.Answer; import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.provider.AuthorizationRequest; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Request; @@ -102,7 +103,6 @@ public class TestMITREidDataService_1_0 { } } - @Test public void testImportRefreshTokens() throws IOException, ParseException { String expiration1 = "2014-09-10T22:49:44.090+0000"; @@ -160,7 +160,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 343L; @Override public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; @@ -188,7 +188,7 @@ public class TestMITREidDataService_1_0 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 678L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -208,12 +208,10 @@ public class TestMITREidDataService_1_0 { assertThat(savedRefreshTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedRefreshTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedRefreshTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedRefreshTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(1).getValue(), equalTo(token2.getValue())); } @@ -294,7 +292,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(tokenRepository.saveAccessToken(isA(OAuth2AccessTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 343L; @Override public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2AccessTokenEntity _token = (OAuth2AccessTokenEntity) invocation.getArguments()[0]; @@ -322,7 +320,7 @@ public class TestMITREidDataService_1_0 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 234L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -342,15 +340,15 @@ public class TestMITREidDataService_1_0 { assertThat(savedAccessTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedAccessTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedAccessTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedAccessTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedAccessTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedAccessTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(1).getValue(), equalTo(token2.getValue())); } + + //several new client fields added in 1.1, perhaps additional tests for these should be added @Test public void testImportClients() throws IOException { ClientDetailsEntity client1 = new ClientDetailsEntity(); @@ -486,7 +484,6 @@ public class TestMITREidDataService_1_0 { WhitelistedSite site3 = new WhitelistedSite(); site3.setId(3L); site3.setClientId("baz"); - //site3.setAllowedScopes(null); String configJson = "{" + "\"" + MITREidDataService.CLIENTS + "\": [], " + @@ -511,7 +508,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(wlSiteRepository.save(isA(WhitelistedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 345L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = (WhitelistedSite) invocation.getArguments()[0]; @@ -550,6 +547,9 @@ public class TestMITREidDataService_1_0 { WhitelistedSite mockWlSite1 = mock(WhitelistedSite.class); when(mockWlSite1.getId()).thenReturn(1L); + OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); + when(mockToken1.getId()).thenReturn(1L); + ApprovedSite site1 = new ApprovedSite(); site1.setId(1L); site1.setClientId("foo"); @@ -558,6 +558,7 @@ public class TestMITREidDataService_1_0 { site1.setUserId("user1"); site1.setWhitelistedSite(mockWlSite1); site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); + site1.setApprovedAccessTokens(ImmutableSet.of(mockToken1)); Date creationDate2 = DateUtil.utcToDate("2014-09-11T18:49:44.090+0000"); Date accessDate2 = DateUtil.utcToDate("2014-09-11T20:49:44.090+0000"); @@ -583,7 +584,8 @@ public class TestMITREidDataService_1_0 { "\"" + MITREidDataService.GRANTS + "\": [" + "{\"id\":1,\"clientId\":\"foo\",\"creationDate\":\"2014-09-10T22:49:44.090+0000\",\"accessDate\":\"2014-09-10T23:49:44.090+0000\"," - + "\"userId\":\"user1\",\"whitelistedSiteId\":null,\"allowedScopes\":[\"openid\",\"phone\"], \"whitelistedSiteId\":1}," + + + "\"userId\":\"user1\",\"whitelistedSiteId\":null,\"allowedScopes\":[\"openid\",\"phone\"], \"whitelistedSiteId\":1," + + "\"approvedAccessTokens\":[1]}," + "{\"id\":2,\"clientId\":\"bar\",\"creationDate\":\"2014-09-11T18:49:44.090+0000\",\"accessDate\":\"2014-09-11T20:49:44.090+0000\"," + "\"timeoutDate\":\"2014-10-01T20:49:44.090+0000\",\"userId\":\"user2\"," + "\"allowedScopes\":[\"openid\",\"offline_access\",\"email\",\"profile\"]}" + @@ -597,7 +599,7 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(approvedSiteRepository.save(isA(ApprovedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 343L; @Override public ApprovedSite answer(InvocationOnMock invocation) throws Throwable { ApprovedSite _site = (ApprovedSite) invocation.getArguments()[0]; @@ -616,7 +618,7 @@ public class TestMITREidDataService_1_0 { } }); when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 2L; + Long id = 244L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = mock(WhitelistedSite.class); @@ -624,10 +626,19 @@ public class TestMITREidDataService_1_0 { return _site; } }); - + when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer() { + Long id = 221L; + @Override + public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); + when(_token.getId()).thenReturn(id++); + return _token; + } + }); + dataService.importData(reader); - //2 for sites, 1 more for updating whitelistedSite ref on #2 - verify(approvedSiteRepository, times(3)).save(capturedApprovedSites.capture()); + //2 for sites, 1 for updating access token ref on #1, 1 more for updating whitelistedSite ref on #2 + verify(approvedSiteRepository, times(4)).save(capturedApprovedSites.capture()); List savedSites = new ArrayList(fakeDb.values()); @@ -639,6 +650,7 @@ public class TestMITREidDataService_1_0 { assertThat(savedSites.get(0).getAllowedScopes(), equalTo(site1.getAllowedScopes())); assertThat(savedSites.get(0).getIsWhitelisted(), equalTo(site1.getIsWhitelisted())); assertThat(savedSites.get(0).getTimeoutDate(), equalTo(site1.getTimeoutDate())); + assertThat(savedSites.get(0).getApprovedAccessTokens().size(), equalTo(site1.getApprovedAccessTokens().size())); assertThat(savedSites.get(1).getClientId(), equalTo(site2.getClientId())); assertThat(savedSites.get(1).getAccessDate(), equalTo(site2.getAccessDate())); @@ -646,17 +658,26 @@ public class TestMITREidDataService_1_0 { assertThat(savedSites.get(1).getAllowedScopes(), equalTo(site2.getAllowedScopes())); assertThat(savedSites.get(1).getTimeoutDate(), equalTo(site2.getTimeoutDate())); assertThat(savedSites.get(1).getIsWhitelisted(), equalTo(site2.getIsWhitelisted())); + assertThat(savedSites.get(1).getApprovedAccessTokens().size(), equalTo(site2.getApprovedAccessTokens().size())); } @Test public void testImportAuthenticationHolders() throws IOException { - OAuth2Authentication auth1 = mock(OAuth2Authentication.class, withSettings().serializable()); + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); holder1.setId(1L); holder1.setAuthentication(auth1); - OAuth2Authentication auth2 = mock(OAuth2Authentication.class, withSettings().serializable()); + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); holder2.setId(2L); @@ -672,9 +693,10 @@ public class TestMITREidDataService_1_0 { "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + - "{\"id\":1,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":null}}," + - "{\"id\":2,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":null}}" + - + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"}," + + "\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"}," + + "\"userAuthentication\":null}}" + " ]" + "}"; @@ -684,26 +706,26 @@ public class TestMITREidDataService_1_0 { final Map fakeDb = new HashMap(); when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 356L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { - AuthenticationHolderEntity _site = (AuthenticationHolderEntity) invocation.getArguments()[0]; - if(_site.getId() == null) { - _site.setId(id++); + AuthenticationHolderEntity _holder = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_holder.getId() == null) { + _holder.setId(id++); } - fakeDb.put(_site.getId(), _site); - return _site; + fakeDb.put(_holder.getId(), _holder); + return _holder; } - }); + }); dataService.importData(reader); verify(authHolderRepository, times(2)).save(capturedAuthHolders.capture()); List savedAuthHolders = capturedAuthHolders.getAllValues(); - assertThat(savedAuthHolders.size(), is(2)); - assertThat(savedAuthHolders.get(0).getAuthentication().getDetails(), equalTo(holder1.getAuthentication().getDetails())); - assertThat(savedAuthHolders.get(1).getAuthentication().getDetails(), equalTo(holder2.getAuthentication().getDetails())); + assertThat(savedAuthHolders.size(), is(2)); + assertThat(savedAuthHolders.get(0).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder1.getAuthentication().getOAuth2Request().getClientId())); + assertThat(savedAuthHolders.get(1).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder2.getAuthentication().getOAuth2Request().getClientId())); } @Test @@ -778,4 +800,138 @@ public class TestMITREidDataService_1_0 { assertThat(savedScopes.get(2).isAllowDynReg(), equalTo(scope3.isAllowDynReg())); } + + @Test + public void testFixRefreshTokenAuthHolderReferencesOnImport() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = DateUtil.utcToDate(expiration1); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ."); + token1.setAuthenticationHolder(holder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = DateUtil.utcToDate(expiration2); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ."); + token2.setAuthenticationHolder(holder2); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + + + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"}," + + "\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"}," + + "\"userAuthentication\":null}}" + + " ]," + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.\"}" + + + " ]" + + "}"; + System.err.println(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + final Map fakeRefreshTokenTable = new HashMap(); + final Map fakeAuthHolderTable = new HashMap(); + when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { + Long id = 343L; + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeRefreshTokenTable.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getRefreshTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeRefreshTokenTable.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { + Long id = 356L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _holder = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_holder.getId() == null) { + _holder.setId(id++); + } + fakeAuthHolderTable.put(_holder.getId(), _holder); + return _holder; + } + }); + when(authHolderRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeAuthHolderTable.get(_id); + } + }); + dataService.importData(reader); + + List savedRefreshTokens = new ArrayList(fakeRefreshTokenTable.values()); //capturedRefreshTokens.getAllValues(); + Collections.sort(savedRefreshTokens, new refreshTokenIdComparator()); + + assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId(), + equalTo(token1.getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId())); + assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId(), + equalTo(token2.getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId())); + } } \ No newline at end of file diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_1.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_1.java index 8ab802abc..e69891ee7 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_1.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_1.java @@ -284,7 +284,7 @@ public class TestMITREidDataService_1_1 { final Map fakeDb = new HashMap(); when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 332L; @Override public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; @@ -312,7 +312,7 @@ public class TestMITREidDataService_1_1 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 131L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -332,12 +332,10 @@ public class TestMITREidDataService_1_1 { assertThat(savedRefreshTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedRefreshTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedRefreshTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedRefreshTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedRefreshTokens.get(1).getValue(), equalTo(token2.getValue())); } @@ -555,7 +553,7 @@ public class TestMITREidDataService_1_1 { final Map fakeDb = new HashMap(); when(tokenRepository.saveAccessToken(isA(OAuth2AccessTokenEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 324L; @Override public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2AccessTokenEntity _token = (OAuth2AccessTokenEntity) invocation.getArguments()[0]; @@ -583,7 +581,7 @@ public class TestMITREidDataService_1_1 { } }); when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 1L; + Long id = 133L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); @@ -603,12 +601,10 @@ public class TestMITREidDataService_1_1 { assertThat(savedAccessTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); assertThat(savedAccessTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); - assertThat(savedAccessTokens.get(0).getAuthenticationHolder().getId(), equalTo(token1.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(0).getValue(), equalTo(token1.getValue())); assertThat(savedAccessTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); assertThat(savedAccessTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); - assertThat(savedAccessTokens.get(1).getAuthenticationHolder().getId(), equalTo(token2.getAuthenticationHolder().getId())); assertThat(savedAccessTokens.get(1).getValue(), equalTo(token2.getValue())); } @@ -1063,7 +1059,7 @@ public class TestMITREidDataService_1_1 { final Map fakeDb = new HashMap(); when(wlSiteRepository.save(isA(WhitelistedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 333L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = (WhitelistedSite) invocation.getArguments()[0]; @@ -1288,7 +1284,7 @@ public class TestMITREidDataService_1_1 { final Map fakeDb = new HashMap(); when(approvedSiteRepository.save(isA(ApprovedSite.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 364L; @Override public ApprovedSite answer(InvocationOnMock invocation) throws Throwable { ApprovedSite _site = (ApprovedSite) invocation.getArguments()[0]; @@ -1307,7 +1303,7 @@ public class TestMITREidDataService_1_1 { } }); when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 2L; + Long id = 432L; @Override public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { WhitelistedSite _site = mock(WhitelistedSite.class); @@ -1316,7 +1312,7 @@ public class TestMITREidDataService_1_1 { } }); when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer() { - Long id = 2L; + Long id = 245L; @Override public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); @@ -1347,7 +1343,7 @@ public class TestMITREidDataService_1_1 { assertThat(savedSites.get(1).getAllowedScopes(), equalTo(site2.getAllowedScopes())); assertThat(savedSites.get(1).getTimeoutDate(), equalTo(site2.getTimeoutDate())); assertThat(savedSites.get(1).getIsWhitelisted(), equalTo(site2.getIsWhitelisted())); - assertThat(savedSites.get(1).getApprovedAccessTokens(), equalTo(site2.getApprovedAccessTokens())); //both should be null or empty + assertThat(savedSites.get(1).getApprovedAccessTokens().size(), equalTo(site2.getApprovedAccessTokens().size())); } @Test @@ -1481,9 +1477,10 @@ public class TestMITREidDataService_1_1 { "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + - "{\"id\":1,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":null}}," + - "{\"id\":2,\"authentication\":{\"clientAuthorization\":{},\"userAuthentication\":null}}" + - + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"}," + + "\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"}," + + "\"userAuthentication\":null}}" + " ]" + "}"; @@ -1493,7 +1490,7 @@ public class TestMITREidDataService_1_1 { final Map fakeDb = new HashMap(); when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { - Long id = 3L; + Long id = 243L; @Override public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { AuthenticationHolderEntity _site = (AuthenticationHolderEntity) invocation.getArguments()[0]; @@ -1511,8 +1508,8 @@ public class TestMITREidDataService_1_1 { List savedAuthHolders = capturedAuthHolders.getAllValues(); assertThat(savedAuthHolders.size(), is(2)); - assertThat(savedAuthHolders.get(0).getAuthentication().getName(), equalTo(holder1.getAuthentication().getName())); - assertThat(savedAuthHolders.get(1).getAuthentication().getName(), equalTo(holder2.getAuthentication().getName())); + assertThat(savedAuthHolders.get(0).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder1.getAuthentication().getOAuth2Request().getClientId())); + assertThat(savedAuthHolders.get(1).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder2.getAuthentication().getOAuth2Request().getClientId())); } @Test @@ -1698,6 +1695,140 @@ public class TestMITREidDataService_1_1 { } + @Test + public void testFixRefreshTokenAuthHolderReferencesOnImport() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = DateUtil.utcToDate(expiration1); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ."); + token1.setAuthenticationHolder(holder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = DateUtil.utcToDate(expiration2); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setValue("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ."); + token2.setAuthenticationHolder(holder2); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + + + "{\"id\":1,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"}," + + "\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"clientAuthorization\":{\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"}," + + "\"userAuthentication\":null}}" + + " ]," + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.\"}" + + + " ]" + + "}"; + System.err.println(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + final Map fakeRefreshTokenTable = new HashMap(); + final Map fakeAuthHolderTable = new HashMap(); + when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { + Long id = 343L; + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeRefreshTokenTable.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getRefreshTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeRefreshTokenTable.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { + Long id = 356L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _holder = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_holder.getId() == null) { + _holder.setId(id++); + } + fakeAuthHolderTable.put(_holder.getId(), _holder); + return _holder; + } + }); + when(authHolderRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeAuthHolderTable.get(_id); + } + }); + dataService.importData(reader); + + List savedRefreshTokens = new ArrayList(fakeRefreshTokenTable.values()); //capturedRefreshTokens.getAllValues(); + Collections.sort(savedRefreshTokens, new refreshTokenIdComparator()); + + assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId(), + equalTo(token1.getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId())); + assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId(), + equalTo(token2.getAuthenticationHolder().getAuthentication().getOAuth2Request().getClientId())); + } + private Set jsonArrayToStringSet(JsonArray a) { Set s = new HashSet(); for (JsonElement jsonElement : a) {