diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java index 961b92420..a5c570d68 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java @@ -66,7 +66,7 @@ public class AuthenticationHolderEntity { private SavedUserAuthentication userAuth; - private Collection authorities; + private Collection authorities; private Set resourceIds; @@ -116,14 +116,14 @@ public class AuthenticationHolderEntity { // pull apart the request and save its bits OAuth2Request o2Request = authentication.getOAuth2Request(); - setAuthorities(o2Request.getAuthorities()); + setAuthorities(o2Request.getAuthorities() == null ? null : new HashSet<>(o2Request.getAuthorities())); setClientId(o2Request.getClientId()); - setExtensions(o2Request.getExtensions()); + setExtensions(o2Request.getExtensions() == null ? null : new HashMap<>(o2Request.getExtensions())); setRedirectUri(o2Request.getRedirectUri()); - setRequestParameters(o2Request.getRequestParameters()); - setResourceIds(o2Request.getResourceIds()); - setResponseTypes(o2Request.getResponseTypes()); - setScope(o2Request.getScope()); + setRequestParameters(o2Request.getRequestParameters() == null ? null : new HashMap<>(o2Request.getRequestParameters())); + setResourceIds(o2Request.getResourceIds() == null ? null : new HashSet<>(o2Request.getResourceIds())); + setResponseTypes(o2Request.getResponseTypes() == null ? null : new HashSet<>(o2Request.getResponseTypes())); + setScope(o2Request.getScope() == null ? null : new HashSet<>(o2Request.getScope())); setApproved(o2Request.isApproved()); if (authentication.getUserAuthentication() != null) { @@ -159,19 +159,15 @@ public class AuthenticationHolderEntity { ) @Convert(converter = SimpleGrantedAuthorityStringConverter.class) @Column(name="authority") - public Collection getAuthorities() { + public Collection getAuthorities() { return authorities; } /** * @param authorities the authorities to set */ - public void setAuthorities(Collection authorities) { - if (authorities != null) { - this.authorities = new HashSet<>(authorities); - } else { - this.authorities = null; - } + public void setAuthorities(Collection authorities) { + this.authorities = authorities; } /** @@ -191,11 +187,7 @@ public class AuthenticationHolderEntity { * @param resourceIds the resourceIds to set */ public void setResourceIds(Set resourceIds) { - if (resourceIds != null) { - this.resourceIds = new HashSet<>(resourceIds); - } else { - this.resourceIds = null; - } + this.resourceIds = resourceIds; } /** @@ -247,11 +239,7 @@ public class AuthenticationHolderEntity { * @param responseTypes the responseTypes to set */ public void setResponseTypes(Set responseTypes) { - if (responseTypes != null) { - this.responseTypes = new HashSet<>(responseTypes); - } else { - this.responseTypes = null; - } + this.responseTypes = responseTypes; } /** @@ -264,7 +252,7 @@ public class AuthenticationHolderEntity { ) @Column(name="val") @MapKeyColumn(name="extension") - @Convert(converter=SerializableStringConverter.class) + @Convert(attributeName="value", converter=SerializableStringConverter.class) public Map getExtensions() { return extensions; } @@ -273,11 +261,7 @@ public class AuthenticationHolderEntity { * @param extensions the extensions to set */ public void setExtensions(Map extensions) { - if (extensions != null) { - this.extensions = new HashMap<>(extensions); - } else { - this.extensions = null; - } + this.extensions = extensions; } /** @@ -313,11 +297,7 @@ public class AuthenticationHolderEntity { * @param scope the scope to set */ public void setScope(Set scope) { - if (scope != null) { - this.scope = new HashSet<>(scope); - } else { - this.scope = null; - } + this.scope = scope; } /** @@ -338,11 +318,7 @@ public class AuthenticationHolderEntity { * @param requestParameters the requestParameters to set */ public void setRequestParameters(Map requestParameters) { - if (requestParameters != null) { - this.requestParameters = new HashMap<>(requestParameters); - } else { - this.requestParameters = null; - } + this.requestParameters = requestParameters; } diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java index 3bff6639c..9ba5a03e2 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java @@ -265,7 +265,7 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken { /** * @return the idToken */ - @OneToOne(cascade=CascadeType.ALL) // one-to-one mapping for now + @OneToOne(cascade=CascadeType.ALL, orphanRemoval=true) // one-to-one mapping for now @JoinColumn(name = "id_token_id") public OAuth2AccessTokenEntity getIdToken() { return idToken; diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/SavedUserAuthentication.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/SavedUserAuthentication.java index c83859fc5..bf242bb1b 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/SavedUserAuthentication.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/SavedUserAuthentication.java @@ -54,7 +54,7 @@ public class SavedUserAuthentication implements Authentication { private String name; - private Collection authorities; + private Collection authorities; private boolean authenticated; @@ -65,7 +65,7 @@ public class SavedUserAuthentication implements Authentication { */ public SavedUserAuthentication(Authentication src) { setName(src.getName()); - setAuthorities(src.getAuthorities()); + setAuthorities(new HashSet<>(src.getAuthorities())); setAuthenticated(src.isAuthenticated()); if (src instanceof SavedUserAuthentication) { @@ -115,7 +115,7 @@ public class SavedUserAuthentication implements Authentication { ) @Convert(converter = SimpleGrantedAuthorityStringConverter.class) @Column(name="authority") - public Collection getAuthorities() { + public Collection getAuthorities() { return authorities; } @@ -175,12 +175,8 @@ public class SavedUserAuthentication implements Authentication { /** * @param authorities the authorities to set */ - public void setAuthorities(Collection authorities) { - if (authorities != null) { - this.authorities = new HashSet<>(authorities); - } else { - this.authorities = null; - } + public void setAuthorities(Collection authorities) { + this.authorities = authorities; } diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java index 7d2bdb5a5..5317a3194 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java @@ -97,7 +97,13 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { OAuth2AccessTokenEntity found = getAccessTokenByValue(accessToken.getValue()); if (found != null) { - manager.remove(found); + OAuth2AccessTokenEntity accessTokenForIdToken = getAccessTokenForIdToken(found); + if (accessTokenForIdToken != null) { + accessTokenForIdToken.setIdToken(null); + JpaUtil.saveOrUpdate(accessTokenForIdToken.getId(), manager, accessTokenForIdToken); + } else { + manager.remove(found); + } } else { throw new IllegalArgumentException("Access token not found: " + accessToken); } @@ -231,7 +237,7 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { @Transactional(value="defaultTransactionManager") public void clearDuplicateAccessTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING c > 1"); + Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); @SuppressWarnings("unchecked") List resultList = query.getResultList(); List values = new ArrayList<>(); @@ -255,7 +261,7 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { @Override @Transactional(value="defaultTransactionManager") public void clearDuplicateRefreshTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING c > 1"); + Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); @SuppressWarnings("unchecked") List resultList = query.getResultList(); List values = new ArrayList<>();