feat: Add oauth2 refresh token authentication (#1856)

* feat: Add oauth2 refresh token authentication

* feat: Add unit test case
pull/1857/head
guqing 2022-04-19 18:16:14 +08:00 committed by GitHub
parent ad562b4917
commit ed6aea6245
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 909 additions and 122 deletions

View File

@ -28,12 +28,13 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import run.halo.app.identity.authentication.InMemoryOAuth2AuthorizationService;
import run.halo.app.identity.authentication.JwtDaoAuthenticationProvider;
import run.halo.app.identity.authentication.JwtGenerator;
import run.halo.app.identity.authentication.JwtUsernamePasswordAuthenticationFilter;
import run.halo.app.identity.authentication.OAuth2AuthorizationService;
import run.halo.app.identity.authentication.OAuth2RefreshTokenAuthenticationProvider;
import run.halo.app.identity.authentication.OAuth2TokenEndpointFilter;
import run.halo.app.identity.authentication.ProviderContextFilter;
import run.halo.app.identity.authentication.ProviderSettings;
import run.halo.app.identity.entrypoint.JwtAccessDeniedHandler;
@ -67,13 +68,13 @@ public class WebSecurityConfig {
ProviderContextFilter providerContextFilter = new ProviderContextFilter(providerSettings);
http
.authorizeHttpRequests((authorize) -> authorize
.antMatchers("/api/v1/oauth2/login").permitAll()
.antMatchers("/api/v1/oauth2/token").permitAll()
.antMatchers("/api/**", "/apis/**").authenticated()
)
.csrf(AbstractHttpConfigurer::disable)
.httpBasic(Customizer.withDefaults())
.addFilterBefore(new JwtUsernamePasswordAuthenticationFilter(authenticationManager()),
UsernamePasswordAuthenticationFilter.class)
.addFilterBefore(new OAuth2TokenEndpointFilter(authenticationManager()),
AbstractPreAuthenticatedProcessingFilter.class)
.addFilterAfter(providerContextFilter, SecurityContextPersistenceFilter.class)
.sessionManagement(
(session) -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
@ -86,7 +87,7 @@ public class WebSecurityConfig {
@Bean
AuthenticationManager authenticationManager() throws Exception {
authenticationManagerBuilder.authenticationProvider(jwtDaoAuthenticationProvider());
authenticationManagerBuilder.authenticationProvider(refreshTokenAuthenticationProvider());
return authenticationManagerBuilder.getOrBuild();
}
@ -108,13 +109,14 @@ public class WebSecurityConfig {
}
@Bean
JwtDaoAuthenticationProvider jwtDaoAuthenticationProvider() {
JwtDaoAuthenticationProvider authenticationProvider =
new JwtDaoAuthenticationProvider(jwtGenerator(),
new InMemoryOAuth2AuthorizationService());
authenticationProvider.setUserDetailsService(userDetailsService());
authenticationProvider.setPasswordEncoder(passwordEncoder());
return authenticationProvider;
OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider() {
return new OAuth2RefreshTokenAuthenticationProvider(oauth2AuthorizationService(),
jwtGenerator());
}
@Bean
OAuth2AuthorizationService oauth2AuthorizationService() {
return new InMemoryOAuth2AuthorizationService();
}
@Bean

View File

@ -0,0 +1,49 @@
package run.halo.app.identity.authentication;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.Assert;
/**
* An {@link AuthenticationConverter} that simply delegates to it's
* internal {@code List} of {@link AuthenticationConverter}(s).
* <p>
* Each {@link AuthenticationConverter} is given a chance to
* {@link AuthenticationConverter#convert(HttpServletRequest)}
* with the first {@code non-null} {@link Authentication} being returned.
*
* @author guqing
* @see AuthenticationConverter
* @since 2.0.0
*/
public class DelegatingAuthenticationConverter implements AuthenticationConverter {
private final List<AuthenticationConverter> converters;
/**
* Constructs a {@code DelegatingAuthenticationConverter} using the provided parameters.
*
* @param converters a {@code List} of {@link AuthenticationConverter}(s)
*/
public DelegatingAuthenticationConverter(List<AuthenticationConverter> converters) {
Assert.notEmpty(converters, "converters cannot be empty");
this.converters = Collections.unmodifiableList(new LinkedList<>(converters));
}
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
Assert.notNull(request, "request cannot be null");
for (AuthenticationConverter converter : this.converters) {
Authentication authentication = converter.convert(request);
if (authentication != null) {
return authentication;
}
}
return null;
}
}

View File

@ -1,108 +0,0 @@
package run.halo.app.identity.authentication;
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
/**
* @author guqing
* @since 2.0.0
*/
public class JwtDaoAuthenticationProvider extends DaoAuthenticationProvider {
private static final String ERROR_URI =
"https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private final OAuth2AuthorizationService authorizationService;
public JwtDaoAuthenticationProvider(
OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator,
OAuth2AuthorizationService authorizationService) {
this.tokenGenerator = tokenGenerator;
this.authorizationService = authorizationService;
}
@Override
protected Authentication createSuccessAuthentication(Object principal,
Authentication authentication, UserDetails user) {
UsernamePasswordAuthenticationToken usernamePasswordAuthenticationToken =
(UsernamePasswordAuthenticationToken) super.createSuccessAuthentication(principal,
authentication, user);
Set<String> scopes = usernamePasswordAuthenticationToken.getAuthorities().stream()
.map(GrantedAuthority::getAuthority).collect(Collectors.toSet());
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.principal(authentication)
.providerContext(ProviderContextHolder.getProviderContext())
.authorizedScopes(scopes);
OAuth2Authorization.Builder authorizationBuilder = new OAuth2Authorization.Builder()
.principalName(authentication.getName())
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, scopes);
// ----- Access token -----
OAuth2TokenContext tokenContext =
tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build();
OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
if (generatedAccessToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the access token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
if (generatedAccessToken instanceof ClaimAccessor) {
authorizationBuilder.token(accessToken, (metadata) -> {
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME,
((ClaimAccessor) generatedAccessToken).getClaims());
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
});
} else {
authorizationBuilder.accessToken(accessToken);
}
ProviderSettings providerSettings =
ProviderContextHolder.getProviderContext().providerSettings();
// ----- Refresh token -----
OAuth2RefreshToken currentRefreshToken = null;
if (!providerSettings.isReuseRefreshTokens()) {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (generatedRefreshToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
currentRefreshToken = new OAuth2RefreshToken(
generatedRefreshToken.getTokenValue(), generatedRefreshToken.getIssuedAt(),
generatedRefreshToken.getExpiresAt());
authorizationBuilder.refreshToken(currentRefreshToken);
}
this.authorizationService.save(authorizationBuilder.build());
return new OAuth2AccessTokenAuthenticationToken(authentication, accessToken,
currentRefreshToken, Collections.emptyMap());
}
@Override
public boolean supports(Class<?> authentication) {
return UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication);
}
}

View File

@ -0,0 +1,71 @@
package run.halo.app.identity.authentication;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
/**
* Base implementation of an {@link Authentication} representing an OAuth 2.0 Authorization Grant.
*
* @author guqing
* @see AbstractAuthenticationToken
* @see AuthorizationGrantType
* @see
* <a href="https://tools.ietf.org/html/rfc6749#section-1.3">Section 1.3 Authorization Grant</a>
* @since 2.0.0
*/
public class OAuth2AuthorizationGrantAuthenticationToken extends AbstractAuthenticationToken {
private final AuthorizationGrantType authorizationGrantType;
private final Map<String, Object> additionalParameters;
/**
* Sub-class constructor.
*
* @param authorizationGrantType the authorization grant type
* @param additionalParameters the additional parameters
*/
protected OAuth2AuthorizationGrantAuthenticationToken(
AuthorizationGrantType authorizationGrantType,
@Nullable Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
this.authorizationGrantType = authorizationGrantType;
this.additionalParameters = Collections.unmodifiableMap(
additionalParameters != null
? new HashMap<>(additionalParameters) :
Collections.emptyMap());
}
/**
* Returns the authorization grant type.
*
* @return the authorization grant type
*/
public AuthorizationGrantType getGrantType() {
return this.authorizationGrantType;
}
@Override
public Object getCredentials() {
return "";
}
@Override
public Object getPrincipal() {
return "";
}
/**
* Returns the additional parameters.
*
* @return the additional parameters
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
}

View File

@ -0,0 +1,36 @@
package run.halo.app.identity.authentication;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Map;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* @author guqing
* @since 2.0.0
*/
public class OAuth2EndpointUtils {
static final String ERROR_URI =
"https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
Map<String, String[]> parameterMap = request.getParameterMap();
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
parameterMap.forEach((key, values) -> {
if (values.length > 0) {
for (String value : values) {
parameters.add(key, value);
}
}
});
return parameters;
}
static void throwError(String errorCode, String parameterName, String errorUri) {
OAuth2Error
error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
throw new OAuth2AuthenticationException(error);
}
}

View File

@ -0,0 +1,81 @@
package run.halo.app.identity.authentication;
import jakarta.servlet.http.HttpServletRequest;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
/**
* Attempts to extract an Access Token Request from {@link HttpServletRequest} for the OAuth 2.0
* Refresh Token Grant
* and then converts it to an {@link OAuth2RefreshTokenAuthenticationToken} used for
* authenticating the authorization grant.
*
* @author guqing
* @see AuthenticationConverter
* @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2TokenEndpointFilter
* @since 2.0.0
*/
public class OAuth2RefreshTokenAuthenticationConverter implements AuthenticationConverter {
static final String ACCESS_TOKEN_REQUEST_ERROR_URI =
"https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
return null;
}
MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
// refresh_token (REQUIRED)
String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN);
if (!StringUtils.hasText(refreshToken)
|| parameters.get(OAuth2ParameterNames.REFRESH_TOKEN).size() != 1) {
OAuth2EndpointUtils.throwError(OAuth2ErrorCodes.INVALID_REQUEST,
OAuth2ParameterNames.REFRESH_TOKEN,
ACCESS_TOKEN_REQUEST_ERROR_URI);
}
// scope (OPTIONAL)
String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
if (StringUtils.hasText(scope)
&& parameters.get(OAuth2ParameterNames.SCOPE).size() != 1) {
OAuth2EndpointUtils.throwError(
OAuth2ErrorCodes.INVALID_REQUEST,
OAuth2ParameterNames.SCOPE,
ACCESS_TOKEN_REQUEST_ERROR_URI);
}
Set<String> requestedScopes = null;
if (StringUtils.hasText(scope)) {
requestedScopes = new HashSet<>(
Arrays.asList(StringUtils.delimitedListToStringArray(scope, " ")));
}
Map<String, Object> additionalParameters = new HashMap<>();
parameters.forEach((key, value) -> {
if (!key.equals(OAuth2ParameterNames.GRANT_TYPE)
&& !key.equals(OAuth2ParameterNames.REFRESH_TOKEN)
&& !key.equals(OAuth2ParameterNames.SCOPE)) {
additionalParameters.put(key, value.get(0));
}
});
return new OAuth2RefreshTokenAuthenticationToken(refreshToken, requestedScopes,
additionalParameters);
}
}

