Support redirecting to page according to query after authenticated (#6736)

#### What type of PR is this?

/kind improvement
/area core
/milestone 2.20.0

#### What this PR does / why we need it:

This PR supports query `redirect_uri` to control where to redirect after authenticated.

#### Which issue(s) this PR fixes:

Fixes https://github.com/halo-dev/halo/issues/6720

#### Special notes for your reviewer:

Every step below needs you logging out.

1. Try to request <http://localhost:8090/console/login?redirect_uri=/xxx
2. Try to request <http://localhost:8090/login?redirect_uri=/xxx
3. Try to request <http://localhost:8090/console/posts

#### Does this PR introduce a user-facing change?

```release-note
None
```
pull/6740/head
John Niang 2024-09-30 18:37:52 +08:00 committed by GitHub
parent 8a9b954969
commit db65dd3b3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 241 additions and 27 deletions

View File

@ -22,6 +22,7 @@ import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.SecurityWebFilterChain;
import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher;
@ -36,6 +37,7 @@ import run.halo.app.extension.ReactiveExtensionClient;
import run.halo.app.infra.AnonymousUserConst; import run.halo.app.infra.AnonymousUserConst;
import run.halo.app.infra.properties.HaloProperties; import run.halo.app.infra.properties.HaloProperties;
import run.halo.app.security.DefaultUserDetailService; import run.halo.app.security.DefaultUserDetailService;
import run.halo.app.security.HaloServerRequestCache;
import run.halo.app.security.authentication.CryptoService; import run.halo.app.security.authentication.CryptoService;
import run.halo.app.security.authentication.SecurityConfigurer; import run.halo.app.security.authentication.SecurityConfigurer;
import run.halo.app.security.authentication.impl.RsaKeyService; import run.halo.app.security.authentication.impl.RsaKeyService;
@ -64,7 +66,8 @@ public class WebServerSecurityConfig {
ServerSecurityContextRepository securityContextRepository, ServerSecurityContextRepository securityContextRepository,
ReactiveExtensionClient client, ReactiveExtensionClient client,
CryptoService cryptoService, CryptoService cryptoService,
HaloProperties haloProperties) { HaloProperties haloProperties,
ServerRequestCache serverRequestCache) {
var pathMatcher = pathMatchers("/**"); var pathMatcher = pathMatchers("/**");
var staticResourcesMatcher = pathMatchers(HttpMethod.GET, var staticResourcesMatcher = pathMatchers(HttpMethod.GET,
@ -134,7 +137,8 @@ public class WebServerSecurityConfig {
haloProperties.getSecurity().getReferrerOptions().getPolicy()) haloProperties.getSecurity().getReferrerOptions().getPolicy())
) )
.hsts(hstsSpec -> hstsSpec.includeSubdomains(false)) .hsts(hstsSpec -> hstsSpec.includeSubdomains(false))
); )
.requestCache(spec -> spec.requestCache(serverRequestCache));
// Integrate with other configurers separately // Integrate with other configurers separately
securityConfigurers.orderedStream() securityConfigurers.orderedStream()
@ -142,6 +146,11 @@ public class WebServerSecurityConfig {
return http.build(); return http.build();
} }
@Bean
ServerRequestCache serverRequestCache() {
return new HaloServerRequestCache();
}
@Bean @Bean
ServerSecurityContextRepository securityContextRepository() { ServerSecurityContextRepository securityContextRepository() {
return new WebSessionServerSecurityContextRepository(); return new WebSessionServerSecurityContextRepository();

View File

@ -5,6 +5,7 @@ import org.springframework.cache.CacheManager;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext; import org.springframework.context.support.GenericApplicationContext;
import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import run.halo.app.content.PostContentService; import run.halo.app.content.PostContentService;
import run.halo.app.core.extension.service.AttachmentService; import run.halo.app.core.extension.service.AttachmentService;
import run.halo.app.extension.DefaultSchemeManager; import run.halo.app.extension.DefaultSchemeManager;
@ -79,6 +80,12 @@ public enum SharedApplicationContextFactory {
.ifUnique(rateLimiterRegistry -> .ifUnique(rateLimiterRegistry ->
beanFactory.registerSingleton("rateLimiterRegistry", rateLimiterRegistry) beanFactory.registerSingleton("rateLimiterRegistry", rateLimiterRegistry)
); );
// Authentication plugins may need this RequestCache to handle successful login redirect
rootContext.getBeanProvider(ServerRequestCache.class)
.ifUnique(serverRequestCache ->
beanFactory.registerSingleton("serverRequestCache", serverRequestCache)
);
// TODO add more shared instance here // TODO add more shared instance here
sharedContext.refresh(); sharedContext.refresh();

View File

@ -5,6 +5,7 @@ import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint; import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
@ -30,9 +31,11 @@ public class DefaultServerAuthenticationEntryPoint implements ServerAuthenticati
private final RedirectServerAuthenticationEntryPoint redirectEntryPoint; private final RedirectServerAuthenticationEntryPoint redirectEntryPoint;
public DefaultServerAuthenticationEntryPoint() { public DefaultServerAuthenticationEntryPoint(ServerRequestCache serverRequestCache) {
this.redirectEntryPoint = var entryPoint =
new RedirectServerAuthenticationEntryPoint("/login?authentication_required"); new RedirectServerAuthenticationEntryPoint("/login?authentication_required");
entryPoint.setRequestCache(serverRequestCache);
this.redirectEntryPoint = entryPoint;
} }
@Override @Override
@ -40,7 +43,7 @@ public class DefaultServerAuthenticationEntryPoint implements ServerAuthenticati
return xhrMatcher.matches(exchange) return xhrMatcher.matches(exchange)
.filter(MatchResult::isMatch) .filter(MatchResult::isMatch)
.switchIfEmpty( .switchIfEmpty(
Mono.defer(() -> this.redirectEntryPoint.commence(exchange, ex)).then(Mono.empty()) Mono.defer(() -> this.redirectEntryPoint.commence(exchange, ex).then(Mono.empty()))
) )
.flatMap(match -> Mono.defer( .flatMap(match -> Mono.defer(
() -> { () -> {

View File

@ -10,6 +10,7 @@ import org.springframework.security.web.server.DelegatingServerAuthenticationEnt
import org.springframework.security.web.server.authentication.AuthenticationConverterServerWebExchangeMatcher; import org.springframework.security.web.server.authentication.AuthenticationConverterServerWebExchangeMatcher;
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
import org.springframework.security.web.server.authorization.ServerWebExchangeDelegatingServerAccessDeniedHandler; import org.springframework.security.web.server.authorization.ServerWebExchangeDelegatingServerAccessDeniedHandler;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -24,10 +25,14 @@ public class ExceptionSecurityConfigurer implements SecurityConfigurer {
private final ServerResponse.Context context; private final ServerResponse.Context context;
private final ServerRequestCache serverRequestCache;
public ExceptionSecurityConfigurer(MessageSource messageSource, public ExceptionSecurityConfigurer(MessageSource messageSource,
ServerResponse.Context context) { ServerResponse.Context context,
ServerRequestCache serverRequestCache) {
this.messageSource = messageSource; this.messageSource = messageSource;
this.context = context; this.context = context;
this.serverRequestCache = serverRequestCache;
} }
@Override @Override
@ -59,7 +64,7 @@ public class ExceptionSecurityConfigurer implements SecurityConfigurer {
)); ));
entryPoints.add(new DelegatingServerAuthenticationEntryPoint.DelegateEntry( entryPoints.add(new DelegatingServerAuthenticationEntryPoint.DelegateEntry(
exchange -> ServerWebExchangeMatcher.MatchResult.match(), exchange -> ServerWebExchangeMatcher.MatchResult.match(),
new DefaultServerAuthenticationEntryPoint() new DefaultServerAuthenticationEntryPoint(serverRequestCache)
)); ));
exception.authenticationEntryPoint( exception.authenticationEntryPoint(

View File

@ -0,0 +1,86 @@
package run.halo.app.security;
import static org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers.pathMatchers;
import java.net.URI;
import java.util.Collections;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.server.RequestPath;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
import org.springframework.security.web.server.util.matcher.AndServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.NegatedServerWebExchangeMatcher;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebSession;
import reactor.core.publisher.Mono;
/**
* Halo server request cache implementation for saving redirect URI from query.
*
* @author johnniang
*/
public class HaloServerRequestCache extends WebSessionServerRequestCache {
/**
* Currently, we have no idea to customize the sessionAttributeName in
* WebSessionServerRequestCache, so we have to copy the attr into here.
*/
private static final String DEFAULT_SAVED_REQUEST_ATTR = "SPRING_SECURITY_SAVED_REQUEST";
private static final String REDIRECT_URI_QUERY = "redirect_uri";
private final String sessionAttrName = DEFAULT_SAVED_REQUEST_ATTR;
public HaloServerRequestCache() {
super();
setSaveRequestMatcher(createDefaultRequestMatcher());
}
@Override
public Mono<Void> saveRequest(ServerWebExchange exchange) {
var redirectUriQuery = exchange.getRequest().getQueryParams().getFirst(REDIRECT_URI_QUERY);
if (StringUtils.isNotBlank(redirectUriQuery)) {
var redirectUri = URI.create(redirectUriQuery);
return saveRedirectUri(exchange, redirectUri);
}
return super.saveRequest(exchange);
}
@Override
public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
return super.getRedirectUri(exchange);
}
@Override
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
return super.removeMatchingRequest(exchange);
}
private Mono<Void> saveRedirectUri(ServerWebExchange exchange, URI redirectUri) {
var requestPath = exchange.getRequest().getPath();
var redirectPath = RequestPath.parse(redirectUri, requestPath.contextPath().value());
var query = redirectUri.getRawQuery();
var finalRedirect =
redirectPath.pathWithinApplication() + (query == null ? "" : "?" + query);
return exchange.getSession()
.map(WebSession::getAttributes)
.doOnNext(attributes -> attributes.put(this.sessionAttrName, finalRedirect))
.then();
}
private static ServerWebExchangeMatcher createDefaultRequestMatcher() {
var get = pathMatchers(HttpMethod.GET, "/**");
var notFavicon = new NegatedServerWebExchangeMatcher(
pathMatchers(
"/favicon.*", "/login/**", "/signup/**", "/password-reset/**", "/challenges/**"
));
var html = new MediaTypeServerWebExchangeMatcher(MediaType.TEXT_HTML);
html.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
return new AndServerWebExchangeMatcher(get, notFavicon, html);
}
}

View File

@ -5,6 +5,7 @@ import org.springframework.security.core.Authentication;
import org.springframework.security.web.server.WebFilterExchange; import org.springframework.security.web.server.WebFilterExchange;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import run.halo.app.security.LoginHandlerEnhancer; import run.halo.app.security.LoginHandlerEnhancer;
@ -13,11 +14,14 @@ public class TotpAuthenticationSuccessHandler implements ServerAuthenticationSuc
private final LoginHandlerEnhancer loginEnhancer; private final LoginHandlerEnhancer loginEnhancer;
private final ServerAuthenticationSuccessHandler successHandler = private final ServerAuthenticationSuccessHandler successHandler;
new RedirectServerAuthenticationSuccessHandler("/uc");
public TotpAuthenticationSuccessHandler(LoginHandlerEnhancer loginEnhancer) { public TotpAuthenticationSuccessHandler(LoginHandlerEnhancer loginEnhancer,
ServerRequestCache serverRequestCache) {
this.loginEnhancer = loginEnhancer; this.loginEnhancer = loginEnhancer;
var successHandler = new RedirectServerAuthenticationSuccessHandler("/uc");
successHandler.setRequestCache(serverRequestCache);
this.successHandler = successHandler;
} }
@Override @Override

View File

@ -8,6 +8,7 @@ import org.springframework.security.config.web.server.ServerHttpSecurity;
import org.springframework.security.web.server.authentication.AuthenticationWebFilter; import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationFailureHandler;
import org.springframework.security.web.server.context.ServerSecurityContextRepository; import org.springframework.security.web.server.context.ServerSecurityContextRepository;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import run.halo.app.security.LoginHandlerEnhancer; import run.halo.app.security.LoginHandlerEnhancer;
import run.halo.app.security.authentication.SecurityConfigurer; import run.halo.app.security.authentication.SecurityConfigurer;
@ -24,13 +25,17 @@ public class TwoFactorAuthSecurityConfigurer implements SecurityConfigurer {
private final LoginHandlerEnhancer loginHandlerEnhancer; private final LoginHandlerEnhancer loginHandlerEnhancer;
private final ServerRequestCache serverRequestCache;
public TwoFactorAuthSecurityConfigurer( public TwoFactorAuthSecurityConfigurer(
ServerSecurityContextRepository securityContextRepository, ServerSecurityContextRepository securityContextRepository,
TotpAuthService totpAuthService, LoginHandlerEnhancer loginHandlerEnhancer TotpAuthService totpAuthService, LoginHandlerEnhancer loginHandlerEnhancer,
ServerRequestCache serverRequestCache
) { ) {
this.securityContextRepository = securityContextRepository; this.securityContextRepository = securityContextRepository;
this.totpAuthService = totpAuthService; this.totpAuthService = totpAuthService;
this.loginHandlerEnhancer = loginHandlerEnhancer; this.loginHandlerEnhancer = loginHandlerEnhancer;
this.serverRequestCache = serverRequestCache;
} }
@Override @Override
@ -43,7 +48,7 @@ public class TwoFactorAuthSecurityConfigurer implements SecurityConfigurer {
filter.setSecurityContextRepository(securityContextRepository); filter.setSecurityContextRepository(securityContextRepository);
filter.setServerAuthenticationConverter(new TotpCodeAuthenticationConverter()); filter.setServerAuthenticationConverter(new TotpCodeAuthenticationConverter());
filter.setAuthenticationSuccessHandler( filter.setAuthenticationSuccessHandler(
new TotpAuthenticationSuccessHandler(loginHandlerEnhancer) new TotpAuthenticationSuccessHandler(loginHandlerEnhancer, serverRequestCache)
); );
filter.setAuthenticationFailureHandler( filter.setAuthenticationFailureHandler(
new RedirectServerAuthenticationFailureHandler("/challenges/two-factor/totp?error") new RedirectServerAuthenticationFailureHandler("/challenges/two-factor/totp?error")

View File

@ -3,8 +3,9 @@ package run.halo.app.security.authentication.twofactor;
import java.net.URI; import java.net.URI;
import org.springframework.context.MessageSource; import org.springframework.context.MessageSource;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.server.DefaultServerRedirectStrategy;
import org.springframework.security.web.server.ServerAuthenticationEntryPoint; import org.springframework.security.web.server.ServerAuthenticationEntryPoint;
import org.springframework.security.web.server.authentication.RedirectServerAuthenticationEntryPoint; import org.springframework.security.web.server.ServerRedirectStrategy;
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
import org.springframework.web.reactive.function.server.ServerResponse; import org.springframework.web.reactive.function.server.ServerResponse;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
@ -18,10 +19,13 @@ public class TwoFactorAuthenticationEntryPoint implements ServerAuthenticationEn
.flatMap(a -> ServerWebExchangeMatcher.MatchResult.match()) .flatMap(a -> ServerWebExchangeMatcher.MatchResult.match())
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
private static final String REDIRECT_LOCATION = "/challenges/two-factor/totp"; private static final URI REDIRECT_LOCATION = URI.create("/challenges/two-factor/totp");
private final RedirectServerAuthenticationEntryPoint redirectEntryPoint = /**
new RedirectServerAuthenticationEntryPoint(REDIRECT_LOCATION); * Because we don't want to cache the request before redirecting to the 2FA page,
* ServerRedirectStrategy is used to redirect the request.
*/
private final ServerRedirectStrategy redirectStrategy = new DefaultServerRedirectStrategy();
private final MessageSource messageSource; private final MessageSource messageSource;
@ -45,10 +49,12 @@ public class TwoFactorAuthenticationEntryPoint implements ServerAuthenticationEn
public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException ex) { public Mono<Void> commence(ServerWebExchange exchange, AuthenticationException ex) {
return XHR_MATCHER.matches(exchange) return XHR_MATCHER.matches(exchange)
.filter(ServerWebExchangeMatcher.MatchResult::isMatch) .filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.switchIfEmpty(redirectEntryPoint.commence(exchange, ex).then(Mono.empty())) .switchIfEmpty(
redirectStrategy.sendRedirect(exchange, REDIRECT_LOCATION).then(Mono.empty())
)
.flatMap(isXhr -> { .flatMap(isXhr -> {
var errorResponse = Exceptions.createErrorResponse( var errorResponse = Exceptions.createErrorResponse(
new TwoFactorAuthRequiredException(URI.create(REDIRECT_LOCATION)), new TwoFactorAuthRequiredException(REDIRECT_LOCATION),
null, exchange, messageSource); null, exchange, messageSource);
return ServerResponse.status(errorResponse.getStatusCode()) return ServerResponse.status(errorResponse.getStatusCode())
.bodyValue(errorResponse.getBody()) .bodyValue(errorResponse.getBody())

View File

@ -8,6 +8,7 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunction;
import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.RouterFunctions;
@ -18,6 +19,7 @@ import run.halo.app.core.extension.AuthProvider;
import run.halo.app.infra.actuator.GlobalInfoService; import run.halo.app.infra.actuator.GlobalInfoService;
import run.halo.app.plugin.PluginConst; import run.halo.app.plugin.PluginConst;
import run.halo.app.security.AuthProviderService; import run.halo.app.security.AuthProviderService;
import run.halo.app.security.HaloServerRequestCache;
import run.halo.app.security.authentication.CryptoService; import run.halo.app.security.authentication.CryptoService;
/** /**
@ -35,6 +37,8 @@ class PreAuthLoginEndpoint {
private final AuthProviderService authProviderService; private final AuthProviderService authProviderService;
private final ServerRequestCache serverRequestCache = new HaloServerRequestCache();
PreAuthLoginEndpoint(CryptoService cryptoService, GlobalInfoService globalInfoService, PreAuthLoginEndpoint(CryptoService cryptoService, GlobalInfoService globalInfoService,
AuthProviderService authProviderService) { AuthProviderService authProviderService) {
this.cryptoService = cryptoService; this.cryptoService = cryptoService;
@ -46,6 +50,7 @@ class PreAuthLoginEndpoint {
RouterFunction<ServerResponse> preAuthLoginEndpoints() { RouterFunction<ServerResponse> preAuthLoginEndpoints() {
return RouterFunctions.nest(path("/login"), RouterFunctions.route() return RouterFunctions.nest(path("/login"), RouterFunctions.route()
.GET("", request -> { .GET("", request -> {
// TODO get redirect URI and cache it
var exchange = request.exchange(); var exchange = request.exchange();
var contextPath = exchange.getRequest().getPath().contextPath().value(); var contextPath = exchange.getRequest().getPath().contextPath().value();
var publicKey = cryptoService.readPublicKey() var publicKey = cryptoService.readPublicKey()
@ -78,7 +83,8 @@ class PreAuthLoginEndpoint {
.filter(ap -> !Objects.equals(loginMethod, ap.getMetadata().getName())) .filter(ap -> !Objects.equals(loginMethod, ap.getMetadata().getName()))
.cache(); .cache();
return ServerResponse.ok().render("login", Map.of( return serverRequestCache.saveRequest(exchange).then(Mono.defer(() ->
ServerResponse.ok().render("login", Map.of(
"action", contextPath + "/login", "action", contextPath + "/login",
"publicKey", publicKey, "publicKey", publicKey,
"globalInfo", globalInfo, "globalInfo", globalInfo,
@ -87,6 +93,7 @@ class PreAuthLoginEndpoint {
"socialAuthProviders", socialAuthProviders, "socialAuthProviders", socialAuthProviders,
"formAuthProviders", formAuthProviders "formAuthProviders", formAuthProviders
// TODO Add more models here // TODO Add more models here
))
)); ));
}) })
.build()); .build());

View File

@ -7,15 +7,22 @@ import java.net.URI;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.InjectMocks; import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException; import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier; import reactor.test.StepVerifier;
@ExtendWith(MockitoExtension.class) @ExtendWith(MockitoExtension.class)
class DefaultServerAuthenticationEntryPointTest { class DefaultServerAuthenticationEntryPointTest {
@Mock
ServerRequestCache requestCache;
@InjectMocks @InjectMocks
DefaultServerAuthenticationEntryPoint entryPoint; DefaultServerAuthenticationEntryPoint entryPoint;
@ -40,6 +47,7 @@ class DefaultServerAuthenticationEntryPointTest {
.build(); .build();
var mockExchange = MockServerWebExchange.builder(mockReq) var mockExchange = MockServerWebExchange.builder(mockReq)
.build(); .build();
Mockito.when(requestCache.saveRequest(mockExchange)).thenReturn(Mono.empty());
var commenceMono = entryPoint.commence(mockExchange, var commenceMono = entryPoint.commence(mockExchange,
new AuthenticationCredentialsNotFoundException("Not Found")); new AuthenticationCredentialsNotFoundException("Not Found"));
StepVerifier.create(commenceMono) StepVerifier.create(commenceMono)

View File

@ -0,0 +1,74 @@
package run.halo.app.security;
import java.net.URI;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.http.MediaType;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.server.session.DefaultWebSessionManager;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
class HaloServerRequestCacheTest {
HaloServerRequestCache requestCache;
@BeforeEach
void setUp() {
requestCache = new HaloServerRequestCache();
}
@Test
void shouldNotSaveIfPageNotCacheable() {
var mockExchange =
MockServerWebExchange.from(MockServerHttpRequest.get("/login"));
requestCache.saveRequest(mockExchange)
.then(requestCache.getRedirectUri(mockExchange))
.as(StepVerifier::create)
.verifyComplete();
}
@Test
void shouldSaveIfPageCacheable() {
var mockExchange = MockServerWebExchange.from(
MockServerHttpRequest.get("/archives").accept(MediaType.TEXT_HTML)
);
requestCache.saveRequest(mockExchange)
.then(requestCache.getRedirectUri(mockExchange))
.as(StepVerifier::create)
.expectNext(URI.create("/archives"))
.verifyComplete();
}
@Test
void shouldSaveIfQueryPresent() {
var mockExchange =
MockServerWebExchange.from(MockServerHttpRequest.get("/login?redirect_uri=/halo?q=v"));
requestCache.saveRequest(mockExchange)
.then(requestCache.getRedirectUri(mockExchange))
.as(StepVerifier::create)
.expectNext(URI.create("/halo?q=v"));
}
@Test
void shouldRemoveIfRedirectUriFound() {
var sessionManager = new DefaultWebSessionManager();
var mockExchange =
MockServerWebExchange.builder(MockServerHttpRequest.get("/login?redirect_uri=/halo"))
.sessionManager(sessionManager)
.build();
var removeExchange = mockExchange.mutate()
.request(builder -> builder.uri(URI.create("/halo")))
.build();
requestCache.saveRequest(mockExchange)
.then(Mono.defer(() -> requestCache.removeMatchingRequest(removeExchange)))
.as(StepVerifier::create)
.assertNext(request -> {
Assertions.assertEquals(URI.create("/halo"), request.getURI());
})
.verifyComplete();
}
}