From f58141e6a7f7a014b848a099770eb4534f5721dc Mon Sep 17 00:00:00 2001 From: Amanda Anganes Date: Tue, 6 Aug 2013 11:28:13 -0400 Subject: [PATCH] Fixed token expiration bug by removing jsql queries. Instead expired tokens or approved sites are filtered at the repository level Moved getExpired to service layers Used Predicates to filter expired tokens and approved sites; --- .../oauth2/model/OAuth2AccessTokenEntity.java | 1 - .../model/OAuth2RefreshTokenEntity.java | 1 - .../repository/OAuth2TokenRepository.java | 7 ----- .../openid/connect/model/ApprovedSite.java | 1 - .../repository/ApprovedSiteRepository.java | 5 ---- .../impl/JpaOAuth2TokenRepository.java | 20 ------------- .../DefaultOAuth2ProviderTokenService.java | 28 +++++++++++++++++-- .../impl/JpaApprovedSiteRepository.java | 11 ++------ .../impl/DefaultApprovedSiteService.java | 16 ++++++++++- 9 files changed, 43 insertions(+), 47 deletions(-) diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java index 3c5cb003b..977b97279 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java @@ -59,7 +59,6 @@ import com.nimbusds.jwt.JWTParser; @NamedQuery(name = "OAuth2AccessTokenEntity.getAll", query = "select a from OAuth2AccessTokenEntity a"), @NamedQuery(name = "OAuth2AccessTokenEntity.getByRefreshToken", query = "select a from OAuth2AccessTokenEntity a where a.refreshToken = :refreshToken"), @NamedQuery(name = "OAuth2AccessTokenEntity.getByClient", query = "select a from OAuth2AccessTokenEntity a where a.client = :client"), - @NamedQuery(name = "OAuth2AccessTokenEntity.getExpired", query = "select a from OAuth2AccessTokenEntity a where a.expiration is not null and a.expiration < current_timestamp"), @NamedQuery(name = "OAuth2AccessTokenEntity.getByAuthentication", query = "select a from OAuth2AccessTokenEntity a where a.authenticationHolder.authentication = :authentication"), @NamedQuery(name = "OAuth2AccessTokenEntity.getByIdToken", query = "select a from OAuth2AccessTokenEntity a where a.idToken = :idToken"), @NamedQuery(name = "OAuth2AccessTokenEntity.getByTokenValue", query = "select a from OAuth2AccessTokenEntity a where a.value = :tokenValue") diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2RefreshTokenEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2RefreshTokenEntity.java index c6b4b5864..902ac913e 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2RefreshTokenEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2RefreshTokenEntity.java @@ -50,7 +50,6 @@ import com.nimbusds.jwt.JWTParser; @NamedQueries({ @NamedQuery(name = "OAuth2RefreshTokenEntity.getAll", query = "select r from OAuth2RefreshTokenEntity r"), @NamedQuery(name = "OAuth2RefreshTokenEntity.getByClient", query = "select r from OAuth2RefreshTokenEntity r where r.client = :client"), - @NamedQuery(name = "OAuth2RefreshTokenEntity.getExpired", query = "select r from OAuth2RefreshTokenEntity r where r.expiration is not null and r.expiration < current_timestamp"), @NamedQuery(name = "OAuth2RefreshTokenEntity.getByTokenValue", query = "select r from OAuth2RefreshTokenEntity r where r.value = :tokenValue"), @NamedQuery(name = "OAuth2RefreshTokenEntity.getByAuthentication", query = "select r from OAuth2RefreshTokenEntity r where r.authenticationHolder.authentication = :authentication") }) diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/repository/OAuth2TokenRepository.java b/openid-connect-common/src/main/java/org/mitre/oauth2/repository/OAuth2TokenRepository.java index 496c4935f..35cc2a6ee 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/repository/OAuth2TokenRepository.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/repository/OAuth2TokenRepository.java @@ -49,15 +49,8 @@ public interface OAuth2TokenRepository { public List getRefreshTokensForClient(ClientDetailsEntity client); - public List getExpiredAccessTokens(); - - public List getExpiredRefreshTokens(); - public OAuth2AccessTokenEntity getByAuthentication(OAuth2Authentication auth); - /** - * @return - */ public OAuth2AccessTokenEntity getAccessTokenForIdToken(OAuth2AccessTokenEntity idToken); public Set getAllAccessTokens(); diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/model/ApprovedSite.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/model/ApprovedSite.java index 195f87e34..bab029b3b 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/model/ApprovedSite.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/model/ApprovedSite.java @@ -41,7 +41,6 @@ import javax.persistence.Transient; @NamedQuery(name = "ApprovedSite.getAll", query = "select a from ApprovedSite a"), @NamedQuery(name = "ApprovedSite.getByUserId", query = "select a from ApprovedSite a where a.userId = :userId"), @NamedQuery(name = "ApprovedSite.getByClientId", query = "select a from ApprovedSite a where a.clientId = :clientId"), - @NamedQuery(name = "ApprovedSite.getExpired", query = "select a from ApprovedSite a where a.timeoutDate is not null and a.timeoutDate < current_timestamp"), @NamedQuery(name = "ApprovedSite.getByClientIdAndUserId", query = "select a from ApprovedSite a where a.clientId = :clientId and a.userId = :userId") }) public class ApprovedSite { diff --git a/openid-connect-common/src/main/java/org/mitre/openid/connect/repository/ApprovedSiteRepository.java b/openid-connect-common/src/main/java/org/mitre/openid/connect/repository/ApprovedSiteRepository.java index 28267f0a0..2a426bb38 100644 --- a/openid-connect-common/src/main/java/org/mitre/openid/connect/repository/ApprovedSiteRepository.java +++ b/openid-connect-common/src/main/java/org/mitre/openid/connect/repository/ApprovedSiteRepository.java @@ -84,9 +84,4 @@ public interface ApprovedSiteRepository { */ public Collection getByClientId(String clientId); - /** - * Get all expired sites - * @return - */ - public Collection getExpired(); } diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java index 535f5ff67..cdfe89987 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java @@ -159,26 +159,6 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { return refreshTokens; } - /* (non-Javadoc) - * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getExpiredAccessTokens() - */ - @Override - public List getExpiredAccessTokens() { - TypedQuery queryA = manager.createNamedQuery("OAuth2AccessTokenEntity.getExpired", OAuth2AccessTokenEntity.class); - List accessTokens = queryA.getResultList(); - return accessTokens; - } - - /* (non-Javadoc) - * @see org.mitre.oauth2.repository.OAuth2TokenRepository#getExpiredRefreshTokens() - */ - @Override - public List getExpiredRefreshTokens() { - TypedQuery queryR = manager.createNamedQuery("OAuth2RefreshTokenEntity.getExpired", OAuth2RefreshTokenEntity.class); - List refreshTokens = queryR.getResultList(); - return refreshTokens; - } - @Override public OAuth2AccessTokenEntity getByAuthentication(OAuth2Authentication auth) { TypedQuery queryA = manager.createNamedQuery("OAuth2AccessTokenEntity.getByAuthentication", OAuth2AccessTokenEntity.class); diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java index c3b25f190..0c69ab0aa 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2ProviderTokenService.java @@ -18,6 +18,7 @@ */ package org.mitre.oauth2.service.impl; +import java.util.Collection; import java.util.Date; import java.util.HashSet; import java.util.List; @@ -45,6 +46,7 @@ import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.token.TokenEnhancer; import org.springframework.stereotype.Service; +import com.google.common.base.Predicate; import com.google.common.collect.Sets; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.PlainJWT; @@ -361,18 +363,40 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi public void clearExpiredTokens() { logger.info("Cleaning out all expired tokens"); - List accessTokens = tokenRepository.getExpiredAccessTokens(); + Collection accessTokens = getExpiredAccessTokens(); logger.info("Found " + accessTokens.size() + " expired access tokens"); for (OAuth2AccessTokenEntity oAuth2AccessTokenEntity : accessTokens) { revokeAccessToken(oAuth2AccessTokenEntity); } - List refreshTokens = tokenRepository.getExpiredRefreshTokens(); + Collection refreshTokens = getExpiredRefreshTokens(); logger.info("Found " + refreshTokens.size() + " expired refresh tokens"); for (OAuth2RefreshTokenEntity oAuth2RefreshTokenEntity : refreshTokens) { revokeRefreshToken(oAuth2RefreshTokenEntity); } } + + private Predicate isAccessTokenExpired = new Predicate() { + @Override + public boolean apply(OAuth2AccessTokenEntity input) { + return (input != null && input.isExpired()); + } + }; + + private Predicate isRefreshTokenExpired = new Predicate() { + @Override + public boolean apply(OAuth2RefreshTokenEntity input) { + return (input != null && input.isExpired()); + } + }; + + private Collection getExpiredAccessTokens() { + return Sets.filter(Sets.newHashSet(tokenRepository.getAllAccessTokens()), isAccessTokenExpired); + } + + private Collection getExpiredRefreshTokens() { + return Sets.filter(Sets.newHashSet(tokenRepository.getAllRefreshTokens()), isRefreshTokenExpired); + } /* (non-Javadoc) * @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveAccessToken(org.mitre.oauth2.model.OAuth2AccessTokenEntity) diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/repository/impl/JpaApprovedSiteRepository.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/repository/impl/JpaApprovedSiteRepository.java index 3d73663e1..4a36cff12 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/repository/impl/JpaApprovedSiteRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/repository/impl/JpaApprovedSiteRepository.java @@ -15,6 +15,8 @@ ******************************************************************************/ package org.mitre.openid.connect.repository.impl; +import static org.mitre.util.jpa.JpaUtil.saveOrUpdate; + import java.util.Collection; import javax.persistence.EntityManager; @@ -26,8 +28,6 @@ import org.mitre.openid.connect.repository.ApprovedSiteRepository; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; -import static org.mitre.util.jpa.JpaUtil.saveOrUpdate; - /** * JPA ApprovedSite repository implementation * @@ -100,11 +100,4 @@ public class JpaApprovedSiteRepository implements ApprovedSiteRepository { return query.getResultList(); } - - @Override - @Transactional - public Collection getExpired() { - TypedQuery query = manager.createNamedQuery("ApprovedSite.getExpired", ApprovedSite.class); - return query.getResultList(); - } } diff --git a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java index c1a6ea425..22d17ba4e 100644 --- a/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java +++ b/openid-connect-server/src/main/java/org/mitre/openid/connect/service/impl/DefaultApprovedSiteService.java @@ -30,6 +30,9 @@ import org.springframework.security.oauth2.provider.ClientDetails; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import com.google.common.base.Predicate; +import com.google.common.collect.Sets; + /** * Implementation of the ApprovedSiteService * @@ -130,12 +133,23 @@ public class DefaultApprovedSiteService implements ApprovedSiteService { logger.info("Clearing expired approved sites"); - Collection expiredSites = approvedSiteRepository.getExpired(); + Collection expiredSites = getExpired(); if (expiredSites != null) { for (ApprovedSite expired : expiredSites) { approvedSiteRepository.remove(expired); } } } + + private Predicate isExpired = new Predicate() { + @Override + public boolean apply(ApprovedSite input) { + return (input != null && input.isExpired()); + } + }; + + private Collection getExpired() { + return Sets.filter(Sets.newHashSet(approvedSiteRepository.getAll()), isExpired); + } }