feat: Add UsernamePasswordAuthenticationProvider and filter to authentication client password request (#1846)

* feat: Add auhentication provider and password authentication filter

* refactor: JwtUsernamePasswordAuthenticationFilter and add test case

* chore: delete unused class

* fix: code style

* refactor: web secutity config
pull/1852/head
guqing 2022-04-18 11:40:11 +08:00 committed by GitHub
parent 20e6d4d1eb
commit 66be1d1ba7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1143 additions and 217 deletions

View File

@ -11,14 +11,15 @@ import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.config.Customizer;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.jwt.JwtDecoder;
@ -27,6 +28,11 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.provisioning.InMemoryUserDetailsManager;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import run.halo.app.identity.authentication.JwtDaoAuthenticationProvider;
import run.halo.app.identity.authentication.JwtGenerator;
import run.halo.app.identity.authentication.JwtUsernamePasswordAuthenticationFilter;
import run.halo.app.identity.authentication.OAuth2AuthorizationService;
import run.halo.app.identity.entrypoint.JwtAccessDeniedHandler;
import run.halo.app.identity.entrypoint.JwtAuthenticationEntryPoint;
import run.halo.app.infra.properties.JwtProperties;
@ -43,19 +49,26 @@ public class WebSecurityConfig {
private final RSAPrivateKey priv;
public WebSecurityConfig(JwtProperties jwtProperties) throws IOException {
private final AuthenticationManagerBuilder authenticationManagerBuilder;
public WebSecurityConfig(JwtProperties jwtProperties,
AuthenticationManagerBuilder authenticationManagerBuilder) throws IOException {
this.key = jwtProperties.readPublicKey();
this.priv = jwtProperties.readPrivateKey();
this.authenticationManagerBuilder = authenticationManagerBuilder;
}
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http
.authorizeHttpRequests((authorize) -> authorize
.antMatchers("/api/v1/oauth2/login").permitAll()
.antMatchers("/api/**", "/apis/**").authenticated()
)
.csrf(AbstractHttpConfigurer::disable)
.httpBasic(Customizer.withDefaults())
.addFilterBefore(new JwtUsernamePasswordAuthenticationFilter(authenticationManager()),
UsernamePasswordAuthenticationFilter.class)
.sessionManagement(
(session) -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
.exceptionHandling((exceptions) -> exceptions
@ -66,13 +79,9 @@ public class WebSecurityConfig {
}
@Bean
UserDetailsService users() {
return new InMemoryUserDetailsManager(
User.withUsername("user")
.password("{noop}password")
.authorities("app")
.build()
);
AuthenticationManager authenticationManager() throws Exception {
authenticationManagerBuilder.authenticationProvider(jwtDaoAuthenticationProvider());
return authenticationManagerBuilder.getOrBuild();
}
@Bean
@ -92,10 +101,24 @@ public class WebSecurityConfig {
return new BCryptPasswordEncoder();
}
@Bean
JwtDaoAuthenticationProvider jwtDaoAuthenticationProvider() {
JwtDaoAuthenticationProvider authenticationProvider =
new JwtDaoAuthenticationProvider(jwtGenerator(), new OAuth2AuthorizationService());
authenticationProvider.setUserDetailsService(userDetailsService());
authenticationProvider.setPasswordEncoder(passwordEncoder());
return authenticationProvider;
}
@Bean
JwtGenerator jwtGenerator() {
return new JwtGenerator(jwtEncoder());
}
@Bean
public InMemoryUserDetailsManager userDetailsService() {
UserDetails user = User.withUsername("user")
.password("password")
.password(passwordEncoder().encode("123456"))
.roles("USER")
.build();
return new InMemoryUserDetailsManager(user);

View File

@ -1,35 +0,0 @@
package run.halo.app.identity.authentication;
import com.fasterxml.jackson.annotation.JsonIgnore;
import java.io.Serializable;
import java.util.Collections;
import java.util.Map;
import lombok.Data;
import org.springframework.security.oauth2.jwt.Jwt;
/**
* @author guqing
* @date 2022-04-12
*/
@Data
public class AccessToken implements Serializable {
private String tokenType;
private Jwt accessToken;
private Jwt refreshToken;
private Map<String, Object> additionalInformation;
@JsonIgnore
private long expiration;
public AccessToken(Jwt accessToken) {
this.tokenType = "Bearer".toLowerCase();
this.additionalInformation = Collections.emptyMap();
this.accessToken = accessToken;
}
public long getExpiresIn() {
return this.expiration;
}
}

View File

@ -0,0 +1,120 @@
package run.halo.app.identity.authentication;
import java.util.Collections;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
/**
* @author guqing
* @since 2.0.0
*/
public class JwtDaoAuthenticationProvider extends DaoAuthenticationProvider {
private static final OAuth2TokenType PASSWORD_TOKEN = new OAuth2TokenType("password");
private static final String ERROR_URI =
"https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private final OAuth2AuthorizationService authorizationService;
/**
* TODO from token settings
*/
private static final boolean isReuseRefreshTokens = false;
public JwtDaoAuthenticationProvider(
OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator,
OAuth2AuthorizationService authorizationService) {
this.tokenGenerator = tokenGenerator;
this.authorizationService = authorizationService;
}
@Override
protected Authentication createSuccessAuthentication(Object principal,
Authentication authentication, UserDetails user) {
UsernamePasswordAuthenticationToken usernamePasswordAuthenticationToken =
(UsernamePasswordAuthenticationToken) super.createSuccessAuthentication(principal,
authentication, user);
OAuth2Authorization authorization = this.authorizationService.findByUsername(
usernamePasswordAuthenticationToken.getName(), PASSWORD_TOKEN);
OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
authorization.getRefreshToken();
if (refreshToken == null || !refreshToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code,
// resource owner credentials) or refresh token is invalid, expired, revoked [...].
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
}
Set<String> scopes = usernamePasswordAuthenticationToken.getAuthorities().stream()
.map(GrantedAuthority::getAuthority).collect(Collectors.toSet());
ProviderContext providerContext =
new ProviderContext(ProviderSettings.builder().build(), () -> "/issuer");
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.principal(authentication)
.providerContext(providerContext)
.authorizedScopes(scopes);
OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization);
// ----- Access token -----
OAuth2TokenContext tokenContext =
tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build();
OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
if (generatedAccessToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the access token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
if (generatedAccessToken instanceof ClaimAccessor) {
authorizationBuilder.token(accessToken, (metadata) -> {
metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME,
((ClaimAccessor) generatedAccessToken).getClaims());
metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
});
} else {
authorizationBuilder.accessToken(accessToken);
}
// ----- Refresh token -----
OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
if (!isReuseRefreshTokens) {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (generatedRefreshToken == null) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
currentRefreshToken = new OAuth2RefreshToken(
generatedRefreshToken.getTokenValue(), generatedRefreshToken.getIssuedAt(),
generatedRefreshToken.getExpiresAt());
authorizationBuilder.refreshToken(currentRefreshToken);
}
this.authorizationService.save(authorizationBuilder.build());
return new OAuth2AccessTokenAuthenticationToken(authentication, accessToken,
currentRefreshToken, Collections.emptyMap());
}
@Override
public boolean supports(Class<?> authentication) {
return UsernamePasswordAuthenticationToken.class.isAssignableFrom(authentication);
}
}

View File

@ -41,7 +41,8 @@ public record JwtGenerator(JwtEncoder jwtEncoder) implements OAuth2TokenGenerato
@Override
public Jwt generate(OAuth2TokenContext context) {
if (context.getTokenType() == null
|| !OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
|| (!OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())
&& !OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType()))) {
return null;
}
Instant issuedAt = Instant.now();

View File

@ -1,86 +0,0 @@
package run.halo.app.identity.authentication;
import java.io.Serializable;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.stereotype.Component;
import run.halo.app.infra.properties.JwtProperties;
import run.halo.app.infra.utils.HaloUtils;
/**
* Jwt token utils.
*
* @author guqing
* @date 2022-04-12
*/
@Component
@EnableConfigurationProperties(JwtProperties.class)
public class JwtTokenProvider implements Serializable {
private final JwtEncoder jwtEncoder;
private final JwtDecoder jwtDecoder;
private final JwtProperties jwtProperties;
public JwtTokenProvider(JwtEncoder jwtEncoder, JwtDecoder jwtDecoder,
JwtProperties jwtProperties) {
this.jwtEncoder = jwtEncoder;
this.jwtDecoder = jwtDecoder;
this.jwtProperties = jwtProperties;
}
private JwtClaimsSet createJwt(Authentication authentication, Instant expireAt) {
String scope = authentication.getAuthorities().stream()
.map(GrantedAuthority::getAuthority)
.collect(Collectors.joining(" "));
return JwtClaimsSet.builder()
// JWT ID (jti)
.id(HaloUtils.simpleUUID())
// 签发者
.issuer(StringUtils.defaultIfBlank(jwtProperties.getIssuerUri(),
"https://halo.run"))
.issuedAt(Instant.now())
// Authentication#getName maps to the JWTs sub property, if one is present.
.subject(authentication.getName())
// expiration time (exp) claim
.expiresAt(expireAt)
.claim("scope", scope)
.build();
}
public AccessToken getToken(Authentication authentication) {
Instant expireAt = Instant.now().plusMillis(SecurityConstant.EXPIRATION_TIME);
JwtClaimsSet tokenClaimsSet = createJwt(authentication, expireAt);
Jwt token = jwtEncoder.encode(JwtEncoderParameters.from(tokenClaimsSet));
JwtClaimsSet refreshTokenClaimsSet =
createJwt(authentication, expireAt.plus(30, ChronoUnit.MINUTES));
Jwt refreshToken = jwtEncoder.encode(JwtEncoderParameters.from(refreshTokenClaimsSet));
AccessToken accessToken = new AccessToken(token);
accessToken.setRefreshToken(refreshToken);
accessToken.setExpiration(expireAt.toEpochMilli());
accessToken.setTokenType(SecurityConstant.TOKEN_PREFIX);
return accessToken;
}
public Jwt verify(String token) {
try {
return jwtDecoder.decode(token);
} catch (JwtException e) {
return null;
}
}
}

View File

@ -0,0 +1,147 @@
package run.halo.app.identity.authentication;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.temporal.ChronoUnit;
import java.util.Map;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* Processes an authentication client request. Called
* {@code AuthenticationProcessingFilter} prior to Spring Security 3.0.
* <p>
* Request parameter must present two parameters to this filter: a username and password. The
* default parameter names to use are contained in the static fields
* {@link #SPRING_SECURITY_FORM_USERNAME_KEY} and
* {@link #SPRING_SECURITY_FORM_PASSWORD_KEY}. The parameter names can also be changed by
* setting the {@code usernameParameter} and {@code passwordParameter} properties.
* <p>
* This filter by default responds to the URL {@code /api/v1/oauth2/login}.
*
* @author guqing
* @see UsernamePasswordAuthenticationFilter
* @see AbstractAuthenticationProcessingFilter
* @see OAuth2AccessTokenAuthenticationToken
* @since 2.0.0
*/
public class JwtUsernamePasswordAuthenticationFilter extends UsernamePasswordAuthenticationFilter {
/**
* The default endpoint {@code URI} for access token requests.
*/
private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/api/v1/oauth2/login";
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
new OAuth2ErrorHttpMessageConverter();
private boolean postOnly = true;
public JwtUsernamePasswordAuthenticationFilter() {
this(DEFAULT_TOKEN_ENDPOINT_URI, null);
}
public JwtUsernamePasswordAuthenticationFilter(String defaultFilterProcessesUrl) {
this(defaultFilterProcessesUrl, null);
}
public JwtUsernamePasswordAuthenticationFilter(AuthenticationManager authenticationManager) {
this(DEFAULT_TOKEN_ENDPOINT_URI, authenticationManager);
}
public JwtUsernamePasswordAuthenticationFilter(String defaultFilterProcessesUrl,
AuthenticationManager authenticationManager) {
super(authenticationManager);
if (!StringUtils.hasText(defaultFilterProcessesUrl)) {
throw new IllegalArgumentException("tokenEndpointUri cannot be empty.");
}
setFilterProcessesUrl(defaultFilterProcessesUrl);
setAuthenticationSuccessHandler(this::sendAccessTokenResponse);
setAuthenticationFailureHandler(this::sendErrorResponse);
}
@Override
public Authentication attemptAuthentication(HttpServletRequest request,
HttpServletResponse response)
throws AuthenticationException {
if (this.postOnly && !HttpMethod.POST.name().equals(request.getMethod())) {
throw new AuthenticationServiceException(
"Authentication method not supported: " + request.getMethod());
}
String username = obtainUsername(request);
username = (username != null) ? username : "";
username = username.trim();
String password = obtainPassword(request);
password = (password != null) ? password : "";
UsernamePasswordAuthenticationToken authRequest =
new UsernamePasswordAuthenticationToken(username, password);
// Allow subclasses to set the "details" property
setDetails(request, authRequest);
return this.getAuthenticationManager().authenticate(authRequest);
}
private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
Authentication authentication) throws IOException {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) authentication;
OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
Map<String, Object> additionalParameters =
accessTokenAuthentication.getAdditionalParameters();
OAuth2AccessTokenResponse.Builder builder =
OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
.tokenType(accessToken.getTokenType())
.scopes(accessToken.getScopes());
if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
builder.expiresIn(
ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
}
if (refreshToken != null) {
builder.refreshToken(refreshToken.getTokenValue());
}
if (!CollectionUtils.isEmpty(additionalParameters)) {
builder.additionalParameters(additionalParameters);
}
OAuth2AccessTokenResponse accessTokenResponse = builder.build();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
}
private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
AuthenticationException exception) throws IOException {
//OAuth2Error error = ((AuthenticationServiceException) exception).getMessage();
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
this.errorHttpResponseConverter.write(new OAuth2Error(exception.getMessage()), null,
httpResponse);
}
@Override
public void setPostOnly(boolean postOnly) {
this.postOnly = postOnly;
super.setPostOnly(postOnly);
}
}

View File

@ -0,0 +1,103 @@
package run.halo.app.identity.authentication;
import java.util.Collections;
import java.util.Map;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.util.Assert;
/**
* @author guqing
* @since 2.0
*/
public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
private final Authentication principal;
private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken;
private final Map<String, Object> additionalParameters;
/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
*
* @param clientPrincipal the authenticated client principal
* @param accessToken the access token
*/
public OAuth2AccessTokenAuthenticationToken(Authentication clientPrincipal,
OAuth2AccessToken accessToken) {
this(clientPrincipal, accessToken, null);
}
/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
*
* @param clientPrincipal the authenticated client principal
* @param accessToken the access token
* @param refreshToken the refresh token
*/
public OAuth2AccessTokenAuthenticationToken(Authentication clientPrincipal,
OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken) {
this(clientPrincipal, accessToken, refreshToken, Collections.emptyMap());
}
/**
* Constructs an {@code OAuth2AccessTokenAuthenticationToken} using the provided parameters.
*
* @param principal the authenticated principal
* @param accessToken the access token
* @param refreshToken the refresh token
* @param additionalParameters the additional parameters
*/
public OAuth2AccessTokenAuthenticationToken(Authentication principal,
OAuth2AccessToken accessToken, @Nullable OAuth2RefreshToken refreshToken,
Map<String, Object> additionalParameters) {
super(Collections.emptyList());
Assert.notNull(principal, "principal cannot be null");
Assert.notNull(accessToken, "accessToken cannot be null");
Assert.notNull(additionalParameters, "additionalParameters cannot be null");
this.principal = principal;
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.additionalParameters = additionalParameters;
}
@Override
public Object getPrincipal() {
return this.principal;
}
@Override
public Object getCredentials() {
return "";
}
/**
* Returns the {@link OAuth2AccessToken access token}.
*
* @return the {@link OAuth2AccessToken}
*/
public OAuth2AccessToken getAccessToken() {
return this.accessToken;
}
/**
* Returns the {@link OAuth2RefreshToken refresh token}.
*
* @return the {@link OAuth2RefreshToken} or {@code null} if not available
*/
@Nullable
public OAuth2RefreshToken getRefreshToken() {
return this.refreshToken;
}
/**
* Returns the additional parameters.
*
* @return a {@code Map} of the additional parameters, may be empty
*/
public Map<String, Object> getAdditionalParameters() {
return this.additionalParameters;
}
}

View File

@ -0,0 +1,452 @@
package run.halo.app.identity.authentication;
import java.io.Serializable;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.function.Consumer;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/**
* @author guqing
* @since 2.0
*/
public class OAuth2Authorization implements Serializable {
/**
* The name of the {@link #getAttribute(String) attribute} used for the authorized scope(s).
* The value of the attribute is of type {@code Set<String>}.
*/
public static final String AUTHORIZED_SCOPE_ATTRIBUTE_NAME =
OAuth2Authorization.class.getName().concat(".AUTHORIZED_SCOPE");
private String id;
private String principalName;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens;
private Map<String, Object> attributes;
protected OAuth2Authorization() {
}
/**
* Returns the identifier for the authorization.
*
* @return the identifier for the authorization
*/
public String getId() {
return this.id;
}
/**
* Returns the {@code Principal} name of the resource owner (or client).
*
* @return the {@code Principal} name of the resource owner (or client)
*/
public String getPrincipalName() {
return this.principalName;
}
/**
* Returns the {@link Token} of type {@link OAuth2AccessToken}.
*
* @return the {@link Token} of type {@link OAuth2AccessToken}
*/
public Token<OAuth2AccessToken> getAccessToken() {
return getToken(OAuth2AccessToken.class);
}
/**
* Returns the {@link Token} of type {@link OAuth2RefreshToken}.
*
* @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if not
* available
*/
@Nullable
public Token<OAuth2RefreshToken> getRefreshToken() {
return getToken(OAuth2RefreshToken.class);
}
/**
* Returns the {@link Token} of type {@code tokenType}.
*
* @param tokenType the token type
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(Class<T> tokenType) {
Assert.notNull(tokenType, "tokenType cannot be null");
Token<?> token = this.tokens.get(tokenType);
return token != null ? (Token<T>) token : null;
}
/**
* Returns the {@link Token} matching the {@code tokenValue}.
*
* @param tokenValue the token value
* @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(String tokenValue) {
Assert.hasText(tokenValue, "tokenValue cannot be empty");
for (Token<?> token : this.tokens.values()) {
if (token.getToken().getTokenValue().equals(tokenValue)) {
return (Token<T>) token;
}
}
return null;
}
/**
* Returns the attribute(s) associated to the authorization.
*
* @return a {@code Map} of the attribute(s)
*/
public Map<String, Object> getAttributes() {
return this.attributes;
}
/**
* Returns the value of an attribute associated to the authorization.
*
* @param name the name of the attribute
* @param <T> the type of the attribute
* @return the value of an attribute associated to the authorization, or {@code null} if not
* available
*/
@Nullable
@SuppressWarnings("unchecked")
public <T> T getAttribute(String name) {
Assert.hasText(name, "name cannot be empty");
return (T) this.attributes.get(name);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
OAuth2Authorization that = (OAuth2Authorization) obj;
return Objects.equals(this.id, that.id)
&& Objects.equals(this.principalName, that.principalName)
&& Objects.equals(this.tokens, that.tokens)
&& Objects.equals(this.attributes, that.attributes);
}
@Override
public int hashCode() {
return Objects.hash(this.id, this.principalName, this.tokens, this.attributes);
}
/**
* Returns a new {@link Builder}, initialized with the values from the provided {@code
* OAuth2Authorization}.
*
* @param authorization the {@code OAuth2Authorization} used for initializing the
* {@link Builder}
* @return the {@link Builder}
*/
public static Builder from(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null");
return new Builder()
.id(authorization.getId())
.principalName(authorization.getPrincipalName())
.tokens(authorization.tokens)
.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
}
/**
* A holder of an OAuth 2.0 Token and it's associated metadata.
*
* @author guqing
* @since 2.0
*/
public static class Token<T extends OAuth2Token> implements Serializable {
protected static final String TOKEN_METADATA_NAMESPACE = "metadata.token.";
/**
* The name of the metadata that indicates if the token has been invalidated.
*/
public static final String INVALIDATED_METADATA_NAME =
TOKEN_METADATA_NAMESPACE.concat("invalidated");
/**
* The name of the metadata used for the claims of the token.
*/
public static final String CLAIMS_METADATA_NAME = TOKEN_METADATA_NAMESPACE.concat("claims");
private final T token;
private final Map<String, Object> metadata;
protected Token(T token) {
this(token, defaultMetadata());
}
protected Token(T token, Map<String, Object> metadata) {
this.token = token;
this.metadata = Collections.unmodifiableMap(metadata);
}
/**
* Returns the token of type {@link OAuth2Token}.
*
* @return the token of type {@link OAuth2Token}
*/
public T getToken() {
return this.token;
}
/**
* Returns {@code true} if the token has been invalidated (e.g. revoked).
* The default is {@code false}.
*
* @return {@code true} if the token has been invalidated, {@code false} otherwise
*/
public boolean isInvalidated() {
return Boolean.TRUE.equals(getMetadata(INVALIDATED_METADATA_NAME));
}
/**
* Returns {@code true} if the token has expired.
*
* @return {@code true} if the token has expired, {@code false} otherwise
*/
public boolean isExpired() {
return getToken().getExpiresAt() != null && Instant.now()
.isAfter(getToken().getExpiresAt());
}
/**
* Returns {@code true} if the token is before the time it can be used.
*
* @return {@code true} if the token is before the time it can be used, {@code false}
* otherwise
*/
public boolean isBeforeUse() {
Instant notBefore = null;
if (!CollectionUtils.isEmpty(getClaims())) {
notBefore = (Instant) getClaims().get("nbf");
}
return notBefore != null && Instant.now().isBefore(notBefore);
}
/**
* Returns {@code true} if the token is currently active.
*
* @return {@code true} if the token is currently active, {@code false} otherwise
*/
public boolean isActive() {
return !isInvalidated() && !isExpired() && !isBeforeUse();
}
/**
* Returns the claims associated to the token.
*
* @return a {@code Map} of the claims, or {@code null} if not available
*/
@Nullable
public Map<String, Object> getClaims() {
return getMetadata(CLAIMS_METADATA_NAME);
}
/**
* Returns the value of the metadata associated to the token.
*
* @param name the name of the metadata
* @param <V> the value type of the metadata
* @return the value of the metadata, or {@code null} if not available
*/
@Nullable
@SuppressWarnings("unchecked")
public <V> V getMetadata(String name) {
Assert.hasText(name, "name cannot be empty");
return (V) this.metadata.get(name);
}
/**
* Returns the metadata associated to the token.
*
* @return a {@code Map} of the metadata
*/
public Map<String, Object> getMetadata() {
return this.metadata;
}
protected static Map<String, Object> defaultMetadata() {
Map<String, Object> metadata = new HashMap<>();
metadata.put(INVALIDATED_METADATA_NAME, false);
return metadata;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Token<?> that = (Token<?>) obj;
return Objects.equals(this.token, that.token)
&& Objects.equals(this.metadata, that.metadata);
}
@Override
public int hashCode() {
return Objects.hash(this.token, this.metadata);
}
}
/**
* A builder for {@link OAuth2Authorization}.
*/
public static class Builder implements Serializable {
private String id;
private String principalName;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens = new HashMap<>();
private final Map<String, Object> attributes = new HashMap<>();
/**
* Sets the identifier for the authorization.
*
* @param id the identifier for the authorization
* @return the {@link Builder}
*/
public Builder id(String id) {
this.id = id;
return this;
}
/**
* Sets the {@code Principal} name of the resource owner (or client).
*
* @param principalName the {@code Principal} name of the resource owner (or client)
* @return the {@link Builder}
*/
public Builder principalName(String principalName) {
this.principalName = principalName;
return this;
}
/**
* Sets the {@link OAuth2AccessToken access token}.
*
* @param accessToken the {@link OAuth2AccessToken}
* @return the {@link Builder}
*/
public Builder accessToken(OAuth2AccessToken accessToken) {
return token(accessToken);
}
/**
* Sets the {@link OAuth2RefreshToken refresh token}.
*
* @param refreshToken the {@link OAuth2RefreshToken}
* @return the {@link Builder}
*/
public Builder refreshToken(OAuth2RefreshToken refreshToken) {
return token(refreshToken);
}
/**
* Sets the {@link OAuth2Token token}.
*
* @param token the token
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends OAuth2Token> Builder token(T token) {
return token(token, (metadata) -> {
});
}
/**
* Sets the {@link OAuth2Token token} and associated metadata.
*
* @param token the token
* @param metadataConsumer a {@code Consumer} of the metadata {@code Map}
* @param <T> the type of the token
* @return the {@link Builder}
*/
public <T extends OAuth2Token> Builder token(T token,
Consumer<Map<String, Object>> metadataConsumer) {
Assert.notNull(token, "token cannot be null");
Map<String, Object> metadata = Token.defaultMetadata();
Token<?> existingToken = this.tokens.get(token.getClass());
if (existingToken != null) {
metadata.putAll(existingToken.getMetadata());
}
metadataConsumer.accept(metadata);
Class<? extends OAuth2Token> tokenClass = token.getClass();
this.tokens.put(tokenClass, new Token<>(token, metadata));
return this;
}
protected final Builder tokens(Map<Class<? extends OAuth2Token>, Token<?>> tokens) {
this.tokens = new HashMap<>(tokens);
return this;
}
/**
* Adds an attribute associated to the authorization.
*
* @param name the name of the attribute
* @param value the value of the attribute
* @return the {@link Builder}
*/
public Builder attribute(String name, Object value) {
Assert.hasText(name, "name cannot be empty");
Assert.notNull(value, "value cannot be null");
this.attributes.put(name, value);
return this;
}
/**
* A {@code Consumer} of the attributes {@code Map}
* allowing the ability to add, replace, or remove.
*
* @param attributesConsumer a {@link Consumer} of the attributes {@code Map}
* @return the {@link Builder}
*/
public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
attributesConsumer.accept(this.attributes);
return this;
}
/**
* Builds a new {@link OAuth2Authorization}.
*
* @return the {@link OAuth2Authorization}
*/
public OAuth2Authorization build() {
Assert.hasText(this.principalName, "principalName cannot be empty");
OAuth2Authorization authorization = new OAuth2Authorization();
if (!StringUtils.hasText(this.id)) {
this.id = UUID.randomUUID().toString();
}
authorization.id = this.id;
authorization.principalName = this.principalName;
authorization.tokens = Collections.unmodifiableMap(this.tokens);
authorization.attributes = Collections.unmodifiableMap(this.attributes);
return authorization;
}
}
}

View File

@ -0,0 +1,27 @@
package run.halo.app.identity.authentication;
import java.time.Instant;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
/**
* @author guqing
* @since 2.0.0
*/
public class OAuth2AuthorizationService {
OAuth2Authorization findByUsername(String username, OAuth2TokenType oauth2TokenType) {
// TODO to be implementation
return new OAuth2Authorization.Builder().id("id")
.accessToken(new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(),
Instant.now().plusMillis(123)))
.refreshToken(
new OAuth2RefreshToken("refresh_token", Instant.now()))
.principalName("guqing")
.build();
}
void save(OAuth2Authorization authorization) {
// TODO to be implementation
}
}

View File

@ -1,19 +0,0 @@
package run.halo.app.identity.authentication;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
* @author guqing
* @date 2022-04-12
*/
@RestController
@RequestMapping("/oauth")
public class OauthController {
@GetMapping("login")
public String login() {
return "hello";
}
}

View File

@ -1,64 +0,0 @@
package run.halo.app;
import static org.assertj.core.api.Assertions.assertThat;
import java.net.MalformedURLException;
import java.net.URL;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.TestPropertySource;
import run.halo.app.identity.authentication.AccessToken;
import run.halo.app.identity.authentication.JwtTokenProvider;
/**
* @author guqing
* @date 2022-04-13
*/
@WithMockUser(username = "test", password = "test")
@TestPropertySource(properties = {"halo.security.oauth2.jwt.public-key-location=classpath:app.pub",
"halo.security.oauth2.jwt.private-key-location=classpath:app.key"})
@SpringBootTest
public class JwtTokenProviderTest {
@Autowired
private JwtTokenProvider jwtTokenProvider;
private AccessToken accessToken;
@BeforeEach
public void setUp() {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
accessToken = jwtTokenProvider.getToken(authentication);
}
@Test
public void createJwt() throws MalformedURLException {
Jwt token = accessToken.getAccessToken();
Jwt refreshToken = accessToken.getRefreshToken();
System.out.println(token.getTokenValue());
assertThat(token.getClaims()).isNotNull()
.containsEntry("sub", "test")
.containsEntry("scope", "ROLE_USER")
.containsEntry("iss", new URL("https://halo.run"))
.containsKey("iss")
.containsKey("jti");
assertThat(refreshToken.getClaims())
.isNotNull()
.containsEntry("sub", "test")
.containsKey("iss")
.containsKey("jti");
}
@Test
public void verifyBadToken() {
Jwt badToken = jwtTokenProvider.verify("badToken");
assertThat(badToken).isNull();
}
}

View File

@ -3,9 +3,9 @@ package run.halo.app;
import static org.assertj.core.api.Assertions.assertThat;
import org.junit.jupiter.api.Test;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.method.HandlerTypePredicate;
import run.halo.app.identity.authentication.OauthController;
/**
* Test case for api path prefix predicate.
@ -24,7 +24,13 @@ public class PathPrefixPredicateTest {
boolean result = HandlerTypePredicate.forAnnotation(RestController.class)
.and(HandlerTypePredicate.forBasePackage(Application.class.getPackageName()))
.test(OauthController.class);
.test(TestController.class);
assertThat(result).isTrue();
}
@RestController
@RequestMapping("/test-prefix")
public static class TestController {
}
}

View File

@ -0,0 +1,251 @@
package run.halo.app.authentication;
import static java.util.Map.entry;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.type;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import jakarta.servlet.FilterChain;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
import org.springframework.security.web.authentication.WebAuthenticationDetails;
import run.halo.app.identity.authentication.JwtUsernamePasswordAuthenticationFilter;
import run.halo.app.identity.authentication.OAuth2AccessTokenAuthenticationToken;
/**
* Tests for {@link JwtUsernamePasswordAuthenticationFilter}.
*
* @author guqing
* @since 2.0.0
*/
public class JwtUsernamePasswordAuthenticationFilterTests {
private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/api/v1/oauth2/login";
private static final String REMOTE_ADDRESS = "remote-address";
private AuthenticationManager authenticationManager;
private JwtUsernamePasswordAuthenticationFilter filter;
private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
new OAuth2AccessTokenResponseHttpMessageConverter();
@BeforeEach
public void setUp() {
this.authenticationManager = mock(AuthenticationManager.class);
this.filter = new JwtUsernamePasswordAuthenticationFilter(this.authenticationManager);
}
@AfterEach
public void cleanup() {
SecurityContextHolder.clearContext();
}
@Test
public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() {
assertThatThrownBy(
() -> new JwtUsernamePasswordAuthenticationFilter(null, this.authenticationManager))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokenEndpointUri cannot be empty.");
}
@Test
public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("AuthenticationDetailsSource required");
}
@Test
public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("successHandler cannot be null");
}
@Test
public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("failureHandler cannot be null");
}
@Test
public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception {
String requestUri = "/path";
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
@Test
public void doFilterWhenTokenRequestGetThenNotProcessed() throws Exception {
String requestUri = DEFAULT_TOKEN_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
AuthenticationFailureHandler authenticationFailureHandler =
mock(AuthenticationFailureHandler.class);
filter.setAuthenticationFailureHandler(authenticationFailureHandler);
this.filter.doFilter(request, response, filterChain);
verify(authenticationFailureHandler).onAuthenticationFailure(any(HttpServletRequest.class),
any(HttpServletResponse.class), any(AuthenticationException.class));
verifyNoInteractions(filterChain);
}
@Test
public void doFilterWhenAuthorizationCodeTokenRequestThenAccessTokenResponse()
throws Exception {
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "token",
Instant.now(), Instant.now().plus(Duration.ofHours(1)),
new HashSet<>(Arrays.asList("scope1", "scope2")));
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
"refresh-token", Instant.now(), Instant.now().plus(Duration.ofDays(1)));
Authentication clientPrincipal =
new UsernamePasswordAuthenticationToken("guqing", "123456");
Map<String, Object>
additionalParameters = Collections.singletonMap("custom-param", "custom-value");
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
new OAuth2AccessTokenAuthenticationToken(clientPrincipal, accessToken, refreshToken,
additionalParameters);
when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createClientTokenRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
ArgumentCaptor<UsernamePasswordAuthenticationToken>
authorizationCodeAuthenticationCaptor =
ArgumentCaptor.forClass(UsernamePasswordAuthenticationToken.class);
verify(this.authenticationManager).authenticate(
authorizationCodeAuthenticationCaptor.capture());
UsernamePasswordAuthenticationToken accessTokenAuthenticationToken =
authorizationCodeAuthenticationCaptor.getValue();
assertThat(accessTokenAuthenticationToken.getName()).isEqualTo("guqing");
assertThat(accessTokenAuthenticationToken.getDetails())
.asInstanceOf(type(WebAuthenticationDetails.class))
.extracting(WebAuthenticationDetails::getRemoteAddress)
.isEqualTo(REMOTE_ADDRESS);
assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse(response);
OAuth2AccessToken accessTokenResult = accessTokenResponse.getAccessToken();
assertThat(accessTokenResult.getTokenType()).isEqualTo(accessToken.getTokenType());
assertThat(accessTokenResult.getTokenValue()).isEqualTo(accessToken.getTokenValue());
assertThat(accessTokenResult.getIssuedAt()).isBetween(
accessToken.getIssuedAt().minusSeconds(1), accessToken.getIssuedAt().plusSeconds(1));
assertThat(accessTokenResult.getExpiresAt()).isBetween(
accessToken.getExpiresAt().minusSeconds(1), accessToken.getExpiresAt().plusSeconds(1));
assertThat(accessTokenResult.getScopes()).isEqualTo(accessToken.getScopes());
assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(
refreshToken.getTokenValue());
assertThat(accessTokenResponse.getAdditionalParameters()).containsExactly(
entry("custom-param", "custom-value"));
}
@Test
public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
AuthenticationSuccessHandler authenticationSuccessHandler =
mock(AuthenticationSuccessHandler.class);
this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
Authentication clientPrincipal =
new UsernamePasswordAuthenticationToken("guqing", "123456");
OAuth2AccessToken accessToken = new OAuth2AccessToken(
OAuth2AccessToken.TokenType.BEARER, "token",
Instant.now(), Instant.now().plus(Duration.ofHours(1)),
new HashSet<>(Arrays.asList("scope1", "scope2")));
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
new OAuth2AccessTokenAuthenticationToken(clientPrincipal, accessToken);
when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(clientPrincipal);
SecurityContextHolder.setContext(securityContext);
MockHttpServletRequest request = createClientTokenRequest();
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
}
private OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response)
throws Exception {
MockClientHttpResponse httpResponse = new MockClientHttpResponse(
response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));
return this.accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class,
httpResponse);
}
private static MockHttpServletRequest createClientTokenRequest() {
String requestUri = DEFAULT_TOKEN_ENDPOINT_URI;
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
request.setServletPath(requestUri);
request.setRemoteAddr(REMOTE_ADDRESS);
request.addParameter("username", "guqing");
request.addParameter("password", "123456");
request.addParameter(OAuth2ParameterNames.GRANT_TYPE,
AuthorizationGrantType.PASSWORD.getValue());
request.addParameter("custom-param-1", "custom-value-1");
return request;
}
}