View File

@ -0,0 +1,148 @@
package run.halo.app.identity.authentication;
import java.security.Principal;
import java.util.Collections;
import java.util.Set;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.util.Assert;
/**
* An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
*
* @author guqing
* @see OAuth2RefreshTokenAuthenticationToken
* @see OAuth2AccessTokenAuthenticationToken
* @see OAuth2AuthorizationService
* @see OAuth2TokenGenerator
* @see
* <a href="https://datatracker.ietf.org/doc/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a>
* @see
* <a href="https://datatracker.ietf.org/doc/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
* @since 2.0.0
*/
public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
private final OAuth2AuthorizationService authorizationService;
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
/**
* Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
*
* @param authorizationService the authorization service
* @param tokenGenerator the token generator
* @since 0.2.3
*/
public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService,
OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
Assert.notNull(authorizationService, "authorizationService cannot be null");
Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
this.authorizationService = authorizationService;
this.tokenGenerator = tokenGenerator;
}
@Override
public Authentication authenticate(Authentication authentication) throws
AuthenticationException {
OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication =
(OAuth2RefreshTokenAuthenticationToken) authentication;
OAuth2Authorization authorization = this.authorizationService.findByToken(
refreshTokenAuthentication.getRefreshToken(), OAuth2TokenType.REFRESH_TOKEN);
if (authorization == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}
OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
authorization.getRefreshToken();
if (refreshToken == null || !refreshToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code,
// resource owner credentials) or refresh token is invalid, expired, revoked [...].
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}
// As per https://tools.ietf.org/html/rfc6749#section-6
// The requested scope MUST NOT include any scope not originally granted by the resource
// owner,
// and if omitted is treated as equal to the scope originally granted by the resource owner.
Set<String> scopes = refreshTokenAuthentication.getScopes();
Set<String> authorizedScopes =
authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME);
if (!authorizedScopes.containsAll(scopes)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_SCOPE);
}
if (scopes.isEmpty()) {
scopes = authorizedScopes;
}
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.principal(authorization.getAttribute(Principal.class.getName()))
.providerContext(ProviderContextHolder.getProviderContext())
.authorizedScopes(scopes);
OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization);
// ----- Access token -----
OAuth2TokenContext tokenContext =
tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build();
OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
if (generatedAccessToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the access token.",
OAuth2EndpointUtils.ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
if (generatedAccessToken instanceof ClaimAccessor) {
authorizationBuilder.token(accessToken, (metadata) -> {
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME,
((ClaimAccessor) generatedAccessToken).getClaims());
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
});
} else {
authorizationBuilder.accessToken(accessToken);
}
ProviderSettings providerSettings =
ProviderContextHolder.getProviderContext().providerSettings();
// ----- Refresh token -----
OAuth2RefreshToken currentRefreshToken = null;
if (!providerSettings.isReuseRefreshTokens()) {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (generatedRefreshToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.",
OAuth2EndpointUtils.ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
currentRefreshToken = new OAuth2RefreshToken(
generatedRefreshToken.getTokenValue(), generatedRefreshToken.getIssuedAt(),
generatedRefreshToken.getExpiresAt());
authorizationBuilder.refreshToken(currentRefreshToken);
}
authorization = authorizationBuilder.build();
this.authorizationService.save(authorization);
return new OAuth2AccessTokenAuthenticationToken(refreshTokenAuthentication, accessToken,
currentRefreshToken, Collections.emptyMap());
}
@Override
public boolean supports(Class<?> authentication) {
return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}
}

