diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java index af0938989..ae1be4cda 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_0.java @@ -22,21 +22,32 @@ import com.google.common.io.BaseEncoding; import com.google.gson.stream.JsonReader; import com.google.gson.stream.JsonToken; import com.google.gson.stream.JsonWriter; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.text.ParseException; import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Collection; import java.util.Date; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.TimeZone; +import org.mitre.jose.JWEAlgorithmEmbed; +import org.mitre.jose.JWEEncryptionMethodEmbed; +import org.mitre.jose.JWSAlgorithmEmbed; import org.mitre.oauth2.model.AuthenticationHolderEntity; import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.model.ClientDetailsEntity.AppType; +import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod; +import org.mitre.oauth2.model.ClientDetailsEntity.SubjectType; import org.mitre.oauth2.model.OAuth2AccessTokenEntity; import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; import org.mitre.oauth2.model.SystemScope; @@ -47,12 +58,16 @@ import org.mitre.oauth2.repository.SystemScopeRepository; import org.mitre.openid.connect.model.ApprovedSite; import org.mitre.openid.connect.model.WhitelistedSite; import org.mitre.openid.connect.repository.ApprovedSiteRepository; +import org.mitre.openid.connect.repository.WhitelistedSiteRepository; import org.mitre.openid.connect.service.MITREidDataService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.GrantedAuthorityImpl; import org.springframework.security.oauth2.provider.AuthorizationRequest; +import org.springframework.security.oauth2.provider.DefaultAuthorizationRequest; import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.stereotype.Service; @@ -65,7 +80,7 @@ import org.springframework.stereotype.Service; */ @Service public class MITREidDataService_1_0 implements MITREidDataService { - + private final static Logger logger = LoggerFactory.getLogger(MITREidDataService_1_0.class); @Autowired private OAuth2ClientRepository clientRepository; @@ -147,7 +162,13 @@ public class MITREidDataService_1_0 implements MITREidDataService { * @param writer */ private void writeRefreshTokens(JsonWriter writer) { - for (OAuth2RefreshTokenEntity token : tokenRepository.getAllRefreshTokens()) { + Collection tokens = new ArrayList(); + try { + tokens = tokenRepository.getAllRefreshTokens(); + } catch (Exception ex) { + logger.error("Unable to read refresh tokens from data source", ex); + } + for (OAuth2RefreshTokenEntity token : tokens) { try { writer.beginObject(); writer.name("id").value(token.getId()); @@ -170,7 +191,13 @@ public class MITREidDataService_1_0 implements MITREidDataService { * @param writer */ private void writeAccessTokens(JsonWriter writer) { - for (OAuth2AccessTokenEntity token : tokenRepository.getAllAccessTokens()) { + Collection tokens = new ArrayList(); + try { + tokens = tokenRepository.getAllAccessTokens(); + } catch (Exception ex) { + logger.error("Unable to read access tokens from data source", ex); + } + for (OAuth2AccessTokenEntity token : tokens) { try { writer.beginObject(); writer.name("id").value(token.getId()); @@ -204,7 +231,13 @@ public class MITREidDataService_1_0 implements MITREidDataService { * @param writer */ private void writeAuthenticationHolders(JsonWriter writer) { - for (AuthenticationHolderEntity holder : authHolderRepository.getAll()) { + Collection holders = new ArrayList(); + try { + holders = authHolderRepository.getAll(); + } catch (Exception ex) { + logger.error("Unable to read authentication holders from data source", ex); + } + for (AuthenticationHolderEntity holder : holders) { try { writer.beginObject(); writer.name("id").value(holder.getId()); @@ -274,14 +307,36 @@ public class MITREidDataService_1_0 implements MITREidDataService { writer.endObject(); } - private String base64UrlEncodeObject(Serializable obj) throws IOException { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(obj); - oos.close(); - return BaseEncoding.base64Url().encode(baos.toByteArray()); + private String base64UrlEncodeObject(Serializable obj) { + String encoded = null; + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(obj); + encoded = BaseEncoding.base64Url().encode(baos.toByteArray()); + oos.close(); + baos.close(); + } catch (IOException ex) { + logger.error("Unable to encode object", ex); + } + return encoded; } - + + private T base64UrlDecodeObject(String encoded, Class type) { + T deserialized = null; + try { + byte[] decoded = BaseEncoding.base64Url().decode(encoded); + ByteArrayInputStream bais = new ByteArrayInputStream(decoded); + ObjectInputStream ois = new ObjectInputStream(bais); + deserialized = type.cast(ois.readObject()); + ois.close(); + bais.close(); + } catch (Exception ex) { + logger.error("Unable to decode object", ex); + } + return deserialized; + } + /** * @param writer */ @@ -486,10 +541,12 @@ public class MITREidDataService_1_0 implements MITREidDataService { break; case END_OBJECT: // the object ended, we're done here + fixObjectReferences(); return; } } } + private Map refreshTokenToClientRefs = new HashMap(); private Map refreshTokenToAuthHolderRefs = new HashMap(); private Map refreshTokenOldToNewIdMap = new HashMap(); @@ -534,7 +591,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { refreshTokenToClientRefs.put(currentId, clientId); refreshTokenToAuthHolderRefs.put(currentId, authHolderId); refreshTokenOldToNewIdMap.put(currentId, newId); - logger.debug("Read refresh token {}", token.getId()); + logger.debug("Read refresh token {}", currentId); } catch (ParseException ex) { logger.error("Unable to read refresh token", ex); } @@ -542,6 +599,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { reader.endArray(); logger.info("Done reading refresh tokens"); } + private Map accessTokenToClientRefs = new HashMap(); private Map accessTokenToAuthHolderRefs = new HashMap(); private Map accessTokenToRefreshTokenRefs = new HashMap(); @@ -593,12 +651,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { idTokenId = reader.nextLong(); } } else if (name.equals("scope")) { - reader.beginArray(); - Set scope = new HashSet(); - while (reader.hasNext()) { - scope.add(reader.nextString()); - } - reader.endArray(); + Set scope = readSet(reader); token.setScope(scope); } else if (name.equals("type")) { token.setTokenType(reader.nextString()); @@ -614,7 +667,7 @@ public class MITREidDataService_1_0 implements MITREidDataService { accessTokenToRefreshTokenRefs.put(currentId, refreshTokenId); accessTokenToIdTokenRefs.put(currentId, idTokenId); accessTokenOldToNewIdMap.put(currentId, newId); - logger.debug("Read access token {}", token.getId()); + logger.debug("Read access token {}", currentId); } catch (ParseException ex) { logger.error("Unable to read access token", ex); } @@ -623,22 +676,190 @@ public class MITREidDataService_1_0 implements MITREidDataService { logger.info("Done reading access tokens"); } + private Map authHolderOldToNewIdMap = new HashMap(); + /** * @param reader * @throws IOException */ private void readAuthenticationHolders(JsonReader reader) throws IOException { - // TODO Auto-generated method stub - reader.skipValue(); + reader.beginArray(); + while (reader.hasNext()) { + AuthenticationHolderEntity ahe = new AuthenticationHolderEntity(); + reader.beginObject(); + Long currentId = null; + while (reader.hasNext()) { + String name = reader.nextName(); + if(name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("ownerId")) { + //not needed + reader.skipValue(); + } else if (name.equals("authentication")) { + AuthorizationRequest clientAuthorization = null; + Authentication userAuthentication = null; + reader.beginObject(); + while(reader.hasNext()) { + if (name.equals("clientAuthorization")) { + clientAuthorization = readAuthorizationRequest(reader); + } else if (name.equals("userAuthentication")) { + userAuthentication = base64UrlDecodeObject(reader.nextString(), Authentication.class); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + } + reader.endObject(); + OAuth2Authentication auth = new OAuth2Authentication(clientAuthorization, userAuthentication); + ahe.setAuthentication(auth); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + } + reader.endObject(); + Long newId = authHolderRepository.save(ahe).getId(); + authHolderOldToNewIdMap.put(currentId, newId); + logger.debug("Read authentication holder {}", currentId); + } + reader.endArray(); + logger.info("Done reading authentication holders"); } + //used by readAuthenticationHolders + private AuthorizationRequest readAuthorizationRequest(JsonReader reader) throws IOException { + Set scope = new LinkedHashSet(); + Set resourceIds = new HashSet(); + boolean approved = false; + Collection authorities = new HashSet(); + Map authorizationParameters = new HashMap(); + Map approvalParameters = new HashMap(); + String redirectUri = null; + String clientId = null; + reader.beginObject(); + while (reader.hasNext()) { + String name = reader.nextName(); + if (name.equals("authorizationParameters")) { + authorizationParameters = readMap(reader); + } else if (name.equals("approvalParameters")) { + approvalParameters = readMap(reader); + } else if (name.equals("clientId")) { + clientId = reader.nextString(); + } else if (name.equals("scope")) { + scope = readSet(reader); + } else if (name.equals("resourceIds")) { + resourceIds = readSet(reader); + } else if (name.equals("authorities")) { + Set authorityStrs = readSet(reader); + authorities = new HashSet(); + for (String s : authorityStrs) { + GrantedAuthority ga = new GrantedAuthorityImpl(s); + authorities.add(ga); + } + } else if (name.equals("approved")) { + approved = reader.nextBoolean(); + } else if (name.equals("denied")) { + if(approved == false) { + approved = !reader.nextBoolean(); + } + } else if (name.equals("redirectUri")) { + redirectUri = reader.nextString(); + } else { + reader.skipValue(); + } + } + reader.endObject(); + DefaultAuthorizationRequest dar = new DefaultAuthorizationRequest(authorizationParameters, approvalParameters, clientId, scope); + dar.setAuthorities(authorities); + dar.setResourceIds(resourceIds); + dar.setApproved(approved); + dar.setRedirectUri(redirectUri); + return dar; + } + + @Autowired + private WhitelistedSiteRepository wlSiteRepository; + /** * @param reader * @throws IOException */ private void readGrants(JsonReader reader) throws IOException { - // TODO Auto-generated method stub - reader.skipValue(); + reader.beginArray(); + while (reader.hasNext()) { + try { + ApprovedSite site = new ApprovedSite(); + Long currentId = null; + reader.beginObject(); + while (reader.hasNext()) { + String name = reader.nextName(); + if (name.equals("id")) { + currentId = reader.nextLong(); + } else if (name.equals("accessDate")) { + if (reader.peek() == JsonToken.NULL) { + reader.nextNull(); + } else { + Date date = utcToDate(reader.nextString()); + site.setAccessDate(date); + } + } else if (name.equals("clientId")) { + site.setClientId(reader.nextString()); + } else if (name.equals("creationDate")) { + if (reader.peek() == JsonToken.NULL) { + reader.nextNull(); + } else { + Date date = utcToDate(reader.nextString()); + site.setCreationDate(date); + } + } else if (name.equals("timeoutDate")) { + if (reader.peek() == JsonToken.NULL) { + reader.nextNull(); + } else { + Date date = utcToDate(reader.nextString()); + site.setTimeoutDate(date); + } + } else if (name.equals("userId")) { + site.setUserId(reader.nextString()); + } else if (name.equals("allowedScopes")) { + Set allowedScopes = readSet(reader); + site.setAllowedScopes(allowedScopes); + } else if (name.equals("whitelistedSite")) { + WhitelistedSite wlSite = new WhitelistedSite(); + reader.beginObject(); + while(reader.hasNext()) { + String wlName = reader.nextName(); + if (wlName.equals("id")) { + //not needed + reader.skipValue(); + } else if (name.equals("clientId")) { + wlSite.setClientId(reader.nextString()); + } else if (name.equals("creatorUserId")) { + wlSite.setCreatorUserId(reader.nextString()); + } else if (name.equals("allowedScopes")) { + Set allowedScopes = readSet(reader); + wlSite.setAllowedScopes(allowedScopes); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + } + reader.endObject(); + wlSite = wlSiteRepository.save(wlSite); + site.setWhitelistedSite(wlSite); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + } + reader.endObject(); + approvedSiteRepository.save(site).getId(); + logger.debug("Read grant {}", currentId); + } catch (ParseException ex) { + logger.error("Unable to read grant", ex); + } + } + reader.endArray(); + logger.info("Done reading grants"); } /** @@ -646,8 +867,111 @@ public class MITREidDataService_1_0 implements MITREidDataService { * @throws IOException */ private void readClients(JsonReader reader) throws IOException { - // TODO Auto-generated method stub - reader.skipValue(); + reader.beginArray(); + while (reader.hasNext()) { + ClientDetailsEntity client = new ClientDetailsEntity(); + reader.beginObject(); + while (reader.hasNext()) { + String name = reader.nextName(); + if (name.equals("clientId")) { + client.setClientId(reader.nextString()); + } else if (name.equals("resourceIds")) { + Set resourceIds = readSet(reader); + client.setResourceIds(resourceIds); + } else if (name.equals("secret")) { + client.setClientSecret(reader.nextString()); + } else if (name.equals("scope")) { + Set scope = readSet(reader); + client.setScope(scope); + } else if (name.equals("authorities")) { + Set authorityStrs = readSet(reader); + Set authorities = new HashSet(); + for (String s : authorityStrs) { + GrantedAuthority ga = new GrantedAuthorityImpl(s); + authorities.add(ga); + } + client.setAuthorities(authorities); + } else if (name.equals("accessTokenValiditySeconds")) { + client.setAccessTokenValiditySeconds(reader.nextInt()); + } else if (name.equals("refreshTokenValiditySeconds")) { + client.setRefreshTokenValiditySeconds(reader.nextInt()); + } else if (name.equals("redirectUris")) { + Set redirectUris = readSet(reader); + client.setRedirectUris(redirectUris); + } else if (name.equals("name")) { + client.setClientName(reader.nextString()); + } else if (name.equals("uri")) { + client.setClientUri(reader.nextString()); + } else if (name.equals("logoUri")) { + client.setLogoUri(reader.nextString()); + } else if (name.equals("contacts")) { + Set contacts = readSet(reader); + client.setContacts(contacts); + } else if (name.equals("tosUri")) { + client.setTosUri(reader.nextString()); + } else if (name.equals("tokenEndpointAuthMethod")) { + AuthMethod am = AuthMethod.getByValue(reader.nextString()); + client.setTokenEndpointAuthMethod(am); + } else if (name.equals("grantTypes")) { + Set grantTypes = readSet(reader); + client.setGrantTypes(grantTypes); + } else if (name.equals("responseTypes")) { + Set responseTypes = readSet(reader); + client.setGrantTypes(responseTypes); + } else if (name.equals("policyUri")) { + client.setPolicyUri(reader.nextString()); + } else if (name.equals("applicationType")) { + AppType appType = AppType.getByValue(reader.nextString()); + client.setApplicationType(appType); + } else if (name.equals("sectorIdentifierUri")) { + client.setSectorIdentifierUri(reader.nextString()); + } else if (name.equals("subjectType")) { + SubjectType st = SubjectType.getByValue(reader.nextString()); + client.setSubjectType(st); + } else if (name.equals("requestObjectSigningAlg")) { + JWSAlgorithmEmbed alg = JWSAlgorithmEmbed.getForAlgorithmName(reader.nextString()); + client.setRequestObjectSigningAlgEmbed(alg); + } else if (name.equals("userInfoEncryptedResponseAlg")) { + JWEAlgorithmEmbed alg = JWEAlgorithmEmbed.getForAlgorithmName(reader.nextString()); + client.setUserInfoEncryptedResponseAlgEmbed(alg); + } else if (name.equals("userInfoEncryptedResponseEnc")) { + JWEEncryptionMethodEmbed alg = JWEEncryptionMethodEmbed.getForAlgorithmName(reader.nextString()); + client.setUserInfoEncryptedResponseEncEmbed(alg); + } else if (name.equals("userInfoSignedResponseAlg")) { + JWSAlgorithmEmbed alg = JWSAlgorithmEmbed.getForAlgorithmName(reader.nextString()); + client.setUserInfoSignedResponseAlgEmbed(alg); + } else if (name.equals("defaultMaxAge")) { + client.setDefaultMaxAge(reader.nextInt()); + } else if (name.equals("requireAuthTime")) { + client.setRequireAuthTime(reader.nextBoolean()); + } else if (name.equals("defaultACRValues")) { + Set defaultACRvalues = readSet(reader); + client.setDefaultACRvalues(defaultACRvalues); + } else if (name.equals("initiateLoginUri")) { + client.setInitiateLoginUri(reader.nextString()); + } else if (name.equals("postLogoutRedirectUri")) { + client.setPostLogoutRedirectUri(reader.nextString()); + } else if (name.equals("requestUris")) { + Set requestUris = readSet(reader); + client.setRequestUris(requestUris); + } else if (name.equals("description")) { + client.setClientDescription(reader.nextString()); + } else if (name.equals("allowIntrospection")) { + client.setAllowIntrospection(reader.nextBoolean()); + } else if(name.equals("reuseRefreshToken")) { + client.setReuseRefreshToken(reader.nextBoolean()); + } else if(name.equals("dynamicallyRegistered")) { + client.setDynamicallyRegistered(reader.nextBoolean()); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + } + reader.endObject(); + clientRepository.saveClient(client); + } + reader.endArray(); + logger.info("Done reading clients"); } /** @@ -688,11 +1012,112 @@ public class MITREidDataService_1_0 implements MITREidDataService { continue; } } - reader.endObject(); - + reader.endObject(); sysScopeRepository.save(scope); } reader.endArray(); - logger.info("Done reading system scopes."); + logger.info("Done reading system scopes"); + } + + private Set readSet(JsonReader reader) throws IOException { + Set arraySet = null; + reader.beginArray(); + switch (reader.peek()) { + case STRING: + arraySet = new HashSet(); + while (reader.hasNext()) { + arraySet.add(reader.nextString()); + } + break; + case NUMBER: + arraySet = new HashSet(); + while (reader.hasNext()) { + arraySet.add(reader.nextLong()); + } + break; + default: + arraySet = new HashSet(); + break; + } + reader.endArray(); + return arraySet; + } + + private Map readMap(JsonReader reader) throws IOException { + Map map = new HashMap(); + reader.beginObject(); + while(reader.hasNext()) { + String name = reader.nextName(); + Object value = null; + switch(reader.peek()) { + case STRING: + value = reader.nextString(); + break; + case BOOLEAN: + value = reader.nextBoolean(); + break; + case NUMBER: + value = reader.nextLong(); + break; + } + map.put(name, value); + } + reader.endObject(); + return map; + } + + private void fixObjectReferences() { + for(Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { + String clientRef = refreshTokenToClientRefs.get(oldRefreshTokenId); + ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + refreshToken.setClient(client); + tokenRepository.saveRefreshToken(refreshToken); + } + for(Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { + Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); + Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); + AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + refreshToken.setAuthenticationHolder(authHolder); + tokenRepository.saveRefreshToken(refreshToken); + } + for(Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { + String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); + ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setClient(client); + tokenRepository.saveAccessToken(accessToken); + } + for(Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { + Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); + Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); + AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setAuthenticationHolder(authHolder); + tokenRepository.saveAccessToken(accessToken); + } + for(Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { + Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setRefreshToken(refreshToken); + tokenRepository.saveAccessToken(accessToken); + } + for(Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { + Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); + Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); + OAuth2AccessTokenEntity idToken = tokenRepository.getAccessTokenById(newIdTokenId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setIdToken(idToken); + tokenRepository.saveAccessToken(accessToken); + } } }