diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java index b293d7a64..6d0c93837 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java @@ -16,10 +16,9 @@ ******************************************************************************/ package org.mitre.openid.connect.service; -import java.io.IOException; - import com.google.gson.stream.JsonReader; import com.google.gson.stream.JsonWriter; +import java.io.IOException; /** * @author jricher @@ -36,6 +35,8 @@ public interface MITREidDataService { // member names public static final String REFRESHTOKENS = "refreshTokens"; public static final String ACCESSTOKENS = "accessTokens"; + public static final String WHITELISTEDSITES = "whitelistedSites"; + public static final String BLACKLISTEDSITES = "blacklistedSites"; public static final String AUTHENTICATIONHOLDERS = "authenticationHolders"; public static final String GRANTS = "grants"; public static final String CLIENTS = "clients"; diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java index f75bce59c..1a1b66b33 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java @@ -50,8 +50,10 @@ import org.mitre.oauth2.repository.OAuth2ClientRepository; import org.mitre.oauth2.repository.OAuth2TokenRepository; import org.mitre.oauth2.repository.SystemScopeRepository; import org.mitre.openid.connect.model.ApprovedSite; +import org.mitre.openid.connect.model.BlacklistedSite; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.repository.ApprovedSiteRepository; +import org.mitre.openid.connect.repository.BlacklistedSiteRepository; import org.mitre.openid.connect.repository.WhitelistedSiteRepository; import org.mitre.openid.connect.service.MITREidDataService; import org.slf4j.Logger; @@ -59,7 +61,7 @@ import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.GrantedAuthorityImpl; +import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.stereotype.Service; @@ -80,6 +82,10 @@ public class MITREidDataService_1_0 implements MITREidDataService { @Autowired private ApprovedSiteRepository approvedSiteRepository; @Autowired + private WhitelistedSiteRepository wlSiteRepository; + @Autowired + private BlacklistedSiteRepository blSiteRepository; + @Autowired private AuthenticationHolderRepository authHolderRepository; @Autowired private OAuth2TokenRepository tokenRepository; @@ -93,13 +99,20 @@ public class MITREidDataService_1_0 implements MITREidDataService { */ @Override public void exportData(JsonWriter writer) throws IOException { + throw new UnsupportedOperationException("Not supported."); } - private static Date utcToDate(String s) throws ParseException { + private static Date utcToDate(String s) { if (s == null) { return null; } - return sdf.parse(s); + Date d = null; + try { + d = sdf.parse(s); + } catch(ParseException ex) { + logger.error("Unable to parse date string {}", s, ex); + } + return d; } /* (non-Javadoc) @@ -123,6 +136,10 @@ public class MITREidDataService_1_0 implements MITREidDataService { readClients(reader); } else if (name.equals(GRANTS)) { readGrants(reader); + } else if (name.equals(WHITELISTEDSITES)) { + readWhitelistedSites(reader); + } else if (name.equals(BLACKLISTEDSITES)) { + readBlacklistedSites(reader); } else if (name.equals(AUTHENTICATIONHOLDERS)) { readAuthenticationHolders(reader); } else if (name.equals(ACCESSTOKENS)) { @@ -149,6 +166,10 @@ public class MITREidDataService_1_0 implements MITREidDataService { private Map refreshTokenToAuthHolderRefs = new HashMap(); private Map refreshTokenOldToNewIdMap = new HashMap(); + /** + * @param reader + * @throws IOException + */ /** * @param reader * @throws IOException @@ -156,51 +177,52 @@ public class MITREidDataService_1_0 implements MITREidDataService { private void readRefreshTokens(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - try { - OAuth2RefreshTokenEntity token = new OAuth2RefreshTokenEntity(); - reader.beginObject(); - Long currentId = null; - String clientId = null; - Long authHolderId = null; - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); - if (reader.peek() == JsonToken.NULL) { - reader.skipValue(); - } else if (name.equals("id")) { - currentId = reader.nextLong(); - } else if (name.equals("expiration")) { - Date date = utcToDate(reader.nextString()); - token.setExpiration(date); - } else if (name.equals("value")) { - token.setValue(reader.nextString()); - } else if (name.equals("clientId")) { - clientId = reader.nextString(); - } else if (name.equals("authenticationHolderId")) { - authHolderId = reader.nextLong(); - } else { - logger.debug("Found unexpected entry"); - reader.skipValue(); + OAuth2RefreshTokenEntity token = new OAuth2RefreshTokenEntity(); + reader.beginObject(); + Long currentId = null; + String clientId = null; + Long authHolderId = null; + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("expiration")) { + Date date = utcToDate(reader.nextString()); + token.setExpiration(date); + } else if (name.equals("value")) { + String value = reader.nextString(); + try { + token.setValue(value); + } catch (ParseException ex) { + logger.error("Unable to set refresh token value to {}", value, ex); } - break; - default: + } else if (name.equals("clientId")) { + clientId = reader.nextString(); + } else if (name.equals("authenticationHolderId")) { + authHolderId = reader.nextLong(); + } else { logger.debug("Found unexpected entry"); reader.skipValue(); - continue; - } + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; } - reader.endObject(); - Long newId = tokenRepository.saveRefreshToken(token).getId(); - refreshTokenToClientRefs.put(currentId, clientId); - refreshTokenToAuthHolderRefs.put(currentId, authHolderId); - refreshTokenOldToNewIdMap.put(currentId, newId); - logger.debug("Read refresh token {}", currentId); - } catch (ParseException ex) { - logger.error("Unable to read refresh token", ex); } + reader.endObject(); + Long newId = tokenRepository.saveRefreshToken(token).getId(); + refreshTokenToClientRefs.put(currentId, clientId); + refreshTokenToAuthHolderRefs.put(currentId, authHolderId); + refreshTokenOldToNewIdMap.put(currentId, newId); + logger.debug("Read refresh token {}", currentId); } reader.endArray(); logger.info("Done reading refresh tokens"); @@ -212,6 +234,10 @@ public class MITREidDataService_1_0 implements MITREidDataService { private Map accessTokenToIdTokenRefs = new HashMap(); private Map accessTokenOldToNewIdMap = new HashMap(); + /** + * @param reader + * @throws IOException + */ /** * @param reader * @throws IOException @@ -219,68 +245,69 @@ public class MITREidDataService_1_0 implements MITREidDataService { private void readAccessTokens(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - try { OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); reader.beginObject(); Long currentId = null; String clientId = null; Long authHolderId = null; Long refreshTokenId = null; - Long idTokenId = null; - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String name = reader.nextName(); - if (reader.peek() == JsonToken.NULL) { - reader.skipValue(); - } else if (name.equals("id")) { - currentId = reader.nextLong(); - } else if (name.equals("expiration")) { - Date date = utcToDate(reader.nextString()); - token.setExpiration(date); - } else if (name.equals("value")) { - token.setValue(reader.nextString()); - } else if (name.equals("clientId")) { - clientId = reader.nextString(); - } else if (name.equals("authenticationHolderId")) { - authHolderId = reader.nextLong(); - } else if (name.equals("refreshTokenId")) { - refreshTokenId = reader.nextLong(); - } else if (name.equals("idTokenId")) { - idTokenId = reader.nextLong(); - } else if (name.equals("scope")) { - Set scope = readSet(reader); - token.setScope(scope); - } else if (name.equals("type")) { - token.setTokenType(reader.nextString()); - } else { - logger.debug("Found unexpected entry"); - reader.skipValue(); + Long idTokenId = null; + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("expiration")) { + Date date = utcToDate(reader.nextString()); + token.setExpiration(date); + } else if (name.equals("value")) { + String value = reader.nextString(); + try { + token.setValue(value); + } catch (ParseException ex) { + logger.error("Unable to set refresh token value to {}", value, ex); } - break; - default: + } else if (name.equals("clientId")) { + clientId = reader.nextString(); + } else if (name.equals("authenticationHolderId")) { + authHolderId = reader.nextLong(); + } else if (name.equals("refreshTokenId")) { + refreshTokenId = reader.nextLong(); + } else if (name.equals("idTokenId")) { + idTokenId = reader.nextLong(); + } else if (name.equals("scope")) { + Set scope = readSet(reader); + token.setScope(scope); + } else if (name.equals("type")) { + token.setTokenType(reader.nextString()); + } else { logger.debug("Found unexpected entry"); reader.skipValue(); - continue; - } + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; } - reader.endObject(); - Long newId = tokenRepository.saveAccessToken(token).getId(); - accessTokenToClientRefs.put(currentId, clientId); - accessTokenToAuthHolderRefs.put(currentId, authHolderId); - if(refreshTokenId != null) { - accessTokenToRefreshTokenRefs.put(currentId, refreshTokenId); - } - if(idTokenId != null) { - accessTokenToIdTokenRefs.put(currentId, idTokenId); - } - accessTokenOldToNewIdMap.put(currentId, newId); - logger.debug("Read access token {}", currentId); - } catch (ParseException ex) { - logger.error("Unable to read access token", ex); } + reader.endObject(); + Long newId = tokenRepository.saveAccessToken(token).getId(); + accessTokenToClientRefs.put(currentId, clientId); + accessTokenToAuthHolderRefs.put(currentId, authHolderId); + if (refreshTokenId != null) { + accessTokenToRefreshTokenRefs.put(currentId, refreshTokenId); + } + if (idTokenId != null) { + accessTokenToIdTokenRefs.put(currentId, idTokenId); + } + accessTokenOldToNewIdMap.put(currentId, newId); + logger.debug("Read access token {}", currentId); } reader.endArray(); logger.info("Done reading access tokens"); @@ -410,7 +437,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { Set authorityStrs = readSet(reader); authorities = new HashSet(); for (String s : authorityStrs) { - GrantedAuthority ga = new GrantedAuthorityImpl(s); + GrantedAuthority ga = new SimpleGrantedAuthority(s); authorities.add(ga); } } else if (name.equals("approved")) { @@ -437,8 +464,8 @@ public class MITREidDataService_1_0 implements MITREidDataService { return new OAuth2Request(authorizationParameters, clientId, authorities, approved, scope, resourceIds, redirectUri, responseTypes, null); } - @Autowired - private WhitelistedSiteRepository wlSiteRepository; + Map grantOldToNewIdMap = new HashMap(); + Map grantToWhitelistedSiteRefs = new HashMap(); /** * @param reader @@ -447,9 +474,9 @@ public class MITREidDataService_1_0 implements MITREidDataService { private void readGrants(JsonReader reader) throws IOException { reader.beginArray(); while (reader.hasNext()) { - try { ApprovedSite site = new ApprovedSite(); Long currentId = null; + Long whitelistedSiteId = null; reader.beginObject(); while (reader.hasNext()) { switch (reader.peek()) { @@ -477,39 +504,8 @@ public class MITREidDataService_1_0 implements MITREidDataService { } else if (name.equals("allowedScopes")) { Set allowedScopes = readSet(reader); site.setAllowedScopes(allowedScopes); - } else if (name.equals("whitelistedSite")) { - WhitelistedSite wlSite = new WhitelistedSite(); - reader.beginObject(); - while (reader.hasNext()) { - switch (reader.peek()) { - case END_OBJECT: - continue; - case NAME: - String wlName = reader.nextName(); - if (wlName.equals("id")) { - //not needed - reader.skipValue(); - } else if (name.equals("clientId")) { - wlSite.setClientId(reader.nextString()); - } else if (name.equals("creatorUserId")) { - wlSite.setCreatorUserId(reader.nextString()); - } else if (name.equals("allowedScopes")) { - Set allowedScopes = readSet(reader); - wlSite.setAllowedScopes(allowedScopes); - } else { - logger.debug("Found unexpected entry"); - reader.skipValue(); - } - break; - default: - logger.debug("Found unexpected entry"); - reader.skipValue(); - continue; - } - } - reader.endObject(); - wlSite = wlSiteRepository.save(wlSite); - site.setWhitelistedSite(wlSite); + } else if (name.equals("whitelistedSiteId")) { + whitelistedSiteId = reader.nextLong(); } else { logger.debug("Found unexpected entry"); reader.skipValue(); @@ -522,16 +518,99 @@ public class MITREidDataService_1_0 implements MITREidDataService { } } reader.endObject(); - approvedSiteRepository.save(site).getId(); + Long newId = approvedSiteRepository.save(site).getId(); + grantOldToNewIdMap.put(currentId, newId); + if(whitelistedSiteId != null) { + grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId); + } logger.debug("Read grant {}", currentId); - } catch (ParseException ex) { - logger.error("Unable to read grant", ex); - } } reader.endArray(); logger.info("Done reading grants"); } - + + Map whitelistedSiteOldToNewIdMap = new HashMap(); + + /** + * @param reader + * @throws IOException + */ + private void readWhitelistedSites(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + WhitelistedSite wlSite = new WhitelistedSite(); + Long currentId = null; + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("clientId")) { + wlSite.setClientId(reader.nextString()); + } else if (name.equals("creatorUserId")) { + wlSite.setCreatorUserId(reader.nextString()); + } else if (name.equals("allowedScopes")) { + Set allowedScopes = readSet(reader); + wlSite.setAllowedScopes(allowedScopes); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = wlSiteRepository.save(wlSite).getId(); + whitelistedSiteOldToNewIdMap.put(currentId, newId); + } + reader.endArray(); + logger.info("Done reading whitelisted sites"); + } + + /** + * @param reader + * @throws IOException + */ + private void readBlacklistedSites(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + BlacklistedSite blSite = new BlacklistedSite(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (name.equals("id")) { + reader.skipValue(); + } else if (name.equals("uri")) { + blSite.setUri(reader.nextString()); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + blSiteRepository.save(blSite); + } + reader.endArray(); + logger.info("Done reading blacklisted sites"); + } /** * @param reader * @throws IOException @@ -563,7 +642,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { Set authorityStrs = readSet(reader); Set authorities = new HashSet(); for (String s : authorityStrs) { - GrantedAuthority ga = new GrantedAuthorityImpl(s); + GrantedAuthority ga = new SimpleGrantedAuthority(s); authorities.add(ga); } client.setAuthorities(authorities); @@ -759,6 +838,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { refreshToken.setClient(client); tokenRepository.saveRefreshToken(refreshToken); } + refreshTokenToClientRefs.clear(); for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); @@ -768,6 +848,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { refreshToken.setAuthenticationHolder(authHolder); tokenRepository.saveRefreshToken(refreshToken); } + refreshTokenToAuthHolderRefs.clear(); for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); @@ -776,6 +857,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { accessToken.setClient(client); tokenRepository.saveAccessToken(accessToken); } + accessTokenToClientRefs.clear(); for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); @@ -785,6 +867,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { accessToken.setAuthenticationHolder(authHolder); tokenRepository.saveAccessToken(accessToken); } + accessTokenToAuthHolderRefs.clear(); for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); @@ -794,6 +877,8 @@ public class MITREidDataService_1_0 implements MITREidDataService { accessToken.setRefreshToken(refreshToken); tokenRepository.saveAccessToken(accessToken); } + accessTokenToRefreshTokenRefs.clear(); + refreshTokenOldToNewIdMap.clear(); for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); @@ -803,5 +888,18 @@ public class MITREidDataService_1_0 implements MITREidDataService { accessToken.setIdToken(idToken); tokenRepository.saveAccessToken(accessToken); } + accessTokenToIdTokenRefs.clear(); + accessTokenOldToNewIdMap.clear(); + for(Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) { + Long oldWhitelistedSiteId = grantToWhitelistedSiteRefs.get(oldGrantId); + Long newWhitelistedSiteId = whitelistedSiteOldToNewIdMap.get(oldWhitelistedSiteId); + WhitelistedSite wlSite = wlSiteRepository.getById(newWhitelistedSiteId); + Long newGrantId = grantOldToNewIdMap.get(oldGrantId); + ApprovedSite approvedSite = approvedSiteRepository.getById(newGrantId); + approvedSite.setWhitelistedSite(wlSite); + approvedSiteRepository.save(approvedSite); + } + grantOldToNewIdMap.clear(); + grantToWhitelistedSiteRefs.clear(); } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java index 83fd9b2cc..9cc2b4de2 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java @@ -16,16 +16,19 @@ ******************************************************************************/ package org.mitre.openid.connect.web; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; import java.io.IOException; import java.io.Reader; import java.security.Principal; import java.text.SimpleDateFormat; import java.util.Date; - import javax.servlet.http.HttpServletResponse; - import org.mitre.openid.connect.config.ConfigurationPropertiesBean; import org.mitre.openid.connect.service.MITREidDataService; +import org.mitre.openid.connect.service.impl.MITREidDataService_1_0; +import org.mitre.openid.connect.service.impl.MITREidDataService_1_1; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -35,12 +38,6 @@ import org.springframework.ui.Model; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMethod; -import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; -import com.google.gson.stream.JsonWriter; -import org.mitre.openid.connect.service.impl.MITREidDataService_1_0; -import org.mitre.openid.connect.service.impl.MITREidDataService_1_1; - /** * API endpoint for importing and exporting the current state of a server. * Includes all tokens, grants, whitelists, blacklists, and clients. @@ -128,10 +125,8 @@ public class DataAPI { writer.close(); } catch (IOException e) { - // TODO Auto-generated catch block - e.printStackTrace(); + logger.error("Unable to export data", e); } - }