diff --git a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java index 7b5560e..f42aa77 100644 --- a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java +++ b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java @@ -26,22 +26,32 @@ import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; +import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.*; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext; +import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder; +import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.stereotype.Controller; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.ResponseBody; import java.io.IOException; +import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Map; @@ -63,6 +73,8 @@ public class OAuthRestController { private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2"; + private static final String CLIENT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1"; + private final AuthenticationRestConverter authenticationConverter; @@ -74,9 +86,19 @@ public class OAuthRestController { private AuthenticationManager authenticationManager; + @Autowired private ApplicationContext applicationContext; + @Autowired + private RegisteredClientRepository registeredClientRepository; + + @Autowired + private PasswordEncoder passwordEncoder; + + @Autowired + private AuthorizationServerSettings authorizationServerSettings; + public OAuthRestController() { @@ -92,41 +114,107 @@ public class OAuthRestController { * Replace OAuth2TokenEndpointFilter flow use restful API * * @param parameters request params + * @see org.springframework.security.oauth2.server.authorization.authentication.ClientSecretAuthenticationProvider */ @PostMapping("/oauth2/rest_token") @ResponseBody - public void postAccessToken(@RequestBody Map parameters, HttpServletResponse response) throws IOException { + public void postAccessToken(@RequestBody Map parameters, HttpServletResponse response) + throws OAuth2AuthenticationException, IOException { - try { - String grantType = parameters.get(OAuth2ParameterNames.GRANT_TYPE); - if (grantType == null) { - throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE); - } + //init OAuth2 contexts + initialOAuth2Contexts(parameters); - Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(parameters); - if (authorizationGrantAuthentication == null) { - throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); - } - if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) { - ((AbstractAuthenticationToken) authorizationGrantAuthentication) + // oauth2 flow start... + String grantType = parameters.get(OAuth2ParameterNames.GRANT_TYPE); + if (grantType == null) { + throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.GRANT_TYPE); + } + + Authentication authorizationGrantAuthentication = this.authenticationConverter.convert(parameters); + if (authorizationGrantAuthentication == null) { + throwError(OAuth2ErrorCodes.UNSUPPORTED_GRANT_TYPE, OAuth2ParameterNames.GRANT_TYPE); + } + if (authorizationGrantAuthentication instanceof AbstractAuthenticationToken) { + ((AbstractAuthenticationToken) authorizationGrantAuthentication) // .setDetails(this.authenticationDetailsSource.buildDetails(request)); - .setDetails(new WebAuthenticationDetails(WebUtils.getIp(), null)); - } + .setDetails(new WebAuthenticationDetails(WebUtils.getIp(), null)); + } + + checkAndInitialAuthenticationManager(); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); + this.sendAccessTokenResponse(response, accessTokenAuthentication); + } + + private void initialOAuth2Contexts(Map parameters) { + String clientId = parameters.get(OAuth2ParameterNames.CLIENT_ID); + RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); + if (registeredClient == null) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); + } + + if (LOG.isTraceEnabled()) { + LOG.trace("Retrieved registered client"); + } + + if (!registeredClient.getClientAuthenticationMethods().contains( + ClientAuthenticationMethod.CLIENT_SECRET_POST)) { + throwInvalidClient("authentication_method"); + } + + String clientSecret = parameters.get(OAuth2ParameterNames.CLIENT_SECRET); + if (clientSecret == null) { + throwInvalidClient("credentials"); + } + +// String clientSecret = clientAuthentication.getCredentials().toString(); + if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { + throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET); + } + + if (registeredClient.getClientSecretExpiresAt() != null && + Instant.now().isAfter(registeredClient.getClientSecretExpiresAt())) { + throwInvalidClient("client_secret_expires_at"); + } - checkAndInitialAuthenticationManager(); + if (LOG.isTraceEnabled()) { + LOG.trace("Authenticated client secret"); + } + OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(registeredClient, + ClientAuthenticationMethod.CLIENT_SECRET_POST, clientSecret); + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(authentication); + SecurityContextHolder.setContext(securityContext); + + // init AuthorizationServerContext + AuthorizationServerContext authorizationServerContext = new AuthorizationServerContext() { + @Override + public String getIssuer() { + return authorizationServerSettings.getIssuer(); + } - OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = - (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); - this.sendAccessTokenResponse(response, accessTokenAuthentication); - } catch (OAuth2AuthenticationException ex) { - SecurityContextHolder.clearContext(); - if (LOG.isTraceEnabled()) { - LOG.trace("Token request failed: {}", ex.getError(), ex); + @Override + public AuthorizationServerSettings getAuthorizationServerSettings() { + return authorizationServerSettings; } - this.sendErrorResponse(response, ex); + }; + AuthorizationServerContextHolder.setContext(authorizationServerContext); + } + + /** + * 异常处理 + */ + @ExceptionHandler(OAuth2AuthenticationException.class) + public void handleOAuth2AuthenticationException(OAuth2AuthenticationException ex, HttpServletResponse response) throws IOException { + SecurityContextHolder.clearContext(); + if (LOG.isTraceEnabled()) { + LOG.trace("Token request failed: {}", ex.getError(), ex); } + this.sendErrorResponse(response, ex); } + private void checkAndInitialAuthenticationManager() { if (this.authenticationManager == null) { OAuth2ServerConfiguration serverConfiguration = applicationContext.getBean(OAuth2ServerConfiguration.class); @@ -178,5 +266,13 @@ public class OAuthRestController { throw new OAuth2AuthenticationException(error); } + private static void throwInvalidClient(String parameterName) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_CLIENT, + "Client authentication failed: " + parameterName, + CLIENT_ERROR_URI + ); + throw new OAuth2AuthenticationException(error); + } }