diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java index f2655977b..1d60fc872 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/OAuth2AccessTokenEntity.java @@ -67,7 +67,7 @@ import com.nimbusds.jwt.JWT; @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, query = "select a from OAuth2AccessTokenEntity a where a.expiration <= :" + OAuth2AccessTokenEntity.PARAM_DATE), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_REFRESH_TOKEN, query = "select a from OAuth2AccessTokenEntity a where a.refreshToken = :" + OAuth2AccessTokenEntity.PARAM_REFRESH_TOKEN), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_CLIENT, query = "select a from OAuth2AccessTokenEntity a where a.client = :" + OAuth2AccessTokenEntity.PARAM_CLIENT), - @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, query = "select a from OAuth2AccessTokenEntity a where a.jwt = :" + OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE), + @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE_HASH, query = "select a from OAuth2AccessTokenEntity a where a.tokenValueHash = :" + OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE_HASH), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, query = "select a from OAuth2AccessTokenEntity a where a.approvedSite = :" + OAuth2AccessTokenEntity.PARAM_APPROVED_SITE), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, query = "select a from OAuth2AccessTokenEntity a join a.permissions p where p.resourceSet.id = :" + OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID), @NamedQuery(name = OAuth2AccessTokenEntity.QUERY_BY_NAME, query = "select r from OAuth2AccessTokenEntity r where r.authenticationHolder.userAuth.name = :" + OAuth2AccessTokenEntity.PARAM_NAME), @@ -78,7 +78,7 @@ import com.nimbusds.jwt.JWT; public class OAuth2AccessTokenEntity implements OAuth2AccessToken { public static final String QUERY_BY_APPROVED_SITE = "OAuth2AccessTokenEntity.getByApprovedSite"; - public static final String QUERY_BY_TOKEN_VALUE = "OAuth2AccessTokenEntity.getByTokenValue"; + public static final String QUERY_BY_TOKEN_VALUE_HASH = "OAuth2AccessTokenEntity.getByTokenValue"; public static final String QUERY_BY_CLIENT = "OAuth2AccessTokenEntity.getByClient"; public static final String QUERY_BY_REFRESH_TOKEN = "OAuth2AccessTokenEntity.getByRefreshToken"; public static final String QUERY_EXPIRED_BY_DATE = "OAuth2AccessTokenEntity.getAllExpiredByDate"; @@ -87,7 +87,7 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken { public static final String QUERY_BY_NAME = "OAuth2AccessTokenEntity.getByName"; public static final String DELETE_BY_REFRESH_TOKEN = "OAuth2AccessTokenEntity.deleteByRefreshToken"; - public static final String PARAM_TOKEN_VALUE = "tokenValue"; + public static final String PARAM_TOKEN_VALUE_HASH = "tokenValueHash"; public static final String PARAM_CLIENT = "client"; public static final String PARAM_REFRESH_TOKEN = "refreshToken"; public static final String PARAM_DATE = "date"; @@ -105,6 +105,8 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken { private JWT jwtValue; // JWT-encoded access token value + private String tokenValueHash; // hash of access token value + private Date expiration; private String tokenType = OAuth2AccessToken.BEARER_TYPE; @@ -274,6 +276,19 @@ public class OAuth2AccessTokenEntity implements OAuth2AccessToken { this.jwtValue = jwt; } + /** + * @return the tokenValueHash + */ + @Basic + @Column(name="token_value_hash") + public String getTokenValueHash() { + return tokenValueHash; + } + + public void setTokenValueHash(String hash) { + this.tokenValueHash = hash; + } + @Override @Transient public int getExpiresIn() { diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java index fc72528c6..72eca59ae 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaOAuth2TokenRepository.java @@ -17,6 +17,9 @@ *******************************************************************************/ package org.mitre.oauth2.repository.impl; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; import java.text.ParseException; import java.util.ArrayList; import java.util.Date; @@ -44,6 +47,7 @@ import org.mitre.uma.model.ResourceSet; import org.mitre.util.jpa.JpaUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.security.crypto.codec.Hex; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; @@ -55,32 +59,47 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { private static final int MAXEXPIREDRESULTS = 1000; - private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); + private static final Logger logger = + LoggerFactory.getLogger(JpaOAuth2TokenRepository.class); - @PersistenceContext(unitName="defaultPersistenceUnit") + @PersistenceContext(unitName = "defaultPersistenceUnit") private EntityManager manager; @Override public Set getAllAccessTokens() { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, OAuth2AccessTokenEntity.class); + TypedQuery query = + manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_ALL, + OAuth2AccessTokenEntity.class); return new LinkedHashSet<>(query.getResultList()); } @Override public Set getAllRefreshTokens() { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, OAuth2RefreshTokenEntity.class); + TypedQuery query = + manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_ALL, + OAuth2RefreshTokenEntity.class); return new LinkedHashSet<>(query.getResultList()); } @Override - public OAuth2AccessTokenEntity getAccessTokenByValue(String accessTokenValue) { + public OAuth2AccessTokenEntity getAccessTokenByValue( + String accessTokenValue) { + MessageDigest md; try { - JWT jwt = JWTParser.parse(accessTokenValue); - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE, jwt); + md = MessageDigest.getInstance("SHA-256"); + byte[] hash = md + .digest(accessTokenValue.getBytes(StandardCharsets.UTF_8)); + String atHash = new String(Hex.encode(hash)); + TypedQuery query = + manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_TOKEN_VALUE_HASH, + OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_TOKEN_VALUE_HASH, + atHash); return JpaUtil.getSingleResult(query.getResultList()); - } catch (ParseException e) { + } catch (NoSuchAlgorithmException e) { + e.printStackTrace(); return null; } } @@ -91,35 +110,44 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { } @Override - @Transactional(value="defaultTransactionManager") - public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) { + @Transactional(value = "defaultTransactionManager") + public OAuth2AccessTokenEntity saveAccessToken( + OAuth2AccessTokenEntity token) { return JpaUtil.saveOrUpdate(token.getId(), manager, token); } @Override - @Transactional(value="defaultTransactionManager") + @Transactional(value = "defaultTransactionManager") public void removeAccessToken(OAuth2AccessTokenEntity accessToken) { OAuth2AccessTokenEntity found = getAccessTokenById(accessToken.getId()); if (found != null) { manager.remove(found); } else { - throw new IllegalArgumentException("Access token not found: " + accessToken); + throw new IllegalArgumentException( + "Access token not found: " + accessToken); } } @Override - @Transactional(value="defaultTransactionManager") - public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.DELETE_BY_REFRESH_TOKEN, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_REFRESH_TOKEN, refreshToken); + @Transactional(value = "defaultTransactionManager") + public void clearAccessTokensForRefreshToken( + OAuth2RefreshTokenEntity refreshToken) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.DELETE_BY_REFRESH_TOKEN, + OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_REFRESH_TOKEN, + refreshToken); query.executeUpdate(); } @Override - public OAuth2RefreshTokenEntity getRefreshTokenByValue(String refreshTokenValue) { + public OAuth2RefreshTokenEntity getRefreshTokenByValue( + String refreshTokenValue) { try { JWT jwt = JWTParser.parse(refreshTokenValue); - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, OAuth2RefreshTokenEntity.class); + TypedQuery query = manager + .createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_TOKEN_VALUE, + OAuth2RefreshTokenEntity.class); query.setParameter(OAuth2RefreshTokenEntity.PARAM_TOKEN_VALUE, jwt); return JpaUtil.getSingleResult(query.getResultList()); } catch (ParseException e) { @@ -133,32 +161,40 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { } @Override - @Transactional(value="defaultTransactionManager") - public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken); + @Transactional(value = "defaultTransactionManager") + public OAuth2RefreshTokenEntity saveRefreshToken( + OAuth2RefreshTokenEntity refreshToken) { + return JpaUtil.saveOrUpdate(refreshToken.getId(), manager, + refreshToken); } @Override - @Transactional(value="defaultTransactionManager") + @Transactional(value = "defaultTransactionManager") public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) { - OAuth2RefreshTokenEntity found = getRefreshTokenById(refreshToken.getId()); + OAuth2RefreshTokenEntity found = + getRefreshTokenById(refreshToken.getId()); if (found != null) { manager.remove(found); } else { - throw new IllegalArgumentException("Refresh token not found: " + refreshToken); + throw new IllegalArgumentException( + "Refresh token not found: " + refreshToken); } } @Override - @Transactional(value="defaultTransactionManager") + @Transactional(value = "defaultTransactionManager") public void clearTokensForClient(ClientDetailsEntity client) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); + TypedQuery queryA = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_CLIENT, + OAuth2AccessTokenEntity.class); queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); List accessTokens = queryA.getResultList(); for (OAuth2AccessTokenEntity accessToken : accessTokens) { removeAccessToken(accessToken); } - TypedQuery queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); + TypedQuery queryR = manager.createNamedQuery( + OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, + OAuth2RefreshTokenEntity.class); queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); List refreshTokens = queryR.getResultList(); for (OAuth2RefreshTokenEntity refreshToken : refreshTokens) { @@ -167,85 +203,112 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { } @Override - public List getAccessTokensForClient(ClientDetailsEntity client) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_CLIENT, OAuth2AccessTokenEntity.class); + public List getAccessTokensForClient( + ClientDetailsEntity client) { + TypedQuery queryA = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_CLIENT, + OAuth2AccessTokenEntity.class); queryA.setParameter(OAuth2AccessTokenEntity.PARAM_CLIENT, client); List accessTokens = queryA.getResultList(); return accessTokens; } @Override - public List getRefreshTokensForClient(ClientDetailsEntity client) { - TypedQuery queryR = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, OAuth2RefreshTokenEntity.class); + public List getRefreshTokensForClient( + ClientDetailsEntity client) { + TypedQuery queryR = manager.createNamedQuery( + OAuth2RefreshTokenEntity.QUERY_BY_CLIENT, + OAuth2RefreshTokenEntity.class); queryR.setParameter(OAuth2RefreshTokenEntity.PARAM_CLIENT, client); List refreshTokens = queryR.getResultList(); return refreshTokens; } - + @Override public Set getAccessTokensByUserName(String name) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name); - List results = query.getResultList(); - return results != null ? new HashSet<>(results) : new HashSet<>(); + TypedQuery query = + manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_NAME, + OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_NAME, name); + List results = query.getResultList(); + return results != null ? new HashSet<>(results) : new HashSet<>(); } - + @Override - public Set getRefreshTokensByUserName(String name) { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, OAuth2RefreshTokenEntity.class); - query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name); - List results = query.getResultList(); - return results != null ? new HashSet<>(results) : new HashSet<>(); + public Set getRefreshTokensByUserName( + String name) { + TypedQuery query = + manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_BY_NAME, + OAuth2RefreshTokenEntity.class); + query.setParameter(OAuth2RefreshTokenEntity.PARAM_NAME, name); + List results = query.getResultList(); + return results != null ? new HashSet<>(results) : new HashSet<>(); } @Override public Set getAllExpiredAccessTokens() { - DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); + DefaultPageCriteria pageCriteria = + new DefaultPageCriteria(0, MAXEXPIREDRESULTS); return getAllExpiredAccessTokens(pageCriteria); } @Override - public Set getAllExpiredAccessTokens(PageCriteria pageCriteria) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2AccessTokenEntity.class); + public Set getAllExpiredAccessTokens( + PageCriteria pageCriteria) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_EXPIRED_BY_DATE, + OAuth2AccessTokenEntity.class); query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); } @Override public Set getAllExpiredRefreshTokens() { - DefaultPageCriteria pageCriteria = new DefaultPageCriteria(0, MAXEXPIREDRESULTS); + DefaultPageCriteria pageCriteria = + new DefaultPageCriteria(0, MAXEXPIREDRESULTS); return getAllExpiredRefreshTokens(pageCriteria); } @Override - public Set getAllExpiredRefreshTokens(PageCriteria pageCriteria) { - TypedQuery query = manager.createNamedQuery(OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, OAuth2RefreshTokenEntity.class); + public Set getAllExpiredRefreshTokens( + PageCriteria pageCriteria) { + TypedQuery query = manager.createNamedQuery( + OAuth2RefreshTokenEntity.QUERY_EXPIRED_BY_DATE, + OAuth2RefreshTokenEntity.class); query.setParameter(OAuth2AccessTokenEntity.PARAM_DATE, new Date()); - return new LinkedHashSet<>(JpaUtil.getResultPage(query,pageCriteria)); + return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria)); } @Override - public Set getAccessTokensForResourceSet(ResourceSet rs) { - TypedQuery query = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, OAuth2AccessTokenEntity.class); - query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, rs.getId()); + public Set getAccessTokensForResourceSet( + ResourceSet rs) { + TypedQuery query = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_RESOURCE_SET, + OAuth2AccessTokenEntity.class); + query.setParameter(OAuth2AccessTokenEntity.PARAM_RESOURCE_SET_ID, + rs.getId()); return new LinkedHashSet<>(query.getResultList()); } @Override - @Transactional(value="defaultTransactionManager") + @Transactional(value = "defaultTransactionManager") public void clearDuplicateAccessTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); + Query query = manager.createQuery( + "select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); @SuppressWarnings("unchecked") List resultList = query.getResultList(); List values = new ArrayList<>(); for (Object[] r : resultList) { - logger.warn("Found duplicate access tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); + logger.warn("Found duplicate access tokens: {}, {}", + ((JWT) r[0]).serialize(), r[1]); values.add((JWT) r[0]); } if (values.size() > 0) { CriteriaBuilder cb = manager.getCriteriaBuilder(); - CriteriaDelete criteriaDelete = cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); - Root root = criteriaDelete.from(OAuth2AccessTokenEntity.class); + CriteriaDelete criteriaDelete = + cb.createCriteriaDelete(OAuth2AccessTokenEntity.class); + Root root = + criteriaDelete.from(OAuth2AccessTokenEntity.class); criteriaDelete.where(root.get("jwt").in(values)); int result = manager.createQuery(criteriaDelete).executeUpdate(); logger.warn("Deleted {} duplicate access tokens", result); @@ -253,20 +316,24 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { } @Override - @Transactional(value="defaultTransactionManager") + @Transactional(value = "defaultTransactionManager") public void clearDuplicateRefreshTokens() { - Query query = manager.createQuery("select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); + Query query = manager.createQuery( + "select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1"); @SuppressWarnings("unchecked") List resultList = query.getResultList(); List values = new ArrayList<>(); for (Object[] r : resultList) { - logger.warn("Found duplicate refresh tokens: {}, {}", ((JWT)r[0]).serialize(), r[1]); + logger.warn("Found duplicate refresh tokens: {}, {}", + ((JWT) r[0]).serialize(), r[1]); values.add((JWT) r[0]); } if (values.size() > 0) { CriteriaBuilder cb = manager.getCriteriaBuilder(); - CriteriaDelete criteriaDelete = cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); - Root root = criteriaDelete.from(OAuth2RefreshTokenEntity.class); + CriteriaDelete criteriaDelete = + cb.createCriteriaDelete(OAuth2RefreshTokenEntity.class); + Root root = + criteriaDelete.from(OAuth2RefreshTokenEntity.class); criteriaDelete.where(root.get("jwt").in(values)); int result = manager.createQuery(criteriaDelete).executeUpdate(); logger.warn("Deleted {} duplicate refresh tokens", result); @@ -275,9 +342,13 @@ public class JpaOAuth2TokenRepository implements OAuth2TokenRepository { } @Override - public List getAccessTokensForApprovedSite(ApprovedSite approvedSite) { - TypedQuery queryA = manager.createNamedQuery(OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, OAuth2AccessTokenEntity.class); - queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, approvedSite); + public List getAccessTokensForApprovedSite( + ApprovedSite approvedSite) { + TypedQuery queryA = manager.createNamedQuery( + OAuth2AccessTokenEntity.QUERY_BY_APPROVED_SITE, + OAuth2AccessTokenEntity.class); + queryA.setParameter(OAuth2AccessTokenEntity.PARAM_APPROVED_SITE, + approvedSite); List accessTokens = queryA.getResultList(); return accessTokens; }