View File

@ -0,0 +1,53 @@
package run.halo.app.identity.authentication;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
/**
* @author guqing
* @since 2.0.0
*/
public class OAuth2RefreshTokenAuthenticationToken
extends OAuth2AuthorizationGrantAuthenticationToken {
private final String refreshToken;
private final Set<String> scopes;
/**
* Constructs an {@code OAuth2RefreshTokenAuthenticationToken} using the provided parameters.
*
* @param refreshToken the refresh token
* @param scopes the requested scope(s)
* @param additionalParameters the additional parameters
*/
public OAuth2RefreshTokenAuthenticationToken(String refreshToken,
@Nullable Set<String> scopes, @Nullable Map<String, Object> additionalParameters) {
super(AuthorizationGrantType.REFRESH_TOKEN, additionalParameters);
Assert.hasText(refreshToken, "refreshToken cannot be empty");
this.refreshToken = refreshToken;
this.scopes = Collections.unmodifiableSet(
scopes != null ? new HashSet<>(scopes) : Collections.emptySet());
}
/**
* Returns the refresh token.
*
* @return the refresh token
*/
public String getRefreshToken() {
return this.refreshToken;
}
/**
* Returns the requested scope(s).
*
* @return the requested scope(s), or an empty {@code Set} if not available
*/
public Set<String> getScopes() {
return this.scopes;
}
}

