diff --git a/src/main/java/com/monkeyk/sos/config/OAuth2ServerConfiguration.java b/src/main/java/com/monkeyk/sos/config/OAuth2ServerConfiguration.java index c4366dd..df1708b 100644 --- a/src/main/java/com/monkeyk/sos/config/OAuth2ServerConfiguration.java +++ b/src/main/java/com/monkeyk/sos/config/OAuth2ServerConfiguration.java @@ -13,6 +13,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.http.MediaType; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.http.SessionCreationPolicy; @@ -30,6 +31,7 @@ import org.springframework.security.oauth2.server.authorization.config.annotatio import org.springframework.security.oauth2.server.authorization.oidc.OidcProviderConfiguration; import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer; +import org.springframework.security.web.DefaultSecurityFilterChain; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint; import org.springframework.security.web.util.matcher.MediaTypeRequestMatcher; @@ -74,6 +76,12 @@ public class OAuth2ServerConfiguration { private JdbcTemplate jdbcTemplate; + /** + * @since 3.0.0 + */ + private AuthenticationManager authenticationManagerOAuth2; + + /** * authorizationServerSecurityFilterChain * @@ -125,10 +133,22 @@ public class OAuth2ServerConfiguration { //ext jwt oauth2ResourceServer.jwt(Customizer.withDefaults())); - return http.build(); + DefaultSecurityFilterChain filterChain = http.build(); + this.authenticationManagerOAuth2 = http.getSharedObject(AuthenticationManager.class); + return filterChain; } + /** + * 获取 OAuth2流程中的 AuthenticationManager + * + * @return AuthenticationManager + * @since 3.0.0 + */ + public AuthenticationManager authenticationManagerOAuth2() { + return authenticationManagerOAuth2; + } + /** * 扩展 oidc 的默认能力配置项 * diff --git a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java index 59c8eee..7b5560e 100644 --- a/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java +++ b/src/main/java/com/monkeyk/sos/web/controller/OAuthRestController.java @@ -11,22 +11,19 @@ */ package com.monkeyk.sos.web.controller; +import com.monkeyk.sos.config.OAuth2ServerConfiguration; import com.monkeyk.sos.web.WebUtils; import com.monkeyk.sos.web.authentication.*; import jakarta.servlet.http.HttpServletResponse; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.beans.BeansException; -import org.springframework.beans.factory.BeanInitializationException; -import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AuthenticationManager; -import org.springframework.security.config.annotation.authentication.configuration.AuthenticationConfiguration; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; @@ -59,7 +56,7 @@ import java.util.Map; * @since 2.0.0 */ @Controller -public class OAuthRestController implements InitializingBean, ApplicationContextAware { +public class OAuthRestController { private static final Logger LOG = LoggerFactory.getLogger(OAuthRestController.class); @@ -77,6 +74,10 @@ public class OAuthRestController implements InitializingBean, ApplicationContext private AuthenticationManager authenticationManager; + @Autowired + private ApplicationContext applicationContext; + + public OAuthRestController() { this.authenticationConverter = new DelegatingAuthenticationRestConverter( @@ -112,6 +113,8 @@ public class OAuthRestController implements InitializingBean, ApplicationContext .setDetails(new WebAuthenticationDetails(WebUtils.getIp(), null)); } + checkAndInitialAuthenticationManager(); + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication); this.sendAccessTokenResponse(response, accessTokenAuthentication); @@ -124,6 +127,14 @@ public class OAuthRestController implements InitializingBean, ApplicationContext } } + private void checkAndInitialAuthenticationManager() { + if (this.authenticationManager == null) { + OAuth2ServerConfiguration serverConfiguration = applicationContext.getBean(OAuth2ServerConfiguration.class); + this.authenticationManager = serverConfiguration.authenticationManagerOAuth2(); + Assert.notNull(this.authenticationManager, "authenticationManager cannot be null"); + } + } + private void sendErrorResponse(HttpServletResponse response, AuthenticationException exception) throws IOException { @@ -168,23 +179,4 @@ public class OAuthRestController implements InitializingBean, ApplicationContext } - @Override - public void afterPropertiesSet() throws Exception { - Assert.state(authenticationManager != null, "AuthenticationManager must be provided"); - } - - - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - if (this.authenticationManager == null) { - AuthenticationConfiguration configuration = applicationContext.getBean(AuthenticationConfiguration.class); - Assert.notNull(configuration, "AuthenticationManagerBuilder is null"); - try { - this.authenticationManager = configuration.getAuthenticationManager(); - } catch (Exception e) { - throw new BeanInitializationException("Call 'getAuthenticationManager' error", e); - } - } - } - }