Fixed token expiration bug by removing jsql queries. Instead expired tokens or approved sites are filtered at the repository level

Moved getExpired to service layers

Used Predicates to filter expired tokens and approved sites;
pull/650/head
Amanda Anganes 2013-08-06 11:28:13 -04:00 committed by Justin Richer
parent 3134c34606
commit f58141e6a7
9 changed files with 43 additions and 47 deletions

View File

@ -59,7 +59,6 @@ import com.nimbusds.jwt.JWTParser;
@NamedQuery(name = "OAuth2AccessTokenEntity.getAll", query = "select a from OAuth2AccessTokenEntity a"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getByRefreshToken", query = "select a from OAuth2AccessTokenEntity a where a.refreshToken = :refreshToken"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getByClient", query = "select a from OAuth2AccessTokenEntity a where a.client = :client"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getExpired", query = "select a from OAuth2AccessTokenEntity a where a.expiration is not null and a.expiration < current_timestamp"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getByAuthentication", query = "select a from OAuth2AccessTokenEntity a where a.authenticationHolder.authentication = :authentication"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getByIdToken", query = "select a from OAuth2AccessTokenEntity a where a.idToken = :idToken"),
@NamedQuery(name = "OAuth2AccessTokenEntity.getByTokenValue", query = "select a from OAuth2AccessTokenEntity a where a.value = :tokenValue")

View File

@ -50,7 +50,6 @@ import com.nimbusds.jwt.JWTParser;
@NamedQueries({
@NamedQuery(name = "OAuth2RefreshTokenEntity.getAll", query = "select r from OAuth2RefreshTokenEntity r"),
@NamedQuery(name = "OAuth2RefreshTokenEntity.getByClient", query = "select r from OAuth2RefreshTokenEntity r where r.client = :client"),
@NamedQuery(name = "OAuth2RefreshTokenEntity.getExpired", query = "select r from OAuth2RefreshTokenEntity r where r.expiration is not null and r.expiration < current_timestamp"),
@NamedQuery(name = "OAuth2RefreshTokenEntity.getByTokenValue", query = "select r from OAuth2RefreshTokenEntity r where r.value = :tokenValue"),
@NamedQuery(name = "OAuth2RefreshTokenEntity.getByAuthentication", query = "select r from OAuth2RefreshTokenEntity r where r.authenticationHolder.authentication = :authentication")
})

View File

@ -49,15 +49,8 @@ public interface OAuth2TokenRepository {
public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client);
public List<OAuth2AccessTokenEntity> getExpiredAccessTokens();
public List<OAuth2RefreshTokenEntity> getExpiredRefreshTokens();
public OAuth2AccessTokenEntity getByAuthentication(OAuth2Authentication auth);
/**
* @return
*/
public OAuth2AccessTokenEntity getAccessTokenForIdToken(OAuth2AccessTokenEntity idToken);
public Set<OAuth2AccessTokenEntity> getAllAccessTokens();

View File

@ -41,7 +41,6 @@ import javax.persistence.Transient;
@NamedQuery(name = "ApprovedSite.getAll", query = "select a from ApprovedSite a"),
@NamedQuery(name = "ApprovedSite.getByUserId", query = "select a from ApprovedSite a where a.userId = :userId"),
@NamedQuery(name = "ApprovedSite.getByClientId", query = "select a from ApprovedSite a where a.clientId = :clientId"),
@NamedQuery(name = "ApprovedSite.getExpired", query = "select a from ApprovedSite a where a.timeoutDate is not null and a.timeoutDate < current_timestamp"),
@NamedQuery(name = "ApprovedSite.getByClientIdAndUserId", query = "select a from ApprovedSite a where a.clientId = :clientId and a.userId = :userId")
})
public class ApprovedSite {

View File

@ -84,9 +84,4 @@ public interface ApprovedSiteRepository {
*/
public Collection<ApprovedSite> getByClientId(String clientId);
/**
* Get all expired sites
* @return
*/
public Collection<ApprovedSite> getExpired();
}

View File

@ -159,26 +159,6 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository {
return refreshTokens;
}
/* (non-Javadoc)
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#getExpiredAccessTokens()
*/
@Override
public List<OAuth2AccessTokenEntity> getExpiredAccessTokens() {
TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery("OAuth2AccessTokenEntity.getExpired", OAuth2AccessTokenEntity.class);
List<OAuth2AccessTokenEntity> accessTokens = queryA.getResultList();
return accessTokens;
}
/* (non-Javadoc)
* @see org.mitre.oauth2.repository.OAuth2TokenRepository#getExpiredRefreshTokens()
*/
@Override
public List<OAuth2RefreshTokenEntity> getExpiredRefreshTokens() {
TypedQuery<OAuth2RefreshTokenEntity> queryR = manager.createNamedQuery("OAuth2RefreshTokenEntity.getExpired", OAuth2RefreshTokenEntity.class);
List<OAuth2RefreshTokenEntity> refreshTokens = queryR.getResultList();
return refreshTokens;
}
@Override
public OAuth2AccessTokenEntity getByAuthentication(OAuth2Authentication auth) {
TypedQuery<OAuth2AccessTokenEntity> queryA = manager.createNamedQuery("OAuth2AccessTokenEntity.getByAuthentication", OAuth2AccessTokenEntity.class);

View File

@ -18,6 +18,7 @@
*/
package org.mitre.oauth2.service.impl;
import java.util.Collection;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
@ -45,6 +46,7 @@ import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenEnhancer;
import org.springframework.stereotype.Service;
import com.google.common.base.Predicate;
import com.google.common.collect.Sets;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.PlainJWT;
@ -361,18 +363,40 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
public void clearExpiredTokens() {
logger.info("Cleaning out all expired tokens");
List<OAuth2AccessTokenEntity> accessTokens = tokenRepository.getExpiredAccessTokens();
Collection<OAuth2AccessTokenEntity> accessTokens = getExpiredAccessTokens();
logger.info("Found " + accessTokens.size() + " expired access tokens");
for (OAuth2AccessTokenEntity oAuth2AccessTokenEntity : accessTokens) {
revokeAccessToken(oAuth2AccessTokenEntity);
}
List<OAuth2RefreshTokenEntity> refreshTokens = tokenRepository.getExpiredRefreshTokens();
Collection<OAuth2RefreshTokenEntity> refreshTokens = getExpiredRefreshTokens();
logger.info("Found " + refreshTokens.size() + " expired refresh tokens");
for (OAuth2RefreshTokenEntity oAuth2RefreshTokenEntity : refreshTokens) {
revokeRefreshToken(oAuth2RefreshTokenEntity);
}
}
private Predicate<OAuth2AccessTokenEntity> isAccessTokenExpired = new Predicate<OAuth2AccessTokenEntity>() {
@Override
public boolean apply(OAuth2AccessTokenEntity input) {
return (input != null && input.isExpired());
}
};
private Predicate<OAuth2RefreshTokenEntity> isRefreshTokenExpired = new Predicate<OAuth2RefreshTokenEntity>() {
@Override
public boolean apply(OAuth2RefreshTokenEntity input) {
return (input != null && input.isExpired());
}
};
private Collection<OAuth2AccessTokenEntity> getExpiredAccessTokens() {
return Sets.filter(Sets.newHashSet(tokenRepository.getAllAccessTokens()), isAccessTokenExpired);
}
private Collection<OAuth2RefreshTokenEntity> getExpiredRefreshTokens() {
return Sets.filter(Sets.newHashSet(tokenRepository.getAllRefreshTokens()), isRefreshTokenExpired);
}
/* (non-Javadoc)
* @see org.mitre.oauth2.service.OAuth2TokenEntityService#saveAccessToken(org.mitre.oauth2.model.OAuth2AccessTokenEntity)

View File

@ -15,6 +15,8 @@
******************************************************************************/
package org.mitre.openid.connect.repository.impl;
import static org.mitre.util.jpa.JpaUtil.saveOrUpdate;
import java.util.Collection;
import javax.persistence.EntityManager;
@ -26,8 +28,6 @@ import org.mitre.openid.connect.repository.ApprovedSiteRepository;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;
import static org.mitre.util.jpa.JpaUtil.saveOrUpdate;
/**
* JPA ApprovedSite repository implementation
*
@ -100,11 +100,4 @@ public class JpaApprovedSiteRepository implements ApprovedSiteRepository {
return query.getResultList();
}
@Override
@Transactional
public Collection<ApprovedSite> getExpired() {
TypedQuery<ApprovedSite> query = manager.createNamedQuery("ApprovedSite.getExpired", ApprovedSite.class);
return query.getResultList();
}
}

View File

@ -30,6 +30,9 @@ import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import com.google.common.base.Predicate;
import com.google.common.collect.Sets;
/**
* Implementation of the ApprovedSiteService
*
@ -130,12 +133,23 @@ public class DefaultApprovedSiteService implements ApprovedSiteService {
logger.info("Clearing expired approved sites");
Collection<ApprovedSite> expiredSites = approvedSiteRepository.getExpired();
Collection<ApprovedSite> expiredSites = getExpired();
if (expiredSites != null) {
for (ApprovedSite expired : expiredSites) {
approvedSiteRepository.remove(expired);
}
}
}
private Predicate<ApprovedSite> isExpired = new Predicate<ApprovedSite>() {
@Override
public boolean apply(ApprovedSite input) {
return (input != null && input.isExpired());
}
};
private Collection<ApprovedSite> getExpired() {
return Sets.filter(Sets.newHashSet(approvedSiteRepository.getAll()), isExpired);
}
}