diff --git a/api/src/main/java/run/halo/app/core/endpoint/WebSocketEndpoint.java b/api/src/main/java/run/halo/app/core/endpoint/WebSocketEndpoint.java new file mode 100644 index 000000000..33d141da3 --- /dev/null +++ b/api/src/main/java/run/halo/app/core/endpoint/WebSocketEndpoint.java @@ -0,0 +1,34 @@ +package run.halo.app.core.endpoint; + +import org.springframework.web.reactive.socket.WebSocketHandler; +import run.halo.app.extension.GroupVersion; + +/** + * Endpoint for WebSocket. + * + * @author johnniang + */ +public interface WebSocketEndpoint { + + /** + * Path of the URL after group version. + * + * @return path of the URL. + */ + String urlPath(); + + /** + * Group and version parts of the endpoint. + * + * @return GroupVersion. + */ + GroupVersion groupVersion(); + + /** + * Real WebSocket handler for the endpoint. + * + * @return WebSocket handler. + */ + WebSocketHandler handler(); + +} diff --git a/application/src/main/java/run/halo/app/config/WebFluxConfig.java b/application/src/main/java/run/halo/app/config/WebFluxConfig.java index a0ad4bd61..9ef009f64 100644 --- a/application/src/main/java/run/halo/app/config/WebFluxConfig.java +++ b/application/src/main/java/run/halo/app/config/WebFluxConfig.java @@ -38,6 +38,7 @@ import org.springframework.web.reactive.result.view.ViewResolver; import reactor.core.publisher.Mono; import run.halo.app.console.ProxyFilter; import run.halo.app.console.WebSocketRequestPredicate; +import run.halo.app.core.endpoint.WebSocketHandlerMapping; import run.halo.app.core.extension.endpoint.CustomEndpoint; import run.halo.app.core.extension.endpoint.CustomEndpointsBuilder; import run.halo.app.infra.properties.HaloProperties; @@ -100,6 +101,13 @@ public class WebFluxConfig implements WebFluxConfigurer { return builder.build(); } + @Bean + public WebSocketHandlerMapping webSocketHandlerMapping() { + WebSocketHandlerMapping handlerMapping = new WebSocketHandlerMapping(); + handlerMapping.setOrder(-2); + return handlerMapping; + } + @Bean RouterFunction consoleIndexRedirection() { var consolePredicate = method(HttpMethod.GET) diff --git a/application/src/main/java/run/halo/app/console/WebSocketUtils.java b/application/src/main/java/run/halo/app/console/WebSocketUtils.java index 933e8c268..29ba096ae 100644 --- a/application/src/main/java/run/halo/app/console/WebSocketUtils.java +++ b/application/src/main/java/run/halo/app/console/WebSocketUtils.java @@ -1,5 +1,6 @@ package run.halo.app.console; +import java.util.Objects; import org.springframework.http.HttpHeaders; public enum WebSocketUtils { @@ -8,8 +9,11 @@ public enum WebSocketUtils { public static boolean isWebSocketUpgrade(HttpHeaders headers) { // See io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil // .isWebsocketUpgrade for more. + var upgradeConnection = headers.getConnection().stream().map(String::toLowerCase) + .anyMatch(conn -> Objects.equals(conn, "upgrade")); + return headers.containsKey(HttpHeaders.UPGRADE) - && headers.getConnection().contains(HttpHeaders.UPGRADE) + && upgradeConnection && "websocket".equalsIgnoreCase(headers.getUpgrade()); } diff --git a/application/src/main/java/run/halo/app/core/endpoint/WebSocketEndpointManager.java b/application/src/main/java/run/halo/app/core/endpoint/WebSocketEndpointManager.java new file mode 100644 index 000000000..bcdf99ecf --- /dev/null +++ b/application/src/main/java/run/halo/app/core/endpoint/WebSocketEndpointManager.java @@ -0,0 +1,16 @@ +package run.halo.app.core.endpoint; + +import java.util.Collection; + +/** + * Interface for managing WebSocket endpoints, including registering and unregistering. + * + * @author johnniang + */ +public interface WebSocketEndpointManager { + + void register(Collection endpoints); + + void unregister(Collection endpoints); + +} diff --git a/application/src/main/java/run/halo/app/core/endpoint/WebSocketHandlerMapping.java b/application/src/main/java/run/halo/app/core/endpoint/WebSocketHandlerMapping.java new file mode 100644 index 000000000..0105c03b3 --- /dev/null +++ b/application/src/main/java/run/halo/app/core/endpoint/WebSocketHandlerMapping.java @@ -0,0 +1,140 @@ +package run.halo.app.core.endpoint; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.observation.ServerRequestObservationContext; +import org.springframework.lang.NonNull; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.reactive.handler.AbstractHandlerMapping; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.util.pattern.PathPattern; +import reactor.core.publisher.Mono; +import run.halo.app.console.WebSocketUtils; + +public class WebSocketHandlerMapping extends AbstractHandlerMapping + implements WebSocketEndpointManager, InitializingBean { + + private final BiMap endpointMap; + + private final ReadWriteLock rwLock; + + public WebSocketHandlerMapping() { + this.endpointMap = HashBiMap.create(); + this.rwLock = new ReentrantReadWriteLock(); + } + + @Override + @NonNull + public Mono getHandlerInternal(ServerWebExchange exchange) { + var request = exchange.getRequest(); + if (!HttpMethod.GET.equals(request.getMethod()) + || !WebSocketUtils.isWebSocketUpgrade(request.getHeaders())) { + // skip getting handler if the request is not a WebSocket. + return Mono.empty(); + } + + var lock = rwLock.readLock(); + lock.lock(); + try { + // Refer to org.springframework.web.reactive.handler.AbstractUrlHandlerMapping + // .lookupHandler + var pathContainer = request.getPath().pathWithinApplication(); + List matches = null; + for (var pattern : this.endpointMap.keySet()) { + if (pattern.matches(pathContainer)) { + if (matches == null) { + matches = new ArrayList<>(); + } + matches.add(pattern); + } + } + if (matches == null) { + return Mono.empty(); + } + + if (matches.size() > 1) { + matches.sort(PathPattern.SPECIFICITY_COMPARATOR); + } + + var pattern = matches.get(0); + exchange.getAttributes().put(BEST_MATCHING_PATTERN_ATTRIBUTE, pattern); + + var handler = endpointMap.get(pattern).handler(); + exchange.getAttributes().put(BEST_MATCHING_HANDLER_ATTRIBUTE, handler); + + ServerRequestObservationContext.findCurrent(exchange.getAttributes()) + .ifPresent(context -> context.setPathPattern(pattern.toString())); + + var pathWithinMapping = pattern.extractPathWithinPattern(pathContainer); + exchange.getAttributes().put(PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, pathWithinMapping); + + var matchInfo = pattern.matchAndExtract(pathContainer); + Assert.notNull(matchInfo, "Expect a match"); + exchange.getAttributes() + .put(URI_TEMPLATE_VARIABLES_ATTRIBUTE, matchInfo.getUriVariables()); + return Mono.just(handler); + } catch (Exception e) { + return Mono.error(e); + } finally { + lock.unlock(); + } + } + + @Override + public void register(Collection endpoints) { + if (CollectionUtils.isEmpty(endpoints)) { + return; + } + var lock = rwLock.writeLock(); + lock.lock(); + try { + endpoints.forEach(endpoint -> { + var urlPath = endpoint.urlPath(); + urlPath = StringUtils.prependIfMissing(urlPath, "/"); + var groupVersion = endpoint.groupVersion(); + var parser = getPathPatternParser(); + var pattern = parser.parse("/apis/" + groupVersion + urlPath); + endpointMap.put(pattern, endpoint); + }); + } finally { + lock.unlock(); + } + } + + @Override + public void unregister(Collection endpoints) { + if (CollectionUtils.isEmpty(endpoints)) { + return; + } + var lock = rwLock.writeLock(); + lock.lock(); + try { + BiMap inverseMap = endpointMap.inverse(); + endpoints.forEach(inverseMap::remove); + } finally { + lock.unlock(); + } + } + + @Override + public void afterPropertiesSet() { + var endpoints = obtainApplicationContext().getBeanProvider(WebSocketEndpoint.class) + .orderedStream() + .toList(); + register(endpoints); + } + + BiMap getEndpointMap() { + return endpointMap; + } +} diff --git a/application/src/main/java/run/halo/app/plugin/DefaultPluginApplicationContextFactory.java b/application/src/main/java/run/halo/app/plugin/DefaultPluginApplicationContextFactory.java index 7142016a6..6ec14f8de 100644 --- a/application/src/main/java/run/halo/app/plugin/DefaultPluginApplicationContextFactory.java +++ b/application/src/main/java/run/halo/app/plugin/DefaultPluginApplicationContextFactory.java @@ -29,6 +29,8 @@ import org.springframework.util.StopWatch; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.ServerResponse; import reactor.core.Exceptions; +import run.halo.app.core.endpoint.WebSocketEndpoint; +import run.halo.app.core.endpoint.WebSocketEndpointManager; import run.halo.app.extension.ReactiveExtensionClient; import run.halo.app.infra.properties.HaloProperties; import run.halo.app.plugin.event.HaloPluginBeforeStopEvent; @@ -125,6 +127,10 @@ public class DefaultPluginApplicationContextFactory implements PluginApplication beanFactory.registerSingleton("finderManager", finderManager); }); + rootContext.getBeanProvider(WebSocketEndpointManager.class) + .ifUnique(manager -> beanFactory.registerSingleton("pluginWebSocketEndpointManager", + new PluginWebSocketEndpointManager(manager))); + rootContext.getBeanProvider(PluginRouterFunctionRegistry.class) .ifUnique(registry -> { var pluginRouterFunctionManager = new PluginRouterFunctionManager(registry); @@ -219,6 +225,31 @@ public class DefaultPluginApplicationContextFactory implements PluginApplication } + private static class PluginWebSocketEndpointManager { + + private final WebSocketEndpointManager manager; + + private List endpoints; + + private PluginWebSocketEndpointManager(WebSocketEndpointManager manager) { + this.manager = manager; + } + + @EventListener + public void onApplicationEvent(ContextRefreshedEvent event) { + var context = event.getApplicationContext(); + this.endpoints = context.getBeanProvider(WebSocketEndpoint.class) + .orderedStream() + .toList(); + manager.register(this.endpoints); + } + + @EventListener + public void onApplicationEvent(ContextClosedEvent ignored) { + manager.unregister(this.endpoints); + } + } + private static class PluginRouterFunctionManager { private final PluginRouterFunctionRegistry routerFunctionRegistry; diff --git a/application/src/main/java/run/halo/app/security/authorization/RequestInfoFactory.java b/application/src/main/java/run/halo/app/security/authorization/RequestInfoFactory.java index 1fb898e5d..4f629f6a7 100644 --- a/application/src/main/java/run/halo/app/security/authorization/RequestInfoFactory.java +++ b/application/src/main/java/run/halo/app/security/authorization/RequestInfoFactory.java @@ -6,6 +6,7 @@ import java.util.Set; import org.apache.commons.lang3.StringUtils; import org.springframework.http.server.PathContainer; import org.springframework.http.server.reactive.ServerHttpRequest; +import run.halo.app.console.WebSocketUtils; /** * Creates {@link RequestInfo} from {@link ServerHttpRequest}. @@ -215,6 +216,10 @@ public class RequestInfoFactory { requestInfo.verb = "deletecollection"; } } + if ("list".equals(requestInfo.verb) + && WebSocketUtils.isWebSocketUpgrade(request.getHeaders())) { + requestInfo.verb = "watch"; + } return requestInfo; } diff --git a/application/src/test/java/run/halo/app/config/WebFluxConfigTest.java b/application/src/test/java/run/halo/app/config/WebFluxConfigTest.java index 8e433318b..2aeec5cab 100644 --- a/application/src/test/java/run/halo/app/config/WebFluxConfigTest.java +++ b/application/src/test/java/run/halo/app/config/WebFluxConfigTest.java @@ -1,21 +1,114 @@ package run.halo.app.config; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +import java.net.URI; import java.util.List; +import java.util.Set; import org.hamcrest.core.StringStartsWith; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Import; import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; +import run.halo.app.core.endpoint.WebSocketEndpoint; +import run.halo.app.core.extension.Role; +import run.halo.app.core.extension.service.RoleService; +import run.halo.app.extension.GroupVersion; +import run.halo.app.extension.Metadata; -@SpringBootTest(properties = "halo.console.location=classpath:/console/") +@SpringBootTest(properties = "halo.console.location=classpath:/console/", webEnvironment = + SpringBootTest.WebEnvironment.RANDOM_PORT) +@Import(WebFluxConfigTest.WebSocketSupportTest.TestWebSocketConfiguration.class) @AutoConfigureWebTestClient class WebFluxConfigTest { @Autowired WebTestClient webClient; + @SpyBean + RoleService roleService; + + @LocalServerPort + int port; + + @Nested + class WebSocketSupportTest { + + @Test + void shouldInitializeWebSocketEndpoint() { + var role = new Role(); + var metadata = new Metadata(); + metadata.setName("fake-role"); + role.setMetadata(metadata); + role.setRules(List.of(new Role.PolicyRule.Builder() + .apiGroups("fake.halo.run") + .verbs("watch") + .resources("resources") + .build())); + when(roleService.listDependenciesFlux(Set.of("anonymous"))).thenReturn(Flux.just(role)); + var webSocketClient = new ReactorNettyWebSocketClient(); + webSocketClient.execute( + URI.create("ws://localhost:" + port + "/apis/fake.halo.run/v1alpha1/resources"), + session -> { + var send = session.send(Flux.just(session.textMessage("halo"))); + var receive = session.receive().map(WebSocketMessage::getPayloadAsText) + .next() + .doOnNext(message -> assertEquals("HALO", message)); + return send.and(receive); + }) + .as(StepVerifier::create) + .verifyComplete(); + } + + @TestConfiguration + static class TestWebSocketConfiguration { + + @Bean + WebSocketEndpoint fakeWebSocketEndpoint() { + return new FakeWebSocketEndpoint(); + } + + } + + static class FakeWebSocketEndpoint implements WebSocketEndpoint { + + @Override + public String urlPath() { + return "/resources"; + } + + @Override + public GroupVersion groupVersion() { + return GroupVersion.parseAPIVersion("fake.halo.run/v1alpha1"); + } + + @Override + public WebSocketHandler handler() { + return session -> { + var messages = session.receive() + .map(message -> session.textMessage( + message.getPayloadAsText().toUpperCase()) + ); + return session.send(messages).then(session.close()); + }; + } + } + + } + @Nested class ConsoleRequest { diff --git a/application/src/test/java/run/halo/app/core/endpoint/WebSocketHandlerMappingTest.java b/application/src/test/java/run/halo/app/core/endpoint/WebSocketHandlerMappingTest.java new file mode 100644 index 000000000..3fb4e43ff --- /dev/null +++ b/application/src/test/java/run/halo/app/core/endpoint/WebSocketHandlerMappingTest.java @@ -0,0 +1,55 @@ +package run.halo.app.core.endpoint; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketSession; +import run.halo.app.extension.GroupVersion; + +@ExtendWith(MockitoExtension.class) +class WebSocketHandlerMappingTest { + + @InjectMocks + WebSocketHandlerMapping handlerMapping; + + @Test + void shouldRegisterEndpoint() { + var endpoint = new FakeWebSocketEndpoint(); + handlerMapping.register(List.of(endpoint)); + assertTrue(handlerMapping.getEndpointMap().containsValue(endpoint)); + } + + @Test + void shouldUnregisterEndpoint() { + var endpoint = new FakeWebSocketEndpoint(); + handlerMapping.register(List.of(endpoint)); + assertTrue(handlerMapping.getEndpointMap().containsValue(endpoint)); + handlerMapping.unregister(List.of(endpoint)); + assertFalse(handlerMapping.getEndpointMap().containsValue(endpoint)); + } + + static class FakeWebSocketEndpoint implements WebSocketEndpoint { + + @Override + public String urlPath() { + return "/resources"; + } + + @Override + public GroupVersion groupVersion() { + return GroupVersion.parseAPIVersion("fake.halo.run/v1alpha1"); + } + + @Override + public WebSocketHandler handler() { + return WebSocketSession::close; + } + } + +} \ No newline at end of file diff --git a/application/src/test/java/run/halo/app/security/authorization/RequestInfoResolverTest.java b/application/src/test/java/run/halo/app/security/authorization/RequestInfoResolverTest.java index af0fad243..3e67d7565 100644 --- a/application/src/test/java/run/halo/app/security/authorization/RequestInfoResolverTest.java +++ b/application/src/test/java/run/halo/app/security/authorization/RequestInfoResolverTest.java @@ -20,6 +20,27 @@ import org.springframework.http.HttpMethod; */ public class RequestInfoResolverTest { + @Test + void shouldResolveAsWatchRequestWhenRequestIsWebSocket() { + var request = method(HttpMethod.GET, "/apis/fake.halo.run/v1alpha1/fakes") + .header("Upgrade", "websocket") + .header("Connection", "Upgrade") + .build(); + RequestInfo requestInfo = RequestInfoFactory.INSTANCE.newRequestInfo(request); + assertThat(requestInfo).isNotNull(); + assertThat(requestInfo.getVerb()).isEqualTo("watch"); + } + + @Test + void shouldNotResolveAsWatchRequestWhenRequestIsNotWebSocket() { + var request = method(HttpMethod.GET, "/apis/fake.halo.run/v1alpha1/fakes") + .header("Upgrade", "websocket") + .build(); + RequestInfo requestInfo = RequestInfoFactory.INSTANCE.newRequestInfo(request); + assertThat(requestInfo).isNotNull(); + assertThat(requestInfo.getVerb()).isEqualTo("list"); + } + @Test public void requestInfoTest() { for (SuccessCase successCase : getTestRequestInfos()) { @@ -178,7 +199,6 @@ public class RequestInfoResolverTest { } - public record SuccessCase(String method, String url, String expectedVerb, String expectedAPIPrefix, String expectedAPIGroup, String expectedAPIVersion, String expectedNamespace, diff --git a/docs/plugin/websocket.md b/docs/plugin/websocket.md new file mode 100644 index 000000000..bd63e1946 --- /dev/null +++ b/docs/plugin/websocket.md @@ -0,0 +1,49 @@ +# 插件中如何实现 WebSocket + +## 背景 + +> https://github.com/halo-dev/halo/issues/5285 + +越来越多的开发者在开发插件过程中需要及时高效获取某些资源的最新状态,但是因为在插件中不支持 WebSocket,故只能选择定时轮训的方式来解决。 + +在插件中支持 WebSocket 的功能需要 Halo Core 来适配并制定规则以方便插件实现 WebSocket。 + +## 实现 + +插件中实现 WebSocket 的代码样例如下所示: + +```java +@Component +public class MyWebSocketEndpoint implements WebSocketEndpoint { + + @Override + public GroupVersion groupVersion() { + return GroupVersion.parseApiVersion("my-plugin.halowrite.com/v1alpha1"); + } + + @Override + public String urlPath() { + return "/resources"; + } + + @Override + public WebSocketHandler handler() { + return session -> { + var messages = session.receive() + .map(message -> { + var payload = message.getPayloadAsText(); + return session.textMessage(payload.toUpperCase()); + }); + return session.send(messages); + }; + } +} +``` + +插件安装成功后,可以通过 `/apis/my-plugin.halowrite.com/v1alpha1/resources` 进行访问。 示例如下所示: + +```bash +websocat --basic-auth admin:admin ws://127.0.0.1:8090/apis/my-plugin.halowrite.com/v1alpha1/resources +``` + +同样地,WebSocket 相关的 API 仍然受当前权限系统管理。