/oauth2/rest_token impl/ test

pull/4/head
shengzhaoli.shengz 2023-10-31 19:37:11 +08:00
parent 96e4829866
commit 0165e92efc
1 changed files with 123 additions and 27 deletions

View File

@ -26,22 +26,32 @@ import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.*; import org.springframework.security.oauth2.core.*;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext;
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.security.web.authentication.WebAuthenticationDetails;
import org.springframework.stereotype.Controller; import org.springframework.stereotype.Controller;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.bind.annotation.ResponseBody;
import java.io.IOException; import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.Arrays; import java.util.Arrays;
import java.util.Map; import java.util.Map;
@ -63,6 +73,8 @@ public class OAuthRestController {
private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private static final String CLIENT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
private final AuthenticationRestConverter authenticationConverter; private final AuthenticationRestConverter authenticationConverter;
@ -74,9 +86,19 @@ public class OAuthRestController {
private AuthenticationManager authenticationManager; private AuthenticationManager authenticationManager;
@Autowired @Autowired
private ApplicationContext applicationContext; private ApplicationContext applicationContext;
@Autowired
private RegisteredClientRepository registeredClientRepository;
@Autowired
private PasswordEncoder passwordEncoder;
@Autowired
private AuthorizationServerSettings authorizationServerSettings;
public OAuthRestController() { public OAuthRestController() {
@ -92,41 +114,107 @@ public class OAuthRestController {
* Replace OAuth2TokenEndpointFilter flow use restful API * Replace OAuth2TokenEndpointFilter flow use restful API
* *
* @param parameters request params * @param parameters request params
* @see org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider
*/ */
@PostMapping("/oauth2/rest_token") @PostMapping("/oauth2/rest_token")
@ResponseBody @ResponseBody
public void postAccessToken(@RequestBody Map<String, String> parameters, HttpServletResponse response) throws IOException { public void postAccessToken(@RequestBody Map<String, String> parameters, HttpServletResponse response)
throws OAuth2AuthenticationException, IOException {
try { //init OAuth2 contexts
String grantType = parameters.get(OAuth2ParameterNames.GRANT_TYPE); initialOAuth2Contexts(parameters);
if (grantType == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
}
Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(parameters); // oauth2 flow start...
if (authorizationGrantAuthentication == null) { String grantType = parameters.get(OAuth2ParameterNames.GRANT_TYPE);
throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); if (grantType == null) {
} throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE);
if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) {
((AbstractAuthenticationToken) authorizationGrantAuthentication)
// .setDetails(this.authenticationDetailsSource.buildDetails(request));
.setDetails(new WebAuthenticationDetails(WebUtils.getIp(), null));
}
checkAndInitialAuthenticationManager();
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
this.sendAccessTokenResponse(response, accessTokenAuthentication);
} catch (OAuth2AuthenticationException ex) {
SecurityContextHolder.clearContext();
if (LOG.isTraceEnabled()) {
LOG.trace("Token request failed: {}", ex.getError(), ex);
}
this.sendErrorResponse(response, ex);
} }
Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(parameters);
if (authorizationGrantAuthentication == null) {
throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE);
}
if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) {
((AbstractAuthenticationToken) authorizationGrantAuthentication)
// .setDetails(this.authenticationDetailsSource.buildDetails(request));
.setDetails(new WebAuthenticationDetails(WebUtils.getIp(), null));
}
checkAndInitialAuthenticationManager();
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
this.sendAccessTokenResponse(response, accessTokenAuthentication);
} }
private void initialOAuth2Contexts(Map<String, String> parameters) {
String clientId = parameters.get(OAuth2ParameterNames.CLIENT_ID);
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
}
if (LOG.isTraceEnabled()) {
LOG.trace("Retrieved registered client");
}
if (!registeredClient.getClientAuthenticationMethods().contains(
ClientAuthenticationMethod.CLIENT_SECRET_POST)) {
throwInvalidClient("authentication_method");
}
String clientSecret = parameters.get(OAuth2ParameterNames.CLIENT_SECRET);
if (clientSecret == null) {
throwInvalidClient("credentials");
}
// String clientSecret = clientAuthentication.getCredentials().toString();
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET);
}
if (registeredClient.getClientSecretExpiresAt() != null &&
Instant.now().isAfter(registeredClient.getClientSecretExpiresAt())) {
throwInvalidClient("client_secret_expires_at");
}
if (LOG.isTraceEnabled()) {
LOG.trace("Authenticated client secret");
}
OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient,
ClientAuthenticationMethod.CLIENT_SECRET_POST, clientSecret);
SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(authentication);
SecurityContextHolder.setContext(securityContext);
// init AuthorizationServerContext
AuthorizationServerContext authorizationServerContext = new AuthorizationServerContext() {
@Override
public String getIssuer() {
return authorizationServerSettings.getIssuer();
}
@Override
public AuthorizationServerSettings getAuthorizationServerSettings() {
return authorizationServerSettings;
}
};
AuthorizationServerContextHolder.setContext(authorizationServerContext);
}
/**
*
*/
@ExceptionHandler(OAuth2AuthenticationException.class)
public void handleOAuth2AuthenticationException(OAuth2AuthenticationException ex, HttpServletResponse response) throws IOException {
SecurityContextHolder.clearContext();
if (LOG.isTraceEnabled()) {
LOG.trace("Token request failed: {}", ex.getError(), ex);
}
this.sendErrorResponse(response, ex);
}
private void checkAndInitialAuthenticationManager() { private void checkAndInitialAuthenticationManager() {
if (this.authenticationManager == null) { if (this.authenticationManager == null) {
OAuth2ServerConfiguration serverConfiguration = applicationContext.getBean(OAuth2ServerConfiguration.class); OAuth2ServerConfiguration serverConfiguration = applicationContext.getBean(OAuth2ServerConfiguration.class);
@ -178,5 +266,13 @@ public class OAuthRestController {
throw new OAuth2AuthenticationException(error); throw new OAuth2AuthenticationException(error);
} }
private static void throwInvalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName,
CLIENT_ERROR_URI
);
throw new OAuth2AuthenticationException(error);
}
} }