diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java index be82505f5..64f81091f 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/service/MITREidDataService.java @@ -33,6 +33,7 @@ public interface MITREidDataService { public static final String MITREID_CONNECT_1_0 = "mitreid-connect-1.0"; public static final String MITREID_CONNECT_1_1 = "mitreid-connect-1.1"; public static final String MITREID_CONNECT_1_2 = "mitreid-connect-1.2"; + public static final String MITREID_CONNECT_1_3 = "mitreid-connect-1.3"; // member names public static final String REFRESHTOKENS = "refreshTokens"; diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_2.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_2.java index 4440e3d37..bc8658ed8 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_2.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_2.java @@ -170,367 +170,7 @@ public class MITREidDataService_1_2 extends MITREidDataServiceSupport implements @Override public void exportData(JsonWriter writer) throws IOException { - // version tag at the root - writer.name(MITREID_CONNECT_1_2); - - writer.beginObject(); - - // clients list - writer.name(CLIENTS); - writer.beginArray(); - writeClients(writer); - writer.endArray(); - - writer.name(GRANTS); - writer.beginArray(); - writeGrants(writer); - writer.endArray(); - - writer.name(WHITELISTEDSITES); - writer.beginArray(); - writeWhitelistedSites(writer); - writer.endArray(); - - writer.name(BLACKLISTEDSITES); - writer.beginArray(); - writeBlacklistedSites(writer); - writer.endArray(); - - writer.name(AUTHENTICATIONHOLDERS); - writer.beginArray(); - writeAuthenticationHolders(writer); - writer.endArray(); - - writer.name(ACCESSTOKENS); - writer.beginArray(); - writeAccessTokens(writer); - writer.endArray(); - - writer.name(REFRESHTOKENS); - writer.beginArray(); - writeRefreshTokens(writer); - writer.endArray(); - - writer.name(SYSTEMSCOPES); - writer.beginArray(); - writeSystemScopes(writer); - writer.endArray(); - - writer.endObject(); // end mitreid-connect-1.2 - } - - /** - * @param writer - */ - private void writeRefreshTokens(JsonWriter writer) throws IOException { - for (OAuth2RefreshTokenEntity token : tokenRepository.getAllRefreshTokens()) { - writer.beginObject(); - writer.name(ID).value(token.getId()); - writer.name(EXPIRATION).value(toUTCString(token.getExpiration())); - writer.name(CLIENT_ID) - .value((token.getClient() != null) ? token.getClient().getClientId() : null); - writer.name(AUTHENTICATION_HOLDER_ID) - .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); - writer.name(VALUE).value(token.getValue()); - writer.endObject(); - logger.debug("Wrote refresh token {}", token.getId()); - } - logger.info("Done writing refresh tokens"); - } - - /** - * @param writer - */ - private void writeAccessTokens(JsonWriter writer) throws IOException { - for (OAuth2AccessTokenEntity token : tokenRepository.getAllAccessTokens()) { - writer.beginObject(); - writer.name(ID).value(token.getId()); - writer.name(EXPIRATION).value(toUTCString(token.getExpiration())); - writer.name(CLIENT_ID) - .value((token.getClient() != null) ? token.getClient().getClientId() : null); - writer.name(AUTHENTICATION_HOLDER_ID) - .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); - writer.name(REFRESH_TOKEN_ID) - .value((token.getRefreshToken() != null) ? token.getRefreshToken().getId() : null); - writer.name(ID_TOKEN_ID) - .value((token.getIdToken() != null) ? token.getIdToken().getId() : null); - writer.name(SCOPE); - writer.beginArray(); - for (String s : token.getScope()) { - writer.value(s); - } - writer.endArray(); - writer.name(TYPE).value(token.getTokenType()); - writer.name(VALUE).value(token.getValue()); - writer.endObject(); - logger.debug("Wrote access token {}", token.getId()); - } - logger.info("Done writing access tokens"); - } - - /** - * @param writer - */ - private void writeAuthenticationHolders(JsonWriter writer) throws IOException { - for (AuthenticationHolderEntity holder : authHolderRepository.getAll()) { - writer.beginObject(); - writer.name(ID).value(holder.getId()); - - writer.name(REQUEST_PARAMETERS); - writer.beginObject(); - for (Entry entry : holder.getRequestParameters().entrySet()) { - writer.name(entry.getKey()).value(entry.getValue()); - } - writer.endObject(); - writer.name(CLIENT_ID).value(holder.getClientId()); - Set scope = holder.getScope(); - writer.name(SCOPE); - writer.beginArray(); - for (String s : scope) { - writer.value(s); - } - writer.endArray(); - writer.name(RESOURCE_IDS); - writer.beginArray(); - if (holder.getResourceIds() != null) { - for (String s : holder.getResourceIds()) { - writer.value(s); - } - } - writer.endArray(); - writer.name(AUTHORITIES); - writer.beginArray(); - for (GrantedAuthority authority : holder.getAuthorities()) { - writer.value(authority.getAuthority()); - } - writer.endArray(); - writer.name(APPROVED).value(holder.isApproved()); - writer.name(REDIRECT_URI).value(holder.getRedirectUri()); - writer.name(RESPONSE_TYPES); - writer.beginArray(); - for (String s : holder.getResponseTypes()) { - writer.value(s); - } - writer.endArray(); - writer.name(EXTENSIONS); - writer.beginObject(); - for (Entry entry : holder.getExtensions().entrySet()) { - // while the extension map itself is Serializable, we enforce storage of Strings - if (entry.getValue() instanceof String) { - writer.name(entry.getKey()).value((String) entry.getValue()); - } else { - logger.warn("Skipping non-string extension: " + entry); - } - } - writer.endObject(); - - writer.name(SAVED_USER_AUTHENTICATION); - if (holder.getUserAuth() != null) { - writer.beginObject(); - writer.name(NAME).value(holder.getUserAuth().getName()); - writer.name(SOURCE_CLASS).value(holder.getUserAuth().getSourceClass()); - writer.name(AUTHENTICATED).value(holder.getUserAuth().isAuthenticated()); - writer.name(AUTHORITIES); - writer.beginArray(); - for (GrantedAuthority authority : holder.getUserAuth().getAuthorities()) { - writer.value(authority.getAuthority()); - } - writer.endArray(); - - writer.endObject(); - } else { - writer.nullValue(); - } - - - writer.endObject(); - logger.debug("Wrote authentication holder {}", holder.getId()); - } - logger.info("Done writing authentication holders"); - } - - /** - * @param writer - */ - private void writeGrants(JsonWriter writer) throws IOException { - for (ApprovedSite site : approvedSiteRepository.getAll()) { - writer.beginObject(); - writer.name(ID).value(site.getId()); - writer.name(ACCESS_DATE).value(toUTCString(site.getAccessDate())); - writer.name(CLIENT_ID).value(site.getClientId()); - writer.name(CREATION_DATE).value(toUTCString(site.getCreationDate())); - writer.name(TIMEOUT_DATE).value(toUTCString(site.getTimeoutDate())); - writer.name(USER_ID).value(site.getUserId()); - writer.name(ALLOWED_SCOPES); - writeNullSafeArray(writer, site.getAllowedScopes()); - Set tokens = site.getApprovedAccessTokens(); - writer.name(APPROVED_ACCESS_TOKENS); - writer.beginArray(); - for (OAuth2AccessTokenEntity token : tokens) { - writer.value(token.getId()); - } - writer.endArray(); - writer.endObject(); - logger.debug("Wrote grant {}", site.getId()); - } - logger.info("Done writing grants"); - } - - /** - * @param writer - */ - private void writeWhitelistedSites(JsonWriter writer) throws IOException { - for (WhitelistedSite wlSite : wlSiteRepository.getAll()) { - writer.beginObject(); - writer.name(ID).value(wlSite.getId()); - writer.name(CLIENT_ID).value(wlSite.getClientId()); - writer.name(CREATOR_USER_ID).value(wlSite.getCreatorUserId()); - writer.name(ALLOWED_SCOPES); - writeNullSafeArray(writer, wlSite.getAllowedScopes()); - writer.endObject(); - logger.debug("Wrote whitelisted site {}", wlSite.getId()); - } - logger.info("Done writing whitelisted sites"); - } - - /** - * @param writer - */ - private void writeBlacklistedSites(JsonWriter writer) throws IOException { - for (BlacklistedSite blSite : blSiteRepository.getAll()) { - writer.beginObject(); - writer.name(ID).value(blSite.getId()); - writer.name(URI).value(blSite.getUri()); - writer.endObject(); - logger.debug("Wrote blacklisted site {}", blSite.getId()); - } - logger.info("Done writing blacklisted sites"); - } - - /** - * @param writer - */ - private void writeClients(JsonWriter writer) { - for (ClientDetailsEntity client : clientRepository.getAllClients()) { - try { - writer.beginObject(); - writer.name(CLIENT_ID).value(client.getClientId()); - writer.name(RESOURCE_IDS); - writeNullSafeArray(writer, client.getResourceIds()); - - writer.name(SECRET).value(client.getClientSecret()); - - writer.name(SCOPE); - writeNullSafeArray(writer, client.getScope()); - - writer.name(AUTHORITIES); - writer.beginArray(); - for (GrantedAuthority authority : client.getAuthorities()) { - writer.value(authority.getAuthority()); - } - writer.endArray(); - writer.name(ACCESS_TOKEN_VALIDITY_SECONDS).value(client.getAccessTokenValiditySeconds()); - writer.name(REFRESH_TOKEN_VALIDITY_SECONDS).value(client.getRefreshTokenValiditySeconds()); - writer.name(REDIRECT_URIS); - writeNullSafeArray(writer, client.getRedirectUris()); - writer.name(CLAIMS_REDIRECT_URIS); - writeNullSafeArray(writer, client.getClaimsRedirectUris()); - writer.name(NAME).value(client.getClientName()); - writer.name(URI).value(client.getClientUri()); - writer.name(LOGO_URI).value(client.getLogoUri()); - writer.name(CONTACTS); - writeNullSafeArray(writer, client.getContacts()); - writer.name(TOS_URI).value(client.getTosUri()); - writer.name(TOKEN_ENDPOINT_AUTH_METHOD) - .value((client.getTokenEndpointAuthMethod() != null) ? client.getTokenEndpointAuthMethod().getValue() : null); - writer.name(GRANT_TYPES); - writer.beginArray(); - for (String s : client.getGrantTypes()) { - writer.value(s); - } - writer.endArray(); - writer.name(RESPONSE_TYPES); - writer.beginArray(); - for (String s : client.getResponseTypes()) { - writer.value(s); - } - writer.endArray(); - writer.name(POLICY_URI).value(client.getPolicyUri()); - writer.name(JWKS_URI).value(client.getJwksUri()); - writer.name(JWKS).value((client.getJwks() != null) ? client.getJwks().toString() : null); - writer.name(APPLICATION_TYPE) - .value((client.getApplicationType() != null) ? client.getApplicationType().getValue() : null); - writer.name(SECTOR_IDENTIFIER_URI).value(client.getSectorIdentifierUri()); - writer.name(SUBJECT_TYPE) - .value((client.getSubjectType() != null) ? client.getSubjectType().getValue() : null); - writer.name(REQUEST_OBJECT_SIGNING_ALG) - .value((client.getRequestObjectSigningAlg() != null) ? client.getRequestObjectSigningAlg().getName() : null); - writer.name(ID_TOKEN_SIGNED_RESPONSE_ALG) - .value((client.getIdTokenSignedResponseAlg() != null) ? client.getIdTokenSignedResponseAlg().getName() : null); - writer.name(ID_TOKEN_ENCRYPTED_RESPONSE_ALG) - .value((client.getIdTokenEncryptedResponseAlg() != null) ? client.getIdTokenEncryptedResponseAlg().getName() : null); - writer.name(ID_TOKEN_ENCRYPTED_RESPONSE_ENC) - .value((client.getIdTokenEncryptedResponseEnc() != null) ? client.getIdTokenEncryptedResponseEnc().getName() : null); - writer.name(USER_INFO_SIGNED_RESPONSE_ALG) - .value((client.getUserInfoSignedResponseAlg() != null) ? client.getUserInfoSignedResponseAlg().getName() : null); - writer.name(USER_INFO_ENCRYPTED_RESPONSE_ALG) - .value((client.getUserInfoEncryptedResponseAlg() != null) ? client.getUserInfoEncryptedResponseAlg().getName() : null); - writer.name(USER_INFO_ENCRYPTED_RESPONSE_ENC) - .value((client.getUserInfoEncryptedResponseEnc() != null) ? client.getUserInfoEncryptedResponseEnc().getName() : null); - writer.name(TOKEN_ENDPOINT_AUTH_SIGNING_ALG) - .value((client.getTokenEndpointAuthSigningAlg() != null) ? client.getTokenEndpointAuthSigningAlg().getName() : null); - writer.name(DEFAULT_MAX_AGE).value(client.getDefaultMaxAge()); - Boolean requireAuthTime = null; - try { - requireAuthTime = client.getRequireAuthTime(); - } catch (NullPointerException e) { - } - if (requireAuthTime != null) { - writer.name(REQUIRE_AUTH_TIME).value(requireAuthTime); - } - writer.name(DEFAULT_ACR_VALUES); - writeNullSafeArray(writer, client.getDefaultACRvalues()); - writer.name(INTITATE_LOGIN_URI).value(client.getInitiateLoginUri()); - writer.name(POST_LOGOUT_REDIRECT_URI); - writeNullSafeArray(writer, client.getPostLogoutRedirectUris()); - writer.name(REQUEST_URIS); - writeNullSafeArray(writer, client.getRequestUris()); - writer.name(DESCRIPTION).value(client.getClientDescription()); - writer.name(ALLOW_INTROSPECTION).value(client.isAllowIntrospection()); - writer.name(REUSE_REFRESH_TOKEN).value(client.isReuseRefreshToken()); - writer.name(CLEAR_ACCESS_TOKENS_ON_REFRESH).value(client.isClearAccessTokensOnRefresh()); - writer.name(DYNAMICALLY_REGISTERED).value(client.isDynamicallyRegistered()); - writer.endObject(); - logger.debug("Wrote client {}", client.getId()); - } catch (IOException ex) { - logger.error("Unable to write client {}", client.getId(), ex); - } - } - logger.info("Done writing clients"); - } - - /** - * @param writer - */ - private void writeSystemScopes(JsonWriter writer) { - for (SystemScope sysScope : sysScopeRepository.getAll()) { - try { - writer.beginObject(); - writer.name(ID).value(sysScope.getId()); - writer.name(DESCRIPTION).value(sysScope.getDescription()); - writer.name(ICON).value(sysScope.getIcon()); - writer.name(VALUE).value(sysScope.getValue()); - writer.name(RESTRICTED).value(sysScope.isRestricted()); - writer.name(STRUCTURED).value(sysScope.isStructured()); - writer.name(STRUCTURED_PARAMETER).value(sysScope.getStructuredParamDescription()); - writer.name(DEFAULT_SCOPE).value(sysScope.isDefaultScope()); - writer.endObject(); - logger.debug("Wrote system scope {}", sysScope.getId()); - } catch (IOException ex) { - logger.error("Unable to write system scope {}", sysScope.getId(), ex); - } - } - logger.info("Done writing system scopes"); + throw new UnsupportedOperationException("Can not export 1.2 format from this version."); } /* (non-Javadoc) diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_3.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_3.java new file mode 100644 index 000000000..1bb9476d1 --- /dev/null +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/MITREidDataService_1_3.java @@ -0,0 +1,1292 @@ +/******************************************************************************* + * Copyright 2016 The MITRE Corporation + * and the MIT Internet Trust Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +package org.mitre.openid.connect.service.impl; + +import java.io.IOException; +import java.io.Serializable; +import java.text.ParseException; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.mitre.oauth2.model.AuthenticationHolderEntity; +import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.model.ClientDetailsEntity.AppType; +import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod; +import org.mitre.oauth2.model.ClientDetailsEntity.SubjectType; +import org.mitre.oauth2.model.OAuth2AccessTokenEntity; +import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; +import org.mitre.oauth2.model.PKCEAlgorithm; +import org.mitre.oauth2.model.SavedUserAuthentication; +import org.mitre.oauth2.model.SystemScope; +import org.mitre.oauth2.repository.AuthenticationHolderRepository; +import org.mitre.oauth2.repository.OAuth2ClientRepository; +import org.mitre.oauth2.repository.OAuth2TokenRepository; +import org.mitre.oauth2.repository.SystemScopeRepository; +import org.mitre.openid.connect.model.ApprovedSite; +import org.mitre.openid.connect.model.BlacklistedSite; +import org.mitre.openid.connect.model.WhitelistedSite; +import org.mitre.openid.connect.repository.ApprovedSiteRepository; +import org.mitre.openid.connect.repository.BlacklistedSiteRepository; +import org.mitre.openid.connect.repository.WhitelistedSiteRepository; +import org.mitre.openid.connect.service.MITREidDataService; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.stereotype.Service; + +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonToken; +import com.google.gson.stream.JsonWriter; +import com.nimbusds.jose.EncryptionMethod; +import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.JWKSet; +import com.nimbusds.jwt.JWTParser; + +import static org.mitre.util.JsonUtils.readMap; +import static org.mitre.util.JsonUtils.readSet; +import static org.mitre.util.JsonUtils.writeNullSafeArray; + +/** + * + * Data service to import and export MITREid 1.2 configuration. + * + * @author jricher + * @author arielak + */ +@Service +@SuppressWarnings(value = {"unchecked"}) +public class MITREidDataService_1_3 extends MITREidDataServiceSupport implements MITREidDataService { + + private static final String DEFAULT_SCOPE = "defaultScope"; + private static final String STRUCTURED_PARAMETER = "structuredParameter"; + private static final String STRUCTURED = "structured"; + private static final String RESTRICTED = "restricted"; + private static final String ICON = "icon"; + private static final String DYNAMICALLY_REGISTERED = "dynamicallyRegistered"; + private static final String CLEAR_ACCESS_TOKENS_ON_REFRESH = "clearAccessTokensOnRefresh"; + private static final String REUSE_REFRESH_TOKEN = "reuseRefreshToken"; + private static final String ALLOW_INTROSPECTION = "allowIntrospection"; + private static final String DESCRIPTION = "description"; + private static final String REQUEST_URIS = "requestUris"; + private static final String POST_LOGOUT_REDIRECT_URI = "postLogoutRedirectUri"; + private static final String INTITATE_LOGIN_URI = "intitateLoginUri"; + private static final String DEFAULT_ACR_VALUES = "defaultACRValues"; + private static final String REQUIRE_AUTH_TIME = "requireAuthTime"; + private static final String DEFAULT_MAX_AGE = "defaultMaxAge"; + private static final String TOKEN_ENDPOINT_AUTH_SIGNING_ALG = "tokenEndpointAuthSigningAlg"; + private static final String USER_INFO_ENCRYPTED_RESPONSE_ENC = "userInfoEncryptedResponseEnc"; + private static final String USER_INFO_ENCRYPTED_RESPONSE_ALG = "userInfoEncryptedResponseAlg"; + private static final String USER_INFO_SIGNED_RESPONSE_ALG = "userInfoSignedResponseAlg"; + private static final String ID_TOKEN_ENCRYPTED_RESPONSE_ENC = "idTokenEncryptedResponseEnc"; + private static final String ID_TOKEN_ENCRYPTED_RESPONSE_ALG = "idTokenEncryptedResponseAlg"; + private static final String ID_TOKEN_SIGNED_RESPONSE_ALG = "idTokenSignedResponseAlg"; + private static final String REQUEST_OBJECT_SIGNING_ALG = "requestObjectSigningAlg"; + private static final String SUBJECT_TYPE = "subjectType"; + private static final String SECTOR_IDENTIFIER_URI = "sectorIdentifierUri"; + private static final String APPLICATION_TYPE = "applicationType"; + private static final String JWKS = "jwks"; + private static final String JWKS_URI = "jwksUri"; + private static final String POLICY_URI = "policyUri"; + private static final String GRANT_TYPES = "grantTypes"; + private static final String TOKEN_ENDPOINT_AUTH_METHOD = "tokenEndpointAuthMethod"; + private static final String TOS_URI = "tosUri"; + private static final String CONTACTS = "contacts"; + private static final String LOGO_URI = "logoUri"; + private static final String REDIRECT_URIS = "redirectUris"; + private static final String REFRESH_TOKEN_VALIDITY_SECONDS = "refreshTokenValiditySeconds"; + private static final String ACCESS_TOKEN_VALIDITY_SECONDS = "accessTokenValiditySeconds"; + private static final String SECRET = "secret"; + private static final String URI = "uri"; + private static final String CREATOR_USER_ID = "creatorUserId"; + private static final String APPROVED_ACCESS_TOKENS = "approvedAccessTokens"; + private static final String ALLOWED_SCOPES = "allowedScopes"; + private static final String USER_ID = "userId"; + private static final String TIMEOUT_DATE = "timeoutDate"; + private static final String CREATION_DATE = "creationDate"; + private static final String ACCESS_DATE = "accessDate"; + private static final String AUTHENTICATED = "authenticated"; + private static final String SOURCE_CLASS = "sourceClass"; + private static final String NAME = "name"; + private static final String SAVED_USER_AUTHENTICATION = "savedUserAuthentication"; + private static final String EXTENSIONS = "extensions"; + private static final String RESPONSE_TYPES = "responseTypes"; + private static final String REDIRECT_URI = "redirectUri"; + private static final String APPROVED = "approved"; + private static final String AUTHORITIES = "authorities"; + private static final String RESOURCE_IDS = "resourceIds"; + private static final String REQUEST_PARAMETERS = "requestParameters"; + private static final String TYPE = "type"; + private static final String SCOPE = "scope"; + private static final String ID_TOKEN_ID = "idTokenId"; + private static final String REFRESH_TOKEN_ID = "refreshTokenId"; + private static final String VALUE = "value"; + private static final String AUTHENTICATION_HOLDER_ID = "authenticationHolderId"; + private static final String CLIENT_ID = "clientId"; + private static final String EXPIRATION = "expiration"; + private static final String CLAIMS_REDIRECT_URIS = "claimsRedirectUris"; + private static final String ID = "id"; + private static final String CODE_CHALLENGE_METHOD = "codeChallengeMethod"; + private static final String SOFTWARE_STATEMENT = "softwareStatement"; + + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(MITREidDataService_1_3.class); + @Autowired + private OAuth2ClientRepository clientRepository; + @Autowired + private ApprovedSiteRepository approvedSiteRepository; + @Autowired + private WhitelistedSiteRepository wlSiteRepository; + @Autowired + private BlacklistedSiteRepository blSiteRepository; + @Autowired + private AuthenticationHolderRepository authHolderRepository; + @Autowired + private OAuth2TokenRepository tokenRepository; + @Autowired + private SystemScopeRepository sysScopeRepository; + + /* (non-Javadoc) + * @see org.mitre.openid.connect.service.MITREidDataService#export(com.google.gson.stream.JsonWriter) + */ + @Override + public void exportData(JsonWriter writer) throws IOException { + + // version tag at the root + writer.name(MITREID_CONNECT_1_3); + + writer.beginObject(); + + // clients list + writer.name(CLIENTS); + writer.beginArray(); + writeClients(writer); + writer.endArray(); + + writer.name(GRANTS); + writer.beginArray(); + writeGrants(writer); + writer.endArray(); + + writer.name(WHITELISTEDSITES); + writer.beginArray(); + writeWhitelistedSites(writer); + writer.endArray(); + + writer.name(BLACKLISTEDSITES); + writer.beginArray(); + writeBlacklistedSites(writer); + writer.endArray(); + + writer.name(AUTHENTICATIONHOLDERS); + writer.beginArray(); + writeAuthenticationHolders(writer); + writer.endArray(); + + writer.name(ACCESSTOKENS); + writer.beginArray(); + writeAccessTokens(writer); + writer.endArray(); + + writer.name(REFRESHTOKENS); + writer.beginArray(); + writeRefreshTokens(writer); + writer.endArray(); + + writer.name(SYSTEMSCOPES); + writer.beginArray(); + writeSystemScopes(writer); + writer.endArray(); + + writer.endObject(); // end mitreid-connect-1.2 + } + + /** + * @param writer + */ + private void writeRefreshTokens(JsonWriter writer) throws IOException { + for (OAuth2RefreshTokenEntity token : tokenRepository.getAllRefreshTokens()) { + writer.beginObject(); + writer.name(ID).value(token.getId()); + writer.name(EXPIRATION).value(toUTCString(token.getExpiration())); + writer.name(CLIENT_ID) + .value((token.getClient() != null) ? token.getClient().getClientId() : null); + writer.name(AUTHENTICATION_HOLDER_ID) + .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); + writer.name(VALUE).value(token.getValue()); + writer.endObject(); + logger.debug("Wrote refresh token {}", token.getId()); + } + logger.info("Done writing refresh tokens"); + } + + /** + * @param writer + */ + private void writeAccessTokens(JsonWriter writer) throws IOException { + for (OAuth2AccessTokenEntity token : tokenRepository.getAllAccessTokens()) { + writer.beginObject(); + writer.name(ID).value(token.getId()); + writer.name(EXPIRATION).value(toUTCString(token.getExpiration())); + writer.name(CLIENT_ID) + .value((token.getClient() != null) ? token.getClient().getClientId() : null); + writer.name(AUTHENTICATION_HOLDER_ID) + .value((token.getAuthenticationHolder() != null) ? token.getAuthenticationHolder().getId() : null); + writer.name(REFRESH_TOKEN_ID) + .value((token.getRefreshToken() != null) ? token.getRefreshToken().getId() : null); + writer.name(ID_TOKEN_ID) + .value((token.getIdToken() != null) ? token.getIdToken().getId() : null); + writer.name(SCOPE); + writer.beginArray(); + for (String s : token.getScope()) { + writer.value(s); + } + writer.endArray(); + writer.name(TYPE).value(token.getTokenType()); + writer.name(VALUE).value(token.getValue()); + writer.endObject(); + logger.debug("Wrote access token {}", token.getId()); + } + logger.info("Done writing access tokens"); + } + + /** + * @param writer + */ + private void writeAuthenticationHolders(JsonWriter writer) throws IOException { + for (AuthenticationHolderEntity holder : authHolderRepository.getAll()) { + writer.beginObject(); + writer.name(ID).value(holder.getId()); + + writer.name(REQUEST_PARAMETERS); + writer.beginObject(); + for (Entry entry : holder.getRequestParameters().entrySet()) { + writer.name(entry.getKey()).value(entry.getValue()); + } + writer.endObject(); + writer.name(CLIENT_ID).value(holder.getClientId()); + Set scope = holder.getScope(); + writer.name(SCOPE); + writer.beginArray(); + for (String s : scope) { + writer.value(s); + } + writer.endArray(); + writer.name(RESOURCE_IDS); + writer.beginArray(); + if (holder.getResourceIds() != null) { + for (String s : holder.getResourceIds()) { + writer.value(s); + } + } + writer.endArray(); + writer.name(AUTHORITIES); + writer.beginArray(); + for (GrantedAuthority authority : holder.getAuthorities()) { + writer.value(authority.getAuthority()); + } + writer.endArray(); + writer.name(APPROVED).value(holder.isApproved()); + writer.name(REDIRECT_URI).value(holder.getRedirectUri()); + writer.name(RESPONSE_TYPES); + writer.beginArray(); + for (String s : holder.getResponseTypes()) { + writer.value(s); + } + writer.endArray(); + writer.name(EXTENSIONS); + writer.beginObject(); + for (Entry entry : holder.getExtensions().entrySet()) { + // while the extension map itself is Serializable, we enforce storage of Strings + if (entry.getValue() instanceof String) { + writer.name(entry.getKey()).value((String) entry.getValue()); + } else { + logger.warn("Skipping non-string extension: " + entry); + } + } + writer.endObject(); + + writer.name(SAVED_USER_AUTHENTICATION); + if (holder.getUserAuth() != null) { + writer.beginObject(); + writer.name(NAME).value(holder.getUserAuth().getName()); + writer.name(SOURCE_CLASS).value(holder.getUserAuth().getSourceClass()); + writer.name(AUTHENTICATED).value(holder.getUserAuth().isAuthenticated()); + writer.name(AUTHORITIES); + writer.beginArray(); + for (GrantedAuthority authority : holder.getUserAuth().getAuthorities()) { + writer.value(authority.getAuthority()); + } + writer.endArray(); + + writer.endObject(); + } else { + writer.nullValue(); + } + + + writer.endObject(); + logger.debug("Wrote authentication holder {}", holder.getId()); + } + logger.info("Done writing authentication holders"); + } + + /** + * @param writer + */ + private void writeGrants(JsonWriter writer) throws IOException { + for (ApprovedSite site : approvedSiteRepository.getAll()) { + writer.beginObject(); + writer.name(ID).value(site.getId()); + writer.name(ACCESS_DATE).value(toUTCString(site.getAccessDate())); + writer.name(CLIENT_ID).value(site.getClientId()); + writer.name(CREATION_DATE).value(toUTCString(site.getCreationDate())); + writer.name(TIMEOUT_DATE).value(toUTCString(site.getTimeoutDate())); + writer.name(USER_ID).value(site.getUserId()); + writer.name(ALLOWED_SCOPES); + writeNullSafeArray(writer, site.getAllowedScopes()); + Set tokens = site.getApprovedAccessTokens(); + writer.name(APPROVED_ACCESS_TOKENS); + writer.beginArray(); + for (OAuth2AccessTokenEntity token : tokens) { + writer.value(token.getId()); + } + writer.endArray(); + writer.endObject(); + logger.debug("Wrote grant {}", site.getId()); + } + logger.info("Done writing grants"); + } + + /** + * @param writer + */ + private void writeWhitelistedSites(JsonWriter writer) throws IOException { + for (WhitelistedSite wlSite : wlSiteRepository.getAll()) { + writer.beginObject(); + writer.name(ID).value(wlSite.getId()); + writer.name(CLIENT_ID).value(wlSite.getClientId()); + writer.name(CREATOR_USER_ID).value(wlSite.getCreatorUserId()); + writer.name(ALLOWED_SCOPES); + writeNullSafeArray(writer, wlSite.getAllowedScopes()); + writer.endObject(); + logger.debug("Wrote whitelisted site {}", wlSite.getId()); + } + logger.info("Done writing whitelisted sites"); + } + + /** + * @param writer + */ + private void writeBlacklistedSites(JsonWriter writer) throws IOException { + for (BlacklistedSite blSite : blSiteRepository.getAll()) { + writer.beginObject(); + writer.name(ID).value(blSite.getId()); + writer.name(URI).value(blSite.getUri()); + writer.endObject(); + logger.debug("Wrote blacklisted site {}", blSite.getId()); + } + logger.info("Done writing blacklisted sites"); + } + + /** + * @param writer + */ + private void writeClients(JsonWriter writer) { + for (ClientDetailsEntity client : clientRepository.getAllClients()) { + try { + writer.beginObject(); + writer.name(CLIENT_ID).value(client.getClientId()); + writer.name(RESOURCE_IDS); + writeNullSafeArray(writer, client.getResourceIds()); + + writer.name(SECRET).value(client.getClientSecret()); + + writer.name(SCOPE); + writeNullSafeArray(writer, client.getScope()); + + writer.name(AUTHORITIES); + writer.beginArray(); + for (GrantedAuthority authority : client.getAuthorities()) { + writer.value(authority.getAuthority()); + } + writer.endArray(); + writer.name(ACCESS_TOKEN_VALIDITY_SECONDS).value(client.getAccessTokenValiditySeconds()); + writer.name(REFRESH_TOKEN_VALIDITY_SECONDS).value(client.getRefreshTokenValiditySeconds()); + writer.name(REDIRECT_URIS); + writeNullSafeArray(writer, client.getRedirectUris()); + writer.name(CLAIMS_REDIRECT_URIS); + writeNullSafeArray(writer, client.getClaimsRedirectUris()); + writer.name(NAME).value(client.getClientName()); + writer.name(URI).value(client.getClientUri()); + writer.name(LOGO_URI).value(client.getLogoUri()); + writer.name(CONTACTS); + writeNullSafeArray(writer, client.getContacts()); + writer.name(TOS_URI).value(client.getTosUri()); + writer.name(TOKEN_ENDPOINT_AUTH_METHOD) + .value((client.getTokenEndpointAuthMethod() != null) ? client.getTokenEndpointAuthMethod().getValue() : null); + writer.name(GRANT_TYPES); + writer.beginArray(); + for (String s : client.getGrantTypes()) { + writer.value(s); + } + writer.endArray(); + writer.name(RESPONSE_TYPES); + writer.beginArray(); + for (String s : client.getResponseTypes()) { + writer.value(s); + } + writer.endArray(); + writer.name(POLICY_URI).value(client.getPolicyUri()); + writer.name(JWKS_URI).value(client.getJwksUri()); + writer.name(JWKS).value((client.getJwks() != null) ? client.getJwks().toString() : null); + writer.name(APPLICATION_TYPE) + .value((client.getApplicationType() != null) ? client.getApplicationType().getValue() : null); + writer.name(SECTOR_IDENTIFIER_URI).value(client.getSectorIdentifierUri()); + writer.name(SUBJECT_TYPE) + .value((client.getSubjectType() != null) ? client.getSubjectType().getValue() : null); + writer.name(REQUEST_OBJECT_SIGNING_ALG) + .value((client.getRequestObjectSigningAlg() != null) ? client.getRequestObjectSigningAlg().getName() : null); + writer.name(ID_TOKEN_SIGNED_RESPONSE_ALG) + .value((client.getIdTokenSignedResponseAlg() != null) ? client.getIdTokenSignedResponseAlg().getName() : null); + writer.name(ID_TOKEN_ENCRYPTED_RESPONSE_ALG) + .value((client.getIdTokenEncryptedResponseAlg() != null) ? client.getIdTokenEncryptedResponseAlg().getName() : null); + writer.name(ID_TOKEN_ENCRYPTED_RESPONSE_ENC) + .value((client.getIdTokenEncryptedResponseEnc() != null) ? client.getIdTokenEncryptedResponseEnc().getName() : null); + writer.name(USER_INFO_SIGNED_RESPONSE_ALG) + .value((client.getUserInfoSignedResponseAlg() != null) ? client.getUserInfoSignedResponseAlg().getName() : null); + writer.name(USER_INFO_ENCRYPTED_RESPONSE_ALG) + .value((client.getUserInfoEncryptedResponseAlg() != null) ? client.getUserInfoEncryptedResponseAlg().getName() : null); + writer.name(USER_INFO_ENCRYPTED_RESPONSE_ENC) + .value((client.getUserInfoEncryptedResponseEnc() != null) ? client.getUserInfoEncryptedResponseEnc().getName() : null); + writer.name(TOKEN_ENDPOINT_AUTH_SIGNING_ALG) + .value((client.getTokenEndpointAuthSigningAlg() != null) ? client.getTokenEndpointAuthSigningAlg().getName() : null); + writer.name(DEFAULT_MAX_AGE).value(client.getDefaultMaxAge()); + Boolean requireAuthTime = null; + try { + requireAuthTime = client.getRequireAuthTime(); + } catch (NullPointerException e) { + } + if (requireAuthTime != null) { + writer.name(REQUIRE_AUTH_TIME).value(requireAuthTime); + } + writer.name(DEFAULT_ACR_VALUES); + writeNullSafeArray(writer, client.getDefaultACRvalues()); + writer.name(INTITATE_LOGIN_URI).value(client.getInitiateLoginUri()); + writer.name(POST_LOGOUT_REDIRECT_URI); + writeNullSafeArray(writer, client.getPostLogoutRedirectUris()); + writer.name(REQUEST_URIS); + writeNullSafeArray(writer, client.getRequestUris()); + writer.name(DESCRIPTION).value(client.getClientDescription()); + writer.name(ALLOW_INTROSPECTION).value(client.isAllowIntrospection()); + writer.name(REUSE_REFRESH_TOKEN).value(client.isReuseRefreshToken()); + writer.name(CLEAR_ACCESS_TOKENS_ON_REFRESH).value(client.isClearAccessTokensOnRefresh()); + writer.name(DYNAMICALLY_REGISTERED).value(client.isDynamicallyRegistered()); + writer.name(CODE_CHALLENGE_METHOD).value(client.getCodeChallengeMethod() != null ? client.getCodeChallengeMethod().getName() : null); + writer.name(SOFTWARE_STATEMENT).value(client.getSoftwareStatement() != null ? client.getSoftwareStatement().serialize() : null); + writer.endObject(); + logger.debug("Wrote client {}", client.getId()); + } catch (IOException ex) { + logger.error("Unable to write client {}", client.getId(), ex); + } + } + logger.info("Done writing clients"); + } + + /** + * @param writer + */ + private void writeSystemScopes(JsonWriter writer) { + for (SystemScope sysScope : sysScopeRepository.getAll()) { + try { + writer.beginObject(); + writer.name(ID).value(sysScope.getId()); + writer.name(DESCRIPTION).value(sysScope.getDescription()); + writer.name(ICON).value(sysScope.getIcon()); + writer.name(VALUE).value(sysScope.getValue()); + writer.name(RESTRICTED).value(sysScope.isRestricted()); + writer.name(STRUCTURED).value(sysScope.isStructured()); + writer.name(STRUCTURED_PARAMETER).value(sysScope.getStructuredParamDescription()); + writer.name(DEFAULT_SCOPE).value(sysScope.isDefaultScope()); + writer.endObject(); + logger.debug("Wrote system scope {}", sysScope.getId()); + } catch (IOException ex) { + logger.error("Unable to write system scope {}", sysScope.getId(), ex); + } + } + logger.info("Done writing system scopes"); + } + + /* (non-Javadoc) + * @see org.mitre.openid.connect.service.MITREidDataService#importData(com.google.gson.stream.JsonReader) + */ + @Override + public void importData(JsonReader reader) throws IOException { + + logger.info("Reading configuration for 1.2"); + + // this *HAS* to start as an object + reader.beginObject(); + + while (reader.hasNext()) { + JsonToken tok = reader.peek(); + switch (tok) { + case NAME: + String name = reader.nextName(); + // find out which member it is + if (name.equals(CLIENTS)) { + readClients(reader); + } else if (name.equals(GRANTS)) { + readGrants(reader); + } else if (name.equals(WHITELISTEDSITES)) { + readWhitelistedSites(reader); + } else if (name.equals(BLACKLISTEDSITES)) { + readBlacklistedSites(reader); + } else if (name.equals(AUTHENTICATIONHOLDERS)) { + readAuthenticationHolders(reader); + } else if (name.equals(ACCESSTOKENS)) { + readAccessTokens(reader); + } else if (name.equals(REFRESHTOKENS)) { + readRefreshTokens(reader); + } else if (name.equals(SYSTEMSCOPES)) { + readSystemScopes(reader); + } else { + // unknown token, skip it + reader.skipValue(); + } + break; + case END_OBJECT: + // the object ended, we're done here + reader.endObject(); + continue; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + fixObjectReferences(); + } + private Map refreshTokenToClientRefs = new HashMap(); + private Map refreshTokenToAuthHolderRefs = new HashMap(); + private Map refreshTokenOldToNewIdMap = new HashMap(); + + /** + * @param reader + * @throws IOException + */ + /** + * @param reader + * @throws IOException + */ + private void readRefreshTokens(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + OAuth2RefreshTokenEntity token = new OAuth2RefreshTokenEntity(); + reader.beginObject(); + Long currentId = null; + String clientId = null; + Long authHolderId = null; + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(ID)) { + currentId = reader.nextLong(); + } else if (name.equals(EXPIRATION)) { + Date date = utcToDate(reader.nextString()); + token.setExpiration(date); + } else if (name.equals(VALUE)) { + String value = reader.nextString(); + try { + token.setJwt(JWTParser.parse(value)); + } catch (ParseException ex) { + logger.error("Unable to set refresh token value to {}", value, ex); + } + } else if (name.equals(CLIENT_ID)) { + clientId = reader.nextString(); + } else if (name.equals(AUTHENTICATION_HOLDER_ID)) { + authHolderId = reader.nextLong(); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = tokenRepository.saveRefreshToken(token).getId(); + refreshTokenToClientRefs.put(currentId, clientId); + refreshTokenToAuthHolderRefs.put(currentId, authHolderId); + refreshTokenOldToNewIdMap.put(currentId, newId); + logger.debug("Read refresh token {}", currentId); + } + reader.endArray(); + logger.info("Done reading refresh tokens"); + } + private Map accessTokenToClientRefs = new HashMap(); + private Map accessTokenToAuthHolderRefs = new HashMap(); + private Map accessTokenToRefreshTokenRefs = new HashMap(); + private Map accessTokenToIdTokenRefs = new HashMap(); + private Map accessTokenOldToNewIdMap = new HashMap(); + + /** + * @param reader + * @throws IOException + */ + /** + * @param reader + * @throws IOException + */ + private void readAccessTokens(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + OAuth2AccessTokenEntity token = new OAuth2AccessTokenEntity(); + reader.beginObject(); + Long currentId = null; + String clientId = null; + Long authHolderId = null; + Long refreshTokenId = null; + Long idTokenId = null; + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(ID)) { + currentId = reader.nextLong(); + } else if (name.equals(EXPIRATION)) { + Date date = utcToDate(reader.nextString()); + token.setExpiration(date); + } else if (name.equals(VALUE)) { + String value = reader.nextString(); + try { + // all tokens are JWTs + token.setJwt(JWTParser.parse(value)); + } catch (ParseException ex) { + logger.error("Unable to set refresh token value to {}", value, ex); + } + } else if (name.equals(CLIENT_ID)) { + clientId = reader.nextString(); + } else if (name.equals(AUTHENTICATION_HOLDER_ID)) { + authHolderId = reader.nextLong(); + } else if (name.equals(REFRESH_TOKEN_ID)) { + refreshTokenId = reader.nextLong(); + } else if (name.equals(ID_TOKEN_ID)) { + idTokenId = reader.nextLong(); + } else if (name.equals(SCOPE)) { + Set scope = readSet(reader); + token.setScope(scope); + } else if (name.equals(TYPE)) { + token.setTokenType(reader.nextString()); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = tokenRepository.saveAccessToken(token).getId(); + accessTokenToClientRefs.put(currentId, clientId); + accessTokenToAuthHolderRefs.put(currentId, authHolderId); + if (refreshTokenId != null) { + accessTokenToRefreshTokenRefs.put(currentId, refreshTokenId); + } + if (idTokenId != null) { + accessTokenToIdTokenRefs.put(currentId, idTokenId); + } + accessTokenOldToNewIdMap.put(currentId, newId); + logger.debug("Read access token {}", currentId); + } + reader.endArray(); + logger.info("Done reading access tokens"); + } + private Map authHolderOldToNewIdMap = new HashMap(); + + /** + * @param reader + * @throws IOException + */ + private void readAuthenticationHolders(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + AuthenticationHolderEntity ahe = new AuthenticationHolderEntity(); + reader.beginObject(); + Long currentId = null; + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(ID)) { + currentId = reader.nextLong(); + } else if (name.equals(REQUEST_PARAMETERS)) { + ahe.setRequestParameters(readMap(reader)); + } else if (name.equals(CLIENT_ID)) { + ahe.setClientId(reader.nextString()); + } else if (name.equals(SCOPE)) { + ahe.setScope(readSet(reader)); + } else if (name.equals(RESOURCE_IDS)) { + ahe.setResourceIds(readSet(reader)); + } else if (name.equals(AUTHORITIES)) { + Set authorityStrs = readSet(reader); + Set authorities = new HashSet(); + for (String s : authorityStrs) { + GrantedAuthority ga = new SimpleGrantedAuthority(s); + authorities.add(ga); + } + ahe.setAuthorities(authorities); + } else if (name.equals(APPROVED)) { + ahe.setApproved(reader.nextBoolean()); + } else if (name.equals(REDIRECT_URI)) { + ahe.setRedirectUri(reader.nextString()); + } else if (name.equals(RESPONSE_TYPES)) { + ahe.setResponseTypes(readSet(reader)); + } else if (name.equals(EXTENSIONS)) { + ahe.setExtensions(readMap(reader)); + } else if (name.equals(SAVED_USER_AUTHENTICATION)) { + ahe.setUserAuth(readSavedUserAuthentication(reader)); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = authHolderRepository.save(ahe).getId(); + authHolderOldToNewIdMap.put(currentId, newId); + logger.debug("Read authentication holder {}", currentId); + } + reader.endArray(); + logger.info("Done reading authentication holders"); + } + + /** + * @param reader + * @return + * @throws IOException + */ + private SavedUserAuthentication readSavedUserAuthentication(JsonReader reader) throws IOException { + SavedUserAuthentication savedUserAuth = new SavedUserAuthentication(); + reader.beginObject(); + + while (reader.hasNext()) { + switch(reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(NAME)) { + savedUserAuth.setName(reader.nextString()); + } else if (name.equals(SOURCE_CLASS)) { + savedUserAuth.setSourceClass(reader.nextString()); + } else if (name.equals(AUTHENTICATED)) { + savedUserAuth.setAuthenticated(reader.nextBoolean()); + } else if (name.equals(AUTHORITIES)) { + Set authorityStrs = readSet(reader); + Set authorities = new HashSet(); + for (String s : authorityStrs) { + GrantedAuthority ga = new SimpleGrantedAuthority(s); + authorities.add(ga); + } + savedUserAuth.setAuthorities(authorities); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + + reader.endObject(); + return savedUserAuth; + } + + Map grantOldToNewIdMap = new HashMap<>(); + Map> grantToAccessTokensRefs = new HashMap<>(); + + /** + * @param reader + * @throws IOException + */ + private void readGrants(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + ApprovedSite site = new ApprovedSite(); + Long currentId = null; + Set tokenIds = null; + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(ID)) { + currentId = reader.nextLong(); + } else if (name.equals(ACCESS_DATE)) { + Date date = utcToDate(reader.nextString()); + site.setAccessDate(date); + } else if (name.equals(CLIENT_ID)) { + site.setClientId(reader.nextString()); + } else if (name.equals(CREATION_DATE)) { + Date date = utcToDate(reader.nextString()); + site.setCreationDate(date); + } else if (name.equals(TIMEOUT_DATE)) { + Date date = utcToDate(reader.nextString()); + site.setTimeoutDate(date); + } else if (name.equals(USER_ID)) { + site.setUserId(reader.nextString()); + } else if (name.equals(ALLOWED_SCOPES)) { + Set allowedScopes = readSet(reader); + site.setAllowedScopes(allowedScopes); + } else if (name.equals(APPROVED_ACCESS_TOKENS)) { + tokenIds = readSet(reader); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = approvedSiteRepository.save(site).getId(); + grantOldToNewIdMap.put(currentId, newId); + if (tokenIds != null) { + grantToAccessTokensRefs.put(currentId, tokenIds); + } + logger.debug("Read grant {}", currentId); + } + reader.endArray(); + logger.info("Done reading grants"); + } + Map whitelistedSiteOldToNewIdMap = new HashMap(); + + /** + * @param reader + * @throws IOException + */ + private void readWhitelistedSites(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + WhitelistedSite wlSite = new WhitelistedSite(); + Long currentId = null; + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (name.equals(ID)) { + currentId = reader.nextLong(); + } else if (name.equals(CLIENT_ID)) { + wlSite.setClientId(reader.nextString()); + } else if (name.equals(CREATOR_USER_ID)) { + wlSite.setCreatorUserId(reader.nextString()); + } else if (name.equals(ALLOWED_SCOPES)) { + Set allowedScopes = readSet(reader); + wlSite.setAllowedScopes(allowedScopes); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + Long newId = wlSiteRepository.save(wlSite).getId(); + whitelistedSiteOldToNewIdMap.put(currentId, newId); + } + reader.endArray(); + logger.info("Done reading whitelisted sites"); + } + + /** + * @param reader + * @throws IOException + */ + private void readBlacklistedSites(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + BlacklistedSite blSite = new BlacklistedSite(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (name.equals(ID)) { + reader.skipValue(); + } else if (name.equals(URI)) { + blSite.setUri(reader.nextString()); + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + blSiteRepository.save(blSite); + } + reader.endArray(); + logger.info("Done reading blacklisted sites"); + } + + /** + * @param reader + * @throws IOException + */ + private void readClients(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + ClientDetailsEntity client = new ClientDetailsEntity(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(CLIENT_ID)) { + client.setClientId(reader.nextString()); + } else if (name.equals(RESOURCE_IDS)) { + Set resourceIds = readSet(reader); + client.setResourceIds(resourceIds); + } else if (name.equals(SECRET)) { + client.setClientSecret(reader.nextString()); + } else if (name.equals(SCOPE)) { + Set scope = readSet(reader); + client.setScope(scope); + } else if (name.equals(AUTHORITIES)) { + Set authorityStrs = readSet(reader); + Set authorities = new HashSet(); + for (String s : authorityStrs) { + GrantedAuthority ga = new SimpleGrantedAuthority(s); + authorities.add(ga); + } + client.setAuthorities(authorities); + } else if (name.equals(ACCESS_TOKEN_VALIDITY_SECONDS)) { + client.setAccessTokenValiditySeconds(reader.nextInt()); + } else if (name.equals(REFRESH_TOKEN_VALIDITY_SECONDS)) { + client.setRefreshTokenValiditySeconds(reader.nextInt()); + } else if (name.equals(REDIRECT_URIS)) { + Set redirectUris = readSet(reader); + client.setRedirectUris(redirectUris); + } else if (name.equals(CLAIMS_REDIRECT_URIS)) { + Set claimsRedirectUris = readSet(reader); + client.setClaimsRedirectUris(claimsRedirectUris); + } else if (name.equals(NAME)) { + client.setClientName(reader.nextString()); + } else if (name.equals(URI)) { + client.setClientUri(reader.nextString()); + } else if (name.equals(LOGO_URI)) { + client.setLogoUri(reader.nextString()); + } else if (name.equals(CONTACTS)) { + Set contacts = readSet(reader); + client.setContacts(contacts); + } else if (name.equals(TOS_URI)) { + client.setTosUri(reader.nextString()); + } else if (name.equals(TOKEN_ENDPOINT_AUTH_METHOD)) { + AuthMethod am = AuthMethod.getByValue(reader.nextString()); + client.setTokenEndpointAuthMethod(am); + } else if (name.equals(GRANT_TYPES)) { + Set grantTypes = readSet(reader); + client.setGrantTypes(grantTypes); + } else if (name.equals(RESPONSE_TYPES)) { + Set responseTypes = readSet(reader); + client.setResponseTypes(responseTypes); + } else if (name.equals(POLICY_URI)) { + client.setPolicyUri(reader.nextString()); + } else if (name.equals(APPLICATION_TYPE)) { + AppType appType = AppType.getByValue(reader.nextString()); + client.setApplicationType(appType); + } else if (name.equals(SECTOR_IDENTIFIER_URI)) { + client.setSectorIdentifierUri(reader.nextString()); + } else if (name.equals(SUBJECT_TYPE)) { + SubjectType st = SubjectType.getByValue(reader.nextString()); + client.setSubjectType(st); + } else if (name.equals(JWKS_URI)) { + client.setJwksUri(reader.nextString()); + } else if (name.equals(JWKS)) { + try { + client.setJwks(JWKSet.parse(reader.nextString())); + } catch (ParseException e) { + logger.error("Couldn't parse JWK Set", e); + } + } else if (name.equals(REQUEST_OBJECT_SIGNING_ALG)) { + JWSAlgorithm alg = JWSAlgorithm.parse(reader.nextString()); + client.setRequestObjectSigningAlg(alg); + } else if (name.equals(USER_INFO_ENCRYPTED_RESPONSE_ALG)) { + JWEAlgorithm alg = JWEAlgorithm.parse(reader.nextString()); + client.setUserInfoEncryptedResponseAlg(alg); + } else if (name.equals(USER_INFO_ENCRYPTED_RESPONSE_ENC)) { + EncryptionMethod alg = EncryptionMethod.parse(reader.nextString()); + client.setUserInfoEncryptedResponseEnc(alg); + } else if (name.equals(USER_INFO_SIGNED_RESPONSE_ALG)) { + JWSAlgorithm alg = JWSAlgorithm.parse(reader.nextString()); + client.setUserInfoSignedResponseAlg(alg); + } else if (name.equals(ID_TOKEN_SIGNED_RESPONSE_ALG)) { + JWSAlgorithm alg = JWSAlgorithm.parse(reader.nextString()); + client.setIdTokenSignedResponseAlg(alg); + } else if (name.equals(ID_TOKEN_ENCRYPTED_RESPONSE_ALG)) { + JWEAlgorithm alg = JWEAlgorithm.parse(reader.nextString()); + client.setIdTokenEncryptedResponseAlg(alg); + } else if (name.equals(ID_TOKEN_ENCRYPTED_RESPONSE_ENC)) { + EncryptionMethod alg = EncryptionMethod.parse(reader.nextString()); + client.setIdTokenEncryptedResponseEnc(alg); + } else if (name.equals(TOKEN_ENDPOINT_AUTH_SIGNING_ALG)) { + JWSAlgorithm alg = JWSAlgorithm.parse(reader.nextString()); + client.setTokenEndpointAuthSigningAlg(alg); + } else if (name.equals(DEFAULT_MAX_AGE)) { + client.setDefaultMaxAge(reader.nextInt()); + } else if (name.equals(REQUIRE_AUTH_TIME)) { + client.setRequireAuthTime(reader.nextBoolean()); + } else if (name.equals(DEFAULT_ACR_VALUES)) { + Set defaultACRvalues = readSet(reader); + client.setDefaultACRvalues(defaultACRvalues); + } else if (name.equals("initiateLoginUri")) { + client.setInitiateLoginUri(reader.nextString()); + } else if (name.equals(POST_LOGOUT_REDIRECT_URI)) { + Set postLogoutUris = readSet(reader); + client.setPostLogoutRedirectUris(postLogoutUris); + } else if (name.equals(REQUEST_URIS)) { + Set requestUris = readSet(reader); + client.setRequestUris(requestUris); + } else if (name.equals(DESCRIPTION)) { + client.setClientDescription(reader.nextString()); + } else if (name.equals(ALLOW_INTROSPECTION)) { + client.setAllowIntrospection(reader.nextBoolean()); + } else if (name.equals(REUSE_REFRESH_TOKEN)) { + client.setReuseRefreshToken(reader.nextBoolean()); + } else if (name.equals(CLEAR_ACCESS_TOKENS_ON_REFRESH)) { + client.setClearAccessTokensOnRefresh(reader.nextBoolean()); + } else if (name.equals(DYNAMICALLY_REGISTERED)) { + client.setDynamicallyRegistered(reader.nextBoolean()); + } else if (name.equals(CODE_CHALLENGE_METHOD)) { + client.setCodeChallengeMethod(PKCEAlgorithm.parse(reader.nextString())); + } else if (name.equals(SOFTWARE_STATEMENT)) { + try { + client.setSoftwareStatement(JWTParser.parse(reader.nextString())); + } catch (ParseException e) { + logger.error("Couldn't parse software statement", e); + } + } else { + logger.debug("Found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + clientRepository.saveClient(client); + } + reader.endArray(); + logger.info("Done reading clients"); + } + + /** + * Read the list of system scopes from the reader and insert them into the + * scope repository. + * + * @param reader + * @throws IOException + */ + private void readSystemScopes(JsonReader reader) throws IOException { + reader.beginArray(); + while (reader.hasNext()) { + SystemScope scope = new SystemScope(); + reader.beginObject(); + while (reader.hasNext()) { + switch (reader.peek()) { + case END_OBJECT: + continue; + case NAME: + String name = reader.nextName(); + if (reader.peek() == JsonToken.NULL) { + reader.skipValue(); + } else if (name.equals(VALUE)) { + scope.setValue(reader.nextString()); + } else if (name.equals(DESCRIPTION)) { + scope.setDescription(reader.nextString()); + } else if (name.equals(RESTRICTED)) { + scope.setRestricted(reader.nextBoolean()); + } else if (name.equals(DEFAULT_SCOPE)) { + scope.setDefaultScope(reader.nextBoolean()); + } else if (name.equals(ICON)) { + scope.setIcon(reader.nextString()); + } else if (name.equals(STRUCTURED)) { + scope.setStructured(reader.nextBoolean()); + } else if (name.equals(STRUCTURED_PARAMETER)) { + scope.setStructuredParamDescription(reader.nextString()); + } else { + logger.debug("found unexpected entry"); + reader.skipValue(); + } + break; + default: + logger.debug("Found unexpected entry"); + reader.skipValue(); + continue; + } + } + reader.endObject(); + sysScopeRepository.save(scope); + } + reader.endArray(); + logger.info("Done reading system scopes"); + } + + private void fixObjectReferences() { + logger.info("Fixing object references..."); + for (Long oldRefreshTokenId : refreshTokenToClientRefs.keySet()) { + String clientRef = refreshTokenToClientRefs.get(oldRefreshTokenId); + ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + refreshToken.setClient(client); + tokenRepository.saveRefreshToken(refreshToken); + } + refreshTokenToClientRefs.clear(); + for (Long oldRefreshTokenId : refreshTokenToAuthHolderRefs.keySet()) { + Long oldAuthHolderId = refreshTokenToAuthHolderRefs.get(oldRefreshTokenId); + Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); + AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + refreshToken.setAuthenticationHolder(authHolder); + tokenRepository.saveRefreshToken(refreshToken); + } + refreshTokenToAuthHolderRefs.clear(); + for (Long oldAccessTokenId : accessTokenToClientRefs.keySet()) { + String clientRef = accessTokenToClientRefs.get(oldAccessTokenId); + ClientDetailsEntity client = clientRepository.getClientByClientId(clientRef); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setClient(client); + tokenRepository.saveAccessToken(accessToken); + } + accessTokenToClientRefs.clear(); + for (Long oldAccessTokenId : accessTokenToAuthHolderRefs.keySet()) { + Long oldAuthHolderId = accessTokenToAuthHolderRefs.get(oldAccessTokenId); + Long newAuthHolderId = authHolderOldToNewIdMap.get(oldAuthHolderId); + AuthenticationHolderEntity authHolder = authHolderRepository.getById(newAuthHolderId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setAuthenticationHolder(authHolder); + tokenRepository.saveAccessToken(accessToken); + } + accessTokenToAuthHolderRefs.clear(); + for (Long oldAccessTokenId : accessTokenToRefreshTokenRefs.keySet()) { + Long oldRefreshTokenId = accessTokenToRefreshTokenRefs.get(oldAccessTokenId); + Long newRefreshTokenId = refreshTokenOldToNewIdMap.get(oldRefreshTokenId); + OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenById(newRefreshTokenId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setRefreshToken(refreshToken); + tokenRepository.saveAccessToken(accessToken); + } + accessTokenToRefreshTokenRefs.clear(); + refreshTokenOldToNewIdMap.clear(); + for (Long oldAccessTokenId : accessTokenToIdTokenRefs.keySet()) { + Long oldIdTokenId = accessTokenToIdTokenRefs.get(oldAccessTokenId); + Long newIdTokenId = accessTokenOldToNewIdMap.get(oldIdTokenId); + OAuth2AccessTokenEntity idToken = tokenRepository.getAccessTokenById(newIdTokenId); + Long newAccessTokenId = accessTokenOldToNewIdMap.get(oldAccessTokenId); + OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenById(newAccessTokenId); + accessToken.setIdToken(idToken); + tokenRepository.saveAccessToken(accessToken); + } + accessTokenToIdTokenRefs.clear(); + for (Long oldGrantId : grantToAccessTokensRefs.keySet()) { + Set oldAccessTokenIds = grantToAccessTokensRefs.get(oldGrantId); + Set tokens = new HashSet(); + for(Long oldTokenId : oldAccessTokenIds) { + Long newTokenId = accessTokenOldToNewIdMap.get(oldTokenId); + tokens.add(tokenRepository.getAccessTokenById(newTokenId)); + } + Long newGrantId = grantOldToNewIdMap.get(oldGrantId); + ApprovedSite site = approvedSiteRepository.getById(newGrantId); + site.setApprovedAccessTokens(tokens); + approvedSiteRepository.save(site); + } + accessTokenOldToNewIdMap.clear(); + grantOldToNewIdMap.clear(); + logger.info("Done fixing object references."); + } + +} diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java index 509ccc60c..15a78c04a 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/web/DataAPI.java @@ -29,6 +29,7 @@ import org.mitre.openid.connect.service.MITREidDataService; import org.mitre.openid.connect.service.impl.MITREidDataService_1_0; import org.mitre.openid.connect.service.impl.MITREidDataService_1_1; import org.mitre.openid.connect.service.impl.MITREidDataService_1_2; +import org.mitre.openid.connect.service.impl.MITREidDataService_1_3; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -75,6 +76,9 @@ public class DataAPI { @Autowired private MITREidDataService_1_2 dataService_1_2; + + @Autowired + private MITREidDataService_1_3 dataService_1_3; @RequestMapping(method = RequestMethod.POST, consumes = MediaType.APPLICATION_JSON_VALUE) public String importData(Reader in, Model m) throws IOException { @@ -94,6 +98,8 @@ public class DataAPI { dataService_1_1.importData(reader); } else if (name.equals(MITREidDataService.MITREID_CONNECT_1_2)) { dataService_1_2.importData(reader); + } else if (name.equals(MITREidDataService.MITREID_CONNECT_1_3)) { + dataService_1_3.importData(reader); } else { // consume the next bit silently for now logger.debug("Skipping value for " + name); // TODO: write these out? @@ -134,7 +140,7 @@ public class DataAPI { writer.value(prin.getName()); // delegate to the service to do the actual export - dataService_1_2.exportData(writer); + dataService_1_3.exportData(writer); writer.endObject(); // end root writer.close(); diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_2.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_2.java index e9b612f01..7073671cc 100644 --- a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_2.java +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_2.java @@ -149,121 +149,6 @@ public class TestMITREidDataService_1_2 { Mockito.reset(clientRepository, approvedSiteRepository, authHolderRepository, tokenRepository, sysScopeRepository, wlSiteRepository, blSiteRepository); } - @Test - public void testExportRefreshTokens() throws IOException, ParseException { - String expiration1 = "2014-09-10T22:49:44.090+0000"; - Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); - - ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); - when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); - - AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); - when(mockedAuthHolder1.getId()).thenReturn(1L); - - OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); - token1.setId(1L); - token1.setClient(mockedClient1); - token1.setExpiration(expirationDate1); - token1.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.")); - token1.setAuthenticationHolder(mockedAuthHolder1); - - String expiration2 = "2015-01-07T18:31:50.079+0000"; - Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); - - ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); - when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); - - AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); - when(mockedAuthHolder2.getId()).thenReturn(2L); - - OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); - token2.setId(2L); - token2.setClient(mockedClient2); - token2.setExpiration(expirationDate2); - token2.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.")); - token2.setAuthenticationHolder(mockedAuthHolder2); - - Set allRefreshTokens = ImmutableSet.of(token1, token2); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(allRefreshTokens); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - - // check our refresh token list (this test) - JsonArray refreshTokens = config.get(MITREidDataService.REFRESHTOKENS).getAsJsonArray(); - - assertThat(refreshTokens.size(), is(2)); - // check for both of our refresh tokens in turn - Set checked = new HashSet<>(); - for (JsonElement e : refreshTokens) { - assertThat(e.isJsonObject(), is(true)); - JsonObject token = e.getAsJsonObject(); - - OAuth2RefreshTokenEntity compare = null; - if (token.get("id").getAsLong() == token1.getId()) { - compare = token1; - } else if (token.get("id").getAsLong() == token2.getId()) { - compare = token2; - } - - if (compare == null) { - fail("Could not find matching id: " + token.get("id").getAsString()); - } else { - assertThat(token.get("id").getAsLong(), equalTo(compare.getId())); - assertThat(token.get("clientId").getAsString(), equalTo(compare.getClient().getClientId())); - assertThat(token.get("expiration").getAsString(), equalTo(formatter.print(compare.getExpiration(), Locale.ENGLISH))); - assertThat(token.get("value").getAsString(), equalTo(compare.getValue())); - assertThat(token.get("authenticationHolderId").getAsLong(), equalTo(compare.getAuthenticationHolder().getId())); - checked.add(compare); - } - } - // make sure all of our refresh tokens were found - assertThat(checked.containsAll(allRefreshTokens), is(true)); - } - private class refreshTokenIdComparator implements Comparator { @Override public int compare(OAuth2RefreshTokenEntity entity1, OAuth2RefreshTokenEntity entity2) { @@ -384,143 +269,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedRefreshTokens.get(1).getValue(), equalTo(token2.getValue())); } - @Test - public void testExportAccessTokens() throws IOException, ParseException { - String expiration1 = "2014-09-10T22:49:44.090+0000"; - Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); - - ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); - when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); - - AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); - when(mockedAuthHolder1.getId()).thenReturn(1L); - - OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); - token1.setId(1L); - token1.setClient(mockedClient1); - token1.setExpiration(expirationDate1); - token1.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3ODk5NjgsInN1YiI6IjkwMzQyLkFTREZKV0ZBIiwiYXRfaGFzaCI6InptTmt1QmNRSmNYQktNaVpFODZqY0EiLCJhdWQiOlsiY2xpZW50Il0sImlzcyI6Imh0dHA6XC9cL2xvY2FsaG9zdDo4MDgwXC9vcGVuaWQtY29ubmVjdC1zZXJ2ZXItd2ViYXBwXC8iLCJpYXQiOjE0MTI3ODkzNjh9.xkEJ9IMXpH7qybWXomfq9WOOlpGYnrvGPgey9UQ4GLzbQx7JC0XgJK83PmrmBZosvFPCmota7FzI_BtwoZLgAZfFiH6w3WIlxuogoH-TxmYbxEpTHoTsszZppkq9mNgOlArV4jrR9y3TPo4MovsH71dDhS_ck-CvAlJunHlqhs0")); - token1.setAuthenticationHolder(mockedAuthHolder1); - token1.setScope(ImmutableSet.of("id-token")); - token1.setTokenType("Bearer"); - - String expiration2 = "2015-01-07T18:31:50.079+0000"; - Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); - - ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); - when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); - - AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); - when(mockedAuthHolder2.getId()).thenReturn(2L); - - OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); - when(mockRefreshToken2.getId()).thenReturn(1L); - - OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); - token2.setId(2L); - token2.setClient(mockedClient2); - token2.setExpiration(expirationDate2); - token2.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3OTI5NjgsImF1ZCI6WyJjbGllbnQiXSwiaXNzIjoiaHR0cDpcL1wvbG9jYWxob3N0OjgwODBcL29wZW5pZC1jb25uZWN0LXNlcnZlci13ZWJhcHBcLyIsImp0aSI6IjBmZGE5ZmRiLTYyYzItNGIzZS05OTdiLWU0M2VhMDUwMzNiOSIsImlhdCI6MTQxMjc4OTM2OH0.xgaVpRLYE5MzbgXfE0tZt823tjAm6Oh3_kdR1P2I9jRLR6gnTlBQFlYi3Y_0pWNnZSerbAE8Tn6SJHZ9k-curVG0-ByKichV7CNvgsE5X_2wpEaUzejvKf8eZ-BammRY-ie6yxSkAarcUGMvGGOLbkFcz5CtrBpZhfd75J49BIQ")); - token2.setAuthenticationHolder(mockedAuthHolder2); - token2.setIdToken(token1); - token2.setRefreshToken(mockRefreshToken2); - token2.setScope(ImmutableSet.of("openid", "offline_access", "email", "profile")); - token2.setTokenType("Bearer"); - - Set allAccessTokens = ImmutableSet.of(token1, token2); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(allAccessTokens); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - - // check our access token list (this test) - JsonArray accessTokens = config.get(MITREidDataService.ACCESSTOKENS).getAsJsonArray(); - - assertThat(accessTokens.size(), is(2)); - // check for both of our access tokens in turn - Set checked = new HashSet<>(); - for (JsonElement e : accessTokens) { - assertTrue(e.isJsonObject()); - JsonObject token = e.getAsJsonObject(); - - OAuth2AccessTokenEntity compare = null; - if (token.get("id").getAsLong() == token1.getId().longValue()) { - compare = token1; - } else if (token.get("id").getAsLong() == token2.getId().longValue()) { - compare = token2; - } - - if (compare == null) { - fail("Could not find matching id: " + token.get("id").getAsString()); - } else { - assertThat(token.get("id").getAsLong(), equalTo(compare.getId())); - assertThat(token.get("clientId").getAsString(), equalTo(compare.getClient().getClientId())); - assertThat(token.get("expiration").getAsString(), equalTo(formatter.print(compare.getExpiration(), Locale.ENGLISH))); - assertThat(token.get("value").getAsString(), equalTo(compare.getValue())); - assertThat(token.get("type").getAsString(), equalTo(compare.getTokenType())); - assertThat(token.get("authenticationHolderId").getAsLong(), equalTo(compare.getAuthenticationHolder().getId())); - assertTrue(token.get("scope").isJsonArray()); - assertThat(jsonArrayToStringSet(token.getAsJsonArray("scope")), equalTo(compare.getScope())); - if(token.get("idTokenId").isJsonNull()) { - assertNull(compare.getIdToken()); - } else { - assertThat(token.get("idTokenId").getAsLong(), equalTo(compare.getIdToken().getId())); - } - if(token.get("refreshTokenId").isJsonNull()) { - assertNull(compare.getIdToken()); - } else { - assertThat(token.get("refreshTokenId").getAsLong(), equalTo(compare.getRefreshToken().getId())); - } - checked.add(compare); - } - } - // make sure all of our access tokens were found - assertThat(checked.containsAll(allAccessTokens), is(true)); - } - private class accessTokenIdComparator implements Comparator { @Override public int compare(OAuth2AccessTokenEntity entity1, OAuth2AccessTokenEntity entity2) { @@ -653,111 +401,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedAccessTokens.get(1).getValue(), equalTo(token2.getValue())); } - @Test - public void testExportClients() throws IOException { - ClientDetailsEntity client1 = new ClientDetailsEntity(); - client1.setId(1L); - client1.setAccessTokenValiditySeconds(3600); - client1.setClientId("client1"); - client1.setClientSecret("clientsecret1"); - client1.setRedirectUris(ImmutableSet.of("http://foo.com/")); - client1.setScope(ImmutableSet.of("foo", "bar", "baz", "dolphin")); - client1.setGrantTypes(ImmutableSet.of("implicit", "authorization_code", "urn:ietf:params:oauth:grant_type:redelegate", "refresh_token")); - client1.setAllowIntrospection(true); - - ClientDetailsEntity client2 = new ClientDetailsEntity(); - client2.setId(2L); - client2.setAccessTokenValiditySeconds(3600); - client2.setClientId("client2"); - client2.setClientSecret("clientsecret2"); - client2.setRedirectUris(ImmutableSet.of("http://bar.baz.com/")); - client2.setScope(ImmutableSet.of("foo", "dolphin", "electric-wombat")); - client2.setGrantTypes(ImmutableSet.of("client_credentials", "urn:ietf:params:oauth:grant_type:redelegate")); - client2.setAllowIntrospection(false); - - Set allClients = ImmutableSet.of(client1, client2); - - Mockito.when(clientRepository.getAllClients()).thenReturn(allClients); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - - // check our client list (this test) - JsonArray clients = config.get(MITREidDataService.CLIENTS).getAsJsonArray(); - - assertThat(clients.size(), is(2)); - // check for both of our clients in turn - Set checked = new HashSet<>(); - for (JsonElement e : clients) { - assertThat(e.isJsonObject(), is(true)); - JsonObject client = e.getAsJsonObject(); - - ClientDetailsEntity compare = null; - if (client.get("clientId").getAsString().equals(client1.getClientId())) { - compare = client1; - } else if (client.get("clientId").getAsString().equals(client2.getClientId())) { - compare = client2; - } - - if (compare == null) { - fail("Could not find matching clientId: " + client.get("clientId").getAsString()); - } else { - assertThat(client.get("clientId").getAsString(), equalTo(compare.getClientId())); - assertThat(client.get("secret").getAsString(), equalTo(compare.getClientSecret())); - assertThat(client.get("accessTokenValiditySeconds").getAsInt(), equalTo(compare.getAccessTokenValiditySeconds())); - assertThat(client.get("allowIntrospection").getAsBoolean(), equalTo(compare.isAllowIntrospection())); - assertThat(jsonArrayToStringSet(client.get("redirectUris").getAsJsonArray()), equalTo(compare.getRedirectUris())); - assertThat(jsonArrayToStringSet(client.get("scope").getAsJsonArray()), equalTo(compare.getScope())); - assertThat(jsonArrayToStringSet(client.get("grantTypes").getAsJsonArray()), equalTo(compare.getGrantTypes())); - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allClients), is(true)); - } - @Test public void testImportClients() throws IOException { ClientDetailsEntity client1 = new ClientDetailsEntity(); @@ -832,99 +475,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedClients.get(1).isAllowIntrospection(), equalTo(client2.isAllowIntrospection())); } - @Test - public void testExportBlacklistedSites() throws IOException { - BlacklistedSite site1 = new BlacklistedSite(); - site1.setId(1L); - site1.setUri("http://foo.com"); - - BlacklistedSite site2 = new BlacklistedSite(); - site2.setId(2L); - site2.setUri("http://bar.com"); - - BlacklistedSite site3 = new BlacklistedSite(); - site3.setId(3L); - site3.setUri("http://baz.com"); - - Set allBlacklistedSites = ImmutableSet.of(site1, site2, site3); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(allBlacklistedSites); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - // check our scope list (this test) - JsonArray sites = config.get(MITREidDataService.BLACKLISTEDSITES).getAsJsonArray(); - - assertThat(sites.size(), is(3)); - // check for both of our sites in turn - Set checked = new HashSet<>(); - for (JsonElement e : sites) { - assertThat(e.isJsonObject(), is(true)); - JsonObject site = e.getAsJsonObject(); - - BlacklistedSite compare = null; - if (site.get("id").getAsLong() == site1.getId().longValue()) { - compare = site1; - } else if (site.get("id").getAsLong() == site2.getId().longValue()) { - compare = site2; - } else if (site.get("id").getAsLong() == site3.getId().longValue()) { - compare = site3; - } - - if (compare == null) { - fail("Could not find matching blacklisted site id: " + site.get("id").getAsString()); - } else { - assertThat(site.get("uri").getAsString(), equalTo(compare.getUri())); - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allBlacklistedSites), is(true)); - - } - @Test public void testImportBlacklistedSites() throws IOException { BlacklistedSite site1 = new BlacklistedSite(); @@ -973,99 +523,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedSites.get(2).getUri(), equalTo(site3.getUri())); } - @Test - public void testExportWhitelistedSites() throws IOException { - WhitelistedSite site1 = new WhitelistedSite(); - site1.setId(1L); - site1.setClientId("foo"); - - WhitelistedSite site2 = new WhitelistedSite(); - site2.setId(2L); - site2.setClientId("bar"); - - WhitelistedSite site3 = new WhitelistedSite(); - site3.setId(3L); - site3.setClientId("baz"); - - Set allWhitelistedSites = ImmutableSet.of(site1, site2, site3); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(allWhitelistedSites); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - // check our scope list (this test) - JsonArray sites = config.get(MITREidDataService.WHITELISTEDSITES).getAsJsonArray(); - - assertThat(sites.size(), is(3)); - // check for both of our sites in turn - Set checked = new HashSet<>(); - for (JsonElement e : sites) { - assertThat(e.isJsonObject(), is(true)); - JsonObject site = e.getAsJsonObject(); - - WhitelistedSite compare = null; - if (site.get("id").getAsLong() == site1.getId().longValue()) { - compare = site1; - } else if (site.get("id").getAsLong() == site2.getId().longValue()) { - compare = site2; - } else if (site.get("id").getAsLong() == site3.getId().longValue()) { - compare = site3; - } - - if (compare == null) { - fail("Could not find matching whitelisted site id: " + site.get("id").getAsString()); - } else { - assertThat(site.get("clientId").getAsString(), equalTo(compare.getClientId())); - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allWhitelistedSites), is(true)); - - } - @Test public void testImportWhitelistedSites() throws IOException { WhitelistedSite site1 = new WhitelistedSite(); @@ -1135,131 +592,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedSites.get(2).getClientId(), equalTo(site3.getClientId())); } - @Test - public void testExportGrants() throws IOException, ParseException { - Date creationDate1 = formatter.parse("2014-09-10T22:49:44.090+0000", Locale.ENGLISH); - Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); - - OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); - when(mockToken1.getId()).thenReturn(1L); - - ApprovedSite site1 = new ApprovedSite(); - site1.setId(1L); - site1.setClientId("foo"); - site1.setCreationDate(creationDate1); - site1.setAccessDate(accessDate1); - site1.setUserId("user1"); - site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); - site1.setApprovedAccessTokens(ImmutableSet.of(mockToken1)); - - Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); - Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); - Date timeoutDate2 = formatter.parse("2014-10-01T20:49:44.090+0000", Locale.ENGLISH); - - ApprovedSite site2 = new ApprovedSite(); - site2.setId(2L); - site2.setClientId("bar"); - site2.setCreationDate(creationDate2); - site2.setAccessDate(accessDate2); - site2.setUserId("user2"); - site2.setAllowedScopes(ImmutableSet.of("openid", "offline_access", "email", "profile")); - site2.setTimeoutDate(timeoutDate2); - - Set allApprovedSites = ImmutableSet.of(site1, site2); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(allApprovedSites); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - // check our scope list (this test) - JsonArray sites = config.get(MITREidDataService.GRANTS).getAsJsonArray(); - - assertThat(sites.size(), is(2)); - // check for both of our sites in turn - Set checked = new HashSet<>(); - for (JsonElement e : sites) { - assertThat(e.isJsonObject(), is(true)); - JsonObject site = e.getAsJsonObject(); - - ApprovedSite compare = null; - if (site.get("id").getAsLong() == site1.getId().longValue()) { - compare = site1; - } else if (site.get("id").getAsLong() == site2.getId().longValue()) { - compare = site2; - } - - if (compare == null) { - fail("Could not find matching whitelisted site id: " + site.get("id").getAsString()); - } else { - assertThat(site.get("clientId").getAsString(), equalTo(compare.getClientId())); - assertThat(site.get("creationDate").getAsString(), equalTo(formatter.print(compare.getCreationDate(), Locale.ENGLISH))); - assertThat(site.get("accessDate").getAsString(), equalTo(formatter.print(compare.getAccessDate(), Locale.ENGLISH))); - if(site.get("timeoutDate").isJsonNull()) { - assertNull(compare.getTimeoutDate()); - } else { - assertThat(site.get("timeoutDate").getAsString(), equalTo(formatter.print(compare.getTimeoutDate(), Locale.ENGLISH))); - } - assertThat(site.get("userId").getAsString(), equalTo(compare.getUserId())); - assertThat(jsonArrayToStringSet(site.getAsJsonArray("allowedScopes")), equalTo(compare.getAllowedScopes())); - if (site.get("approvedAccessTokens").isJsonNull() || site.getAsJsonArray("approvedAccessTokens") == null) { - assertTrue(compare.getApprovedAccessTokens() == null || compare.getApprovedAccessTokens().isEmpty()); - } else { - assertNotNull(compare.getApprovedAccessTokens()); - Set tokenIds = new HashSet<>(); - for(OAuth2AccessTokenEntity entity : compare.getApprovedAccessTokens()) { - tokenIds.add(entity.getId().toString()); - } - assertThat(jsonArrayToStringSet(site.getAsJsonArray("approvedAccessTokens")), equalTo(tokenIds)); - } - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allApprovedSites), is(true)); - } - @Test public void testImportGrants() throws IOException, ParseException { Date creationDate1 = formatter.parse("2014-09-10T22:49:44.090+0000", Locale.ENGLISH); @@ -1376,113 +708,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedSites.get(1).getApprovedAccessTokens().size(), equalTo(site2.getApprovedAccessTokens().size())); } - @Test - public void testExportAuthenticationHolders() throws IOException { - OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), - true, new HashSet(), new HashSet(), "http://foo.com", - new HashSet(), null); - Authentication mockAuth1 = new UsernamePasswordAuthenticationToken("user1", "pass1", AuthorityUtils.commaSeparatedStringToAuthorityList("ROLE_USER")); - OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); - - AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); - holder1.setId(1L); - holder1.setAuthentication(auth1); - - OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), - true, new HashSet(), new HashSet(), "http://bar.com", - new HashSet(), null); - OAuth2Authentication auth2 = new OAuth2Authentication(req2, null); - - AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); - holder2.setId(2L); - holder2.setAuthentication(auth2); - - List allAuthHolders = ImmutableList.of(holder1, holder2); - - when(clientRepository.getAllClients()).thenReturn(new HashSet()); - when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - when(blSiteRepository.getAll()).thenReturn(new HashSet()); - when(authHolderRepository.getAll()).thenReturn(allAuthHolders); - when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - when(sysScopeRepository.getAll()).thenReturn(new HashSet()); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - - // check our holder list (this test) - JsonArray holders = config.get(MITREidDataService.AUTHENTICATIONHOLDERS).getAsJsonArray(); - - assertThat(holders.size(), is(2)); - // check for both of our clients in turn - Set checked = new HashSet<>(); - for (JsonElement e : holders) { - assertThat(e.isJsonObject(), is(true)); - JsonObject holder = e.getAsJsonObject(); - - AuthenticationHolderEntity compare = null; - if (holder.get("id").getAsLong() == holder1.getId()) { - compare = holder1; - } else if (holder.get("id").getAsLong() == holder2.getId()) { - compare = holder2; - } - - if (compare == null) { - fail("Could not find matching authentication holder id: " + holder.get("id").getAsString()); - } else { - assertTrue(holder.get("clientId").getAsString().equals(compare.getClientId())); - assertTrue(holder.get("approved").getAsBoolean() == compare.isApproved()); - assertTrue(holder.get("redirectUri").getAsString().equals(compare.getRedirectUri())); - if (compare.getUserAuth() != null) { - assertTrue(holder.get("savedUserAuthentication").isJsonObject()); - JsonObject savedAuth = holder.get("savedUserAuthentication").getAsJsonObject(); - assertTrue(savedAuth.get("name").getAsString().equals(compare.getUserAuth().getName())); - assertTrue(savedAuth.get("authenticated").getAsBoolean() == compare.getUserAuth().isAuthenticated()); - assertTrue(savedAuth.get("sourceClass").getAsString().equals(compare.getUserAuth().getSourceClass())); - } - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allAuthHolders), is(true)); - } - @Test public void testImportAuthenticationHolders() throws IOException { OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), @@ -1550,116 +775,6 @@ public class TestMITREidDataService_1_2 { assertThat(savedAuthHolders.get(1).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder2.getAuthentication().getOAuth2Request().getClientId())); } - @Test - public void testExportSystemScopes() throws IOException { - SystemScope scope1 = new SystemScope(); - scope1.setId(1L); - scope1.setValue("scope1"); - scope1.setDescription("Scope 1"); - scope1.setRestricted(true); - scope1.setDefaultScope(false); - scope1.setIcon("glass"); - - SystemScope scope2 = new SystemScope(); - scope2.setId(2L); - scope2.setValue("scope2"); - scope2.setDescription("Scope 2"); - scope2.setRestricted(false); - scope2.setDefaultScope(false); - scope2.setIcon("ball"); - - SystemScope scope3 = new SystemScope(); - scope3.setId(3L); - scope3.setValue("scope3"); - scope3.setDescription("Scope 3"); - scope3.setRestricted(false); - scope3.setDefaultScope(true); - scope3.setIcon("road"); - - Set allScopes = ImmutableSet.of(scope1, scope2, scope3); - - Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); - Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); - Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); - Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); - Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); - Mockito.when(sysScopeRepository.getAll()).thenReturn(allScopes); - - // do the data export - StringWriter stringWriter = new StringWriter(); - JsonWriter writer = new JsonWriter(stringWriter); - writer.beginObject(); - dataService.exportData(writer); - writer.endObject(); - writer.close(); - - // parse the output as a JSON object for testing - JsonElement elem = new JsonParser().parse(stringWriter.toString()); - JsonObject root = elem.getAsJsonObject(); - - // make sure the root is there - assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_2), is(true)); - - JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_2).getAsJsonObject(); - - // make sure all the root elements are there - assertThat(config.has(MITREidDataService.CLIENTS), is(true)); - assertThat(config.has(MITREidDataService.GRANTS), is(true)); - assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); - assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); - assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); - assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); - assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); - - // make sure the root elements are all arrays - assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); - assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); - - - // check our scope list (this test) - JsonArray scopes = config.get(MITREidDataService.SYSTEMSCOPES).getAsJsonArray(); - - assertThat(scopes.size(), is(3)); - // check for both of our clients in turn - Set checked = new HashSet<>(); - for (JsonElement e : scopes) { - assertThat(e.isJsonObject(), is(true)); - JsonObject scope = e.getAsJsonObject(); - - SystemScope compare = null; - if (scope.get("value").getAsString().equals(scope1.getValue())) { - compare = scope1; - } else if (scope.get("value").getAsString().equals(scope2.getValue())) { - compare = scope2; - } else if (scope.get("value").getAsString().equals(scope3.getValue())) { - compare = scope3; - } - - if (compare == null) { - fail("Could not find matching scope value: " + scope.get("value").getAsString()); - } else { - assertThat(scope.get("value").getAsString(), equalTo(compare.getValue())); - assertThat(scope.get("description").getAsString(), equalTo(compare.getDescription())); - assertThat(scope.get("icon").getAsString(), equalTo(compare.getIcon())); - assertThat(scope.get("restricted").getAsBoolean(), equalTo(compare.isRestricted())); - assertThat(scope.get("defaultScope").getAsBoolean(), equalTo(compare.isDefaultScope())); - checked.add(compare); - } - } - // make sure all of our clients were found - assertThat(checked.containsAll(allScopes), is(true)); - - } - @Test public void testImportSystemScopes() throws IOException { SystemScope scope1 = new SystemScope(); diff --git a/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_3.java b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_3.java new file mode 100644 index 000000000..8ce49b760 --- /dev/null +++ b/openid-connect-server/src/test/java/org/mitre/openid/connect/service/impl/TestMITREidDataService_1_3.java @@ -0,0 +1,1887 @@ +/******************************************************************************* + * Copyright 2016 The MITRE Corporation + * and the MIT Internet Trust Consortium + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ +package org.mitre.openid.connect.service.impl; + +import java.io.IOException; +import java.io.StringReader; +import java.io.StringWriter; +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mitre.oauth2.model.AuthenticationHolderEntity; +import org.mitre.oauth2.model.ClientDetailsEntity; +import org.mitre.oauth2.model.OAuth2AccessTokenEntity; +import org.mitre.oauth2.model.OAuth2RefreshTokenEntity; +import org.mitre.oauth2.model.PKCEAlgorithm; +import org.mitre.oauth2.model.SystemScope; +import org.mitre.oauth2.repository.AuthenticationHolderRepository; +import org.mitre.oauth2.repository.OAuth2ClientRepository; +import org.mitre.oauth2.repository.OAuth2TokenRepository; +import org.mitre.oauth2.repository.SystemScopeRepository; +import org.mitre.openid.connect.model.ApprovedSite; +import org.mitre.openid.connect.model.BlacklistedSite; +import org.mitre.openid.connect.model.WhitelistedSite; +import org.mitre.openid.connect.repository.ApprovedSiteRepository; +import org.mitre.openid.connect.repository.BlacklistedSiteRepository; +import org.mitre.openid.connect.repository.WhitelistedSiteRepository; +import org.mitre.openid.connect.service.MITREidDataService; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.format.annotation.DateTimeFormat.ISO; +import org.springframework.format.datetime.DateFormatter; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.provider.OAuth2Authentication; +import org.springframework.security.oauth2.provider.OAuth2Request; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; +import com.nimbusds.jwt.JWTParser; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.isA; +import static org.mockito.Matchers.isNull; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(MockitoJUnitRunner.class) +@SuppressWarnings(value = {"rawtypes", "unchecked"}) +public class TestMITREidDataService_1_3 { + + private static Logger logger = LoggerFactory.getLogger(TestMITREidDataService_1_3.class); + + @Mock + private OAuth2ClientRepository clientRepository; + @Mock + private ApprovedSiteRepository approvedSiteRepository; + @Mock + private WhitelistedSiteRepository wlSiteRepository; + @Mock + private BlacklistedSiteRepository blSiteRepository; + @Mock + private AuthenticationHolderRepository authHolderRepository; + @Mock + private OAuth2TokenRepository tokenRepository; + @Mock + private SystemScopeRepository sysScopeRepository; + + @Captor + private ArgumentCaptor capturedRefreshTokens; + @Captor + private ArgumentCaptor capturedAccessTokens; + @Captor + private ArgumentCaptor capturedClients; + @Captor + private ArgumentCaptor capturedBlacklistedSites; + @Captor + private ArgumentCaptor capturedWhitelistedSites; + @Captor + private ArgumentCaptor capturedApprovedSites; + @Captor + private ArgumentCaptor capturedAuthHolders; + @Captor + private ArgumentCaptor capturedScope; + + @InjectMocks + private MITREidDataService_1_3 dataService; + private DateFormatter formatter; + + @Before + public void prepare() { + formatter = new DateFormatter(); + formatter.setIso(ISO.DATE_TIME); + + Mockito.reset(clientRepository, approvedSiteRepository, authHolderRepository, tokenRepository, sysScopeRepository, wlSiteRepository, blSiteRepository); + } + + @Test + public void testExportRefreshTokens() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder1.getId()).thenReturn(1L); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.")); + token1.setAuthenticationHolder(mockedAuthHolder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder2.getId()).thenReturn(2L); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.")); + token2.setAuthenticationHolder(mockedAuthHolder2); + + Set allRefreshTokens = ImmutableSet.of(token1, token2); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(allRefreshTokens); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + + // check our refresh token list (this test) + JsonArray refreshTokens = config.get(MITREidDataService.REFRESHTOKENS).getAsJsonArray(); + + assertThat(refreshTokens.size(), is(2)); + // check for both of our refresh tokens in turn + Set checked = new HashSet<>(); + for (JsonElement e : refreshTokens) { + assertThat(e.isJsonObject(), is(true)); + JsonObject token = e.getAsJsonObject(); + + OAuth2RefreshTokenEntity compare = null; + if (token.get("id").getAsLong() == token1.getId()) { + compare = token1; + } else if (token.get("id").getAsLong() == token2.getId()) { + compare = token2; + } + + if (compare == null) { + fail("Could not find matching id: " + token.get("id").getAsString()); + } else { + assertThat(token.get("id").getAsLong(), equalTo(compare.getId())); + assertThat(token.get("clientId").getAsString(), equalTo(compare.getClient().getClientId())); + assertThat(token.get("expiration").getAsString(), equalTo(formatter.print(compare.getExpiration(), Locale.ENGLISH))); + assertThat(token.get("value").getAsString(), equalTo(compare.getValue())); + assertThat(token.get("authenticationHolderId").getAsLong(), equalTo(compare.getAuthenticationHolder().getId())); + checked.add(compare); + } + } + // make sure all of our refresh tokens were found + assertThat(checked.containsAll(allRefreshTokens), is(true)); + } + + private class refreshTokenIdComparator implements Comparator { + @Override + public int compare(OAuth2RefreshTokenEntity entity1, OAuth2RefreshTokenEntity entity2) { + return entity1.getId().compareTo(entity2.getId()); + } + } + + + @Test + public void testImportRefreshTokens() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder1.getId()).thenReturn(1L); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.")); + token1.setAuthenticationHolder(mockedAuthHolder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder2.getId()).thenReturn(2L); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.")); + token2.setAuthenticationHolder(mockedAuthHolder2); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.\"}" + + + " ]" + + "}"; + + logger.debug(configJson); + JsonReader reader = new JsonReader(new StringReader(configJson)); + + final Map fakeDb = new HashMap<>(); + when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { + Long id = 332L; + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeDb.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getRefreshTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeDb.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { + Long id = 131L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); + when(_auth.getId()).thenReturn(id); + id++; + return _auth; + } + }); + dataService.importData(reader); + //2 times for token, 2 times to update client, 2 times to update authHolder + verify(tokenRepository, times(6)).saveRefreshToken(capturedRefreshTokens.capture()); + + List savedRefreshTokens = new ArrayList(fakeDb.values()); //capturedRefreshTokens.getAllValues(); + Collections.sort(savedRefreshTokens, new refreshTokenIdComparator()); + + assertThat(savedRefreshTokens.size(), is(2)); + + assertThat(savedRefreshTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); + assertThat(savedRefreshTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); + assertThat(savedRefreshTokens.get(0).getValue(), equalTo(token1.getValue())); + + assertThat(savedRefreshTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); + assertThat(savedRefreshTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); + assertThat(savedRefreshTokens.get(1).getValue(), equalTo(token2.getValue())); + } + + @Test + public void testExportAccessTokens() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder1.getId()).thenReturn(1L); + + OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3ODk5NjgsInN1YiI6IjkwMzQyLkFTREZKV0ZBIiwiYXRfaGFzaCI6InptTmt1QmNRSmNYQktNaVpFODZqY0EiLCJhdWQiOlsiY2xpZW50Il0sImlzcyI6Imh0dHA6XC9cL2xvY2FsaG9zdDo4MDgwXC9vcGVuaWQtY29ubmVjdC1zZXJ2ZXItd2ViYXBwXC8iLCJpYXQiOjE0MTI3ODkzNjh9.xkEJ9IMXpH7qybWXomfq9WOOlpGYnrvGPgey9UQ4GLzbQx7JC0XgJK83PmrmBZosvFPCmota7FzI_BtwoZLgAZfFiH6w3WIlxuogoH-TxmYbxEpTHoTsszZppkq9mNgOlArV4jrR9y3TPo4MovsH71dDhS_ck-CvAlJunHlqhs0")); + token1.setAuthenticationHolder(mockedAuthHolder1); + token1.setScope(ImmutableSet.of("id-token")); + token1.setTokenType("Bearer"); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder2.getId()).thenReturn(2L); + + OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); + when(mockRefreshToken2.getId()).thenReturn(1L); + + OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3OTI5NjgsImF1ZCI6WyJjbGllbnQiXSwiaXNzIjoiaHR0cDpcL1wvbG9jYWxob3N0OjgwODBcL29wZW5pZC1jb25uZWN0LXNlcnZlci13ZWJhcHBcLyIsImp0aSI6IjBmZGE5ZmRiLTYyYzItNGIzZS05OTdiLWU0M2VhMDUwMzNiOSIsImlhdCI6MTQxMjc4OTM2OH0.xgaVpRLYE5MzbgXfE0tZt823tjAm6Oh3_kdR1P2I9jRLR6gnTlBQFlYi3Y_0pWNnZSerbAE8Tn6SJHZ9k-curVG0-ByKichV7CNvgsE5X_2wpEaUzejvKf8eZ-BammRY-ie6yxSkAarcUGMvGGOLbkFcz5CtrBpZhfd75J49BIQ")); + token2.setAuthenticationHolder(mockedAuthHolder2); + token2.setIdToken(token1); + token2.setRefreshToken(mockRefreshToken2); + token2.setScope(ImmutableSet.of("openid", "offline_access", "email", "profile")); + token2.setTokenType("Bearer"); + + Set allAccessTokens = ImmutableSet.of(token1, token2); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(allAccessTokens); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + + // check our access token list (this test) + JsonArray accessTokens = config.get(MITREidDataService.ACCESSTOKENS).getAsJsonArray(); + + assertThat(accessTokens.size(), is(2)); + // check for both of our access tokens in turn + Set checked = new HashSet<>(); + for (JsonElement e : accessTokens) { + assertTrue(e.isJsonObject()); + JsonObject token = e.getAsJsonObject(); + + OAuth2AccessTokenEntity compare = null; + if (token.get("id").getAsLong() == token1.getId().longValue()) { + compare = token1; + } else if (token.get("id").getAsLong() == token2.getId().longValue()) { + compare = token2; + } + + if (compare == null) { + fail("Could not find matching id: " + token.get("id").getAsString()); + } else { + assertThat(token.get("id").getAsLong(), equalTo(compare.getId())); + assertThat(token.get("clientId").getAsString(), equalTo(compare.getClient().getClientId())); + assertThat(token.get("expiration").getAsString(), equalTo(formatter.print(compare.getExpiration(), Locale.ENGLISH))); + assertThat(token.get("value").getAsString(), equalTo(compare.getValue())); + assertThat(token.get("type").getAsString(), equalTo(compare.getTokenType())); + assertThat(token.get("authenticationHolderId").getAsLong(), equalTo(compare.getAuthenticationHolder().getId())); + assertTrue(token.get("scope").isJsonArray()); + assertThat(jsonArrayToStringSet(token.getAsJsonArray("scope")), equalTo(compare.getScope())); + if(token.get("idTokenId").isJsonNull()) { + assertNull(compare.getIdToken()); + } else { + assertThat(token.get("idTokenId").getAsLong(), equalTo(compare.getIdToken().getId())); + } + if(token.get("refreshTokenId").isJsonNull()) { + assertNull(compare.getIdToken()); + } else { + assertThat(token.get("refreshTokenId").getAsLong(), equalTo(compare.getRefreshToken().getId())); + } + checked.add(compare); + } + } + // make sure all of our access tokens were found + assertThat(checked.containsAll(allAccessTokens), is(true)); + } + + private class accessTokenIdComparator implements Comparator { + @Override + public int compare(OAuth2AccessTokenEntity entity1, OAuth2AccessTokenEntity entity2) { + return entity1.getId().compareTo(entity2.getId()); + } + } + + @Test + public void testImportAccessTokens() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + AuthenticationHolderEntity mockedAuthHolder1 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder1.getId()).thenReturn(1L); + + OAuth2AccessTokenEntity token1 = new OAuth2AccessTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3ODk5NjgsInN1YiI6IjkwMzQyLkFTREZKV0ZBIiwiYXRfaGFzaCI6InptTmt1QmNRSmNYQktNaVpFODZqY0EiLCJhdWQiOlsiY2xpZW50Il0sImlzcyI6Imh0dHA6XC9cL2xvY2FsaG9zdDo4MDgwXC9vcGVuaWQtY29ubmVjdC1zZXJ2ZXItd2ViYXBwXC8iLCJpYXQiOjE0MTI3ODkzNjh9.xkEJ9IMXpH7qybWXomfq9WOOlpGYnrvGPgey9UQ4GLzbQx7JC0XgJK83PmrmBZosvFPCmota7FzI_BtwoZLgAZfFiH6w3WIlxuogoH-TxmYbxEpTHoTsszZppkq9mNgOlArV4jrR9y3TPo4MovsH71dDhS_ck-CvAlJunHlqhs0")); + token1.setAuthenticationHolder(mockedAuthHolder1); + token1.setScope(ImmutableSet.of("id-token")); + token1.setTokenType("Bearer"); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + AuthenticationHolderEntity mockedAuthHolder2 = mock(AuthenticationHolderEntity.class); + when(mockedAuthHolder2.getId()).thenReturn(2L); + + OAuth2RefreshTokenEntity mockRefreshToken2 = mock(OAuth2RefreshTokenEntity.class); + when(mockRefreshToken2.getId()).thenReturn(1L); + + OAuth2AccessTokenEntity token2 = new OAuth2AccessTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setJwt(JWTParser.parse("eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3OTI5NjgsImF1ZCI6WyJjbGllbnQiXSwiaXNzIjoiaHR0cDpcL1wvbG9jYWxob3N0OjgwODBcL29wZW5pZC1jb25uZWN0LXNlcnZlci13ZWJhcHBcLyIsImp0aSI6IjBmZGE5ZmRiLTYyYzItNGIzZS05OTdiLWU0M2VhMDUwMzNiOSIsImlhdCI6MTQxMjc4OTM2OH0.xgaVpRLYE5MzbgXfE0tZt823tjAm6Oh3_kdR1P2I9jRLR6gnTlBQFlYi3Y_0pWNnZSerbAE8Tn6SJHZ9k-curVG0-ByKichV7CNvgsE5X_2wpEaUzejvKf8eZ-BammRY-ie6yxSkAarcUGMvGGOLbkFcz5CtrBpZhfd75J49BIQ")); + token2.setAuthenticationHolder(mockedAuthHolder2); + token2.setIdToken(token1); + token2.setRefreshToken(mockRefreshToken2); + token2.setScope(ImmutableSet.of("openid", "offline_access", "email", "profile")); + token2.setTokenType("Bearer"); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"refreshTokenId\":null,\"idTokenId\":null,\"scope\":[\"id-token\"],\"type\":\"Bearer\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3ODk5NjgsInN1YiI6IjkwMzQyLkFTREZKV0ZBIiwiYXRfaGFzaCI6InptTmt1QmNRSmNYQktNaVpFODZqY0EiLCJhdWQiOlsiY2xpZW50Il0sImlzcyI6Imh0dHA6XC9cL2xvY2FsaG9zdDo4MDgwXC9vcGVuaWQtY29ubmVjdC1zZXJ2ZXItd2ViYXBwXC8iLCJpYXQiOjE0MTI3ODkzNjh9.xkEJ9IMXpH7qybWXomfq9WOOlpGYnrvGPgey9UQ4GLzbQx7JC0XgJK83PmrmBZosvFPCmota7FzI_BtwoZLgAZfFiH6w3WIlxuogoH-TxmYbxEpTHoTsszZppkq9mNgOlArV4jrR9y3TPo4MovsH71dDhS_ck-CvAlJunHlqhs0\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"refreshTokenId\":1,\"idTokenId\":1,\"scope\":[\"openid\",\"offline_access\",\"email\",\"profile\"],\"type\":\"Bearer\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJSUzI1NiJ9.eyJleHAiOjE0MTI3OTI5NjgsImF1ZCI6WyJjbGllbnQiXSwiaXNzIjoiaHR0cDpcL1wvbG9jYWxob3N0OjgwODBcL29wZW5pZC1jb25uZWN0LXNlcnZlci13ZWJhcHBcLyIsImp0aSI6IjBmZGE5ZmRiLTYyYzItNGIzZS05OTdiLWU0M2VhMDUwMzNiOSIsImlhdCI6MTQxMjc4OTM2OH0.xgaVpRLYE5MzbgXfE0tZt823tjAm6Oh3_kdR1P2I9jRLR6gnTlBQFlYi3Y_0pWNnZSerbAE8Tn6SJHZ9k-curVG0-ByKichV7CNvgsE5X_2wpEaUzejvKf8eZ-BammRY-ie6yxSkAarcUGMvGGOLbkFcz5CtrBpZhfd75J49BIQ\"}" + + + " ]" + + "}"; + + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + final Map fakeDb = new HashMap<>(); + when(tokenRepository.saveAccessToken(isA(OAuth2AccessTokenEntity.class))).thenAnswer(new Answer() { + Long id = 324L; + @Override + public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2AccessTokenEntity _token = (OAuth2AccessTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeDb.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getAccessTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeDb.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { + Long id = 133L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _auth = mock(AuthenticationHolderEntity.class); + when(_auth.getId()).thenReturn(id); + id++; + return _auth; + } + }); + dataService.importData(reader); + //2 times for token, 2 times to update client, 2 times to update authHolder, 2 times to update id token, 2 times to update refresh token + verify(tokenRepository, times(8)).saveAccessToken(capturedAccessTokens.capture()); + + List savedAccessTokens = new ArrayList(fakeDb.values()); //capturedAccessTokens.getAllValues(); + Collections.sort(savedAccessTokens, new accessTokenIdComparator()); + + assertThat(savedAccessTokens.size(), is(2)); + + assertThat(savedAccessTokens.get(0).getClient().getClientId(), equalTo(token1.getClient().getClientId())); + assertThat(savedAccessTokens.get(0).getExpiration(), equalTo(token1.getExpiration())); + assertThat(savedAccessTokens.get(0).getValue(), equalTo(token1.getValue())); + + assertThat(savedAccessTokens.get(1).getClient().getClientId(), equalTo(token2.getClient().getClientId())); + assertThat(savedAccessTokens.get(1).getExpiration(), equalTo(token2.getExpiration())); + assertThat(savedAccessTokens.get(1).getValue(), equalTo(token2.getValue())); + } + + @Test + public void testExportClients() throws IOException { + ClientDetailsEntity client1 = new ClientDetailsEntity(); + client1.setId(1L); + client1.setAccessTokenValiditySeconds(3600); + client1.setClientId("client1"); + client1.setClientSecret("clientsecret1"); + client1.setRedirectUris(ImmutableSet.of("http://foo.com/")); + client1.setScope(ImmutableSet.of("foo", "bar", "baz", "dolphin")); + client1.setGrantTypes(ImmutableSet.of("implicit", "authorization_code", "urn:ietf:params:oauth:grant_type:redelegate", "refresh_token")); + client1.setAllowIntrospection(true); + + ClientDetailsEntity client2 = new ClientDetailsEntity(); + client2.setId(2L); + client2.setAccessTokenValiditySeconds(3600); + client2.setClientId("client2"); + client2.setClientSecret("clientsecret2"); + client2.setRedirectUris(ImmutableSet.of("http://bar.baz.com/")); + client2.setScope(ImmutableSet.of("foo", "dolphin", "electric-wombat")); + client2.setGrantTypes(ImmutableSet.of("client_credentials", "urn:ietf:params:oauth:grant_type:redelegate")); + client2.setAllowIntrospection(false); + client2.setCodeChallengeMethod(PKCEAlgorithm.S256); + + Set allClients = ImmutableSet.of(client1, client2); + + Mockito.when(clientRepository.getAllClients()).thenReturn(allClients); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + + // check our client list (this test) + JsonArray clients = config.get(MITREidDataService.CLIENTS).getAsJsonArray(); + + assertThat(clients.size(), is(2)); + // check for both of our clients in turn + Set checked = new HashSet<>(); + for (JsonElement e : clients) { + assertThat(e.isJsonObject(), is(true)); + JsonObject client = e.getAsJsonObject(); + + ClientDetailsEntity compare = null; + if (client.get("clientId").getAsString().equals(client1.getClientId())) { + compare = client1; + } else if (client.get("clientId").getAsString().equals(client2.getClientId())) { + compare = client2; + } + + if (compare == null) { + fail("Could not find matching clientId: " + client.get("clientId").getAsString()); + } else { + assertThat(client.get("clientId").getAsString(), equalTo(compare.getClientId())); + assertThat(client.get("secret").getAsString(), equalTo(compare.getClientSecret())); + assertThat(client.get("accessTokenValiditySeconds").getAsInt(), equalTo(compare.getAccessTokenValiditySeconds())); + assertThat(client.get("allowIntrospection").getAsBoolean(), equalTo(compare.isAllowIntrospection())); + assertThat(jsonArrayToStringSet(client.get("redirectUris").getAsJsonArray()), equalTo(compare.getRedirectUris())); + assertThat(jsonArrayToStringSet(client.get("scope").getAsJsonArray()), equalTo(compare.getScope())); + assertThat(jsonArrayToStringSet(client.get("grantTypes").getAsJsonArray()), equalTo(compare.getGrantTypes())); + assertThat((client.has("codeChallengeMethod") && !client.get("codeChallengeMethod").isJsonNull()) ? PKCEAlgorithm.parse(client.get("codeChallengeMethod").getAsString()) : null, equalTo(compare.getCodeChallengeMethod())); + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allClients), is(true)); + } + + @Test + public void testImportClients() throws IOException { + ClientDetailsEntity client1 = new ClientDetailsEntity(); + client1.setId(1L); + client1.setAccessTokenValiditySeconds(3600); + client1.setClientId("client1"); + client1.setClientSecret("clientsecret1"); + client1.setRedirectUris(ImmutableSet.of("http://foo.com/")); + client1.setScope(ImmutableSet.of("foo", "bar", "baz", "dolphin")); + client1.setGrantTypes(ImmutableSet.of("implicit", "authorization_code", "urn:ietf:params:oauth:grant_type:redelegate", "refresh_token")); + client1.setAllowIntrospection(true); + + ClientDetailsEntity client2 = new ClientDetailsEntity(); + client2.setId(2L); + client2.setAccessTokenValiditySeconds(3600); + client2.setClientId("client2"); + client2.setClientSecret("clientsecret2"); + client2.setRedirectUris(ImmutableSet.of("http://bar.baz.com/")); + client2.setScope(ImmutableSet.of("foo", "dolphin", "electric-wombat")); + client2.setGrantTypes(ImmutableSet.of("client_credentials", "urn:ietf:params:oauth:grant_type:redelegate")); + client2.setAllowIntrospection(false); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [" + + + "{\"id\":1,\"accessTokenValiditySeconds\":3600,\"clientId\":\"client1\",\"secret\":\"clientsecret1\"," + + "\"redirectUris\":[\"http://foo.com/\"]," + + "\"scope\":[\"foo\",\"bar\",\"baz\",\"dolphin\"]," + + "\"grantTypes\":[\"implicit\",\"authorization_code\",\"urn:ietf:params:oauth:grant_type:redelegate\",\"refresh_token\"]," + + "\"allowIntrospection\":true}," + + "{\"id\":2,\"accessTokenValiditySeconds\":3600,\"clientId\":\"client2\",\"secret\":\"clientsecret2\"," + + "\"redirectUris\":[\"http://bar.baz.com/\"]," + + "\"scope\":[\"foo\",\"dolphin\",\"electric-wombat\"]," + + "\"grantTypes\":[\"client_credentials\",\"urn:ietf:params:oauth:grant_type:redelegate\"]," + + "\"allowIntrospection\":false}" + + + " ]" + + "}"; + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + dataService.importData(reader); + verify(clientRepository, times(2)).saveClient(capturedClients.capture()); + + List savedClients = capturedClients.getAllValues(); + + assertThat(savedClients.size(), is(2)); + + assertThat(savedClients.get(0).getAccessTokenValiditySeconds(), equalTo(client1.getAccessTokenValiditySeconds())); + assertThat(savedClients.get(0).getClientId(), equalTo(client1.getClientId())); + assertThat(savedClients.get(0).getClientSecret(), equalTo(client1.getClientSecret())); + assertThat(savedClients.get(0).getRedirectUris(), equalTo(client1.getRedirectUris())); + assertThat(savedClients.get(0).getScope(), equalTo(client1.getScope())); + assertThat(savedClients.get(0).getGrantTypes(), equalTo(client1.getGrantTypes())); + assertThat(savedClients.get(0).isAllowIntrospection(), equalTo(client1.isAllowIntrospection())); + + assertThat(savedClients.get(1).getAccessTokenValiditySeconds(), equalTo(client2.getAccessTokenValiditySeconds())); + assertThat(savedClients.get(1).getClientId(), equalTo(client2.getClientId())); + assertThat(savedClients.get(1).getClientSecret(), equalTo(client2.getClientSecret())); + assertThat(savedClients.get(1).getRedirectUris(), equalTo(client2.getRedirectUris())); + assertThat(savedClients.get(1).getScope(), equalTo(client2.getScope())); + assertThat(savedClients.get(1).getGrantTypes(), equalTo(client2.getGrantTypes())); + assertThat(savedClients.get(1).isAllowIntrospection(), equalTo(client2.isAllowIntrospection())); + } + + @Test + public void testExportBlacklistedSites() throws IOException { + BlacklistedSite site1 = new BlacklistedSite(); + site1.setId(1L); + site1.setUri("http://foo.com"); + + BlacklistedSite site2 = new BlacklistedSite(); + site2.setId(2L); + site2.setUri("http://bar.com"); + + BlacklistedSite site3 = new BlacklistedSite(); + site3.setId(3L); + site3.setUri("http://baz.com"); + + Set allBlacklistedSites = ImmutableSet.of(site1, site2, site3); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(allBlacklistedSites); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + // check our scope list (this test) + JsonArray sites = config.get(MITREidDataService.BLACKLISTEDSITES).getAsJsonArray(); + + assertThat(sites.size(), is(3)); + // check for both of our sites in turn + Set checked = new HashSet<>(); + for (JsonElement e : sites) { + assertThat(e.isJsonObject(), is(true)); + JsonObject site = e.getAsJsonObject(); + + BlacklistedSite compare = null; + if (site.get("id").getAsLong() == site1.getId().longValue()) { + compare = site1; + } else if (site.get("id").getAsLong() == site2.getId().longValue()) { + compare = site2; + } else if (site.get("id").getAsLong() == site3.getId().longValue()) { + compare = site3; + } + + if (compare == null) { + fail("Could not find matching blacklisted site id: " + site.get("id").getAsString()); + } else { + assertThat(site.get("uri").getAsString(), equalTo(compare.getUri())); + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allBlacklistedSites), is(true)); + + } + + @Test + public void testImportBlacklistedSites() throws IOException { + BlacklistedSite site1 = new BlacklistedSite(); + site1.setId(1L); + site1.setUri("http://foo.com"); + + BlacklistedSite site2 = new BlacklistedSite(); + site2.setId(2L); + site2.setUri("http://bar.com"); + + BlacklistedSite site3 = new BlacklistedSite(); + site3.setId(3L); + site3.setUri("http://baz.com"); + + String configJson = "{" + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [" + + + "{\"id\":1,\"uri\":\"http://foo.com\"}," + + "{\"id\":2,\"uri\":\"http://bar.com\"}," + + "{\"id\":3,\"uri\":\"http://baz.com\"}" + + + " ]" + + "}"; + + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + dataService.importData(reader); + verify(blSiteRepository, times(3)).save(capturedBlacklistedSites.capture()); + + List savedSites = capturedBlacklistedSites.getAllValues(); + + assertThat(savedSites.size(), is(3)); + + assertThat(savedSites.get(0).getUri(), equalTo(site1.getUri())); + assertThat(savedSites.get(1).getUri(), equalTo(site2.getUri())); + assertThat(savedSites.get(2).getUri(), equalTo(site3.getUri())); + } + + @Test + public void testExportWhitelistedSites() throws IOException { + WhitelistedSite site1 = new WhitelistedSite(); + site1.setId(1L); + site1.setClientId("foo"); + + WhitelistedSite site2 = new WhitelistedSite(); + site2.setId(2L); + site2.setClientId("bar"); + + WhitelistedSite site3 = new WhitelistedSite(); + site3.setId(3L); + site3.setClientId("baz"); + + Set allWhitelistedSites = ImmutableSet.of(site1, site2, site3); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(allWhitelistedSites); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + // check our scope list (this test) + JsonArray sites = config.get(MITREidDataService.WHITELISTEDSITES).getAsJsonArray(); + + assertThat(sites.size(), is(3)); + // check for both of our sites in turn + Set checked = new HashSet<>(); + for (JsonElement e : sites) { + assertThat(e.isJsonObject(), is(true)); + JsonObject site = e.getAsJsonObject(); + + WhitelistedSite compare = null; + if (site.get("id").getAsLong() == site1.getId().longValue()) { + compare = site1; + } else if (site.get("id").getAsLong() == site2.getId().longValue()) { + compare = site2; + } else if (site.get("id").getAsLong() == site3.getId().longValue()) { + compare = site3; + } + + if (compare == null) { + fail("Could not find matching whitelisted site id: " + site.get("id").getAsString()); + } else { + assertThat(site.get("clientId").getAsString(), equalTo(compare.getClientId())); + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allWhitelistedSites), is(true)); + + } + + @Test + public void testImportWhitelistedSites() throws IOException { + WhitelistedSite site1 = new WhitelistedSite(); + site1.setId(1L); + site1.setClientId("foo"); + + WhitelistedSite site2 = new WhitelistedSite(); + site2.setId(2L); + site2.setClientId("bar"); + + WhitelistedSite site3 = new WhitelistedSite(); + site3.setId(3L); + site3.setClientId("baz"); + //site3.setAllowedScopes(null); + + String configJson = "{" + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [" + + + "{\"id\":1,\"clientId\":\"foo\"}," + + "{\"id\":2,\"clientId\":\"bar\"}," + + "{\"id\":3,\"clientId\":\"baz\"}" + + + " ]" + + "}"; + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + final Map fakeDb = new HashMap<>(); + when(wlSiteRepository.save(isA(WhitelistedSite.class))).thenAnswer(new Answer() { + Long id = 333L; + @Override + public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { + WhitelistedSite _site = (WhitelistedSite) invocation.getArguments()[0]; + if(_site.getId() == null) { + _site.setId(id++); + } + fakeDb.put(_site.getId(), _site); + return _site; + } + }); + when(wlSiteRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeDb.get(_id); + } + }); + + dataService.importData(reader); + verify(wlSiteRepository, times(3)).save(capturedWhitelistedSites.capture()); + + List savedSites = capturedWhitelistedSites.getAllValues(); + + assertThat(savedSites.size(), is(3)); + + assertThat(savedSites.get(0).getClientId(), equalTo(site1.getClientId())); + assertThat(savedSites.get(1).getClientId(), equalTo(site2.getClientId())); + assertThat(savedSites.get(2).getClientId(), equalTo(site3.getClientId())); + } + + @Test + public void testExportGrants() throws IOException, ParseException { + Date creationDate1 = formatter.parse("2014-09-10T22:49:44.090+0000", Locale.ENGLISH); + Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); + + OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); + when(mockToken1.getId()).thenReturn(1L); + + ApprovedSite site1 = new ApprovedSite(); + site1.setId(1L); + site1.setClientId("foo"); + site1.setCreationDate(creationDate1); + site1.setAccessDate(accessDate1); + site1.setUserId("user1"); + site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); + site1.setApprovedAccessTokens(ImmutableSet.of(mockToken1)); + + Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); + Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); + Date timeoutDate2 = formatter.parse("2014-10-01T20:49:44.090+0000", Locale.ENGLISH); + + ApprovedSite site2 = new ApprovedSite(); + site2.setId(2L); + site2.setClientId("bar"); + site2.setCreationDate(creationDate2); + site2.setAccessDate(accessDate2); + site2.setUserId("user2"); + site2.setAllowedScopes(ImmutableSet.of("openid", "offline_access", "email", "profile")); + site2.setTimeoutDate(timeoutDate2); + + Set allApprovedSites = ImmutableSet.of(site1, site2); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(allApprovedSites); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + // check our scope list (this test) + JsonArray sites = config.get(MITREidDataService.GRANTS).getAsJsonArray(); + + assertThat(sites.size(), is(2)); + // check for both of our sites in turn + Set checked = new HashSet<>(); + for (JsonElement e : sites) { + assertThat(e.isJsonObject(), is(true)); + JsonObject site = e.getAsJsonObject(); + + ApprovedSite compare = null; + if (site.get("id").getAsLong() == site1.getId().longValue()) { + compare = site1; + } else if (site.get("id").getAsLong() == site2.getId().longValue()) { + compare = site2; + } + + if (compare == null) { + fail("Could not find matching whitelisted site id: " + site.get("id").getAsString()); + } else { + assertThat(site.get("clientId").getAsString(), equalTo(compare.getClientId())); + assertThat(site.get("creationDate").getAsString(), equalTo(formatter.print(compare.getCreationDate(), Locale.ENGLISH))); + assertThat(site.get("accessDate").getAsString(), equalTo(formatter.print(compare.getAccessDate(), Locale.ENGLISH))); + if(site.get("timeoutDate").isJsonNull()) { + assertNull(compare.getTimeoutDate()); + } else { + assertThat(site.get("timeoutDate").getAsString(), equalTo(formatter.print(compare.getTimeoutDate(), Locale.ENGLISH))); + } + assertThat(site.get("userId").getAsString(), equalTo(compare.getUserId())); + assertThat(jsonArrayToStringSet(site.getAsJsonArray("allowedScopes")), equalTo(compare.getAllowedScopes())); + if (site.get("approvedAccessTokens").isJsonNull() || site.getAsJsonArray("approvedAccessTokens") == null) { + assertTrue(compare.getApprovedAccessTokens() == null || compare.getApprovedAccessTokens().isEmpty()); + } else { + assertNotNull(compare.getApprovedAccessTokens()); + Set tokenIds = new HashSet<>(); + for(OAuth2AccessTokenEntity entity : compare.getApprovedAccessTokens()) { + tokenIds.add(entity.getId().toString()); + } + assertThat(jsonArrayToStringSet(site.getAsJsonArray("approvedAccessTokens")), equalTo(tokenIds)); + } + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allApprovedSites), is(true)); + } + + @Test + public void testImportGrants() throws IOException, ParseException { + Date creationDate1 = formatter.parse("2014-09-10T22:49:44.090+0000", Locale.ENGLISH); + Date accessDate1 = formatter.parse("2014-09-10T23:49:44.090+0000", Locale.ENGLISH); + + OAuth2AccessTokenEntity mockToken1 = mock(OAuth2AccessTokenEntity.class); + when(mockToken1.getId()).thenReturn(1L); + + ApprovedSite site1 = new ApprovedSite(); + site1.setId(1L); + site1.setClientId("foo"); + site1.setCreationDate(creationDate1); + site1.setAccessDate(accessDate1); + site1.setUserId("user1"); + site1.setAllowedScopes(ImmutableSet.of("openid", "phone")); + site1.setApprovedAccessTokens(ImmutableSet.of(mockToken1)); + + Date creationDate2 = formatter.parse("2014-09-11T18:49:44.090+0000", Locale.ENGLISH); + Date accessDate2 = formatter.parse("2014-09-11T20:49:44.090+0000", Locale.ENGLISH); + Date timeoutDate2 = formatter.parse("2014-10-01T20:49:44.090+0000", Locale.ENGLISH); + + ApprovedSite site2 = new ApprovedSite(); + site2.setId(2L); + site2.setClientId("bar"); + site2.setCreationDate(creationDate2); + site2.setAccessDate(accessDate2); + site2.setUserId("user2"); + site2.setAllowedScopes(ImmutableSet.of("openid", "offline_access", "email", "profile")); + site2.setTimeoutDate(timeoutDate2); + + String configJson = "{" + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [" + + + "{\"id\":1,\"clientId\":\"foo\",\"creationDate\":\"2014-09-10T22:49:44.090+0000\",\"accessDate\":\"2014-09-10T23:49:44.090+0000\"," + + "\"userId\":\"user1\",\"whitelistedSiteId\":null,\"allowedScopes\":[\"openid\",\"phone\"], \"whitelistedSiteId\":1," + + "\"approvedAccessTokens\":[1]}," + + "{\"id\":2,\"clientId\":\"bar\",\"creationDate\":\"2014-09-11T18:49:44.090+0000\",\"accessDate\":\"2014-09-11T20:49:44.090+0000\"," + + "\"timeoutDate\":\"2014-10-01T20:49:44.090+0000\",\"userId\":\"user2\"," + + "\"allowedScopes\":[\"openid\",\"offline_access\",\"email\",\"profile\"]}" + + + " ]" + + "}"; + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + final Map fakeDb = new HashMap<>(); + when(approvedSiteRepository.save(isA(ApprovedSite.class))).thenAnswer(new Answer() { + Long id = 364L; + @Override + public ApprovedSite answer(InvocationOnMock invocation) throws Throwable { + ApprovedSite _site = (ApprovedSite) invocation.getArguments()[0]; + if(_site.getId() == null) { + _site.setId(id++); + } + fakeDb.put(_site.getId(), _site); + return _site; + } + }); + when(approvedSiteRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public ApprovedSite answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeDb.get(_id); + } + }); + when(wlSiteRepository.getById(isNull(Long.class))).thenAnswer(new Answer() { + Long id = 432L; + @Override + public WhitelistedSite answer(InvocationOnMock invocation) throws Throwable { + WhitelistedSite _site = mock(WhitelistedSite.class); + when(_site.getId()).thenReturn(id++); + return _site; + } + }); + when(tokenRepository.getAccessTokenById(isNull(Long.class))).thenAnswer(new Answer() { + Long id = 245L; + @Override + public OAuth2AccessTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2AccessTokenEntity _token = mock(OAuth2AccessTokenEntity.class); + when(_token.getId()).thenReturn(id++); + return _token; + } + }); + + dataService.importData(reader); + //2 for sites, 1 for updating access token ref on #1 + verify(approvedSiteRepository, times(3)).save(capturedApprovedSites.capture()); + + List savedSites = new ArrayList(fakeDb.values()); + + assertThat(savedSites.size(), is(2)); + + assertThat(savedSites.get(0).getClientId(), equalTo(site1.getClientId())); + assertThat(savedSites.get(0).getAccessDate(), equalTo(site1.getAccessDate())); + assertThat(savedSites.get(0).getCreationDate(), equalTo(site1.getCreationDate())); + assertThat(savedSites.get(0).getAllowedScopes(), equalTo(site1.getAllowedScopes())); + assertThat(savedSites.get(0).getTimeoutDate(), equalTo(site1.getTimeoutDate())); + assertThat(savedSites.get(0).getApprovedAccessTokens().size(), equalTo(site1.getApprovedAccessTokens().size())); + + assertThat(savedSites.get(1).getClientId(), equalTo(site2.getClientId())); + assertThat(savedSites.get(1).getAccessDate(), equalTo(site2.getAccessDate())); + assertThat(savedSites.get(1).getCreationDate(), equalTo(site2.getCreationDate())); + assertThat(savedSites.get(1).getAllowedScopes(), equalTo(site2.getAllowedScopes())); + assertThat(savedSites.get(1).getTimeoutDate(), equalTo(site2.getTimeoutDate())); + assertThat(savedSites.get(1).getApprovedAccessTokens().size(), equalTo(site2.getApprovedAccessTokens().size())); + } + + @Test + public void testExportAuthenticationHolders() throws IOException { + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = new UsernamePasswordAuthenticationToken("user1", "pass1", AuthorityUtils.commaSeparatedStringToAuthorityList("ROLE_USER")); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, null); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + List allAuthHolders = ImmutableList.of(holder1, holder2); + + when(clientRepository.getAllClients()).thenReturn(new HashSet()); + when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + when(blSiteRepository.getAll()).thenReturn(new HashSet()); + when(authHolderRepository.getAll()).thenReturn(allAuthHolders); + when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + when(sysScopeRepository.getAll()).thenReturn(new HashSet()); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + + // check our holder list (this test) + JsonArray holders = config.get(MITREidDataService.AUTHENTICATIONHOLDERS).getAsJsonArray(); + + assertThat(holders.size(), is(2)); + // check for both of our clients in turn + Set checked = new HashSet<>(); + for (JsonElement e : holders) { + assertThat(e.isJsonObject(), is(true)); + JsonObject holder = e.getAsJsonObject(); + + AuthenticationHolderEntity compare = null; + if (holder.get("id").getAsLong() == holder1.getId()) { + compare = holder1; + } else if (holder.get("id").getAsLong() == holder2.getId()) { + compare = holder2; + } + + if (compare == null) { + fail("Could not find matching authentication holder id: " + holder.get("id").getAsString()); + } else { + assertTrue(holder.get("clientId").getAsString().equals(compare.getClientId())); + assertTrue(holder.get("approved").getAsBoolean() == compare.isApproved()); + assertTrue(holder.get("redirectUri").getAsString().equals(compare.getRedirectUri())); + if (compare.getUserAuth() != null) { + assertTrue(holder.get("savedUserAuthentication").isJsonObject()); + JsonObject savedAuth = holder.get("savedUserAuthentication").getAsJsonObject(); + assertTrue(savedAuth.get("name").getAsString().equals(compare.getUserAuth().getName())); + assertTrue(savedAuth.get("authenticated").getAsBoolean() == compare.getUserAuth().isAuthenticated()); + assertTrue(savedAuth.get("sourceClass").getAsString().equals(compare.getUserAuth().getSourceClass())); + } + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allAuthHolders), is(true)); + } + + @Test + public void testImportAuthenticationHolders() throws IOException { + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + String configJson = "{" + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + + + "{\"id\":1,\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"," + + "\"savedUserAuthentication\":null}," + + "{\"id\":2,\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"," + + "\"savedUserAuthentication\":null}" + + " ]" + + "}"; + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + final Map fakeDb = new HashMap<>(); + when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { + Long id = 243L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _site = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_site.getId() == null) { + _site.setId(id++); + } + fakeDb.put(_site.getId(), _site); + return _site; + } + }); + + dataService.importData(reader); + verify(authHolderRepository, times(2)).save(capturedAuthHolders.capture()); + + List savedAuthHolders = capturedAuthHolders.getAllValues(); + + assertThat(savedAuthHolders.size(), is(2)); + assertThat(savedAuthHolders.get(0).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder1.getAuthentication().getOAuth2Request().getClientId())); + assertThat(savedAuthHolders.get(1).getAuthentication().getOAuth2Request().getClientId(), equalTo(holder2.getAuthentication().getOAuth2Request().getClientId())); + } + + @Test + public void testExportSystemScopes() throws IOException { + SystemScope scope1 = new SystemScope(); + scope1.setId(1L); + scope1.setValue("scope1"); + scope1.setDescription("Scope 1"); + scope1.setRestricted(true); + scope1.setDefaultScope(false); + scope1.setIcon("glass"); + + SystemScope scope2 = new SystemScope(); + scope2.setId(2L); + scope2.setValue("scope2"); + scope2.setDescription("Scope 2"); + scope2.setRestricted(false); + scope2.setDefaultScope(false); + scope2.setIcon("ball"); + + SystemScope scope3 = new SystemScope(); + scope3.setId(3L); + scope3.setValue("scope3"); + scope3.setDescription("Scope 3"); + scope3.setRestricted(false); + scope3.setDefaultScope(true); + scope3.setIcon("road"); + + Set allScopes = ImmutableSet.of(scope1, scope2, scope3); + + Mockito.when(clientRepository.getAllClients()).thenReturn(new HashSet()); + Mockito.when(approvedSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(wlSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(blSiteRepository.getAll()).thenReturn(new HashSet()); + Mockito.when(authHolderRepository.getAll()).thenReturn(new ArrayList()); + Mockito.when(tokenRepository.getAllAccessTokens()).thenReturn(new HashSet()); + Mockito.when(tokenRepository.getAllRefreshTokens()).thenReturn(new HashSet()); + Mockito.when(sysScopeRepository.getAll()).thenReturn(allScopes); + + // do the data export + StringWriter stringWriter = new StringWriter(); + JsonWriter writer = new JsonWriter(stringWriter); + writer.beginObject(); + dataService.exportData(writer); + writer.endObject(); + writer.close(); + + // parse the output as a JSON object for testing + JsonElement elem = new JsonParser().parse(stringWriter.toString()); + JsonObject root = elem.getAsJsonObject(); + + // make sure the root is there + assertThat(root.has(MITREidDataService.MITREID_CONNECT_1_3), is(true)); + + JsonObject config = root.get(MITREidDataService.MITREID_CONNECT_1_3).getAsJsonObject(); + + // make sure all the root elements are there + assertThat(config.has(MITREidDataService.CLIENTS), is(true)); + assertThat(config.has(MITREidDataService.GRANTS), is(true)); + assertThat(config.has(MITREidDataService.WHITELISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.BLACKLISTEDSITES), is(true)); + assertThat(config.has(MITREidDataService.REFRESHTOKENS), is(true)); + assertThat(config.has(MITREidDataService.ACCESSTOKENS), is(true)); + assertThat(config.has(MITREidDataService.SYSTEMSCOPES), is(true)); + assertThat(config.has(MITREidDataService.AUTHENTICATIONHOLDERS), is(true)); + + // make sure the root elements are all arrays + assertThat(config.get(MITREidDataService.CLIENTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.GRANTS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.WHITELISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.BLACKLISTEDSITES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.REFRESHTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.ACCESSTOKENS).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.SYSTEMSCOPES).isJsonArray(), is(true)); + assertThat(config.get(MITREidDataService.AUTHENTICATIONHOLDERS).isJsonArray(), is(true)); + + + // check our scope list (this test) + JsonArray scopes = config.get(MITREidDataService.SYSTEMSCOPES).getAsJsonArray(); + + assertThat(scopes.size(), is(3)); + // check for both of our clients in turn + Set checked = new HashSet<>(); + for (JsonElement e : scopes) { + assertThat(e.isJsonObject(), is(true)); + JsonObject scope = e.getAsJsonObject(); + + SystemScope compare = null; + if (scope.get("value").getAsString().equals(scope1.getValue())) { + compare = scope1; + } else if (scope.get("value").getAsString().equals(scope2.getValue())) { + compare = scope2; + } else if (scope.get("value").getAsString().equals(scope3.getValue())) { + compare = scope3; + } + + if (compare == null) { + fail("Could not find matching scope value: " + scope.get("value").getAsString()); + } else { + assertThat(scope.get("value").getAsString(), equalTo(compare.getValue())); + assertThat(scope.get("description").getAsString(), equalTo(compare.getDescription())); + assertThat(scope.get("icon").getAsString(), equalTo(compare.getIcon())); + assertThat(scope.get("restricted").getAsBoolean(), equalTo(compare.isRestricted())); + assertThat(scope.get("defaultScope").getAsBoolean(), equalTo(compare.isDefaultScope())); + checked.add(compare); + } + } + // make sure all of our clients were found + assertThat(checked.containsAll(allScopes), is(true)); + + } + + @Test + public void testImportSystemScopes() throws IOException { + SystemScope scope1 = new SystemScope(); + scope1.setId(1L); + scope1.setValue("scope1"); + scope1.setDescription("Scope 1"); + scope1.setRestricted(true); + scope1.setDefaultScope(false); + scope1.setIcon("glass"); + + SystemScope scope2 = new SystemScope(); + scope2.setId(2L); + scope2.setValue("scope2"); + scope2.setDescription("Scope 2"); + scope2.setRestricted(false); + scope2.setDefaultScope(false); + scope2.setIcon("ball"); + + SystemScope scope3 = new SystemScope(); + scope3.setId(3L); + scope3.setValue("scope3"); + scope3.setDescription("Scope 3"); + scope3.setRestricted(false); + scope3.setDefaultScope(true); + scope3.setIcon("road"); + scope3.setStructured(true); + scope3.setStructuredParamDescription("Structured Parameter"); + + String configJson = "{" + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [], " + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [" + + + "{\"id\":1,\"description\":\"Scope 1\",\"icon\":\"glass\",\"value\":\"scope1\",\"restricted\":true,\"defaultScope\":false}," + + "{\"id\":2,\"description\":\"Scope 2\",\"icon\":\"ball\",\"value\":\"scope2\",\"restricted\":false,\"defaultScope\":false}," + + "{\"id\":3,\"description\":\"Scope 3\",\"icon\":\"road\",\"value\":\"scope3\",\"restricted\":false,\"defaultScope\":true,\"structured\":true,\"structuredParameter\":\"Structured Parameter\"}" + + + " ]" + + "}"; + + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + + dataService.importData(reader); + verify(sysScopeRepository, times(3)).save(capturedScope.capture()); + + List savedScopes = capturedScope.getAllValues(); + + assertThat(savedScopes.size(), is(3)); + assertThat(savedScopes.get(0).getValue(), equalTo(scope1.getValue())); + assertThat(savedScopes.get(0).getDescription(), equalTo(scope1.getDescription())); + assertThat(savedScopes.get(0).getIcon(), equalTo(scope1.getIcon())); + assertThat(savedScopes.get(0).isDefaultScope(), equalTo(scope1.isDefaultScope())); + assertThat(savedScopes.get(0).isRestricted(), equalTo(scope1.isRestricted())); + assertThat(savedScopes.get(0).isStructured(), equalTo(scope1.isStructured())); + assertThat(savedScopes.get(0).getStructuredParamDescription(), equalTo(scope1.getStructuredParamDescription())); + + assertThat(savedScopes.get(1).getValue(), equalTo(scope2.getValue())); + assertThat(savedScopes.get(1).getDescription(), equalTo(scope2.getDescription())); + assertThat(savedScopes.get(1).getIcon(), equalTo(scope2.getIcon())); + assertThat(savedScopes.get(1).isDefaultScope(), equalTo(scope2.isDefaultScope())); + assertThat(savedScopes.get(1).isRestricted(), equalTo(scope2.isRestricted())); + assertThat(savedScopes.get(1).isStructured(), equalTo(scope2.isStructured())); + assertThat(savedScopes.get(1).getStructuredParamDescription(), equalTo(scope2.getStructuredParamDescription())); + + assertThat(savedScopes.get(2).getValue(), equalTo(scope3.getValue())); + assertThat(savedScopes.get(2).getDescription(), equalTo(scope3.getDescription())); + assertThat(savedScopes.get(2).getIcon(), equalTo(scope3.getIcon())); + assertThat(savedScopes.get(2).isDefaultScope(), equalTo(scope3.isDefaultScope())); + assertThat(savedScopes.get(2).isRestricted(), equalTo(scope3.isRestricted())); + assertThat(savedScopes.get(2).isStructured(), equalTo(scope3.isStructured())); + assertThat(savedScopes.get(2).getStructuredParamDescription(), equalTo(scope3.getStructuredParamDescription())); + + } + + @Test + public void testFixRefreshTokenAuthHolderReferencesOnImport() throws IOException, ParseException { + String expiration1 = "2014-09-10T22:49:44.090+0000"; + Date expirationDate1 = formatter.parse(expiration1, Locale.ENGLISH); + + ClientDetailsEntity mockedClient1 = mock(ClientDetailsEntity.class); + when(mockedClient1.getClientId()).thenReturn("mocked_client_1"); + + OAuth2Request req1 = new OAuth2Request(new HashMap(), "client1", new ArrayList(), + true, new HashSet(), new HashSet(), "http://foo.com", + new HashSet(), null); + Authentication mockAuth1 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth1 = new OAuth2Authentication(req1, mockAuth1); + + AuthenticationHolderEntity holder1 = new AuthenticationHolderEntity(); + holder1.setId(1L); + holder1.setAuthentication(auth1); + + OAuth2RefreshTokenEntity token1 = new OAuth2RefreshTokenEntity(); + token1.setId(1L); + token1.setClient(mockedClient1); + token1.setExpiration(expirationDate1); + token1.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.")); + token1.setAuthenticationHolder(holder1); + + String expiration2 = "2015-01-07T18:31:50.079+0000"; + Date expirationDate2 = formatter.parse(expiration2, Locale.ENGLISH); + + ClientDetailsEntity mockedClient2 = mock(ClientDetailsEntity.class); + when(mockedClient2.getClientId()).thenReturn("mocked_client_2"); + + OAuth2Request req2 = new OAuth2Request(new HashMap(), "client2", new ArrayList(), + true, new HashSet(), new HashSet(), "http://bar.com", + new HashSet(), null); + Authentication mockAuth2 = mock(Authentication.class, withSettings().serializable()); + OAuth2Authentication auth2 = new OAuth2Authentication(req2, mockAuth2); + + AuthenticationHolderEntity holder2 = new AuthenticationHolderEntity(); + holder2.setId(2L); + holder2.setAuthentication(auth2); + + OAuth2RefreshTokenEntity token2 = new OAuth2RefreshTokenEntity(); + token2.setId(2L); + token2.setClient(mockedClient2); + token2.setExpiration(expirationDate2); + token2.setJwt(JWTParser.parse("eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.")); + token2.setAuthenticationHolder(holder2); + + String configJson = "{" + + "\"" + MITREidDataService.SYSTEMSCOPES + "\": [], " + + "\"" + MITREidDataService.ACCESSTOKENS + "\": [], " + + "\"" + MITREidDataService.CLIENTS + "\": [], " + + "\"" + MITREidDataService.GRANTS + "\": [], " + + "\"" + MITREidDataService.WHITELISTEDSITES + "\": [], " + + "\"" + MITREidDataService.BLACKLISTEDSITES + "\": [], " + + "\"" + MITREidDataService.AUTHENTICATIONHOLDERS + "\": [" + + + "{\"id\":1,\"authentication\":{\"authorizationRequest\":{\"clientId\":\"client1\",\"redirectUri\":\"http://foo.com\"}," + + "\"userAuthentication\":null}}," + + "{\"id\":2,\"authentication\":{\"authorizationRequest\":{\"clientId\":\"client2\",\"redirectUri\":\"http://bar.com\"}," + + "\"userAuthentication\":null}}" + + " ]," + + "\"" + MITREidDataService.REFRESHTOKENS + "\": [" + + + "{\"id\":1,\"clientId\":\"mocked_client_1\",\"expiration\":\"2014-09-10T22:49:44.090+0000\"," + + "\"authenticationHolderId\":1,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJmOTg4OWQyOS0xMTk1LTQ4ODEtODgwZC1lZjVlYzAwY2Y4NDIifQ.\"}," + + "{\"id\":2,\"clientId\":\"mocked_client_2\",\"expiration\":\"2015-01-07T18:31:50.079+0000\"," + + "\"authenticationHolderId\":2,\"value\":\"eyJhbGciOiJub25lIn0.eyJqdGkiOiJlYmEyYjc3My0xNjAzLTRmNDAtOWQ3MS1hMGIxZDg1OWE2MDAifQ.\"}" + + + " ]" + + "}"; + logger.debug(configJson); + + JsonReader reader = new JsonReader(new StringReader(configJson)); + final Map fakeRefreshTokenTable = new HashMap<>(); + final Map fakeAuthHolderTable = new HashMap<>(); + when(tokenRepository.saveRefreshToken(isA(OAuth2RefreshTokenEntity.class))).thenAnswer(new Answer() { + Long id = 343L; + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + OAuth2RefreshTokenEntity _token = (OAuth2RefreshTokenEntity) invocation.getArguments()[0]; + if(_token.getId() == null) { + _token.setId(id++); + } + fakeRefreshTokenTable.put(_token.getId(), _token); + return _token; + } + }); + when(tokenRepository.getRefreshTokenById(anyLong())).thenAnswer(new Answer() { + @Override + public OAuth2RefreshTokenEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeRefreshTokenTable.get(_id); + } + }); + when(clientRepository.getClientByClientId(anyString())).thenAnswer(new Answer() { + @Override + public ClientDetailsEntity answer(InvocationOnMock invocation) throws Throwable { + String _clientId = (String) invocation.getArguments()[0]; + ClientDetailsEntity _client = mock(ClientDetailsEntity.class); + when(_client.getClientId()).thenReturn(_clientId); + return _client; + } + }); + when(authHolderRepository.save(isA(AuthenticationHolderEntity.class))).thenAnswer(new Answer() { + Long id = 356L; + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + AuthenticationHolderEntity _holder = (AuthenticationHolderEntity) invocation.getArguments()[0]; + if(_holder.getId() == null) { + _holder.setId(id++); + } + fakeAuthHolderTable.put(_holder.getId(), _holder); + return _holder; + } + }); + when(authHolderRepository.getById(anyLong())).thenAnswer(new Answer() { + @Override + public AuthenticationHolderEntity answer(InvocationOnMock invocation) throws Throwable { + Long _id = (Long) invocation.getArguments()[0]; + return fakeAuthHolderTable.get(_id); + } + }); + dataService.importData(reader); + + List savedRefreshTokens = new ArrayList(fakeRefreshTokenTable.values()); //capturedRefreshTokens.getAllValues(); + Collections.sort(savedRefreshTokens, new refreshTokenIdComparator()); + + assertThat(savedRefreshTokens.get(0).getAuthenticationHolder().getId(), equalTo(356L)); + assertThat(savedRefreshTokens.get(1).getAuthenticationHolder().getId(), equalTo(357L)); + } + + private Set jsonArrayToStringSet(JsonArray a) { + Set s = new HashSet<>(); + for (JsonElement jsonElement : a) { + s.add(jsonElement.getAsString()); + } + return s; + } + +}