diff --git a/src/main/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistry.java b/src/main/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistry.java index e94a0a4a0..19399f2ee 100644 --- a/src/main/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistry.java +++ b/src/main/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistry.java @@ -1,13 +1,13 @@ package run.halo.app.plugin.resources; +import com.google.common.collect.LinkedHashMultimap; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.locks.StampedLock; import org.springframework.stereotype.Component; import org.springframework.util.Assert; -import org.springframework.util.LinkedMultiValueMap; -import org.springframework.util.MultiValueMap; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.ServerResponse; import reactor.core.publisher.Mono; @@ -27,8 +27,8 @@ public class ReverseProxyRouterFunctionRegistry { private final StampedLock lock = new StampedLock(); private final Map> proxyNameRouterFunctionRegistry = new HashMap<>(); - private final MultiValueMap pluginIdReverseProxyMap = - new LinkedMultiValueMap<>(); + private final LinkedHashMultimap pluginIdReverseProxyMap = + LinkedHashMultimap.create(); public ReverseProxyRouterFunctionRegistry( ReverseProxyRouterFunctionFactory reverseProxyRouterFunctionFactory) { @@ -47,11 +47,7 @@ public class ReverseProxyRouterFunctionRegistry { final String proxyName = reverseProxy.getMetadata().getName(); long stamp = lock.writeLock(); try { - List reverseProxyNames = pluginIdReverseProxyMap.get(pluginId); - if (reverseProxyNames != null && reverseProxyNames.contains(proxyName)) { - return Mono.empty(); - } - pluginIdReverseProxyMap.add(pluginId, proxyName); + pluginIdReverseProxyMap.put(pluginId, proxyName); // Obtain plugin application context PluginApplicationContext pluginApplicationContext = @@ -71,8 +67,8 @@ public class ReverseProxyRouterFunctionRegistry { * Only for test. */ protected int reverseProxySize(String pluginId) { - List names = pluginIdReverseProxyMap.get(pluginId); - return names == null ? 0 : names.size(); + Set names = pluginIdReverseProxyMap.get(pluginId); + return names.size(); } /** @@ -84,10 +80,7 @@ public class ReverseProxyRouterFunctionRegistry { Assert.notNull(pluginId, "The plugin id must not be null."); long stamp = lock.writeLock(); try { - List proxyNames = pluginIdReverseProxyMap.remove(pluginId); - if (proxyNames == null) { - return Mono.empty(); - } + Set proxyNames = pluginIdReverseProxyMap.removeAll(pluginId); for (String proxyName : proxyNames) { proxyNameRouterFunctionRegistry.remove(proxyName); } @@ -103,11 +96,7 @@ public class ReverseProxyRouterFunctionRegistry { public Mono remove(String pluginId, String reverseProxyName) { long stamp = lock.writeLock(); try { - List proxyNames = pluginIdReverseProxyMap.get(pluginId); - if (proxyNames == null) { - return Mono.empty(); - } - proxyNames.remove(reverseProxyName); + pluginIdReverseProxyMap.remove(pluginId, reverseProxyName); proxyNameRouterFunctionRegistry.remove(reverseProxyName); return Mono.empty(); } finally { diff --git a/src/test/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistryTest.java b/src/test/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistryTest.java index 165df92d6..0b688c570 100644 --- a/src/test/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistryTest.java +++ b/src/test/java/run/halo/app/plugin/resources/ReverseProxyRouterFunctionRegistryTest.java @@ -51,14 +51,7 @@ class ReverseProxyRouterFunctionRegistryTest { @Test void register() { - ReverseProxy mock = Mockito.mock(ReverseProxy.class); - Metadata metadata = new Metadata(); - metadata.setName("test-reverse-proxy"); - when(mock.getMetadata()).thenReturn(metadata); - RouterFunction routerFunction = request -> Mono.empty(); - - when(reverseProxyRouterFunctionFactory.create(any(), any())) - .thenReturn(Mono.just(routerFunction)); + ReverseProxy mock = getMockReverseProxy(); registry.register("fake-plugin", mock) .as(StepVerifier::create) .verifyComplete(); @@ -72,6 +65,42 @@ class ReverseProxyRouterFunctionRegistryTest { assertThat(registry.reverseProxySize("fake-plugin")).isEqualTo(1); - verify(reverseProxyRouterFunctionFactory, times(1)).create(any(), any()); + verify(reverseProxyRouterFunctionFactory, times(2)).create(any(), any()); + } + + @Test + void remove() { + ReverseProxy mock = getMockReverseProxy(); + registry.register("fake-plugin", mock) + .as(StepVerifier::create) + .verifyComplete(); + + registry.remove("fake-plugin").block(); + + assertThat(registry.reverseProxySize("fake-plugin")).isEqualTo(0); + } + + @Test + void removeByKeyValue() { + ReverseProxy mock = getMockReverseProxy(); + registry.register("fake-plugin", mock) + .as(StepVerifier::create) + .verifyComplete(); + + registry.remove("fake-plugin", "test-reverse-proxy").block(); + + assertThat(registry.reverseProxySize("fake-plugin")).isEqualTo(0); + } + + private ReverseProxy getMockReverseProxy() { + ReverseProxy mock = Mockito.mock(ReverseProxy.class); + Metadata metadata = new Metadata(); + metadata.setName("test-reverse-proxy"); + when(mock.getMetadata()).thenReturn(metadata); + RouterFunction routerFunction = request -> Mono.empty(); + + when(reverseProxyRouterFunctionFactory.create(any(), any())) + .thenReturn(Mono.just(routerFunction)); + return mock; } } \ No newline at end of file