View File

@ -0,0 +1,229 @@
package run.halo.app.identity.authentication;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationDetailsSource;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.filter.OncePerRequestFilter;
/**
* @author guqing
* @since 2.0.0
*/
public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
/**
* The default endpoint {@code URI} for access token requests.
*/
private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/api/v1/oauth2/token";
private static final String DEFAULT_ERROR_URI =
"https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private final AuthenticationManager authenticationManager;
private final RequestMatcher tokenEndpointMatcher;
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
new OAuth2ErrorHttpMessageConverter();
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
new WebAuthenticationDetailsSource();
private AuthenticationConverter authenticationConverter;
private AuthenticationSuccessHandler authenticationSuccessHandler =
this::sendAccessTokenResponse;
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
/**
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
*
* @param authenticationManager the authentication manager
*/
public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager) {
this(authenticationManager, DEFAULT_TOKEN_ENDPOINT_URI);
}
/**
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
*
* @param authenticationManager the authentication manager
* @param tokenEndpointUri the endpoint {@code URI} for access token requests
*/
public OAuth2TokenEndpointFilter(AuthenticationManager authenticationManager,
String tokenEndpointUri) {
Assert.notNull(authenticationManager, "authenticationManager cannot be null");
Assert.hasText(tokenEndpointUri, "tokenEndpointUri cannot be empty");
this.authenticationManager = authenticationManager;
this.tokenEndpointMatcher =
new AntPathRequestMatcher(tokenEndpointUri, HttpMethod.POST.name());
this.authenticationConverter = new DelegatingAuthenticationConverter(
List.of(new OAuth2RefreshTokenAuthenticationConverter())
);
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain)
throws ServletException, IOException {
if (!this.tokenEndpointMatcher.matches(request)) {
filterChain.doFilter(request, response);
return;
}
try {
String[] grantTypes = request.getParameterValues(OAuth2ParameterNames.GRANT_TYPE);
if (grantTypes == null || grantTypes.length != 1) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
}
Authentication authorizationGrantAuthentication =
this.authenticationConverter.convert(request);
if (authorizationGrantAuthentication == null) {
throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE,
OAuth2ParameterNames.GRANT_TYPE);
}
if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) {
((AbstractAuthenticationToken) authorizationGrantAuthentication)
.setDetails(this.authenticationDetailsSource.buildDetails(request));
}
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(
authorizationGrantAuthentication);
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response,
accessTokenAuthentication);
} catch (OAuth2AuthenticationException ex) {
SecurityContextHolder.clearContext();
this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
}
}
/**
* Sets the {@link AuthenticationDetailsSource} used for building an authentication details
* instance from {@link HttpServletRequest}.
*
* @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for
* building an authentication details instance from {@link HttpServletRequest}
*/
public void setAuthenticationDetailsSource(
AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
this.authenticationDetailsSource = authenticationDetailsSource;
}
/**
* Sets the {@link AuthenticationConverter} used when attempting to extract an Access Token
* Request from {@link HttpServletRequest}
* to an instance of {@link OAuth2AuthorizationGrantAuthenticationToken} used for
* authenticating the authorization grant.
*
* @param authenticationConverter the {@link AuthenticationConverter} used when attempting to
* extract an Access Token Request from {@link HttpServletRequest}
*/
public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) {
Assert.notNull(authenticationConverter, "authenticationConverter cannot be null");
this.authenticationConverter = authenticationConverter;
}
/**
* Sets the {@link AuthenticationSuccessHandler} used for handling an
* {@link OAuth2AccessTokenAuthenticationToken}
* and returning the {@link OAuth2AccessTokenResponse Access Token Response}.
*
* @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for
* handling an {@link OAuth2AccessTokenAuthenticationToken}
*/
public void setAuthenticationSuccessHandler(
AuthenticationSuccessHandler authenticationSuccessHandler) {
Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
this.authenticationSuccessHandler = authenticationSuccessHandler;
}
/**
* Sets the {@link AuthenticationFailureHandler} used for handling an
* {@link OAuth2AuthenticationException}
* and returning the {@link OAuth2Error Error Response}.
*
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for
* handling an {@link OAuth2AuthenticationException}
*/
public void setAuthenticationFailureHandler(
AuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
}
private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) authentication;
OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
Map<String, Object> additionalParameters =
accessTokenAuthentication.getAdditionalParameters();
OAuth2AccessTokenResponse.Builder builder =
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
.tokenType(accessToken.getTokenType())
.scopes(accessToken.getScopes());
if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
builder.expiresIn(
ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
}
if (refreshToken != null) {
builder.refreshToken(refreshToken.getTokenValue());
}
if (!CollectionUtils.isEmpty(additionalParameters)) {
builder.additionalParameters(additionalParameters);
}
OAuth2AccessTokenResponse accessTokenResponse = builder.build();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
}
private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {
OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
this.errorHttpResponseConverter.write(error, null, httpResponse);
}
private static void throwError(String errorCode, String parameterName) {
OAuth2Error error =
new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
}

