Added export/import support for BlacklistedSites, broke out WhitelistedSites into distinct import/export operations, and did some general cleaning up

pull/661/merge
arielak 2014-08-07 16:05:02 -04:00 committed by Justin Richer
parent 3e5240750c
commit 9d2ec0b845
2 changed files with 413 additions and 377 deletions

View File

@ -35,6 +35,8 @@ public interface MITREidDataService {
// member names // member names
public static final String REFRESHTOKENS = "refreshTokens"; public static final String REFRESHTOKENS = "refreshTokens";
public static final String ACCESSTOKENS = "accessTokens"; public static final String ACCESSTOKENS = "accessTokens";
public static final String WHITELISTEDSITES = "whitelistedSites";
public static final String BLACKLISTEDSITES = "blacklistedSites";
public static final String AUTHENTICATIONHOLDERS = "authenticationHolders"; public static final String AUTHENTICATIONHOLDERS = "authenticationHolders";
public static final String GRANTS = "grants"; public static final String GRANTS = "grants";
public static final String CLIENTS = "clients"; public static final String CLIENTS = "clients";

View File

@ -30,7 +30,6 @@ import java.io.ObjectOutputStream;
import java.io.Serializable; import java.io.Serializable;
import java.text.ParseException; import java.text.ParseException;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Date; import java.util.Date;
import java.util.HashMap; import java.util.HashMap;
@ -40,6 +39,7 @@ import java.util.Map;
import java.util.Map.Entry; import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.TimeZone; import java.util.TimeZone;
import java.util.logging.Level;
import org.mitre.jose.JWEAlgorithmEmbed; import org.mitre.jose.JWEAlgorithmEmbed;
import org.mitre.jose.JWEEncryptionMethodEmbed; import org.mitre.jose.JWEEncryptionMethodEmbed;
import org.mitre.jose.JWSAlgorithmEmbed; import org.mitre.jose.JWSAlgorithmEmbed;
@ -56,8 +56,10 @@ import org.mitre.oauth2.repository.OAuth2ClientRepository;
import org.mitre.oauth2.repository.OAuth2TokenRepository; import org.mitre.oauth2.repository.OAuth2TokenRepository;
import org.mitre.oauth2.repository.SystemScopeRepository; import org.mitre.oauth2.repository.SystemScopeRepository;
import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.ApprovedSite;
import org.mitre.openid.connect.model.BlacklistedSite;
import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.model.WhitelistedSite;
import org.mitre.openid.connect.repository.ApprovedSiteRepository; import org.mitre.openid.connect.repository.ApprovedSiteRepository;
import org.mitre.openid.connect.repository.BlacklistedSiteRepository;
import org.mitre.openid.connect.repository.WhitelistedSiteRepository; import org.mitre.openid.connect.repository.WhitelistedSiteRepository;
import org.mitre.openid.connect.service.MITREidDataService; import org.mitre.openid.connect.service.MITREidDataService;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -87,6 +89,10 @@ public class MITREidDataService_1_0 implements MITREidDataService {
@Autowired @Autowired
private ApprovedSiteRepository approvedSiteRepository; private ApprovedSiteRepository approvedSiteRepository;
@Autowired @Autowired
private WhitelistedSiteRepository wlSiteRepository;
@Autowired
private BlacklistedSiteRepository blSiteRepository;
@Autowired
private AuthenticationHolderRepository authHolderRepository; private AuthenticationHolderRepository authHolderRepository;
@Autowired @Autowired
private OAuth2TokenRepository tokenRepository; private OAuth2TokenRepository tokenRepository;
@ -118,6 +124,16 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writeGrants(writer); writeGrants(writer);
writer.endArray(); writer.endArray();
writer.name(WHITELISTEDSITES);
writer.beginArray();
writeWhitelistedSites(writer);
writer.endArray();
writer.name(BLACKLISTEDSITES);
writer.beginArray();
writeBlacklistedSites(writer);
writer.endArray();
writer.name(AUTHENTICATIONHOLDERS); writer.name(AUTHENTICATIONHOLDERS);
writer.beginArray(); writer.beginArray();
writeAuthenticationHolders(writer); writeAuthenticationHolders(writer);
@ -151,25 +167,24 @@ public class MITREidDataService_1_0 implements MITREidDataService {
return sdf.format(date); return sdf.format(date);
} }
private static Date utcToDate(String s) throws ParseException { private static Date utcToDate(String s) {
if (s == null) { if (s == null) {
return null; return null;
} }
return sdf.parse(s); Date d = null;
try {
d = sdf.parse(s);
} catch(ParseException ex) {
logger.error("Unable to parse date string {}", s, ex);
}
return d;
} }
/** /**
* @param writer * @param writer
*/ */
private void writeRefreshTokens(JsonWriter writer) { private void writeRefreshTokens(JsonWriter writer) throws IOException {
Collection<OAuth2RefreshTokenEntity> tokens = new ArrayList<OAuth2RefreshTokenEntity>(); for (OAuth2RefreshTokenEntity token : tokenRepository.getAllRefreshTokens()) {
try {
tokens = tokenRepository.getAllRefreshTokens();
} catch (Exception ex) {
logger.error("Unable to read refresh tokens from data source", ex);
}
for (OAuth2RefreshTokenEntity token : tokens) {
try {
writer.beginObject(); writer.beginObject();
writer.name("id").value(token.getId()); writer.name("id").value(token.getId());
writer.name("expiration").value(toUTCString(token.getExpiration())); writer.name("expiration").value(toUTCString(token.getExpiration()));
@ -180,9 +195,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("value").value(token.getValue()); writer.name("value").value(token.getValue());
writer.endObject(); writer.endObject();
logger.debug("Wrote refresh token {}", token.getId()); logger.debug("Wrote refresh token {}", token.getId());
} catch (IOException ex) {
logger.error("Unable to write refresh token {}", token.getId(), ex);
}
} }
logger.info("Done writing refresh tokens"); logger.info("Done writing refresh tokens");
} }
@ -190,15 +202,8 @@ public class MITREidDataService_1_0 implements MITREidDataService {
/** /**
* @param writer * @param writer
*/ */
private void writeAccessTokens(JsonWriter writer) { private void writeAccessTokens(JsonWriter writer) throws IOException {
Collection<OAuth2AccessTokenEntity> tokens = new ArrayList<OAuth2AccessTokenEntity>(); for (OAuth2AccessTokenEntity token : tokenRepository.getAllAccessTokens()) {
try {
tokens = tokenRepository.getAllAccessTokens();
} catch (Exception ex) {
logger.error("Unable to read access tokens from data source", ex);
}
for (OAuth2AccessTokenEntity token : tokens) {
try {
writer.beginObject(); writer.beginObject();
writer.name("id").value(token.getId()); writer.name("id").value(token.getId());
writer.name("expiration").value(toUTCString(token.getExpiration())); writer.name("expiration").value(toUTCString(token.getExpiration()));
@ -211,18 +216,11 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("idTokenId") writer.name("idTokenId")
.value((token.getIdToken() != null) ? token.getIdToken().getId() : null); .value((token.getIdToken() != null) ? token.getIdToken().getId() : null);
writer.name("scope"); writer.name("scope");
writer.beginArray(); writeNullSafeArray(writer, token.getScope());
for (String s : token.getScope()) {
writer.value(s);
}
writer.endArray();
writer.name("type").value(token.getTokenType()); writer.name("type").value(token.getTokenType());
writer.name("value").value(token.getValue()); writer.name("value").value(token.getValue());
writer.endObject(); writer.endObject();
logger.debug("Wrote access token {}", token.getId()); logger.debug("Wrote access token {}", token.getId());
} catch (IOException ex) {
logger.error("Unable to write access token {}", token.getId(), ex);
}
} }
logger.info("Done writing access tokens"); logger.info("Done writing access tokens");
} }
@ -230,15 +228,8 @@ public class MITREidDataService_1_0 implements MITREidDataService {
/** /**
* @param writer * @param writer
*/ */
private void writeAuthenticationHolders(JsonWriter writer) { private void writeAuthenticationHolders(JsonWriter writer) throws IOException {
Collection<AuthenticationHolderEntity> holders = new ArrayList<AuthenticationHolderEntity>(); for (AuthenticationHolderEntity holder : authHolderRepository.getAll()) {
try {
holders = authHolderRepository.getAll();
} catch (Exception ex) {
logger.error("Unable to read authentication holders from data source", ex);
}
for (AuthenticationHolderEntity holder : holders) {
try {
writer.beginObject(); writer.beginObject();
writer.name("id").value(holder.getId()); writer.name("id").value(holder.getId());
writer.name("ownerId").value(holder.getOwnerId()); writer.name("ownerId").value(holder.getOwnerId());
@ -252,9 +243,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.endObject(); writer.endObject();
writer.endObject(); writer.endObject();
logger.debug("Wrote authentication holder {}", holder.getId()); logger.debug("Wrote authentication holder {}", holder.getId());
} catch (IOException ex) {
logger.error("Unable to write authentication holder {}", holder.getId(), ex);
}
} }
logger.info("Done writing authentication holders"); logger.info("Done writing authentication holders");
} }
@ -275,19 +263,10 @@ public class MITREidDataService_1_0 implements MITREidDataService {
} }
writer.endObject(); writer.endObject();
writer.name("clientId").value(authReq.getClientId()); writer.name("clientId").value(authReq.getClientId());
Set<String> scope = authReq.getScope();
writer.name("scope"); writer.name("scope");
writer.beginArray(); writeNullSafeArray(writer, authReq.getScope());
for (String s : scope) {
writer.value(s);
}
writer.endArray();
writer.name("resourceIds"); writer.name("resourceIds");
writer.beginArray(); writeNullSafeArray(writer, authReq.getResourceIds());
for (String s : authReq.getResourceIds()) {
writer.value(s);
}
writer.endArray();
writer.name("authorities"); writer.name("authorities");
writer.beginArray(); writer.beginArray();
for (GrantedAuthority authority : authReq.getAuthorities()) { for (GrantedAuthority authority : authReq.getAuthorities()) {
@ -299,40 +278,32 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("state").value(authReq.getState()); writer.name("state").value(authReq.getState());
writer.name("redirectUri").value(authReq.getRedirectUri()); writer.name("redirectUri").value(authReq.getRedirectUri());
writer.name("responseTypes"); writer.name("responseTypes");
writer.beginArray(); writeNullSafeArray(writer, authReq.getResponseTypes());
for (String s : authReq.getResponseTypes()) {
writer.value(s);
}
writer.endArray();
writer.endObject(); writer.endObject();
} }
private String base64UrlEncodeObject(Serializable obj) { private String base64UrlEncodeObject(Serializable obj) throws IOException {
String encoded = null;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos); ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(obj); oos.writeObject(obj);
encoded = BaseEncoding.base64Url().encode(baos.toByteArray()); String encoded = BaseEncoding.base64Url().encode(baos.toByteArray());
oos.close(); oos.close();
baos.close(); baos.close();
} catch (IOException ex) {
logger.error("Unable to encode object", ex);
}
return encoded; return encoded;
} }
private <T> T base64UrlDecodeObject(String encoded, Class<T> type) { private <T> T base64UrlDecodeObject(String encoded, Class<T> type) throws IOException {
T deserialized = null;
try {
byte[] decoded = BaseEncoding.base64Url().decode(encoded); byte[] decoded = BaseEncoding.base64Url().decode(encoded);
ByteArrayInputStream bais = new ByteArrayInputStream(decoded); ByteArrayInputStream bais = new ByteArrayInputStream(decoded);
ObjectInputStream ois = new ObjectInputStream(bais); ObjectInputStream ois = new ObjectInputStream(bais);
T deserialized = null;
try {
deserialized = type.cast(ois.readObject()); deserialized = type.cast(ois.readObject());
} catch (ClassNotFoundException ex) {
logger.error("Unable to decode object as type {}", type.getName(), ex);
} finally {
ois.close(); ois.close();
bais.close(); bais.close();
} catch (Exception ex) {
logger.error("Unable to decode object", ex);
} }
return deserialized; return deserialized;
} }
@ -340,9 +311,8 @@ public class MITREidDataService_1_0 implements MITREidDataService {
/** /**
* @param writer * @param writer
*/ */
private void writeGrants(JsonWriter writer) { private void writeGrants(JsonWriter writer) throws IOException {
for (ApprovedSite site : approvedSiteRepository.getAll()) { for (ApprovedSite site : approvedSiteRepository.getAll()) {
try {
writer.beginObject(); writer.beginObject();
writer.name("id").value(site.getId()); writer.name("id").value(site.getId());
writer.name("accessDate").value(toUTCString(site.getAccessDate())); writer.name("accessDate").value(toUTCString(site.getAccessDate()));
@ -351,31 +321,10 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("timeoutDate").value(toUTCString(site.getTimeoutDate())); writer.name("timeoutDate").value(toUTCString(site.getTimeoutDate()));
writer.name("userId").value(site.getUserId()); writer.name("userId").value(site.getUserId());
writer.name("allowedScopes"); writer.name("allowedScopes");
writer.beginArray(); writeNullSafeArray(writer, site.getAllowedScopes());
for (String s : site.getAllowedScopes()) { writer.name("whitelistedSiteId").value(site.getIsWhitelisted() ? site.getWhitelistedSite().getId() : null);
writer.value(s);
}
writer.endArray();
if (site.getIsWhitelisted()) {
WhitelistedSite wlSite = site.getWhitelistedSite();
writer.name("whitelistedSite");
writer.beginObject();
writer.name("id").value(wlSite.getId());
writer.name("clientId").value(wlSite.getClientId());
writer.name("creatorUserId").value(wlSite.getCreatorUserId());
writer.name("allowedScopes");
writer.beginArray();
for (String s : wlSite.getAllowedScopes()) {
writer.value(s);
}
writer.endArray();
writer.endObject();
}
writer.endObject(); writer.endObject();
logger.debug("Wrote grant {}", site.getId()); logger.debug("Wrote grant {}", site.getId());
} catch (IOException ex) {
logger.error("Unable to write grant {}", site.getId(), ex);
}
} }
logger.info("Done writing grants"); logger.info("Done writing grants");
} }
@ -383,19 +332,46 @@ public class MITREidDataService_1_0 implements MITREidDataService {
/** /**
* @param writer * @param writer
*/ */
private void writeClients(JsonWriter writer) { private void writeWhitelistedSites(JsonWriter writer) throws IOException {
for (WhitelistedSite wlSite : wlSiteRepository.getAll()) {
writer.beginObject();
writer.name("id").value(wlSite.getId());
writer.name("clientId").value(wlSite.getClientId());
writer.name("creatorUserId").value(wlSite.getCreatorUserId());
writer.name("allowedScopes");
writeNullSafeArray(writer, wlSite.getAllowedScopes());
writer.endObject();
logger.debug("Wrote whitelisted site {}", wlSite.getId());
}
logger.info("Done writing whitelisted sites");
}
/**
* @param writer
*/
private void writeBlacklistedSites(JsonWriter writer) throws IOException {
for (BlacklistedSite blSite : blSiteRepository.getAll()) {
writer.beginObject();
writer.name("id").value(blSite.getId());
writer.name("uri").value(blSite.getUri());
writer.endObject();
logger.debug("Wrote blacklisted site {}", blSite.getId());
}
logger.info("Done writing blacklisted sites");
}
/**
* @param writer
*/
private void writeClients(JsonWriter writer) throws IOException {
for (ClientDetailsEntity client : clientRepository.getAllClients()) { for (ClientDetailsEntity client : clientRepository.getAllClients()) {
try {
writer.beginObject(); writer.beginObject();
writer.name("clientId").value(client.getClientId()); writer.name("clientId").value(client.getClientId());
writer.name("resourceIds"); writer.name("resourceIds");
writeNullSafeArray(writer, client.getResourceIds()); writeNullSafeArray(writer, client.getResourceIds());
writer.name("secret").value(client.getClientSecret()); writer.name("secret").value(client.getClientSecret());
writer.name("scope"); writer.name("scope");
writeNullSafeArray(writer, client.getScope()); writeNullSafeArray(writer, client.getScope());
writer.name("authorities"); writer.name("authorities");
writer.beginArray(); writer.beginArray();
for (GrantedAuthority authority : client.getAuthorities()) { for (GrantedAuthority authority : client.getAuthorities()) {
@ -415,17 +391,9 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("tokenEndpointAuthMethod") writer.name("tokenEndpointAuthMethod")
.value((client.getTokenEndpointAuthMethod() != null) ? client.getTokenEndpointAuthMethod().getValue() : null); .value((client.getTokenEndpointAuthMethod() != null) ? client.getTokenEndpointAuthMethod().getValue() : null);
writer.name("grantTypes"); writer.name("grantTypes");
writer.beginArray(); writeNullSafeArray(writer, client.getGrantTypes());
for (String s : client.getGrantTypes()) {
writer.value(s);
}
writer.endArray();
writer.name("responseTypes"); writer.name("responseTypes");
writer.beginArray(); writeNullSafeArray(writer, client.getResponseTypes());
for (String s : client.getResponseTypes()) {
writer.value(s);
}
writer.endArray();
writer.name("policyUri").value(client.getPolicyUri()); writer.name("policyUri").value(client.getPolicyUri());
writer.name("jwksUri").value(client.getJwksUri()); writer.name("jwksUri").value(client.getJwksUri());
writer.name("applicationType") writer.name("applicationType")
@ -445,8 +413,7 @@ public class MITREidDataService_1_0 implements MITREidDataService {
Boolean requireAuthTime = null; Boolean requireAuthTime = null;
try { try {
requireAuthTime = client.getRequireAuthTime(); requireAuthTime = client.getRequireAuthTime();
} catch (NullPointerException e) { } catch (NullPointerException e) {}
}
if (requireAuthTime != null) { if (requireAuthTime != null) {
writer.name("requireAuthTime").value(requireAuthTime); writer.name("requireAuthTime").value(requireAuthTime);
} }
@ -462,9 +429,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
writer.name("dynamicallyRegistered").value(client.isDynamicallyRegistered()); writer.name("dynamicallyRegistered").value(client.isDynamicallyRegistered());
writer.endObject(); writer.endObject();
logger.debug("Wrote client {}", client.getId()); logger.debug("Wrote client {}", client.getId());
} catch (IOException ex) {
logger.error("Unable to write client {}", client.getId(), ex);
}
} }
logger.info("Done writing clients"); logger.info("Done writing clients");
} }
@ -525,6 +489,10 @@ public class MITREidDataService_1_0 implements MITREidDataService {
readClients(reader); readClients(reader);
} else if (name.equals(GRANTS)) { } else if (name.equals(GRANTS)) {
readGrants(reader); readGrants(reader);
} else if (name.equals(WHITELISTEDSITES)) {
readWhitelistedSites(reader);
} else if (name.equals(BLACKLISTEDSITES)) {
readBlacklistedSites(reader);
} else if (name.equals(AUTHENTICATIONHOLDERS)) { } else if (name.equals(AUTHENTICATIONHOLDERS)) {
readAuthenticationHolders(reader); readAuthenticationHolders(reader);
} else if (name.equals(ACCESSTOKENS)) { } else if (name.equals(ACCESSTOKENS)) {
@ -558,7 +526,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
private void readRefreshTokens(JsonReader reader) throws IOException { private void readRefreshTokens(JsonReader reader) throws IOException {
reader.beginArray(); reader.beginArray();
while (reader.hasNext()) { while (reader.hasNext()) {
try {
OAuth2RefreshTokenEntity token = new OAuth2RefreshTokenEntity(); OAuth2RefreshTokenEntity token = new OAuth2RefreshTokenEntity();
reader.beginObject(); reader.beginObject();
Long currentId = null; Long currentId = null;
@ -578,7 +545,12 @@ public class MITREidDataService_1_0 implements MITREidDataService {
Date date = utcToDate(reader.nextString()); Date date = utcToDate(reader.nextString());
token.setExpiration(date); token.setExpiration(date);
} else if (name.equals("value")) { } else if (name.equals("value")) {
token.setValue(reader.nextString()); String value = reader.nextString();
try {
token.setValue(value);
} catch (ParseException ex) {
logger.error("Unable to set refresh token value to {}", value, ex);
}
} else if (name.equals("clientId")) { } else if (name.equals("clientId")) {
clientId = reader.nextString(); clientId = reader.nextString();
} else if (name.equals("authenticationHolderId")) { } else if (name.equals("authenticationHolderId")) {
@ -600,9 +572,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
refreshTokenToAuthHolderRefs.put(currentId, authHolderId); refreshTokenToAuthHolderRefs.put(currentId, authHolderId);
refreshTokenOldToNewIdMap.put(currentId, newId); refreshTokenOldToNewIdMap.put(currentId, newId);
logger.debug("Read refresh token {}", currentId); logger.debug("Read refresh token {}", currentId);
} catch (ParseException ex) {
logger.error("Unable to read refresh token", ex);
}
} }
reader.endArray(); reader.endArray();
logger.info("Done reading refresh tokens"); logger.info("Done reading refresh tokens");
@ -621,7 +590,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
private void readAccessTokens(JsonReader reader) throws IOException { private void readAccessTokens(JsonReader reader) throws IOException {
reader.beginArray(); reader.beginArray();
while (reader.hasNext()) { while (reader.hasNext()) {
try {
OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity();
reader.beginObject(); reader.beginObject();
Long currentId = null; Long currentId = null;
@ -643,7 +611,12 @@ public class MITREidDataService_1_0 implements MITREidDataService {
Date date = utcToDate(reader.nextString()); Date date = utcToDate(reader.nextString());
token.setExpiration(date); token.setExpiration(date);
} else if (name.equals("value")) { } else if (name.equals("value")) {
token.setValue(reader.nextString()); String value = reader.nextString();
try {
token.setValue(value);
} catch (ParseException ex) {
logger.error("Unable to set refresh token value to {}", value, ex);
}
} else if (name.equals("clientId")) { } else if (name.equals("clientId")) {
clientId = reader.nextString(); clientId = reader.nextString();
} else if (name.equals("authenticationHolderId")) { } else if (name.equals("authenticationHolderId")) {
@ -680,9 +653,6 @@ public class MITREidDataService_1_0 implements MITREidDataService {
} }
accessTokenOldToNewIdMap.put(currentId, newId); accessTokenOldToNewIdMap.put(currentId, newId);
logger.debug("Read access token {}", currentId); logger.debug("Read access token {}", currentId);
} catch (ParseException ex) {
logger.error("Unable to read access token", ex);
}
} }
reader.endArray(); reader.endArray();
logger.info("Done reading access tokens"); logger.info("Done reading access tokens");
@ -825,8 +795,8 @@ public class MITREidDataService_1_0 implements MITREidDataService {
return dar; return dar;
} }
@Autowired Map<Long, Long> grantOldToNewIdMap = new HashMap<Long, Long>();
private WhitelistedSiteRepository wlSiteRepository; Map<Long, Long> grantToWhitelistedSiteRefs = new HashMap<Long, Long>();
/** /**
* @param reader * @param reader
@ -835,9 +805,9 @@ public class MITREidDataService_1_0 implements MITREidDataService {
private void readGrants(JsonReader reader) throws IOException { private void readGrants(JsonReader reader) throws IOException {
reader.beginArray(); reader.beginArray();
while (reader.hasNext()) { while (reader.hasNext()) {
try {
ApprovedSite site = new ApprovedSite(); ApprovedSite site = new ApprovedSite();
Long currentId = null; Long currentId = null;
Long whitelistedSiteId = null;
reader.beginObject(); reader.beginObject();
while (reader.hasNext()) { while (reader.hasNext()) {
switch (reader.peek()) { switch (reader.peek()) {
@ -865,18 +835,47 @@ public class MITREidDataService_1_0 implements MITREidDataService {
} else if (name.equals("allowedScopes")) { } else if (name.equals("allowedScopes")) {
Set<String> allowedScopes = readSet(reader); Set<String> allowedScopes = readSet(reader);
site.setAllowedScopes(allowedScopes); site.setAllowedScopes(allowedScopes);
} else if (name.equals("whitelistedSite")) { } else if (name.equals("whitelistedSiteId")) {
whitelistedSiteId = reader.nextLong();
} else {
logger.debug("Found unexpected entry");
reader.skipValue();
}
break;
default:
logger.debug("Found unexpected entry");
reader.skipValue();
continue;
}
}
reader.endObject();
Long newId = approvedSiteRepository.save(site).getId();
grantOldToNewIdMap.put(currentId, newId);
if(whitelistedSiteId != null) {
grantToWhitelistedSiteRefs.put(currentId, whitelistedSiteId);
}
logger.debug("Read grant {}", currentId);
}
reader.endArray();
logger.info("Done reading grants");
}
Map<Long, Long> whitelistedSiteOldToNewIdMap = new HashMap<Long, Long>();
private void readWhitelistedSites(JsonReader reader) throws IOException {
reader.beginArray();
while (reader.hasNext()) {
WhitelistedSite wlSite = new WhitelistedSite(); WhitelistedSite wlSite = new WhitelistedSite();
Long currentId = null;
reader.beginObject(); reader.beginObject();
while (reader.hasNext()) { while (reader.hasNext()) {
switch (reader.peek()) { switch (reader.peek()) {
case END_OBJECT: case END_OBJECT:
continue; continue;
case NAME: case NAME:
String wlName = reader.nextName(); String name = reader.nextName();
if (wlName.equals("id")) { if (name.equals("id")) {
//not needed currentId = reader.nextLong();
reader.skipValue();
} else if (name.equals("clientId")) { } else if (name.equals("clientId")) {
wlSite.setClientId(reader.nextString()); wlSite.setClientId(reader.nextString());
} else if (name.equals("creatorUserId")) { } else if (name.equals("creatorUserId")) {
@ -896,8 +895,28 @@ public class MITREidDataService_1_0 implements MITREidDataService {
} }
} }
reader.endObject(); reader.endObject();
wlSite = wlSiteRepository.save(wlSite); Long newId = wlSiteRepository.save(wlSite).getId();
site.setWhitelistedSite(wlSite); whitelistedSiteOldToNewIdMap.put(currentId, newId);
}
reader.endArray();
logger.info("Done reading whitelisted sites");
}
private void readBlacklistedSites(JsonReader reader) throws IOException {
reader.beginArray();
while (reader.hasNext()) {
BlacklistedSite blSite = new BlacklistedSite();
reader.beginObject();
while (reader.hasNext()) {
switch (reader.peek()) {
case END_OBJECT:
continue;
case NAME:
String name = reader.nextName();
if (name.equals("id")) {
reader.skipValue();
} else if (name.equals("uri")) {
blSite.setUri(reader.nextString());
} else { } else {
logger.debug("Found unexpected entry"); logger.debug("Found unexpected entry");
reader.skipValue(); reader.skipValue();
@ -910,14 +929,10 @@ public class MITREidDataService_1_0 implements MITREidDataService {
} }
} }
reader.endObject(); reader.endObject();
approvedSiteRepository.save(site).getId(); blSiteRepository.save(blSite);
logger.debug("Read grant {}", currentId);
} catch (ParseException ex) {
logger.error("Unable to read grant", ex);
}
} }
reader.endArray(); reader.endArray();
logger.info("Done reading grants"); logger.info("Done reading blacklisted sites");
} }
/** /**
@ -1147,6 +1162,7 @@ public class MITREidDataService_1_0 implements MITREidDataService {
refreshToken.setClient(client); refreshToken.setClient(client);
tokenRepository.saveRefreshToken(refreshToken); tokenRepository.saveRefreshToken(refreshToken);
} }
refreshTokenToClientRefs.clear();
for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) {
Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId);
Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId);
@ -1156,6 +1172,7 @@ public class MITREidDataService_1_0 implements MITREidDataService {
refreshToken.setAuthenticationHolder(authHolder); refreshToken.setAuthenticationHolder(authHolder);
tokenRepository.saveRefreshToken(refreshToken); tokenRepository.saveRefreshToken(refreshToken);
} }
refreshTokenToAuthHolderRefs.clear();
for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) {
String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); String clientRef = accessTokenToClientRefs.get(oldAccessTokenId);
ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef);
@ -1164,6 +1181,7 @@ public class MITREidDataService_1_0 implements MITREidDataService {
accessToken.setClient(client); accessToken.setClient(client);
tokenRepository.saveAccessToken(accessToken); tokenRepository.saveAccessToken(accessToken);
} }
accessTokenToClientRefs.clear();
for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) {
Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId);
Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId);
@ -1173,6 +1191,7 @@ public class MITREidDataService_1_0 implements MITREidDataService {
accessToken.setAuthenticationHolder(authHolder); accessToken.setAuthenticationHolder(authHolder);
tokenRepository.saveAccessToken(accessToken); tokenRepository.saveAccessToken(accessToken);
} }
accessTokenToAuthHolderRefs.clear();
for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) {
Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId);
Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId);
@ -1182,6 +1201,8 @@ public class MITREidDataService_1_0 implements MITREidDataService {
accessToken.setRefreshToken(refreshToken); accessToken.setRefreshToken(refreshToken);
tokenRepository.saveAccessToken(accessToken); tokenRepository.saveAccessToken(accessToken);
} }
accessTokenToRefreshTokenRefs.clear();
refreshTokenOldToNewIdMap.clear();
for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) {
Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId);
Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId);
@ -1191,5 +1212,18 @@ public class MITREidDataService_1_0 implements MITREidDataService {
accessToken.setIdToken(idToken); accessToken.setIdToken(idToken);
tokenRepository.saveAccessToken(accessToken); tokenRepository.saveAccessToken(accessToken);
} }
accessTokenToIdTokenRefs.clear();
accessTokenOldToNewIdMap.clear();
for(Long oldGrantId : grantToWhitelistedSiteRefs.keySet()) {
Long oldWhitelistedSiteId = grantToWhitelistedSiteRefs.get(oldGrantId);
Long newWhitelistedSiteId = whitelistedSiteOldToNewIdMap.get(oldWhitelistedSiteId);
WhitelistedSite wlSite = wlSiteRepository.getById(newWhitelistedSiteId);
Long newGrantId = grantOldToNewIdMap.get(oldGrantId);
ApprovedSite approvedSite = approvedSiteRepository.getById(newGrantId);
approvedSite.setWhitelistedSite(wlSite);
approvedSiteRepository.save(approvedSite);
}
grantOldToNewIdMap.clear();
grantToWhitelistedSiteRefs.clear();
} }
} }