Refactor ExtensionGetter for enabling or disabling extensions (#6134)

#### What type of PR is this?

/kind improvement
/kind api-change
/area core

#### What this PR does / why we need it:

This PR refactors ExtensionGetter implementation to add a support of enabling extension point(s). Here is an example of data field of `system` config map:

```json
{
  "data": {
    "extensionPointEnabled": "{  \"search-engine\": [\"search-engine-algolia\"]}"
  },
```

> 1. The `search-engine` is a name of extension point definition.
> 2. The `search-engine-algolia` is a name of extension definition.

#### Does this PR introduce a user-facing change?

```release-note
None
```
pull/6122/head^2
John Niang 2024-06-25 15:46:45 +08:00 committed by GitHub
parent 705bd235c3
commit e4cce918f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 310 additions and 96 deletions

View File

@ -1,6 +1,7 @@
package run.halo.app.infra;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Set;
import lombok.Data;
import org.springframework.boot.convert.ApplicationConversionService;
@ -115,12 +116,10 @@ public class SystemSetting {
}
/**
* ExtensionPointEnabled key is full qualified name of extension point and value is a list of
* full qualified name of implementation.
* ExtensionPointEnabled key is metadata name of extension point and value is a list of
* extension definition names.
*/
public static class ExtensionPointEnabled extends LinkedHashMap<String, Set<String>> {
public static final ExtensionPointEnabled EMPTY = new ExtensionPointEnabled();
public static class ExtensionPointEnabled extends LinkedHashMap<String, LinkedHashSet<String>> {
public static final String GROUP = "extensionPointEnabled";

View File

@ -66,7 +66,7 @@ public class DefaultNotificationSender
Mono<ReactiveNotifier> selectNotifier(String notifierExtensionName) {
return client.fetch(ExtensionDefinition.class, notifierExtensionName)
.flatMap(extDefinition -> extensionGetter.getEnabledExtensionByDefinition(
.flatMap(extDefinition -> extensionGetter.getEnabledExtensions(
ReactiveNotifier.class)
.filter(notifier -> notifier.getClass().getName()
.equals(extDefinition.getSpec().getClassName())

View File

@ -2,17 +2,13 @@ package run.halo.app.plugin.extensionpoint;
import static run.halo.app.extension.index.query.QueryFactory.equal;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import java.util.Objects;
import lombok.RequiredArgsConstructor;
import org.pf4j.ExtensionPoint;
import org.pf4j.PluginManager;
import org.springframework.context.ApplicationContext;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.annotation.AnnotationAwareOrderComparator;
import org.springframework.data.domain.Sort;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
@ -32,88 +28,65 @@ public class DefaultExtensionGetter implements ExtensionGetter {
private final PluginManager pluginManager;
private final ApplicationContext applicationContext;
private final BeanFactory beanFactory;
private final ReactiveExtensionClient client;
@Override
public <T extends ExtensionPoint> Flux<T> getExtensions(Class<T> extensionPoint) {
return Flux.fromIterable(pluginManager.getExtensions(extensionPoint))
.concatWith(
Flux.fromStream(() -> beanFactory.getBeanProvider(extensionPoint).orderedStream())
)
.sort(new AnnotationAwareOrderComparator());
}
@Override
public <T extends ExtensionPoint> Mono<T> getEnabledExtension(Class<T> extensionPoint) {
return systemConfigFetcher.fetch(ExtensionPointEnabled.GROUP, ExtensionPointEnabled.class)
.switchIfEmpty(Mono.just(ExtensionPointEnabled.EMPTY))
.mapNotNull(enabled -> {
var implClassNames = enabled.getOrDefault(extensionPoint.getName(), Set.of());
List<T> allExtensions = getAllExtensions(extensionPoint);
if (allExtensions.isEmpty()) {
return null;
}
return allExtensions
.stream()
.filter(impl -> implClassNames.contains(impl.getClass().getName()))
.findFirst()
// Fallback to local implementation of the extension point.
// This will happen when no proper configuration is found.
.orElseGet(() -> allExtensions.get(0));
});
return getEnabledExtensions(extensionPoint).next();
}
@Override
public <T extends ExtensionPoint> Flux<T> getEnabledExtensions(Class<T> extensionPoint) {
return systemConfigFetcher.fetch(ExtensionPointEnabled.GROUP, ExtensionPointEnabled.class)
.switchIfEmpty(Mono.just(ExtensionPointEnabled.EMPTY))
.flatMapMany(enabled -> {
var implClassNames = enabled.getOrDefault(extensionPoint.getName(), Set.of());
var extensions = pluginManager.getExtensions(extensionPoint)
.stream()
.filter(impl -> implClassNames.contains(impl.getClass().getName()))
.toList();
if (extensions.isEmpty()) {
extensions = applicationContext.getBeanProvider(extensionPoint)
.orderedStream()
// we only fetch one implementation here
.limit(1)
.toList();
}
return Flux.fromIterable(extensions);
});
}
@Override
public <T extends ExtensionPoint> Flux<T> getEnabledExtensionByDefinition(
public <T extends ExtensionPoint> Flux<T> getEnabledExtensions(
Class<T> extensionPoint) {
return fetchExtensionPointDefinition(extensionPoint)
.flatMapMany(extensionPointDefinition -> {
ExtensionPointDefinition.ExtensionPointType type =
extensionPointDefinition.getSpec().getType();
.flatMapMany(epd -> {
var epdName = epd.getMetadata().getName();
var type = epd.getSpec().getType();
if (type == ExtensionPointDefinition.ExtensionPointType.SINGLETON) {
return getEnabledExtension(extensionPoint).flux();
return getEnabledExtensions(epdName, extensionPoint).take(1);
}
// TODO If the type is sortable, may need to process the returned order.
return Flux.fromIterable(getAllExtensions(extensionPoint));
return getEnabledExtensions(epdName, extensionPoint);
});
}
@Override
public <T extends ExtensionPoint> Flux<T> getExtensions(Class<T> extensionPointClass) {
var extensions = new ArrayList<>(pluginManager.getExtensions(extensionPointClass));
applicationContext.getBeanProvider(extensionPointClass)
.orderedStream()
.forEach(extensions::add);
return Flux.fromIterable(extensions);
private <T extends ExtensionPoint> Flux<T> getEnabledExtensions(String epdName,
Class<T> extensionPoint) {
return systemConfigFetcher.fetch(ExtensionPointEnabled.GROUP, ExtensionPointEnabled.class)
.switchIfEmpty(Mono.fromSupplier(ExtensionPointEnabled::new))
.flatMapMany(enabled -> {
var extensionDefNames = enabled.getOrDefault(epdName, null);
if (extensionDefNames == null) {
// get all extensions if not specified
return Flux.defer(() -> getExtensions(extensionPoint));
}
var extensions = getExtensions(extensionPoint).cache();
return Flux.fromIterable(extensionDefNames)
.concatMap(extensionDefName ->
client.fetch(ExtensionDefinition.class, extensionDefName)
)
.concatMap(extensionDef -> {
var className = extensionDef.getSpec().getClassName();
return extensions.filter(
extension -> Objects.equals(extension.getClass().getName(),
className)
);
});
});
}
@NonNull
<T extends ExtensionPoint> List<T> getAllExtensions(Class<T> extensionPoint) {
Stream<T> pluginExtsStream = pluginManager.getExtensions(extensionPoint)
.stream();
Stream<T> systemExtsStream = applicationContext.getBeanProvider(extensionPoint)
.orderedStream();
return Stream.concat(systemExtsStream, pluginExtsStream)
.sorted(new AnnotationAwareOrderComparator())
.toList();
}
Mono<ExtensionPointDefinition> fetchExtensionPointDefinition(
private Mono<ExtensionPointDefinition> fetchExtensionPointDefinition(
Class<? extends ExtensionPoint> extensionPoint) {
var listOptions = new ListOptions();
listOptions.setFieldSelector(FieldSelector.of(
@ -125,4 +98,5 @@ public class DefaultExtensionGetter implements ExtensionGetter {
)
.flatMap(list -> Mono.justOrEmpty(ListResult.first(list)));
}
}

View File

@ -15,15 +15,6 @@ public interface ExtensionGetter {
*/
<T extends ExtensionPoint> Mono<T> getEnabledExtension(Class<T> extensionPoint);
/**
* Get enabled extension list from system configuration.
*
* @param extensionPoint is extension point class.
* @return implementations of the corresponding extension point. If no configuration is found,
* we will use the default implementation from application context instead.
*/
<T extends ExtensionPoint> Flux<T> getEnabledExtensions(Class<T> extensionPoint);
/**
* Get the extension(s) according to the {@link ExtensionPointDefinition} queried
* by incoming extension point class.
@ -33,7 +24,7 @@ public interface ExtensionGetter {
* @throws IllegalArgumentException if the incoming extension point class does not have
* the {@link ExtensionPointDefinition}.
*/
<T extends ExtensionPoint> Flux<T> getEnabledExtensionByDefinition(Class<T> extensionPoint);
<T extends ExtensionPoint> Flux<T> getEnabledExtensions(Class<T> extensionPoint);
/**
* Get all extensions according to extension point class.

View File

@ -26,7 +26,7 @@ public class UsernamePasswordDelegatingAuthenticationManager
@Override
public Mono<Authentication> authenticate(Authentication authentication) {
return extensionGetter
.getEnabledExtensionByDefinition(UsernamePasswordAuthenticationManager.class)
.getEnabledExtensions(UsernamePasswordAuthenticationManager.class)
.next()
.flatMap(authenticationManager -> authenticationManager.authenticate(authentication)
.doOnError(t -> log.error(

View File

@ -84,7 +84,7 @@ public class CommentEnabledVariableProcessor extends AbstractTemplateBoundariesP
}
ExtensionGetter extensionGetter = appCtx.getBean(ExtensionGetter.class);
return extensionGetter.getEnabledExtensionByDefinition(CommentWidget.class)
return extensionGetter.getEnabledExtensions(CommentWidget.class)
.next()
.blockOptional();
}

View File

@ -158,7 +158,7 @@ public class PostPublicQueryServiceImpl implements PostPublicQueryService {
ContentWrapper wrapper) {
Assert.notNull(post, "Post name must not be null");
Assert.notNull(wrapper, "Post content must not be null");
return extensionGetter.getEnabledExtensionByDefinition(ReactivePostContentHandler.class)
return extensionGetter.getEnabledExtensions(ReactivePostContentHandler.class)
.reduce(Mono.fromSupplier(() -> ReactivePostContentHandler.PostContentContext.builder()
.post(post)
.content(wrapper.getContent())

View File

@ -58,7 +58,7 @@ public class SinglePageConversionServiceImpl implements SinglePageConversionServ
ContentWrapper wrapper) {
Assert.notNull(singlePage, "SinglePage must not be null");
Assert.notNull(wrapper, "SinglePage content must not be null");
return extensionGetter.getEnabledExtensionByDefinition(
return extensionGetter.getEnabledExtensions(
ReactiveSinglePageContentHandler.class)
.reduce(Mono.fromSupplier(() -> SinglePageContentContext.builder()
.singlePage(singlePage)

View File

@ -24,7 +24,7 @@ public class AdditionalWebFilterChainProxy implements WebFilter {
@Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
return extensionGetter.getEnabledExtensionByDefinition(AdditionalWebFilter.class)
return extensionGetter.getEnabledExtensions(AdditionalWebFilter.class)
.sort(AnnotationAwareOrderComparator.INSTANCE)
.cast(WebFilter.class)
.collectList()

View File

@ -0,0 +1,250 @@
package run.halo.app.plugin.extensionpoint;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static run.halo.app.infra.SystemSetting.ExtensionPointEnabled.GROUP;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.pf4j.ExtensionPoint;
import org.pf4j.PluginManager;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.core.annotation.Order;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import run.halo.app.extension.ListOptions;
import run.halo.app.extension.ListResult;
import run.halo.app.extension.Metadata;
import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.infra.SystemConfigurableEnvironmentFetcher;
import run.halo.app.infra.SystemSetting.ExtensionPointEnabled;
import run.halo.app.plugin.extensionpoint.ExtensionPointDefinition.ExtensionPointType;
@ExtendWith(MockitoExtension.class)
class DefaultExtensionGetterTest {
@Mock
ReactiveExtensionClient client;
@Mock
PluginManager pluginManager;
@Mock
SystemConfigurableEnvironmentFetcher configFetcher;
@Mock
BeanFactory beanFactory;
@InjectMocks
DefaultExtensionGetter getter;
@Test
void shouldGetExtensionBySingletonDefinitionWhenExtensionPointEnabledSet() {
// prepare extension point definition
when(client.listBy(same(ExtensionPointDefinition.class), any(ListOptions.class), any()))
.thenReturn(Mono.fromSupplier(() -> {
var epd = createExtensionPointDefinition("fake-extension-point",
FakeExtensionPoint.class,
ExtensionPointType.SINGLETON);
return new ListResult<>(List.of(epd));
}));
when(client.fetch(ExtensionDefinition.class, "fake-extension"))
.thenReturn(Mono.fromSupplier(() -> createExtensionDefinition(
"fake-extension",
FakeExtensionPointImpl.class,
"fake-extension-point")));
when(configFetcher.fetch(GROUP, ExtensionPointEnabled.class))
.thenReturn(Mono.fromSupplier(() -> {
var extensionPointEnabled = new ExtensionPointEnabled();
extensionPointEnabled.put("fake-extension-point",
new LinkedHashSet<>(List.of("fake-extension")));
return extensionPointEnabled;
}));
@SuppressWarnings("unchecked")
ObjectProvider<FakeExtensionPoint> objectProvider = mock(ObjectProvider.class);
when(objectProvider.orderedStream())
.thenReturn(Stream.of(new FakeExtensionPointDefaultImpl()));
when(beanFactory.getBeanProvider(FakeExtensionPoint.class)).thenReturn(objectProvider);
var extensionImpl = new FakeExtensionPointImpl();
when(pluginManager.getExtensions(FakeExtensionPoint.class))
.thenReturn(List.of(extensionImpl));
getter.getEnabledExtensions(FakeExtensionPoint.class)
.as(StepVerifier::create)
.expectNext(extensionImpl)
.verifyComplete();
}
@Test
void shouldGetDefaultSingletonDefinitionWhileExtensionPointEnabledNotSet() {
when(client.listBy(same(ExtensionPointDefinition.class), any(ListOptions.class), any()))
.thenReturn(Mono.fromSupplier(() -> {
var epd = createExtensionPointDefinition("fake-extension-point",
FakeExtensionPoint.class,
ExtensionPointType.SINGLETON);
return new ListResult<>(List.of(epd));
}));
when(configFetcher.fetch(GROUP, ExtensionPointEnabled.class))
.thenReturn(Mono.empty());
@SuppressWarnings("unchecked")
ObjectProvider<FakeExtensionPoint> objectProvider = mock(ObjectProvider.class);
var extensionDefaultImpl = new FakeExtensionPointDefaultImpl();
when(objectProvider.orderedStream())
.thenReturn(Stream.of(extensionDefaultImpl));
when(beanFactory.getBeanProvider(FakeExtensionPoint.class)).thenReturn(objectProvider);
when(pluginManager.getExtensions(FakeExtensionPoint.class))
.thenReturn(List.of());
getter.getEnabledExtensions(FakeExtensionPoint.class)
.as(StepVerifier::create)
.expectNext(extensionDefaultImpl)
.verifyComplete();
}
@Test
void shouldGetMultiInstanceExtensionWhileExtensionPointEnabledSet() {
// prepare extension point definition
when(client.listBy(same(ExtensionPointDefinition.class), any(ListOptions.class), any()))
.thenReturn(Mono.fromSupplier(() -> {
var epd = createExtensionPointDefinition("fake-extension-point",
FakeExtensionPoint.class,
ExtensionPointType.MULTI_INSTANCE);
return new ListResult<>(List.of(epd));
}));
when(client.fetch(ExtensionDefinition.class, "fake-extension"))
.thenReturn(Mono.fromSupplier(() -> createExtensionDefinition(
"fake-extension",
FakeExtensionPointImpl.class,
"fake-extension-point")));
when(client.fetch(ExtensionDefinition.class, "default-fake-extension"))
.thenReturn(Mono.fromSupplier(() -> createExtensionDefinition(
"default-fake-extension",
FakeExtensionPointDefaultImpl.class,
"fake-extension-point")));
when(configFetcher.fetch(GROUP, ExtensionPointEnabled.class))
.thenReturn(Mono.fromSupplier(() -> {
var extensionPointEnabled = new ExtensionPointEnabled();
extensionPointEnabled.put("fake-extension-point",
new LinkedHashSet<>(List.of("default-fake-extension", "fake-extension")));
return extensionPointEnabled;
}));
@SuppressWarnings("unchecked")
ObjectProvider<FakeExtensionPoint> objectProvider = mock(ObjectProvider.class);
var extensionDefaultImpl = new FakeExtensionPointDefaultImpl();
when(objectProvider.orderedStream())
.thenReturn(Stream.of(extensionDefaultImpl));
when(beanFactory.getBeanProvider(FakeExtensionPoint.class)).thenReturn(objectProvider);
var extensionImpl = new FakeExtensionPointImpl();
var anotherExtensionImpl = new FakeExtensionPoint() {
};
when(pluginManager.getExtensions(FakeExtensionPoint.class))
.thenReturn(List.of(extensionImpl, anotherExtensionImpl));
getter.getEnabledExtensions(FakeExtensionPoint.class)
.as(StepVerifier::create)
// should keep the order of enabled extensions
.expectNext(extensionDefaultImpl)
.expectNext(extensionImpl)
.verifyComplete();
}
@Test
void shouldGetMultiInstanceExtensionWhileExtensionPointEnabledNotSet() {
// prepare extension point definition
when(client.listBy(same(ExtensionPointDefinition.class), any(ListOptions.class), any()))
.thenReturn(Mono.fromSupplier(() -> {
var epd = createExtensionPointDefinition("fake-extension-point",
FakeExtensionPoint.class,
ExtensionPointType.MULTI_INSTANCE);
return new ListResult<>(List.of(epd));
}));
when(configFetcher.fetch(GROUP, ExtensionPointEnabled.class))
.thenReturn(Mono.empty());
@SuppressWarnings("unchecked")
ObjectProvider<FakeExtensionPoint> objectProvider = mock(ObjectProvider.class);
var extensionDefaultImpl = new FakeExtensionPointDefaultImpl();
when(objectProvider.orderedStream())
.thenReturn(Stream.of(extensionDefaultImpl));
when(beanFactory.getBeanProvider(FakeExtensionPoint.class)).thenReturn(objectProvider);
var extensionImpl = new FakeExtensionPointImpl();
var anotherExtensionImpl = new FakeExtensionPoint() {
};
when(pluginManager.getExtensions(FakeExtensionPoint.class))
.thenReturn(List.of(extensionImpl, anotherExtensionImpl));
getter.getEnabledExtensions(FakeExtensionPoint.class)
.as(StepVerifier::create)
// should keep the order according to @Order annotation
// order is 1
.expectNext(extensionImpl)
// order is 2
.expectNext(extensionDefaultImpl)
// order is not set
.expectNext(anotherExtensionImpl)
.verifyComplete();
}
interface FakeExtensionPoint extends ExtensionPoint {
}
@Order(1)
static class FakeExtensionPointImpl implements FakeExtensionPoint {
}
@Order(2)
static class FakeExtensionPointDefaultImpl implements FakeExtensionPoint {
}
ExtensionDefinition createExtensionDefinition(String name, Class<?> clazz, String epdName) {
var ed = new ExtensionDefinition();
var metadata = new Metadata();
metadata.setName(name);
ed.setMetadata(metadata);
var spec = new ExtensionDefinition.ExtensionSpec();
spec.setClassName(clazz.getName());
spec.setExtensionPointName(epdName);
ed.setSpec(spec);
return ed;
}
ExtensionPointDefinition createExtensionPointDefinition(String name,
Class<?> clazz,
ExtensionPointType type) {
var epd = new ExtensionPointDefinition();
var metadata = new Metadata();
metadata.setName(name);
epd.setMetadata(metadata);
var spec = new ExtensionPointDefinition.ExtensionPointSpec();
spec.setClassName(clazz.getName());
spec.setType(type);
epd.setSpec(spec);
return epd;
}
}

View File

@ -75,7 +75,7 @@ class CommentElementTagProcessorTest {
.thenReturn(Mono.just(commentSetting));
when(commentSetting.getEnable()).thenReturn(true);
when(extensionGetter.getEnabledExtensionByDefinition(eq(CommentWidget.class)))
when(extensionGetter.getEnabledExtensions(eq(CommentWidget.class)))
.thenReturn(Flux.empty());
String result = templateEngine.process("commentWidget", context);
assertThat(result).isEqualTo("""
@ -88,7 +88,7 @@ class CommentElementTagProcessorTest {
</html>
""");
when(extensionGetter.getEnabledExtensionByDefinition(eq(CommentWidget.class)))
when(extensionGetter.getEnabledExtensions(eq(CommentWidget.class)))
.thenReturn(Flux.just(new DefaultCommentWidget()));
result = templateEngine.process("commentWidget", context);
assertThat(result).isEqualTo("""

View File

@ -54,7 +54,7 @@ class CommentEnabledVariableProcessorTest {
.thenReturn(Mono.just(commentSetting));
CommentWidget commentWidget = mock(CommentWidget.class);
when(extensionGetter.getEnabledExtensionByDefinition(CommentWidget.class))
when(extensionGetter.getEnabledExtensions(CommentWidget.class))
.thenReturn(Flux.just(commentWidget));
WebEngineContext webContext = mock(WebEngineContext.class);
var evaluationContext = mock(ThymeleafEvaluationContext.class);

View File

@ -35,7 +35,7 @@ class PostPublicQueryServiceImplTest {
@Test
void extendPostContent() {
when(extensionGetter.getEnabledExtensionByDefinition(
when(extensionGetter.getEnabledExtensions(
eq(ReactivePostContentHandler.class))).thenReturn(
Flux.just(new PostContentHandlerB(), new PostContentHandlerA(),
new PostContentHandlerC()));

View File

@ -36,7 +36,7 @@ class SinglePageConversionServiceImplTest {
@Test
void extendPageContent() {
when(extensionGetter.getEnabledExtensionByDefinition(
when(extensionGetter.getEnabledExtensions(
eq(ReactiveSinglePageContentHandler.class)))
.thenReturn(
Flux.just(new PageContentHandlerB(),