From 4f12fab56bf424cf9a176836c50bda524e5bfea1 Mon Sep 17 00:00:00 2001 From: Justin Richer Date: Fri, 13 Mar 2015 13:45:49 -0400 Subject: [PATCH] made unused auth codes expired (they're still single-use), refactored auth code service layer --- .../model/AuthenticationHolderEntity.java | 5 +- .../oauth2/model/AuthorizationCodeEntity.java | 48 +++++++++---- .../AuthorizationCodeRepository.java | 21 ++++-- .../db/tables/hsql_database_tables.sql | 3 +- .../db/tables/mysql_database_tables.sql | 3 +- .../src/main/webapp/WEB-INF/task-config.xml | 1 + .../impl/JpaAuthorizationCodeRepository.java | 45 +++++++----- ...DefaultOAuth2AuthorizationCodeService.java | 69 ++++++++++++++++++- 8 files changed, 153 insertions(+), 42 deletions(-) diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java index 5193e88f9..3d3151bd8 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthenticationHolderEntity.java @@ -34,7 +34,10 @@ import org.springframework.security.oauth2.provider.OAuth2Authentication; @Table(name = "authentication_holder") @NamedQueries ({ @NamedQuery(name = AuthenticationHolderEntity.QUERY_ALL, query = "select a from AuthenticationHolderEntity a"), - @NamedQuery(name = AuthenticationHolderEntity.QUERY_GET_UNUSED, query = "select a from AuthenticationHolderEntity a where a.id not in (select t.authenticationHolder.id from OAuth2AccessTokenEntity t) and a.id not in (select r.authenticationHolder.id from OAuth2RefreshTokenEntity r)") + @NamedQuery(name = AuthenticationHolderEntity.QUERY_GET_UNUSED, query = "select a from AuthenticationHolderEntity a where " + + "a.id not in (select t.authenticationHolder.id from OAuth2AccessTokenEntity t) and " + + "a.id not in (select r.authenticationHolder.id from OAuth2RefreshTokenEntity r) and " + + "a.id not in (select c.authenticationHolder.id from AuthorizationCodeEntity c)") }) public class AuthenticationHolderEntity { diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthorizationCodeEntity.java b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthorizationCodeEntity.java index f93b03475..2e636369e 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthorizationCodeEntity.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/model/AuthorizationCodeEntity.java @@ -16,19 +16,20 @@ *******************************************************************************/ package org.mitre.oauth2.model; +import java.util.Date; + import javax.persistence.Basic; import javax.persistence.Column; import javax.persistence.Entity; -import javax.persistence.FetchType; import javax.persistence.GeneratedValue; import javax.persistence.GenerationType; import javax.persistence.Id; -import javax.persistence.Lob; +import javax.persistence.JoinColumn; +import javax.persistence.ManyToOne; import javax.persistence.NamedQueries; import javax.persistence.NamedQuery; import javax.persistence.Table; - -import org.springframework.security.oauth2.provider.OAuth2Authentication; +import javax.persistence.Temporal; /** * Entity class for authorization codes @@ -39,17 +40,23 @@ import org.springframework.security.oauth2.provider.OAuth2Authentication; @Entity @Table(name = "authorization_code") @NamedQueries({ - @NamedQuery(name = AuthorizationCodeEntity.QUERY_BY_VALUE, query = "select a from AuthorizationCodeEntity a where a.code = :code") + @NamedQuery(name = AuthorizationCodeEntity.QUERY_BY_VALUE, query = "select a from AuthorizationCodeEntity a where a.code = :code"), + @NamedQuery(name = AuthorizationCodeEntity.QUERY_EXPIRATION_BY_DATE, query = "select a from AuthorizationCodeEntity a where a.expiration <= :" + AuthorizationCodeEntity.PARAM_DATE) }) public class AuthorizationCodeEntity { public static final String QUERY_BY_VALUE = "AuthorizationCodeEntity.getByValue"; + public static final String QUERY_EXPIRATION_BY_DATE = "AuthorizationCodeEntity.expirationByDate"; + + public static final String PARAM_DATE = "date"; private Long id; private String code; - private OAuth2Authentication authentication; + private AuthenticationHolderEntity authenticationHolder; + + private Date expiration; /** * Default constructor. @@ -64,9 +71,10 @@ public class AuthorizationCodeEntity { * @param code the authorization code * @param authRequest the AuthoriztionRequestHolder associated with the original code request */ - public AuthorizationCodeEntity(String code, OAuth2Authentication authRequest) { + public AuthorizationCodeEntity(String code, AuthenticationHolderEntity authenticationHolder, Date expiration) { this.code = code; - this.authentication = authRequest; + this.authenticationHolder = authenticationHolder; + this.expiration = expiration; } /** @@ -103,20 +111,30 @@ public class AuthorizationCodeEntity { } /** + * The authentication in place when this token was created. * @return the authentication */ - @Lob - @Basic(fetch=FetchType.EAGER) - @Column(name="authentication") - public OAuth2Authentication getAuthentication() { - return authentication; + @ManyToOne + @JoinColumn(name = "auth_holder_id") + public AuthenticationHolderEntity getAuthenticationHolder() { + return authenticationHolder; } /** * @param authentication the authentication to set */ - public void setAuthentication(OAuth2Authentication authentication) { - this.authentication = authentication; + public void setAuthenticationHolder(AuthenticationHolderEntity authenticationHolder) { + this.authenticationHolder = authenticationHolder; } + @Basic + @Temporal(javax.persistence.TemporalType.TIMESTAMP) + @Column(name = "expiration") + public Date getExpiration() { + return expiration; + } + + public void setExpiration(Date expiration) { + this.expiration = expiration; + } } diff --git a/openid-connect-common/src/main/java/org/mitre/oauth2/repository/AuthorizationCodeRepository.java b/openid-connect-common/src/main/java/org/mitre/oauth2/repository/AuthorizationCodeRepository.java index b5e98a123..dbdaa4e04 100644 --- a/openid-connect-common/src/main/java/org/mitre/oauth2/repository/AuthorizationCodeRepository.java +++ b/openid-connect-common/src/main/java/org/mitre/oauth2/repository/AuthorizationCodeRepository.java @@ -16,9 +16,9 @@ *******************************************************************************/ package org.mitre.oauth2.repository; +import java.util.Collection; + import org.mitre.oauth2.model.AuthorizationCodeEntity; -import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; -import org.springframework.security.oauth2.provider.OAuth2Authentication; /** * Interface for saving and consuming OAuth2 authorization codes as AuthorizationCodeEntitys. @@ -37,12 +37,23 @@ public interface AuthorizationCodeRepository { public AuthorizationCodeEntity save(AuthorizationCodeEntity authorizationCode); /** - * Consume an authorization code. + * Get an authorization code from the repository by value. * * @param code the authorization code value * @return the authentication associated with the code - * @throws InvalidGrantException if no AuthorizationCodeEntity is found with the given value */ - public OAuth2Authentication consume(String code) throws InvalidGrantException; + public AuthorizationCodeEntity getByCode(String code); + /** + * Remove an authorization code from the repository + * + * @param authorizationCodeEntity + */ + public void remove(AuthorizationCodeEntity authorizationCodeEntity); + + /** + * @return A collection of all expired codes. + */ + public Collection getExpiredCodes(); + } diff --git a/openid-connect-server-webapp/src/main/resources/db/tables/hsql_database_tables.sql b/openid-connect-server-webapp/src/main/resources/db/tables/hsql_database_tables.sql index ea932fec3..0cbc68867 100644 --- a/openid-connect-server-webapp/src/main/resources/db/tables/hsql_database_tables.sql +++ b/openid-connect-server-webapp/src/main/resources/db/tables/hsql_database_tables.sql @@ -52,7 +52,8 @@ CREATE TABLE IF NOT EXISTS client_authority ( CREATE TABLE IF NOT EXISTS authorization_code ( id BIGINT GENERATED BY DEFAULT AS IDENTITY(START WITH 1) PRIMARY KEY, code VARCHAR(256), - authentication LONGVARBINARY + auth_holder_id BIGINT, + expiration TIMESTAMP ); CREATE TABLE IF NOT EXISTS client_grant_type ( diff --git a/openid-connect-server-webapp/src/main/resources/db/tables/mysql_database_tables.sql b/openid-connect-server-webapp/src/main/resources/db/tables/mysql_database_tables.sql index 4adddd772..46948049f 100644 --- a/openid-connect-server-webapp/src/main/resources/db/tables/mysql_database_tables.sql +++ b/openid-connect-server-webapp/src/main/resources/db/tables/mysql_database_tables.sql @@ -52,7 +52,8 @@ CREATE TABLE IF NOT EXISTS client_authority ( CREATE TABLE IF NOT EXISTS authorization_code ( id BIGINT AUTO_INCREMENT PRIMARY KEY, code VARCHAR(256), - authentication LONGBLOB + auth_holder_id BIGINT, + expiration TIMESTAMP NULL ); CREATE TABLE IF NOT EXISTS client_grant_type ( diff --git a/openid-connect-server-webapp/src/main/webapp/WEB-INF/task-config.xml b/openid-connect-server-webapp/src/main/webapp/WEB-INF/task-config.xml index ea500d290..031b2133d 100644 --- a/openid-connect-server-webapp/src/main/webapp/WEB-INF/task-config.xml +++ b/openid-connect-server-webapp/src/main/webapp/WEB-INF/task-config.xml @@ -30,6 +30,7 @@ + diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaAuthorizationCodeRepository.java b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaAuthorizationCodeRepository.java index 06955ceac..b7dd22e9c 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaAuthorizationCodeRepository.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/repository/impl/JpaAuthorizationCodeRepository.java @@ -19,6 +19,9 @@ */ package org.mitre.oauth2.repository.impl; +import java.util.Collection; +import java.util.Date; + import javax.persistence.EntityManager; import javax.persistence.PersistenceContext; import javax.persistence.TypedQuery; @@ -26,8 +29,6 @@ import javax.persistence.TypedQuery; import org.mitre.oauth2.model.AuthorizationCodeEntity; import org.mitre.oauth2.repository.AuthorizationCodeRepository; import org.mitre.util.jpa.JpaUtil; -import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; -import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; @@ -56,27 +57,39 @@ public class JpaAuthorizationCodeRepository implements AuthorizationCodeReposito } /* (non-Javadoc) - * @see org.mitre.oauth2.repository.AuthorizationCodeRepository#consume(java.lang.String) + * @see org.mitre.oauth2.repository.AuthorizationCodeRepository#getByCode(java.lang.String) */ @Override @Transactional - public OAuth2Authentication consume(String code) throws InvalidGrantException { - + public AuthorizationCodeEntity getByCode(String code) { TypedQuery query = manager.createNamedQuery(AuthorizationCodeEntity.QUERY_BY_VALUE, AuthorizationCodeEntity.class); query.setParameter("code", code); AuthorizationCodeEntity result = JpaUtil.getSingleResult(query.getResultList()); - - if (result == null) { - throw new InvalidGrantException("JpaAuthorizationCodeRepository: no authorization code found for value " + code); - } - - OAuth2Authentication authRequest = result.getAuthentication(); - - manager.remove(result); - - return authRequest; - + return result; } + /* (non-Javadoc) + * @see org.mitre.oauth2.repository.AuthorizationCodeRepository#remove(org.mitre.oauth2.model.AuthorizationCodeEntity) + */ + @Override + public void remove(AuthorizationCodeEntity authorizationCodeEntity) { + AuthorizationCodeEntity found = manager.find(AuthorizationCodeEntity.class, authorizationCodeEntity.getId()); + if (found != null) { + manager.remove(found); + } + } + + /* (non-Javadoc) + * @see org.mitre.oauth2.repository.AuthorizationCodeRepository#getExpiredCodes() + */ + @Override + public Collection getExpiredCodes() { + TypedQuery query = manager.createNamedQuery(AuthorizationCodeEntity.QUERY_EXPIRATION_BY_DATE, AuthorizationCodeEntity.class); + query.setParameter(AuthorizationCodeEntity.PARAM_DATE, new Date()); // this gets anything that's already expired + return query.getResultList(); + } + + + } diff --git a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2AuthorizationCodeService.java b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2AuthorizationCodeService.java index ee60baffb..14a9f711e 100644 --- a/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2AuthorizationCodeService.java +++ b/openid-connect-server/src/main/java/org/mitre/oauth2/service/impl/DefaultOAuth2AuthorizationCodeService.java @@ -19,7 +19,15 @@ */ package org.mitre.oauth2.service.impl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Date; + +import org.mitre.oauth2.model.AuthenticationHolderEntity; import org.mitre.oauth2.model.AuthorizationCodeEntity; +import org.mitre.oauth2.repository.AuthenticationHolderRepository; import org.mitre.oauth2.repository.AuthorizationCodeRepository; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; @@ -27,6 +35,7 @@ import org.springframework.security.oauth2.common.util.RandomValueStringGenerato import org.springframework.security.oauth2.provider.OAuth2Authentication; import org.springframework.security.oauth2.provider.code.AuthorizationCodeServices; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; /** * Database-backed, random-value authorization code service implementation. @@ -34,11 +43,18 @@ import org.springframework.stereotype.Service; * @author aanganes * */ -@Service +@Service("defaultOAuth2AuthorizationCodeService") public class DefaultOAuth2AuthorizationCodeService implements AuthorizationCodeServices { + // Logger for this class + private static final Logger logger = LoggerFactory.getLogger(DefaultOAuth2AuthorizationCodeService.class); @Autowired private AuthorizationCodeRepository repository; + + @Autowired + private AuthenticationHolderRepository authenticationHolderRepository; + + private int authCodeExpirationSeconds = 60 * 5; // expire in 5 minutes by default private RandomValueStringGenerator generator = new RandomValueStringGenerator(); @@ -54,7 +70,15 @@ public class DefaultOAuth2AuthorizationCodeService implements AuthorizationCodeS public String createAuthorizationCode(OAuth2Authentication authentication) { String code = generator.generate(); - AuthorizationCodeEntity entity = new AuthorizationCodeEntity(code, authentication); + // attach the authorization so that we can look it up later + AuthenticationHolderEntity authHolder = new AuthenticationHolderEntity(); + authHolder.setAuthentication(authentication); + authHolder = authenticationHolderRepository.save(authHolder); + + // set the auth code to expire + Date expiration = new Date(System.currentTimeMillis() + (getAuthCodeExpirationSeconds() * 1000L)); + + AuthorizationCodeEntity entity = new AuthorizationCodeEntity(code, authHolder, expiration); repository.save(entity); return code; @@ -73,9 +97,34 @@ public class DefaultOAuth2AuthorizationCodeService implements AuthorizationCodeS @Override public OAuth2Authentication consumeAuthorizationCode(String code) throws InvalidGrantException { - OAuth2Authentication auth = repository.consume(code); + AuthorizationCodeEntity result = repository.getByCode(code); + + if (result == null) { + throw new InvalidGrantException("JpaAuthorizationCodeRepository: no authorization code found for value " + code); + } + + OAuth2Authentication auth = result.getAuthenticationHolder().getAuthentication(); + + repository.remove(result); + return auth; } + + /** + * Find and remove all expired auth codes. + */ + @Transactional + public void clearExpiredAuthorizationCodes() { + + Collection codes = repository.getExpiredCodes(); + + for (AuthorizationCodeEntity code : codes) { + repository.remove(code); + } + + logger.info("Removed " + codes.size() + " expired authorization codes."); + + } /** * @return the repository @@ -91,4 +140,18 @@ public class DefaultOAuth2AuthorizationCodeService implements AuthorizationCodeS this.repository = repository; } + /** + * @return the authCodeExpirationSeconds + */ + public int getAuthCodeExpirationSeconds() { + return authCodeExpirationSeconds; + } + + /** + * @param authCodeExpirationSeconds the authCodeExpirationSeconds to set + */ + public void setAuthCodeExpirationSeconds(int authCodeExpirationSeconds) { + this.authCodeExpirationSeconds = authCodeExpirationSeconds; + } + }