From a8377513a68608cc1235c5ab57c0b109ebbf23f0 Mon Sep 17 00:00:00 2001 From: arielak Date: Thu, 4 Sep 2014 16:20:21 -0400 Subject: [PATCH] Fixed reading/writing of approved access tokens --- .../service/impl/MITREidDataService_1_0.java | 227 +++++----- .../service/impl/MITREidDataService_1_1.java | 389 ++++++++---------- .../service/impl/MITREidDataService_1_X.java | 13 + 3 files changed, 318 insertions(+), 311 deletions(-) diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java index d3e021536..f1034dbbb 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java @@ -70,7 +70,7 @@ import org.springframework.stereotype.Service; */ @Service public class MITREidDataService_1_0 extends MITREidDataService_1_X { - + private final static Logger logger = LoggerFactory.getLogger(MITREidDataService_1_0.class); @Autowired private OAuth2ClientRepository clientRepository; @@ -89,11 +89,12 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { /* (non-Javadoc) * @see org.mitre.openid.connect.service.MITREidDataService#export(com.google.gson.stream.JsonWriter) */ + @Override public void exportData(JsonWriter writer) throws IOException { throw new UnsupportedOperationException("Not supported."); } - + /* (non-Javadoc) * @see org.mitre.openid.connect.service.MITREidDataService#importData(com.google.gson.stream.JsonReader) */ @@ -140,7 +141,6 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { } fixObjectReferences(); } - private Map refreshTokenToClientRefs = new HashMap(); private Map refreshTokenToAuthHolderRefs = new HashMap(); private Map refreshTokenOldToNewIdMap = new HashMap(); @@ -206,7 +206,6 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading refresh tokens"); } - private Map accessTokenToClientRefs = new HashMap(); private Map accessTokenToAuthHolderRefs = new HashMap(); private Map accessTokenToRefreshTokenRefs = new HashMap(); @@ -224,12 +223,12 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { private void readAccessTokens(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); - reader.beginObject(); - Long currentId = null; - String clientId = null; - Long authHolderId = null; - Long refreshTokenId = null; + OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); + reader.beginObject(); + Long currentId = null; + String clientId = null; + Long authHolderId = null; + Long refreshTokenId = null; Long idTokenId = null; while (reader.hasNext()) { switch (reader.peek()) { @@ -291,9 +290,8 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading access tokens"); } - private Map authHolderOldToNewIdMap = new HashMap(); - + /** * @param reader * @throws IOException @@ -425,10 +423,10 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { reader.endObject(); return new OAuth2Request(authorizationParameters, clientId, authorities, approved, scope, resourceIds, redirectUri, responseTypes, null); } - Map grantOldToNewIdMap = new HashMap(); Map grantToWhitelistedSiteRefs = new HashMap(); - + Map> grantToAccessTokensRefs = new HashMap>(); + /** * @param reader * @throws IOException @@ -436,63 +434,68 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { private void readGrants(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - ApprovedSite site = new ApprovedSite(); - Long currentId = null; - Long whitelistedSiteId = null; - reader.beginObject(); - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); - if (reader.peek() == JsonToken.NULL) { - reader.skipValue(); - } else if (name.equals("id")) { - currentId = reader.nextLong(); - } else if (name.equals("accessDate")) { - Date date = utcToDate(reader.nextString()); - site.setAccessDate(date); - } else if (name.equals("clientId")) { - site.setClientId(reader.nextString()); - } else if (name.equals("creationDate")) { - Date date = utcToDate(reader.nextString()); - site.setCreationDate(date); - } else if (name.equals("timeoutDate")) { - Date date = utcToDate(reader.nextString()); - site.setTimeoutDate(date); - } else if (name.equals("userId")) { - site.setUserId(reader.nextString()); - } else if (name.equals("allowedScopes")) { - Set allowedScopes = readSet(reader); - site.setAllowedScopes(allowedScopes); - } else if (name.equals("whitelistedSiteId")) { - whitelistedSiteId = reader.nextLong(); - } else { - logger.debug("Found unexpected entry"); - reader.skipValue(); - } - break; - default: + ApprovedSite site = new ApprovedSite(); + Long currentId = null; + Long whitelistedSiteId = null; + Set tokenIds = null; + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("accessDate")) { + Date date = utcToDate(reader.nextString()); + site.setAccessDate(date); + } else if (name.equals("clientId")) { + site.setClientId(reader.nextString()); + } else if (name.equals("creationDate")) { + Date date = utcToDate(reader.nextString()); + site.setCreationDate(date); + } else if (name.equals("timeoutDate")) { + Date date = utcToDate(reader.nextString()); + site.setTimeoutDate(date); + } else if (name.equals("userId")) { + site.setUserId(reader.nextString()); + } else if (name.equals("allowedScopes")) { + Set allowedScopes = readSet(reader); + site.setAllowedScopes(allowedScopes); + } else if (name.equals("whitelistedSiteId")) { + whitelistedSiteId = reader.nextLong(); + } else if (name.equals("approvedAccessTokens")) { + tokenIds = readSet(reader); + } else { logger.debug("Found unexpected entry"); reader.skipValue(); - continue; - } + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; } - reader.endObject(); - Long newId = approvedSiteRepository.save(site).getId(); - grantOldToNewIdMap.put(currentId, newId); - if(whitelistedSiteId != null) { - grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId); - } - logger.debug("Read grant {}", currentId); + } + reader.endObject(); + Long newId = approvedSiteRepository.save(site).getId(); + grantOldToNewIdMap.put(currentId, newId); + if (whitelistedSiteId != null) { + grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId); + } + if (tokenIds != null) { + grantToAccessTokensRefs.put(currentId, tokenIds); + } + logger.debug("Read grant {}", currentId); } reader.endArray(); logger.info("Done reading grants"); } - Map whitelistedSiteOldToNewIdMap = new HashMap(); - + /** * @param reader * @throws IOException @@ -536,7 +539,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading whitelisted sites"); } - + /** * @param reader * @throws IOException @@ -573,6 +576,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading blacklisted sites"); } + /** * @param reader * @throws IOException @@ -698,22 +702,23 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { } /** - * Read the list of system scopes from the reader and insert them - * into the scope repository. + * Read the list of system scopes from the reader and insert them into the + * scope repository. + * * @param reader * @throws IOException */ private void readSystemScopes(JsonReader reader) throws IOException { - reader.beginArray(); - while (reader.hasNext()) { - SystemScope scope = new SystemScope(); - reader.beginObject(); - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); + reader.beginArray(); + while (reader.hasNext()) { + SystemScope scope = new SystemScope(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); if (reader.peek() == JsonToken.NULL) { reader.skipValue(); } else if (name.equals("value")) { @@ -723,29 +728,29 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { } else if (name.equals("allowDynReg")) { scope.setAllowDynReg(reader.nextBoolean()); } else if (name.equals("defaultScope")) { - scope.setDefaultScope(reader.nextBoolean()); - } else if (name.equals("icon")) { - scope.setIcon(reader.nextString()); - } else { - logger.debug("found unexpected entry"); - reader.skipValue(); - } - break; - default: - logger.debug("Found unexpected entry"); - reader.skipValue(); - continue; - } - } - reader.endObject(); - sysScopeRepository.save(scope); - } - reader.endArray(); - logger.info("Done reading system scopes"); + scope.setDefaultScope(reader.nextBoolean()); + } else if (name.equals("icon")) { + scope.setIcon(reader.nextString()); + } else { + logger.debug("found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + sysScopeRepository.save(scope); + } + reader.endArray(); + logger.info("Done reading system scopes"); } - + private void fixObjectReferences() { - for(Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { + for (Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { String clientRef = refreshTokenToClientRefs.get(oldRefreshTokenId); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); @@ -754,7 +759,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { tokenRepository.saveRefreshToken(refreshToken); } refreshTokenToClientRefs.clear(); - for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { + for (Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); @@ -764,7 +769,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { tokenRepository.saveRefreshToken(refreshToken); } refreshTokenToAuthHolderRefs.clear(); - for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); @@ -773,7 +778,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToClientRefs.clear(); - for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); @@ -783,7 +788,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToAuthHolderRefs.clear(); - for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); @@ -794,7 +799,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { } accessTokenToRefreshTokenRefs.clear(); refreshTokenOldToNewIdMap.clear(); - for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); OAuth2AccessTokenEntity idToken = tokenRepository.getAccessTokenById(newIdTokenId); @@ -804,8 +809,7 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToIdTokenRefs.clear(); - accessTokenOldToNewIdMap.clear(); - for(Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) { + for (Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) { Long oldWhitelistedSiteId = grantToWhitelistedSiteRefs.get(oldGrantId); Long newWhitelistedSiteId = whitelistedSiteOldToNewIdMap.get(oldWhitelistedSiteId); WhitelistedSite wlSite = wlSiteRepository.getById(newWhitelistedSiteId); @@ -814,7 +818,20 @@ public class MITREidDataService_1_0 extends MITREidDataService_1_X { approvedSite.setWhitelistedSite(wlSite); approvedSiteRepository.save(approvedSite); } - grantOldToNewIdMap.clear(); grantToWhitelistedSiteRefs.clear(); + for (Long oldGrantId : grantToAccessTokensRefs.keySet()) { + Set oldAccessTokenIds = grantToAccessTokensRefs.get(oldGrantId); + Set tokens = new HashSet(); + for(Long oldTokenId : oldAccessTokenIds) { + Long newTokenId = accessTokenOldToNewIdMap.get(oldTokenId); + tokens.add(tokenRepository.getAccessTokenById(newTokenId)); + } + Long newGrantId = grantOldToNewIdMap.get(oldGrantId); + ApprovedSite site = approvedSiteRepository.getById(newGrantId); + site.setApprovedAccessTokens(tokens); + approvedSiteRepository.save(site); + } + accessTokenOldToNewIdMap.clear(); + grantOldToNewIdMap.clear(); } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_1.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_1.java index b24ed97ed..08aebb0e1 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_1.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_1.java @@ -24,7 +24,6 @@ import com.google.gson.stream.JsonWriter; import java.io.IOException; import java.io.Serializable; import java.text.ParseException; -import java.util.ArrayList; import java.util.Collection; import java.util.Date; import java.util.HashMap; @@ -73,7 +72,7 @@ import org.springframework.stereotype.Service; */ @Service public class MITREidDataService_1_1 extends MITREidDataService_1_X { - + private final static Logger logger = LoggerFactory.getLogger(MITREidDataService_1_1.class); @Autowired private OAuth2ClientRepository clientRepository; @@ -99,7 +98,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { // version tag at the root writer.name(MITREID_CONNECT_1_1); - + writer.beginObject(); // clients list @@ -117,7 +116,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { writer.beginArray(); writeWhitelistedSites(writer); writer.endArray(); - + writer.name(BLACKLISTEDSITES); writer.beginArray(); writeBlacklistedSites(writer); @@ -149,28 +148,18 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { /** * @param writer */ - private void writeRefreshTokens(JsonWriter writer) { - Collection tokens = new ArrayList(); - try { - tokens = tokenRepository.getAllRefreshTokens(); - } catch (Exception ex) { - logger.error("Unable to read refresh tokens from data source", ex); - } - for (OAuth2RefreshTokenEntity token : tokens) { - try { - writer.beginObject(); - writer.name("id").value(token.getId()); - writer.name("expiration").value(toUTCString(token.getExpiration())); - writer.name("clientId") - .value((token.getClient() != null) ? token.getClient().getClientId() : null); - writer.name("authenticationHolderId") - .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); - writer.name("value").value(token.getValue()); - writer.endObject(); - logger.debug("Wrote refresh token {}", token.getId()); - } catch (IOException ex) { - logger.error("Unable to write refresh token {}", token.getId(), ex); - } + private void writeRefreshTokens(JsonWriter writer) throws IOException { + for (OAuth2RefreshTokenEntity token : tokenRepository.getAllRefreshTokens()) { + writer.beginObject(); + writer.name("id").value(token.getId()); + writer.name("expiration").value(toUTCString(token.getExpiration())); + writer.name("clientId") + .value((token.getClient() != null) ? token.getClient().getClientId() : null); + writer.name("authenticationHolderId") + .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); + writer.name("value").value(token.getValue()); + writer.endObject(); + logger.debug("Wrote refresh token {}", token.getId()); } logger.info("Done writing refresh tokens"); } @@ -178,39 +167,29 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { /** * @param writer */ - private void writeAccessTokens(JsonWriter writer) { - Collection tokens = new ArrayList(); - try { - tokens = tokenRepository.getAllAccessTokens(); - } catch (Exception ex) { - logger.error("Unable to read access tokens from data source", ex); - } - for (OAuth2AccessTokenEntity token : tokens) { - try { - writer.beginObject(); - writer.name("id").value(token.getId()); - writer.name("expiration").value(toUTCString(token.getExpiration())); - writer.name("clientId") - .value((token.getClient() != null) ? token.getClient().getClientId() : null); - writer.name("authenticationHolderId") - .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); - writer.name("refreshTokenId") - .value((token.getRefreshToken() != null) ? token.getRefreshToken().getId() : null); - writer.name("idTokenId") - .value((token.getIdToken() != null) ? token.getIdToken().getId() : null); - writer.name("scope"); - writer.beginArray(); - for (String s : token.getScope()) { - writer.value(s); - } - writer.endArray(); - writer.name("type").value(token.getTokenType()); - writer.name("value").value(token.getValue()); - writer.endObject(); - logger.debug("Wrote access token {}", token.getId()); - } catch (IOException ex) { - logger.error("Unable to write access token {}", token.getId(), ex); + private void writeAccessTokens(JsonWriter writer) throws IOException { + for (OAuth2AccessTokenEntity token : tokenRepository.getAllAccessTokens()) { + writer.beginObject(); + writer.name("id").value(token.getId()); + writer.name("expiration").value(toUTCString(token.getExpiration())); + writer.name("clientId") + .value((token.getClient() != null) ? token.getClient().getClientId() : null); + writer.name("authenticationHolderId") + .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); + writer.name("refreshTokenId") + .value((token.getRefreshToken() != null) ? token.getRefreshToken().getId() : null); + writer.name("idTokenId") + .value((token.getIdToken() != null) ? token.getIdToken().getId() : null); + writer.name("scope"); + writer.beginArray(); + for (String s : token.getScope()) { + writer.value(s); } + writer.endArray(); + writer.name("type").value(token.getTokenType()); + writer.name("value").value(token.getValue()); + writer.endObject(); + logger.debug("Wrote access token {}", token.getId()); } logger.info("Done writing access tokens"); } @@ -218,31 +197,21 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { /** * @param writer */ - private void writeAuthenticationHolders(JsonWriter writer) { - Collection holders = new ArrayList(); - try { - holders = authHolderRepository.getAll(); - } catch (Exception ex) { - logger.error("Unable to read authentication holders from data source", ex); - } - for (AuthenticationHolderEntity holder : holders) { - try { - writer.beginObject(); - writer.name("id").value(holder.getId()); - writer.name("ownerId").value(holder.getOwnerId()); - writer.name("authentication"); - writer.beginObject(); - OAuth2Authentication oa2Auth = holder.getAuthentication(); - writer.name("clientAuthorization"); - writeAuthorizationRequest(oa2Auth.getOAuth2Request(), writer); - String userAuthentication = base64UrlEncodeObject(oa2Auth.getUserAuthentication()); - writer.name("userAuthentication").value(userAuthentication); - writer.endObject(); - writer.endObject(); - logger.debug("Wrote authentication holder {}", holder.getId()); - } catch (IOException ex) { - logger.error("Unable to write authentication holder {}", holder.getId(), ex); - } + private void writeAuthenticationHolders(JsonWriter writer) throws IOException { + for (AuthenticationHolderEntity holder : authHolderRepository.getAll()) { + writer.beginObject(); + writer.name("id").value(holder.getId()); + writer.name("ownerId").value(holder.getOwnerId()); + writer.name("authentication"); + writer.beginObject(); + OAuth2Authentication oa2Auth = holder.getAuthentication(); + writer.name("clientAuthorization"); + writeAuthorizationRequest(oa2Auth.getOAuth2Request(), writer); + String userAuthentication = base64UrlEncodeObject(oa2Auth.getUserAuthentication()); + writer.name("userAuthentication").value(userAuthentication); + writer.endObject(); + writer.endObject(); + logger.debug("Wrote authentication holder {}", holder.getId()); } logger.info("Done writing authentication holders"); } @@ -266,7 +235,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { writer.endArray(); writer.name("resourceIds"); writer.beginArray(); - if(authReq.getResourceIds() != null) { + if (authReq.getResourceIds() != null) { for (String s : authReq.getResourceIds()) { writer.value(s); } @@ -294,7 +263,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { writer.endObject(); writer.endObject(); } - + /** * @param writer */ @@ -310,6 +279,13 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { writer.name("allowedScopes"); writeNullSafeArray(writer, site.getAllowedScopes()); writer.name("whitelistedSiteId").value(site.getIsWhitelisted() ? site.getWhitelistedSite().getId() : null); + Set tokens = site.getApprovedAccessTokens(); + writer.name("approvedAccessTokens"); + writer.beginArray(); + for (OAuth2AccessTokenEntity token : tokens) { + writer.value(token.getId()); + } + writer.endArray(); writer.endObject(); logger.debug("Wrote grant {}", site.getId()); } @@ -332,7 +308,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } logger.info("Done writing whitelisted sites"); } - + /** * @param writer */ @@ -346,7 +322,6 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } logger.info("Done writing blacklisted sites"); } - /** * @param writer @@ -423,7 +398,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { writer.name("intitateLoginUri").value(client.getInitiateLoginUri()); writer.name("postLogoutRedirectUri").value(client.getPostLogoutRedirectUri()); writer.name("requestUris"); - writeNullSafeArray(writer, client.getRequestUris()); + writeNullSafeArray(writer, client.getRequestUris()); writer.name("description").value(client.getClientDescription()); writer.name("allowIntrospection").value(client.isAllowIntrospection()); writer.name("reuseRefreshToken").value(client.isReuseRefreshToken()); @@ -437,19 +412,6 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { logger.info("Done writing clients"); } - private void writeNullSafeArray(JsonWriter writer, Set items) - throws IOException { - if (items != null) { - writer.beginArray(); - for (String s : items) { - writer.value(s); - } - writer.endArray(); - } else { - writer.nullValue(); - } - } - /** * @param writer */ @@ -518,7 +480,6 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } fixObjectReferences(); } - private Map refreshTokenToClientRefs = new HashMap(); private Map refreshTokenToAuthHolderRefs = new HashMap(); private Map refreshTokenOldToNewIdMap = new HashMap(); @@ -584,7 +545,6 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading refresh tokens"); } - private Map accessTokenToClientRefs = new HashMap(); private Map accessTokenToAuthHolderRefs = new HashMap(); private Map accessTokenToRefreshTokenRefs = new HashMap(); @@ -602,12 +562,12 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { private void readAccessTokens(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); - reader.beginObject(); - Long currentId = null; - String clientId = null; - Long authHolderId = null; - Long refreshTokenId = null; + OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); + reader.beginObject(); + Long currentId = null; + String clientId = null; + Long authHolderId = null; + Long refreshTokenId = null; Long idTokenId = null; while (reader.hasNext()) { switch (reader.peek()) { @@ -669,9 +629,8 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading access tokens"); } - private Map authHolderOldToNewIdMap = new HashMap(); - + /** * @param reader * @throws IOException @@ -791,9 +750,9 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { responseTypes = readSet(reader); } else if (name.equals("extensions")) { Map extEnc = readMap(reader); - for(Entry entry : extEnc.entrySet()) { + for (Entry entry : extEnc.entrySet()) { Serializable decoded = base64UrlDecodeObject(entry.getValue(), Serializable.class); - if(decoded != null) { + if (decoded != null) { extensions.put(entry.getKey(), decoded); } } @@ -810,10 +769,9 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { reader.endObject(); return new OAuth2Request(requestParameters, clientId, authorities, approved, scope, resourceIds, redirectUri, responseTypes, extensions); } - Map grantOldToNewIdMap = new HashMap(); Map grantToWhitelistedSiteRefs = new HashMap(); - + Map> grantToAccessTokensRefs = new HashMap>(); /** * @param reader * @throws IOException @@ -821,63 +779,68 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { private void readGrants(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - ApprovedSite site = new ApprovedSite(); - Long currentId = null; - Long whitelistedSiteId = null; - reader.beginObject(); - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); - if (reader.peek() == JsonToken.NULL) { - reader.skipValue(); - } else if (name.equals("id")) { - currentId = reader.nextLong(); - } else if (name.equals("accessDate")) { - Date date = utcToDate(reader.nextString()); - site.setAccessDate(date); - } else if (name.equals("clientId")) { - site.setClientId(reader.nextString()); - } else if (name.equals("creationDate")) { - Date date = utcToDate(reader.nextString()); - site.setCreationDate(date); - } else if (name.equals("timeoutDate")) { - Date date = utcToDate(reader.nextString()); - site.setTimeoutDate(date); - } else if (name.equals("userId")) { - site.setUserId(reader.nextString()); - } else if (name.equals("allowedScopes")) { - Set allowedScopes = readSet(reader); - site.setAllowedScopes(allowedScopes); - } else if (name.equals("whitelistedSiteId")) { - whitelistedSiteId = reader.nextLong(); - } else { - logger.debug("Found unexpected entry"); - reader.skipValue(); - } - break; - default: + ApprovedSite site = new ApprovedSite(); + Long currentId = null; + Long whitelistedSiteId = null; + Set tokenIds = null; + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("accessDate")) { + Date date = utcToDate(reader.nextString()); + site.setAccessDate(date); + } else if (name.equals("clientId")) { + site.setClientId(reader.nextString()); + } else if (name.equals("creationDate")) { + Date date = utcToDate(reader.nextString()); + site.setCreationDate(date); + } else if (name.equals("timeoutDate")) { + Date date = utcToDate(reader.nextString()); + site.setTimeoutDate(date); + } else if (name.equals("userId")) { + site.setUserId(reader.nextString()); + } else if (name.equals("allowedScopes")) { + Set allowedScopes = readSet(reader); + site.setAllowedScopes(allowedScopes); + } else if (name.equals("whitelistedSiteId")) { + whitelistedSiteId = reader.nextLong(); + } else if (name.equals("approvedAccessTokens")) { + tokenIds = readSet(reader); + } else { logger.debug("Found unexpected entry"); reader.skipValue(); - continue; - } + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; } - reader.endObject(); - Long newId = approvedSiteRepository.save(site).getId(); - grantOldToNewIdMap.put(currentId, newId); - if(whitelistedSiteId != null) { - grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId); - } - logger.debug("Read grant {}", currentId); + } + reader.endObject(); + Long newId = approvedSiteRepository.save(site).getId(); + grantOldToNewIdMap.put(currentId, newId); + if (whitelistedSiteId != null) { + grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId); + } + if (tokenIds != null) { + grantToAccessTokensRefs.put(currentId, tokenIds); + } + logger.debug("Read grant {}", currentId); } reader.endArray(); logger.info("Done reading grants"); } - Map whitelistedSiteOldToNewIdMap = new HashMap(); - + /** * @param reader * @throws IOException @@ -921,7 +884,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading whitelisted sites"); } - + /** * @param reader * @throws IOException @@ -958,6 +921,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { reader.endArray(); logger.info("Done reading blacklisted sites"); } + /** * @param reader * @throws IOException @@ -1083,22 +1047,23 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } /** - * Read the list of system scopes from the reader and insert them - * into the scope repository. + * Read the list of system scopes from the reader and insert them into the + * scope repository. + * * @param reader * @throws IOException */ private void readSystemScopes(JsonReader reader) throws IOException { - reader.beginArray(); - while (reader.hasNext()) { - SystemScope scope = new SystemScope(); - reader.beginObject(); - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); + reader.beginArray(); + while (reader.hasNext()) { + SystemScope scope = new SystemScope(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); if (reader.peek() == JsonToken.NULL) { reader.skipValue(); } else if (name.equals("value")) { @@ -1108,29 +1073,29 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } else if (name.equals("allowDynReg")) { scope.setAllowDynReg(reader.nextBoolean()); } else if (name.equals("defaultScope")) { - scope.setDefaultScope(reader.nextBoolean()); - } else if (name.equals("icon")) { - scope.setIcon(reader.nextString()); - } else { - logger.debug("found unexpected entry"); - reader.skipValue(); - } - break; - default: - logger.debug("Found unexpected entry"); - reader.skipValue(); - continue; - } - } - reader.endObject(); - sysScopeRepository.save(scope); - } - reader.endArray(); - logger.info("Done reading system scopes"); + scope.setDefaultScope(reader.nextBoolean()); + } else if (name.equals("icon")) { + scope.setIcon(reader.nextString()); + } else { + logger.debug("found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + sysScopeRepository.save(scope); + } + reader.endArray(); + logger.info("Done reading system scopes"); } - + private void fixObjectReferences() { - for(Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { + for (Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { String clientRef = refreshTokenToClientRefs.get(oldRefreshTokenId); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); @@ -1139,7 +1104,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { tokenRepository.saveRefreshToken(refreshToken); } refreshTokenToClientRefs.clear(); - for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { + for (Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); @@ -1149,7 +1114,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { tokenRepository.saveRefreshToken(refreshToken); } refreshTokenToAuthHolderRefs.clear(); - for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); @@ -1158,7 +1123,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToClientRefs.clear(); - for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); @@ -1168,7 +1133,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToAuthHolderRefs.clear(); - for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); @@ -1179,7 +1144,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { } accessTokenToRefreshTokenRefs.clear(); refreshTokenOldToNewIdMap.clear(); - for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { + for (Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); OAuth2AccessTokenEntity idToken = tokenRepository.getAccessTokenById(newIdTokenId); @@ -1189,8 +1154,7 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { tokenRepository.saveAccessToken(accessToken); } accessTokenToIdTokenRefs.clear(); - accessTokenOldToNewIdMap.clear(); - for(Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) { + for (Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) { Long oldWhitelistedSiteId = grantToWhitelistedSiteRefs.get(oldGrantId); Long newWhitelistedSiteId = whitelistedSiteOldToNewIdMap.get(oldWhitelistedSiteId); WhitelistedSite wlSite = wlSiteRepository.getById(newWhitelistedSiteId); @@ -1199,7 +1163,20 @@ public class MITREidDataService_1_1 extends MITREidDataService_1_X { approvedSite.setWhitelistedSite(wlSite); approvedSiteRepository.save(approvedSite); } - grantOldToNewIdMap.clear(); grantToWhitelistedSiteRefs.clear(); + for (Long oldGrantId : grantToAccessTokensRefs.keySet()) { + Set oldAccessTokenIds = grantToAccessTokensRefs.get(oldGrantId); + Set tokens = new HashSet(); + for(Long oldTokenId : oldAccessTokenIds) { + Long newTokenId = accessTokenOldToNewIdMap.get(oldTokenId); + tokens.add(tokenRepository.getAccessTokenById(newTokenId)); + } + Long newGrantId = grantOldToNewIdMap.get(oldGrantId); + ApprovedSite site = approvedSiteRepository.getById(newGrantId); + site.setApprovedAccessTokens(tokens); + approvedSiteRepository.save(site); + } + accessTokenOldToNewIdMap.clear(); + grantOldToNewIdMap.clear(); } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_X.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_X.java index cc2daa66a..da5d5eaa9 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_X.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_X.java @@ -20,6 +20,7 @@ package org.mitre.openid.connect.service.impl; import com.google.common.io.BaseEncoding; import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -145,4 +146,16 @@ public abstract class MITREidDataService_1_X implements MITREidDataService { return map; } + protected void writeNullSafeArray(JsonWriter writer, Set items) + throws IOException { + if (items != null) { + writer.beginArray(); + for (String s : items) { + writer.value(s); + } + writer.endArray(); + } else { + writer.nullValue(); + } + } }