View File

@ -0,0 +1,166 @@
package run.halo.app.authentication;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.security.Principal;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import run.halo.app.identity.authentication.JwtGenerator;
import run.halo.app.identity.authentication.OAuth2AccessTokenAuthenticationToken;
import run.halo.app.identity.authentication.OAuth2Authorization;
import run.halo.app.identity.authentication.OAuth2AuthorizationService;
import run.halo.app.identity.authentication.OAuth2RefreshTokenAuthenticationProvider;
import run.halo.app.identity.authentication.OAuth2RefreshTokenAuthenticationToken;
import run.halo.app.identity.authentication.OAuth2TokenContext;
import run.halo.app.identity.authentication.OAuth2TokenGenerator;
import run.halo.app.identity.authentication.OAuth2TokenType;
import run.halo.app.identity.authentication.ProviderContext;
import run.halo.app.identity.authentication.ProviderContextHolder;
import run.halo.app.identity.authentication.ProviderSettings;
/**
* Tests for {@link OAuth2RefreshTokenAuthenticationProvider}.
*
* @author guqing
* @since 2.0.0
*/
public class OAuth2RefreshTokenAuthenticationProviderTest {
private OAuth2AuthorizationService authorizationService;
private OAuth2TokenGenerator<?> tokenGenerator;
private OAuth2RefreshTokenAuthenticationProvider authenticationProvider;
@BeforeEach
public void setUp() {
this.authorizationService = mock(OAuth2AuthorizationService.class);
JwtEncoder jwtEncoder = mock(JwtEncoder.class);
when(jwtEncoder.encode(any())).thenReturn(createJwt(Collections.singleton("scope1")));
JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder);
this.tokenGenerator = spy(new OAuth2TokenGenerator<OAuth2Token>() {
@Override
public OAuth2Token generate(OAuth2TokenContext context) {
return jwtGenerator.generate(context);
}
});
this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
this.authorizationService, this.tokenGenerator);
ProviderSettings
providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null));
}
@AfterEach
public void cleanup() {
ProviderContextHolder.resetProviderContext();
}
private static Jwt createJwt(Set<String> scope) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
return Jwt.withTokenValue("refreshed-access-token")
.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
.issuedAt(issuedAt)
.expiresAt(expiresAt)
.claim(OAuth2ParameterNames.SCOPE, scope)
.build();
}
@Test
public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(null, tokenGenerator))
.isInstanceOf(IllegalArgumentException.class)
.extracting(Throwable::getMessage)
.isEqualTo("authorizationService cannot be null");
}
@Test
public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
assertThatThrownBy(
() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService,
null))
.isInstanceOf(IllegalArgumentException.class)
.extracting(Throwable::getMessage)
.isEqualTo("tokenGenerator cannot be null");
}
@Test
public void supportsWhenSupportedAuthenticationThenTrue() {
assertThat(this.authenticationProvider.supports(
OAuth2RefreshTokenAuthenticationToken.class)).isTrue();
}
@Test
public void supportsWhenUnsupportedAuthenticationThenFalse() {
assertThat(this.authenticationProvider.supports(
UsernamePasswordAuthenticationToken.class)).isFalse();
}
@Test
public void authenticateWhenValidRefreshTokenThenReturnAccessToken() {
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
when(this.authorizationService.findByToken(
eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(OAuth2TokenType.REFRESH_TOKEN)))
.thenReturn(authorization);
String tokenValue = authorization.getRefreshToken().getToken().getTokenValue();
OAuth2RefreshTokenAuthenticationToken authentication =
new OAuth2RefreshTokenAuthenticationToken(
tokenValue, null, null);
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(
authentication);
ArgumentCaptor<OAuth2TokenContext> tokenContextArgumentCaptor =
ArgumentCaptor.forClass(OAuth2TokenContext.class);
verify(this.tokenGenerator, times(2)).generate(tokenContextArgumentCaptor.capture());
List<OAuth2TokenContext> allValues = tokenContextArgumentCaptor.getAllValues();
assertThat(allValues).isNotNull();
// refresh token generate capture
assertThat(allValues.size()).isEqualTo(2);
OAuth2TokenContext tokenContext = allValues.get(1);
assertThat(tokenContext.<Authentication>getPrincipal()).isEqualTo(
authorization.getAttribute(
Principal.class.getName()));
assertThat(tokenContext.getAuthorizedScopes())
.isEqualTo(
authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
assertThat(tokenContext.getTokenType()).isEqualTo(OAuth2TokenType.REFRESH_TOKEN);
ArgumentCaptor<OAuth2Authorization> authorizationCaptor =
ArgumentCaptor.forClass(OAuth2Authorization.class);
verify(this.authorizationService).save(authorizationCaptor.capture());
OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(
updatedAuthorization.getAccessToken().getToken());
assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(
authorization.getAccessToken());
assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(
updatedAuthorization.getRefreshToken().getToken());
// By default, refresh token is opened
assertThat(updatedAuthorization.getRefreshToken()).isNotEqualTo(
authorization.getRefreshToken());
}
}

