pulled checks for expired tokens into utility functions

pull/990/head
Justin Richer 2015-12-18 11:22:50 -05:00
parent 105d5d9e3d
commit aa878cc3cf
1 changed files with 48 additions and 19 deletions

View File

@ -46,6 +46,7 @@ import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.InvalidScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.ClientAlreadyExistsException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.TokenRequest;
@ -84,6 +85,10 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Autowired
private SystemScopeService scopeService;
@Autowired
private ApprovedSiteService approvedSiteService;
@Override
public Set<OAuth2AccessTokenEntity> getAllAccessTokensForUser(String id) {
@ -91,7 +96,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
Set<OAuth2AccessTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2AccessTokenEntity token : all) {
if (token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
if (clearExpiredAccessToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
@ -106,7 +111,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
Set<OAuth2RefreshTokenEntity> results = Sets.newLinkedHashSet();
for (OAuth2RefreshTokenEntity token : all) {
if (token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
if (clearExpiredRefreshToken(token) != null && token.getAuthenticationHolder().getAuthentication().getName().equals(id)) {
results.add(token);
}
}
@ -116,18 +121,50 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Override
public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
return tokenRepository.getAccessTokenById(id);
return clearExpiredAccessToken(tokenRepository.getAccessTokenById(id));
}
@Override
public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
return tokenRepository.getRefreshTokenById(id);
return clearExpiredRefreshToken(tokenRepository.getRefreshTokenById(id));
}
@Autowired
private ApprovedSiteService approvedSiteService;
/**
* Utility function to delete an access token that's expired before returning it.
* @param token the token to check
* @return null if the token is null or expired, the input token (unchanged) if it hasn't
*/
private OAuth2AccessTokenEntity clearExpiredAccessToken(OAuth2AccessTokenEntity token) {
if (token == null) {
return null;
} else if (token.isExpired()) {
// immediately revoke expired token
logger.debug("Clearing expired access token: " + token.getValue());
revokeAccessToken(token);
return null;
} else {
return token;
}
}
/**
* Utility function to delete a refresh token that's expired before returning it.
* @param token the token to check
* @return null if the token is null or expired, the input token (unchanged) if it hasn't
*/
private OAuth2RefreshTokenEntity clearExpiredRefreshToken(OAuth2RefreshTokenEntity token) {
if (token == null) {
return null;
} else if (token.isExpired()) {
// immediately revoke expired token
logger.debug("Clearing expired refresh token: " + token.getValue());
revokeRefreshToken(token);
return null;
} else {
return token;
}
}
@Override
public OAuth2AccessTokenEntity createAccessToken(OAuth2Authentication authentication) throws AuthenticationException, InvalidClientException {
if (authentication != null && authentication.getOAuth2Request() != null) {
@ -238,7 +275,7 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Override
public OAuth2AccessTokenEntity refreshAccessToken(String refreshTokenValue, TokenRequest authRequest) throws AuthenticationException {
OAuth2RefreshTokenEntity refreshToken = tokenRepository.getRefreshTokenByValue(refreshTokenValue);
OAuth2RefreshTokenEntity refreshToken = clearExpiredRefreshToken(tokenRepository.getRefreshTokenByValue(refreshTokenValue));
if (refreshToken == null) {
throw new InvalidTokenException("Invalid refresh token: " + refreshTokenValue);
@ -331,14 +368,10 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
@Override
public OAuth2Authentication loadAuthentication(String accessTokenValue) throws AuthenticationException {
OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenByValue(accessTokenValue);
OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
if (accessToken == null) {
throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
} else if (accessToken.isExpired()) {
//tokenRepository.removeAccessToken(accessToken);
revokeAccessToken(accessToken);
throw new InvalidTokenException("Expired access token: " + accessTokenValue);
} else {
return accessToken.getAuthenticationHolder().getAuthentication();
}
@ -350,13 +383,9 @@ public class DefaultOAuth2ProviderTokenService implements OAuth2TokenEntityServi
*/
@Override
public OAuth2AccessTokenEntity readAccessToken(String accessTokenValue) throws AuthenticationException {
OAuth2AccessTokenEntity accessToken = tokenRepository.getAccessTokenByValue(accessTokenValue);
OAuth2AccessTokenEntity accessToken = clearExpiredAccessToken(tokenRepository.getAccessTokenByValue(accessTokenValue));
if (accessToken == null) {
throw new InvalidTokenException("Access token for value " + accessTokenValue + " was not found");
} else if (accessToken.isExpired()) {
// immediately revoke the expired token
revokeAccessToken(accessToken);
throw new InvalidTokenException("Access token for value " + accessTokenValue + " is expired");
} else {
return accessToken;
}