diff --git a/src/main/java/run/halo/app/config/WebServerSecurityConfig.java b/src/main/java/run/halo/app/config/WebServerSecurityConfig.java index 82d36e7e7..4c80d8fa7 100644 --- a/src/main/java/run/halo/app/config/WebServerSecurityConfig.java +++ b/src/main/java/run/halo/app/config/WebServerSecurityConfig.java @@ -7,9 +7,12 @@ import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.jwk.source.ImmutableJWKSet; +import java.util.Arrays; +import java.util.List; import org.springframework.context.annotation.Bean; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; +import org.springframework.http.HttpHeaders; import org.springframework.http.codec.ServerCodecConfigurer; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; import org.springframework.security.config.web.server.SecurityWebFiltersOrder; @@ -24,6 +27,9 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; import org.springframework.security.oauth2.jwt.SupplierReactiveJwtDecoder; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.csrf.CookieServerCsrfTokenRepository; +import org.springframework.web.cors.CorsConfiguration; +import org.springframework.web.cors.reactive.CorsConfigurationSource; +import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource; import org.springframework.web.reactive.function.server.ServerResponse; import run.halo.app.core.extension.service.RoleService; import run.halo.app.core.extension.service.UserService; @@ -55,6 +61,7 @@ public class WebServerSecurityConfig { UserService userService, RoleService roleService) { http.csrf().disable() + .cors(corsSpec -> corsSpec.configurationSource(apiCorsConfigurationSource())) .securityMatcher(pathMatchers("/api/**", "/apis/**")) .authorizeExchange(exchanges -> exchanges.anyExchange().access(new RequestInfoAuthorizationManager(roleService))) @@ -94,6 +101,18 @@ public class WebServerSecurityConfig { return http.build(); } + CorsConfigurationSource apiCorsConfigurationSource() { + CorsConfiguration configuration = new CorsConfiguration(); + configuration.setAllowedOriginPatterns(List.of("*")); + configuration.setAllowedHeaders( + List.of(HttpHeaders.AUTHORIZATION, HttpHeaders.CONTENT_TYPE, HttpHeaders.ACCEPT)); + configuration.setAllowedMethods(Arrays.asList("GET", "POST", "PUT", "DELETE", "PATCH")); + UrlBasedCorsConfigurationSource source = new UrlBasedCorsConfigurationSource(); + source.registerCorsConfiguration("/api/**", configuration); + source.registerCorsConfiguration("/apis/**", configuration); + return source; + } + @Bean ReactiveUserDetailsService userDetailsService(UserService userService, RoleService roleService) { diff --git a/src/test/java/run/halo/app/config/CorsTest.java b/src/test/java/run/halo/app/config/CorsTest.java new file mode 100644 index 000000000..efc670121 --- /dev/null +++ b/src/test/java/run/halo/app/config/CorsTest.java @@ -0,0 +1,93 @@ +package run.halo.app.config; + +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.security.test.context.support.WithMockUser; +import org.springframework.test.web.reactive.server.WebTestClient; + +@SpringBootTest +@AutoConfigureWebTestClient +class CorsTest { + + @Autowired + WebTestClient webClient; + + @Nested + class RequestCorsEnabledApi { + + @Test + @WithMockUser + void shouldNotResponseAllowOriginHeaderWithSameOrigin() { + webClient.get().uri("http://localhost:3000/apis/cors-enabled") + .header(HttpHeaders.ORIGIN, "http://localhost:3000") + .header(HttpHeaders.AUTHORIZATION, "fake-authorization") + .header("FakeHeader", "fake-header-value") + .accept(MediaType.APPLICATION_JSON) + .exchange() + .expectHeader() + .doesNotExist(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN); + } + + @Test + @WithMockUser + void shouldResponseAllowOriginHeaderWithDifferentOrigin() { + webClient.get().uri("http://localhost:3000/apis/cors-enabled") + .header(HttpHeaders.ORIGIN, "https://another.website") + .header(HttpHeaders.AUTHORIZATION, "fake-authorization") + // .header("ForbiddenHeader", "fake value") + .accept(MediaType.APPLICATION_JSON) + .exchange() + .expectHeader() + .exists(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN); + } + + @Test + @WithMockUser + void shouldResponseAllowOriginHeaderWithForbiddenHeader() { + webClient.get().uri("http://localhost:3000/apis/cors-enabled") + .header(HttpHeaders.ORIGIN, "https://another.website") + .header(HttpHeaders.AUTHORIZATION, "fake-authorization") + .header("FakeHeader", "fake-header-value") + // .header("ForbiddenHeader", "fake value") + .accept(MediaType.APPLICATION_JSON) + .exchange() + .expectHeader() + .exists(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN); + } + } + + @Nested + class RequestCorsDisabledApi { + + @Test + @WithMockUser + void shouldNotResponseAllowOriginHeaderWithDifferentOrigin() { + webClient.get().uri("http://localhost:3000/cors-disabled") + .header(HttpHeaders.ORIGIN, "https://another.website") + .header(HttpHeaders.AUTHORIZATION, "fake-authorization") + .accept(MediaType.APPLICATION_JSON) + .exchange() + .expectHeader() + .doesNotExist(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN); + } + + @Test + @WithMockUser + void shouldNotResponseAllowOriginHeaderWithSameOrigin() { + webClient.get().uri("http://localhost:3000/cors-disabled") + .header(HttpHeaders.ORIGIN, "http://localhost:3000") + .header(HttpHeaders.AUTHORIZATION, "fake-authorization") + .header("FakeHeader", "fake-header-value") + .accept(MediaType.APPLICATION_JSON) + .exchange() + .expectHeader() + .doesNotExist(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN); + } + } + +}