View File

@ -0,0 +1,60 @@
package run.halo.app.authentication;
import java.security.Principal;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.util.CollectionUtils;
import run.halo.app.identity.authentication.OAuth2Authorization;
/**
* @author guqing
* @since 2.0.0
*/
public class TestOAuth2Authorizations {
public static OAuth2Authorization.Builder authorization() {
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(),
Instant.now().plusSeconds(300));
return authorization(accessToken, Collections.emptyMap());
}
private static OAuth2Authorization.Builder authorization(
OAuth2AccessToken accessToken, Map<String, Object> accessTokenClaims) {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
"refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS));
return new OAuth2Authorization.Builder()
.id("id")
.principalName("principal")
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.token(accessToken, (metadata) -> metadata.putAll(tokenMetadata(accessTokenClaims)))
.refreshToken(refreshToken)
.attribute(Principal.class.getName(),
new TestingAuthenticationToken("principal", "123456", "ROLE_A", "ROLE_B"))
.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, Collections.emptySet());
}
private static Map<String, Object> tokenMetadata(Map<String, Object> tokenClaims) {
Map<String, Object> tokenMetadata = new HashMap<>();
tokenMetadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
if (CollectionUtils.isEmpty(tokenClaims)) {
tokenClaims = defaultTokenClaims();
}
tokenMetadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, tokenClaims);
return tokenMetadata;
}
private static Map<String, Object> defaultTokenClaims() {
Map<String, Object> claims = new HashMap<>();
claims.put("claim1", "value1");
claims.put("claim2", "value2");
claims.put("claim3", "value3");
return claims;
}
}