From b72baa335fd0959552e0a2e2f711a8f19b88a536 Mon Sep 17 00:00:00 2001 From: rmiccoli Date: Fri, 16 Jun 2023 17:34:59 +0200 Subject: [PATCH] Change query in order to search by AT value hash instead of AT value --- .../oauth2/model/OAuth2AccessTokenEntity.java | 3 +- .../impl/JpaOAuth2TokenRepository.java | 437 ++++++++++-------- 2 files changed, 234 insertions(+), 206 deletions(-) diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java index 7d27eef86..578e701a4 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java @@ -67,7 +67,7 @@ import com.nimbusds.jwt.JWT; @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, query = "select a from OAuth2AccessTokenEntity a where a.expiration <= :" + OAuth2AccessTokenEntity.PARAM_DATE), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, query = "select a from OAuth2AccessTokenEntity a where a.refreshToken = :" + OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_CLIENT, query = "select a from OAuth2AccessTokenEntity a where a.client = :" + OAuth2AccessTokenEntity.PARAM_CLIENT), - @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select a from OAuth2AccessTokenEntity a where a.jwt = :" + OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE), + @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select a from OAuth2AccessTokenEntity a where a.tokenValueHash = :" + OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE_HASH), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, query = "select a from OAuth2AccessTokenEntity a where a.approvedSite = :" + OAuth2AccessTokenEntity.PARAM_APPROVED_SITE), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2AccessTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2AccessTokenEntity.PARAM_NAME) @@ -86,6 +86,7 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken { public static final String QUERY_BY_NAME = "OAuth2AccessTokenEntity.getByName"; public static final String PARAM_TOKEN_VALUE = "tokenValue"; + public static final String PARAM_TOKEN_VALUE_HASH = "tokenValueHash"; public static final String PARAM_CLIENT = "client"; public static final String PARAM_REFERSH_TOKEN = "refreshToken"; public static final String PARAM_DATE = "date"; diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java index 718a23357..053cde511 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java @@ -17,6 +17,9 @@ *******************************************************************************/ package org.mitre.oauth2.repository.impl; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.util.ArrayList; import java.util.Date; @@ -44,6 +47,7 @@ import org.mitre.uma.model.ResourceSet; import org.mitre.util.jpa.JpaUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.security.crypto.codec.Hex; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; @@ -53,236 +57,259 @@ import com.nimbusds.jwt.JWTParser; @Repository public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { - private static final int MAXEXPIREDRESULTS = 1000; + private static final int MAXEXPIREDRESULTS = 1000; - private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); + private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); - @PersistenceContext(unitName="defaultPersistenceUnit") - private EntityManager manager; + @PersistenceContext(unitName = "defaultPersistenceUnit") + private EntityManager manager; - @Override - public Set getAllAccessTokens() { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class); - return new LinkedHashSet<>(query.getResultList()); - } + @Override + public Set getAllAccessTokens() { + TypedQuery query = + manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class); + return new LinkedHashSet<>(query.getResultList()); + } - @Override - public Set getAllRefreshTokens() { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class); - return new LinkedHashSet<>(query.getResultList()); - } + @Override + public Set getAllRefreshTokens() { + TypedQuery query = manager + .createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class); + return new LinkedHashSet<>(query.getResultList()); + } - @Override - public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) { - try { - JWT jwt = JWTParser.parse(accessTokenValue); - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt); - return JpaUtil.getSingleResult(query.getResultList()); - } catch (ParseException e) { - return null; - } - } + @Override + public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) { + MessageDigest md; + try { + md = MessageDigest.getInstance("SHA-256"); + byte[] hash = md.digest(accessTokenValue.getBytes(StandardCharsets.UTF_8)); + String atHash = new String(Hex.encode(hash)); + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE_HASH, atHash); + return JpaUtil.getSingleResult(query.getResultList()); + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); + return null; + } + } - @Override - public OAuth2AccessTokenEntity getAccessTokenById(Long id) { - return manager.find(OAuth2AccessTokenEntity.class, id); - } + @Override + public OAuth2AccessTokenEntity getAccessTokenById(Long id) { + return manager.find(OAuth2AccessTokenEntity.class, id); + } - @Override - @Transactional(value="defaultTransactionManager") - public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) { - return JpaUtil.saveOrUpdate(token.getId(), manager, token); - } + @Override + @Transactional(value = "defaultTransactionManager") + public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) { + return JpaUtil.saveOrUpdate(token.getId(), manager, token); + } - @Override - @Transactional(value="defaultTransactionManager") - public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { - OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId()); - if (found != null) { - manager.remove(found); - } else { - throw new IllegalArgumentException("Access token not found: " + accessToken); - } - } + @Override + @Transactional(value = "defaultTransactionManager") + public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { + OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId()); + if (found != null) { + manager.remove(found); + } else { + throw new IllegalArgumentException("Access token not found: " + accessToken); + } + } - @Override - @Transactional(value="defaultTransactionManager") - public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken); - List accessTokens = query.getResultList(); - for (OAuth2AccessTokenEntity accessToken : accessTokens) { - removeAccessToken(accessToken); - } - } + @Override + @Transactional(value = "defaultTransactionManager") + public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_REFERSH_TOKEN, refreshToken); + List accessTokens = query.getResultList(); + for (OAuth2AccessTokenEntity accessToken : accessTokens) { + removeAccessToken(accessToken); + } + } - @Override - public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) { - try { - JWT jwt = JWTParser.parse(refreshTokenValue); - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class); - query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt); - return JpaUtil.getSingleResult(query.getResultList()); - } catch (ParseException e) { - return null; - } - } + @Override + public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) { + try { + JWT jwt = JWTParser.parse(refreshTokenValue); + TypedQuery query = manager.createNamedQuery( + OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class); + query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt); + return JpaUtil.getSingleResult(query.getResultList()); + } catch (ParseException e) { + return null; + } + } - @Override - public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) { - return manager.find(OAuth2RefreshTokenEntity.class, id); - } + @Override + public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) { + return manager.find(OAuth2RefreshTokenEntity.class, id); + } - @Override - @Transactional(value="defaultTransactionManager") - public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken); - } + @Override + @Transactional(value = "defaultTransactionManager") + public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) { + return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken); + } - @Override - @Transactional(value="defaultTransactionManager") - public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId()); - if (found != null) { - manager.remove(found); - } else { - throw new IllegalArgumentException("Refresh token not found: " + refreshToken); - } - } + @Override + @Transactional(value = "defaultTransactionManager") + public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) { + OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId()); + if (found != null) { + manager.remove(found); + } else { + throw new IllegalArgumentException("Refresh token not found: " + refreshToken); + } + } - @Override - @Transactional(value="defaultTransactionManager") - public void clearTokensForClient(ClientDetailsEntity client) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); - queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); - List accessTokens = queryA.getResultList(); - for (OAuth2AccessTokenEntity accessToken : accessTokens) { - removeAccessToken(accessToken); - } - TypedQuery queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); - queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); - List refreshTokens = queryR.getResultList(); - for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) { - removeRefreshToken(refreshToken); - } - } + @Override + @Transactional(value = "defaultTransactionManager") + public void clearTokensForClient(ClientDetailsEntity client) { + TypedQuery queryA = manager + .createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); + queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); + List accessTokens = queryA.getResultList(); + for (OAuth2AccessTokenEntity accessToken : accessTokens) { + removeAccessToken(accessToken); + } + TypedQuery queryR = manager + .createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); + queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); + List refreshTokens = queryR.getResultList(); + for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) { + removeRefreshToken(refreshToken); + } + } - @Override - public List getAccessTokensForClient(ClientDetailsEntity client) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); - queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); - List accessTokens = queryA.getResultList(); - return accessTokens; - } + @Override + public List getAccessTokensForClient(ClientDetailsEntity client) { + TypedQuery queryA = manager + .createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); + queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); + List accessTokens = queryA.getResultList(); + return accessTokens; + } - @Override - public List getRefreshTokensForClient(ClientDetailsEntity client) { - TypedQuery queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); - queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); - List refreshTokens = queryR.getResultList(); - return refreshTokens; - } - - @Override - public Set getAccessTokensByUserName(String name) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name); - List results = query.getResultList(); - return results != null ? new HashSet<>(results) : new HashSet<>(); - } - - @Override - public Set getRefreshTokensByUserName(String name) { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class); - query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name); - List results = query.getResultList(); - return results != null ? new HashSet<>(results) : new HashSet<>(); - } + @Override + public List getRefreshTokensForClient(ClientDetailsEntity client) { + TypedQuery queryR = manager + .createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); + queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); + List refreshTokens = queryR.getResultList(); + return refreshTokens; + } - @Override - public Set getAllExpiredAccessTokens() { - DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); - return getAllExpiredAccessTokens(pageCriteria); - } + @Override + public Set getAccessTokensByUserName(String name) { + TypedQuery query = manager + .createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name); + List results = query.getResultList(); + return results != null ? new HashSet<>(results) : new HashSet<>(); + } - @Override - public Set getAllExpiredAccessTokens(PageCriteria pageCriteria) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); - return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); - } + @Override + public Set getRefreshTokensByUserName(String name) { + TypedQuery query = manager + .createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class); + query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name); + List results = query.getResultList(); + return results != null ? new HashSet<>(results) : new HashSet<>(); + } - @Override - public Set getAllExpiredRefreshTokens() { - DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); - return getAllExpiredRefreshTokens(pageCriteria); - } + @Override + public Set getAllExpiredAccessTokens() { + DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); + return getAllExpiredAccessTokens(pageCriteria); + } - @Override - public Set getAllExpiredRefreshTokens(PageCriteria pageCriteria) { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); - return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria)); - } + @Override + public Set getAllExpiredAccessTokens(PageCriteria pageCriteria) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); + return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); + } - @Override - public Set getAccessTokensForResourceSet(ResourceSet rs) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId()); - return new LinkedHashSet<>(query.getResultList()); - } + @Override + public Set getAllExpiredRefreshTokens() { + DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); + return getAllExpiredRefreshTokens(pageCriteria); + } - @Override - @Transactional(value="defaultTransactionManager") - public void clearDuplicateAccessTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); - @SuppressWarnings("unchecked") - List resultList = query.getResultList(); - List values = new ArrayList<>(); - for (Object[] r : resultList) { - logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); - values.add((JWT) r[0]); - } - if (values.size() > 0) { - CriteriaBuilder cb = manager.getCriteriaBuilder(); - CriteriaDelete criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); - Root root = criteriaDelete.from(OAuth2AccessTokenEntity.class); - criteriaDelete.where(root.get("jwt").in(values)); - int result = manager.createQuery(criteriaDelete).executeUpdate(); - logger.warn("Deleted {} duplicate access tokens", result); - } - } + @Override + public Set getAllExpiredRefreshTokens(PageCriteria pageCriteria) { + TypedQuery query = manager.createNamedQuery( + OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); + return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); + } - @Override - @Transactional(value="defaultTransactionManager") - public void clearDuplicateRefreshTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); - @SuppressWarnings("unchecked") - List resultList = query.getResultList(); - List values = new ArrayList<>(); - for (Object[] r : resultList) { - logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); - values.add((JWT) r[0]); - } - if (values.size() > 0) { - CriteriaBuilder cb = manager.getCriteriaBuilder(); - CriteriaDelete criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); - Root root = criteriaDelete.from(OAuth2RefreshTokenEntity.class); - criteriaDelete.where(root.get("jwt").in(values)); - int result = manager.createQuery(criteriaDelete).executeUpdate(); - logger.warn("Deleted {} duplicate refresh tokens", result); - } + @Override + public Set getAccessTokensForResourceSet(ResourceSet rs) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId()); + return new LinkedHashSet<>(query.getResultList()); + } - } + @Override + @Transactional(value = "defaultTransactionManager") + public void clearDuplicateAccessTokens() { + Query query = manager.createQuery( + "select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); + @SuppressWarnings("unchecked") + List resultList = query.getResultList(); + List values = new ArrayList<>(); + for (Object[] r : resultList) { + logger.warn("Found duplicate access tokens: {}, {}", ((JWT) r[0]).serialize(), r[1]); + values.add((JWT) r[0]); + } + if (values.size() > 0) { + CriteriaBuilder cb = manager.getCriteriaBuilder(); + CriteriaDelete criteriaDelete = + cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); + Root root = criteriaDelete.from(OAuth2AccessTokenEntity.class); + criteriaDelete.where(root.get("jwt").in(values)); + int result = manager.createQuery(criteriaDelete).executeUpdate(); + logger.warn("Deleted {} duplicate access tokens", result); + } + } - @Override - public List getAccessTokensForApprovedSite(ApprovedSite approvedSite) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class); - queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite); - List accessTokens = queryA.getResultList(); - return accessTokens; - } + @Override + @Transactional(value = "defaultTransactionManager") + public void clearDuplicateRefreshTokens() { + Query query = manager.createQuery( + "select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); + @SuppressWarnings("unchecked") + List resultList = query.getResultList(); + List values = new ArrayList<>(); + for (Object[] r : resultList) { + logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT) r[0]).serialize(), r[1]); + values.add((JWT) r[0]); + } + if (values.size() > 0) { + CriteriaBuilder cb = manager.getCriteriaBuilder(); + CriteriaDelete criteriaDelete = + cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); + Root root = criteriaDelete.from(OAuth2RefreshTokenEntity.class); + criteriaDelete.where(root.get("jwt").in(values)); + int result = manager.createQuery(criteriaDelete).executeUpdate(); + logger.warn("Deleted {} duplicate refresh tokens", result); + } + + } + + @Override + public List getAccessTokensForApprovedSite(ApprovedSite approvedSite) { + TypedQuery queryA = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class); + queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite); + List accessTokens = queryA.getResultList(); + return accessTokens; + } }