Add WebSocket support in plugins (#5662)

#### What type of PR is this?

/kind feature
/area core
/area plugin

#### What this PR does / why we need it:

This PR allows plugin developers defining WebSocket endpoints in plugins.

#### Which issue(s) this PR fixes:

Fixes #5285 

#### Does this PR introduce a user-facing change?

```release-note
支持在插件中实现 WebSocket
```
pull/5787/head
John Niang 2024-04-25 16:19:14 +08:00 committed by GitHub
parent 924aad1304
commit a635881d34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 458 additions and 3 deletions

View File

@ -0,0 +1,34 @@
package run.halo.app.core.endpoint;
import org.springframework.web.reactive.socket.WebSocketHandler;
import run.halo.app.extension.GroupVersion;
/**
* Endpoint for WebSocket.
*
* @author johnniang
*/
public interface WebSocketEndpoint {
/**
* Path of the URL after group version.
*
* @return path of the URL.
*/
String urlPath();
/**
* Group and version parts of the endpoint.
*
* @return GroupVersion.
*/
GroupVersion groupVersion();
/**
* Real WebSocket handler for the endpoint.
*
* @return WebSocket handler.
*/
WebSocketHandler handler();
}

View File

@ -38,6 +38,7 @@ import org.springframework.web.reactive.result.view.ViewResolver;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import run.halo.app.console.ProxyFilter; import run.halo.app.console.ProxyFilter;
import run.halo.app.console.WebSocketRequestPredicate; import run.halo.app.console.WebSocketRequestPredicate;
import run.halo.app.core.endpoint.WebSocketHandlerMapping;
import run.halo.app.core.extension.endpoint.CustomEndpoint; import run.halo.app.core.extension.endpoint.CustomEndpoint;
import run.halo.app.core.extension.endpoint.CustomEndpointsBuilder; import run.halo.app.core.extension.endpoint.CustomEndpointsBuilder;
import run.halo.app.infra.properties.HaloProperties; import run.halo.app.infra.properties.HaloProperties;
@ -100,6 +101,13 @@ public class WebFluxConfig implements WebFluxConfigurer {
return builder.build(); return builder.build();
} }
@Bean
public WebSocketHandlerMapping webSocketHandlerMapping() {
WebSocketHandlerMapping handlerMapping = new WebSocketHandlerMapping();
handlerMapping.setOrder(-2);
return handlerMapping;
}
@Bean @Bean
RouterFunction<ServerResponse> consoleIndexRedirection() { RouterFunction<ServerResponse> consoleIndexRedirection() {
var consolePredicate = method(HttpMethod.GET) var consolePredicate = method(HttpMethod.GET)

View File

@ -1,5 +1,6 @@
package run.halo.app.console; package run.halo.app.console;
import java.util.Objects;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
public enum WebSocketUtils { public enum WebSocketUtils {
@ -8,8 +9,11 @@ public enum WebSocketUtils {
public static boolean isWebSocketUpgrade(HttpHeaders headers) { public static boolean isWebSocketUpgrade(HttpHeaders headers) {
// See io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil // See io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil
// .isWebsocketUpgrade for more. // .isWebsocketUpgrade for more.
var upgradeConnection = headers.getConnection().stream().map(String::toLowerCase)
.anyMatch(conn -> Objects.equals(conn, "upgrade"));
return headers.containsKey(HttpHeaders.UPGRADE) return headers.containsKey(HttpHeaders.UPGRADE)
&& headers.getConnection().contains(HttpHeaders.UPGRADE) && upgradeConnection
&& "websocket".equalsIgnoreCase(headers.getUpgrade()); && "websocket".equalsIgnoreCase(headers.getUpgrade());
} }

View File

@ -0,0 +1,16 @@
package run.halo.app.core.endpoint;
import java.util.Collection;
/**
* Interface for managing WebSocket endpoints, including registering and unregistering.
*
* @author johnniang
*/
public interface WebSocketEndpointManager {
void register(Collection<WebSocketEndpoint> endpoints);
void unregister(Collection<WebSocketEndpoint> endpoints);
}

View File

@ -0,0 +1,140 @@
package run.halo.app.core.endpoint;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.observation.ServerRequestObservationContext;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.handler.AbstractHandlerMapping;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.pattern.PathPattern;
import reactor.core.publisher.Mono;
import run.halo.app.console.WebSocketUtils;
public class WebSocketHandlerMapping extends AbstractHandlerMapping
implements WebSocketEndpointManager, InitializingBean {
private final BiMap<PathPattern, WebSocketEndpoint> endpointMap;
private final ReadWriteLock rwLock;
public WebSocketHandlerMapping() {
this.endpointMap = HashBiMap.create();
this.rwLock = new ReentrantReadWriteLock();
}
@Override
@NonNull
public Mono<WebSocketHandler> getHandlerInternal(ServerWebExchange exchange) {
var request = exchange.getRequest();
if (!HttpMethod.GET.equals(request.getMethod())
|| !WebSocketUtils.isWebSocketUpgrade(request.getHeaders())) {
// skip getting handler if the request is not a WebSocket.
return Mono.empty();
}
var lock = rwLock.readLock();
lock.lock();
try {
// Refer to org.springframework.web.reactive.handler.AbstractUrlHandlerMapping
// .lookupHandler
var pathContainer = request.getPath().pathWithinApplication();
List<PathPattern> matches = null;
for (var pattern : this.endpointMap.keySet()) {
if (pattern.matches(pathContainer)) {
if (matches == null) {
matches = new ArrayList<>();
}
matches.add(pattern);
}
}
if (matches == null) {
return Mono.empty();
}
if (matches.size() > 1) {
matches.sort(PathPattern.SPECIFICITY_COMPARATOR);
}
var pattern = matches.get(0);
exchange.getAttributes().put(BEST_MATCHING_PATTERN_ATTRIBUTE, pattern);
var handler = endpointMap.get(pattern).handler();
exchange.getAttributes().put(BEST_MATCHING_HANDLER_ATTRIBUTE, handler);
ServerRequestObservationContext.findCurrent(exchange.getAttributes())
.ifPresent(context -> context.setPathPattern(pattern.toString()));
var pathWithinMapping = pattern.extractPathWithinPattern(pathContainer);
exchange.getAttributes().put(PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, pathWithinMapping);
var matchInfo = pattern.matchAndExtract(pathContainer);
Assert.notNull(matchInfo, "Expect a match");
exchange.getAttributes()
.put(URI_TEMPLATE_VARIABLES_ATTRIBUTE, matchInfo.getUriVariables());
return Mono.just(handler);
} catch (Exception e) {
return Mono.error(e);
} finally {
lock.unlock();
}
}
@Override
public void register(Collection<WebSocketEndpoint> endpoints) {
if (CollectionUtils.isEmpty(endpoints)) {
return;
}
var lock = rwLock.writeLock();
lock.lock();
try {
endpoints.forEach(endpoint -> {
var urlPath = endpoint.urlPath();
urlPath = StringUtils.prependIfMissing(urlPath, "/");
var groupVersion = endpoint.groupVersion();
var parser = getPathPatternParser();
var pattern = parser.parse("/apis/" + groupVersion + urlPath);
endpointMap.put(pattern, endpoint);
});
} finally {
lock.unlock();
}
}
@Override
public void unregister(Collection<WebSocketEndpoint> endpoints) {
if (CollectionUtils.isEmpty(endpoints)) {
return;
}
var lock = rwLock.writeLock();
lock.lock();
try {
BiMap<WebSocketEndpoint, PathPattern> inverseMap = endpointMap.inverse();
endpoints.forEach(inverseMap::remove);
} finally {
lock.unlock();
}
}
@Override
public void afterPropertiesSet() {
var endpoints = obtainApplicationContext().getBeanProvider(WebSocketEndpoint.class)
.orderedStream()
.toList();
register(endpoints);
}
BiMap<PathPattern, WebSocketEndpoint> getEndpointMap() {
return endpointMap;
}
}

View File

@ -29,6 +29,8 @@ import org.springframework.util.StopWatch;
import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.ServerResponse; import org.springframework.web.reactive.function.server.ServerResponse;
import reactor.core.Exceptions; import reactor.core.Exceptions;
import run.halo.app.core.endpoint.WebSocketEndpoint;
import run.halo.app.core.endpoint.WebSocketEndpointManager;
import run.halo.app.extension.ReactiveExtensionClient; import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.infra.properties.HaloProperties; import run.halo.app.infra.properties.HaloProperties;
import run.halo.app.plugin.event.HaloPluginBeforeStopEvent; import run.halo.app.plugin.event.HaloPluginBeforeStopEvent;
@ -125,6 +127,10 @@ public class DefaultPluginApplicationContextFactory implements PluginApplication
beanFactory.registerSingleton("finderManager", finderManager); beanFactory.registerSingleton("finderManager", finderManager);
}); });
rootContext.getBeanProvider(WebSocketEndpointManager.class)
.ifUnique(manager -> beanFactory.registerSingleton("pluginWebSocketEndpointManager",
new PluginWebSocketEndpointManager(manager)));
rootContext.getBeanProvider(PluginRouterFunctionRegistry.class) rootContext.getBeanProvider(PluginRouterFunctionRegistry.class)
.ifUnique(registry -> { .ifUnique(registry -> {
var pluginRouterFunctionManager = new PluginRouterFunctionManager(registry); var pluginRouterFunctionManager = new PluginRouterFunctionManager(registry);
@ -219,6 +225,31 @@ public class DefaultPluginApplicationContextFactory implements PluginApplication
} }
private static class PluginWebSocketEndpointManager {
private final WebSocketEndpointManager manager;
private List<WebSocketEndpoint> endpoints;
private PluginWebSocketEndpointManager(WebSocketEndpointManager manager) {
this.manager = manager;
}
@EventListener
public void onApplicationEvent(ContextRefreshedEvent event) {
var context = event.getApplicationContext();
this.endpoints = context.getBeanProvider(WebSocketEndpoint.class)
.orderedStream()
.toList();
manager.register(this.endpoints);
}
@EventListener
public void onApplicationEvent(ContextClosedEvent ignored) {
manager.unregister(this.endpoints);
}
}
private static class PluginRouterFunctionManager { private static class PluginRouterFunctionManager {
private final PluginRouterFunctionRegistry routerFunctionRegistry; private final PluginRouterFunctionRegistry routerFunctionRegistry;

View File

@ -6,6 +6,7 @@ import java.util.Set;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.http.server.PathContainer; import org.springframework.http.server.PathContainer;
import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpRequest;
import run.halo.app.console.WebSocketUtils;
/** /**
* Creates {@link RequestInfo} from {@link ServerHttpRequest}. * Creates {@link RequestInfo} from {@link ServerHttpRequest}.
@ -215,6 +216,10 @@ public class RequestInfoFactory {
requestInfo.verb = "deletecollection"; requestInfo.verb = "deletecollection";
} }
} }
if ("list".equals(requestInfo.verb)
&& WebSocketUtils.isWebSocketUpgrade(request.getHeaders())) {
requestInfo.verb = "watch";
}
return requestInfo; return requestInfo;
} }

View File

@ -1,21 +1,114 @@
package run.halo.app.config; package run.halo.app.config;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.when;
import java.net.URI;
import java.util.List; import java.util.List;
import java.util.Set;
import org.hamcrest.core.StringStartsWith; import org.hamcrest.core.StringStartsWith;
import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient; import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.client.ReactorNettyWebSocketClient;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;
import run.halo.app.core.endpoint.WebSocketEndpoint;
import run.halo.app.core.extension.Role;
import run.halo.app.core.extension.service.RoleService;
import run.halo.app.extension.GroupVersion;
import run.halo.app.extension.Metadata;
@SpringBootTest(properties = "halo.console.location=classpath:/console/") @SpringBootTest(properties = "halo.console.location=classpath:/console/", webEnvironment =
SpringBootTest.WebEnvironment.RANDOM_PORT)
@Import(WebFluxConfigTest.WebSocketSupportTest.TestWebSocketConfiguration.class)
@AutoConfigureWebTestClient @AutoConfigureWebTestClient
class WebFluxConfigTest { class WebFluxConfigTest {
@Autowired @Autowired
WebTestClient webClient; WebTestClient webClient;
@SpyBean
RoleService roleService;
@LocalServerPort
int port;
@Nested
class WebSocketSupportTest {
@Test
void shouldInitializeWebSocketEndpoint() {
var role = new Role();
var metadata = new Metadata();
metadata.setName("fake-role");
role.setMetadata(metadata);
role.setRules(List.of(new Role.PolicyRule.Builder()
.apiGroups("fake.halo.run")
.verbs("watch")
.resources("resources")
.build()));
when(roleService.listDependenciesFlux(Set.of("anonymous"))).thenReturn(Flux.just(role));
var webSocketClient = new ReactorNettyWebSocketClient();
webSocketClient.execute(
URI.create("ws://localhost:" + port + "/apis/fake.halo.run/v1alpha1/resources"),
session -> {
var send = session.send(Flux.just(session.textMessage("halo")));
var receive = session.receive().map(WebSocketMessage::getPayloadAsText)
.next()
.doOnNext(message -> assertEquals("HALO", message));
return send.and(receive);
})
.as(StepVerifier::create)
.verifyComplete();
}
@TestConfiguration
static class TestWebSocketConfiguration {
@Bean
WebSocketEndpoint fakeWebSocketEndpoint() {
return new FakeWebSocketEndpoint();
}
}
static class FakeWebSocketEndpoint implements WebSocketEndpoint {
@Override
public String urlPath() {
return "/resources";
}
@Override
public GroupVersion groupVersion() {
return GroupVersion.parseAPIVersion("fake.halo.run/v1alpha1");
}
@Override
public WebSocketHandler handler() {
return session -> {
var messages = session.receive()
.map(message -> session.textMessage(
message.getPayloadAsText().toUpperCase())
);
return session.send(messages).then(session.close());
};
}
}
}
@Nested @Nested
class ConsoleRequest { class ConsoleRequest {

View File

@ -0,0 +1,55 @@
package run.halo.app.core.endpoint;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketSession;
import run.halo.app.extension.GroupVersion;
@ExtendWith(MockitoExtension.class)
class WebSocketHandlerMappingTest {
@InjectMocks
WebSocketHandlerMapping handlerMapping;
@Test
void shouldRegisterEndpoint() {
var endpoint = new FakeWebSocketEndpoint();
handlerMapping.register(List.of(endpoint));
assertTrue(handlerMapping.getEndpointMap().containsValue(endpoint));
}
@Test
void shouldUnregisterEndpoint() {
var endpoint = new FakeWebSocketEndpoint();
handlerMapping.register(List.of(endpoint));
assertTrue(handlerMapping.getEndpointMap().containsValue(endpoint));
handlerMapping.unregister(List.of(endpoint));
assertFalse(handlerMapping.getEndpointMap().containsValue(endpoint));
}
static class FakeWebSocketEndpoint implements WebSocketEndpoint {
@Override
public String urlPath() {
return "/resources";
}
@Override
public GroupVersion groupVersion() {
return GroupVersion.parseAPIVersion("fake.halo.run/v1alpha1");
}
@Override
public WebSocketHandler handler() {
return WebSocketSession::close;
}
}
}

View File

@ -20,6 +20,27 @@ import org.springframework.http.HttpMethod;
*/ */
public class RequestInfoResolverTest { public class RequestInfoResolverTest {
@Test
void shouldResolveAsWatchRequestWhenRequestIsWebSocket() {
var request = method(HttpMethod.GET, "/apis/fake.halo.run/v1alpha1/fakes")
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.build();
RequestInfo requestInfo = RequestInfoFactory.INSTANCE.newRequestInfo(request);
assertThat(requestInfo).isNotNull();
assertThat(requestInfo.getVerb()).isEqualTo("watch");
}
@Test
void shouldNotResolveAsWatchRequestWhenRequestIsNotWebSocket() {
var request = method(HttpMethod.GET, "/apis/fake.halo.run/v1alpha1/fakes")
.header("Upgrade", "websocket")
.build();
RequestInfo requestInfo = RequestInfoFactory.INSTANCE.newRequestInfo(request);
assertThat(requestInfo).isNotNull();
assertThat(requestInfo.getVerb()).isEqualTo("list");
}
@Test @Test
public void requestInfoTest() { public void requestInfoTest() {
for (SuccessCase successCase : getTestRequestInfos()) { for (SuccessCase successCase : getTestRequestInfos()) {
@ -178,7 +199,6 @@ public class RequestInfoResolverTest {
} }
public record SuccessCase(String method, String url, String expectedVerb, public record SuccessCase(String method, String url, String expectedVerb,
String expectedAPIPrefix, String expectedAPIGroup, String expectedAPIPrefix, String expectedAPIGroup,
String expectedAPIVersion, String expectedNamespace, String expectedAPIVersion, String expectedNamespace,

49
docs/plugin/websocket.md Normal file
View File

@ -0,0 +1,49 @@
# 插件中如何实现 WebSocket
## 背景
> https://github.com/halo-dev/halo/issues/5285
越来越多的开发者在开发插件过程中需要及时高效获取某些资源的最新状态,但是因为在插件中不支持 WebSocket故只能选择定时轮训的方式来解决。
在插件中支持 WebSocket 的功能需要 Halo Core 来适配并制定规则以方便插件实现 WebSocket。
## 实现
插件中实现 WebSocket 的代码样例如下所示:
```java
@Component
public class MyWebSocketEndpoint implements WebSocketEndpoint {
@Override
public GroupVersion groupVersion() {
return GroupVersion.parseApiVersion("my-plugin.halowrite.com/v1alpha1");
}
@Override
public String urlPath() {
return "/resources";
}
@Override
public WebSocketHandler handler() {
return session -> {
var messages = session.receive()
.map(message -> {
var payload = message.getPayloadAsText();
return session.textMessage(payload.toUpperCase());
});
return session.send(messages);
};
}
}
```
插件安装成功后,可以通过 `/apis/my-plugin.halowrite.com/v1alpha1/resources` 进行访问。 示例如下所示:
```bash
websocat --basic-auth admin:admin ws://127.0.0.1:8090/apis/my-plugin.halowrite.com/v1alpha1/resources
```
同样地WebSocket 相关的 API 仍然受当前权限系统管理。