Refactor authentication filter again

pull/146/head
johnniang 2019-05-06 19:14:51 +08:00
parent 30f21c50ab
commit 22a7311fef
3 changed files with 50 additions and 37 deletions

View File

@ -8,6 +8,7 @@ import org.springframework.web.filter.OncePerRequestFilter;
import run.halo.app.config.properties.HaloProperties; import run.halo.app.config.properties.HaloProperties;
import run.halo.app.exception.NotInstallException; import run.halo.app.exception.NotInstallException;
import run.halo.app.model.properties.PrimaryProperties; import run.halo.app.model.properties.PrimaryProperties;
import run.halo.app.security.context.SecurityContextHolder;
import run.halo.app.security.handler.AuthenticationFailureHandler; import run.halo.app.security.handler.AuthenticationFailureHandler;
import run.halo.app.security.handler.DefaultAuthenticationFailureHandler; import run.halo.app.security.handler.DefaultAuthenticationFailureHandler;
import run.halo.app.service.OptionService; import run.halo.app.service.OptionService;
@ -75,6 +76,7 @@ public abstract class AbstractAuthenticationFilter extends OncePerRequestFilter
* @param request http servlet request must not be null. * @param request http servlet request must not be null.
* @return true if the request should skip authentication failure; false otherwise * @return true if the request should skip authentication failure; false otherwise
*/ */
@Deprecated
protected boolean shouldSkipAuthenticateFailure(@NonNull HttpServletRequest request) { protected boolean shouldSkipAuthenticateFailure(@NonNull HttpServletRequest request) {
Assert.notNull(request, "Http servlet request must not be null"); Assert.notNull(request, "Http servlet request must not be null");
@ -126,6 +128,7 @@ public abstract class AbstractAuthenticationFilter extends OncePerRequestFilter
* @param url url must not be blank * @param url url must not be blank
* @param method method must not be blank * @param method method must not be blank
*/ */
@Deprecated
public void addTryAuthUrlMethodPattern(@NonNull String url, @NonNull String method) { public void addTryAuthUrlMethodPattern(@NonNull String url, @NonNull String method) {
Assert.hasText(url, "Try authenticating url must not be blank"); Assert.hasText(url, "Try authenticating url must not be blank");
Assert.hasText(method, "Try authenticating method must not be blank"); Assert.hasText(method, "Try authenticating method must not be blank");
@ -176,5 +179,19 @@ public abstract class AbstractAuthenticationFilter extends OncePerRequestFilter
getFailureHandler().onFailure(request, response, new NotInstallException("The blog has not been initialized yet!")); getFailureHandler().onFailure(request, response, new NotInstallException("The blog has not been initialized yet!"));
return; return;
} }
if (shouldNotFilter(request)) {
filterChain.doFilter(request, response);
return;
}
try {
// Do authenticate
doAuthenticate(request, response, filterChain);
} finally {
SecurityContextHolder.clearContext();
}
} }
protected abstract void doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException;
} }

View File

@ -77,47 +77,45 @@ public class AdminAuthenticationFilter extends AbstractAuthenticationFilter {
} }
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { protected void doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
super.doFilterInternal(request, response, filterChain); if (!haloProperties.isAuthEnabled()) {
if (haloProperties.isAuthEnabled()) {
// Get token from request
String token = getTokenFromRequest(request);
if (StringUtils.isBlank(token)) {
if (!shouldSkipAuthenticateFailure(request)) {
getFailureHandler().onFailure(request, response, new AuthenticationException("You have to login before accessing admin api"));
return;
}
} else {
// Get user id from cache
Optional<Integer> optionalUserId = cacheStore.getAny(SecurityUtils.buildTokenAccessKey(token), Integer.class);
if (!optionalUserId.isPresent()) {
getFailureHandler().onFailure(request, response, new AuthenticationException("The token has been expired or not exist").setErrorData(token));
return;
}
// Get the user
User user = userService.getById(optionalUserId.get());
// Build user detail
UserDetail userDetail = new UserDetail(user);
// Set security
SecurityContextHolder.setContext(new SecurityContextImpl(new AuthenticationImpl(userDetail)));
}
} else {
// Set security // Set security
userService.getCurrentUser().ifPresent(user -> userService.getCurrentUser().ifPresent(user ->
SecurityContextHolder.setContext(new SecurityContextImpl(new AuthenticationImpl(new UserDetail(user))))); SecurityContextHolder.setContext(new SecurityContextImpl(new AuthenticationImpl(new UserDetail(user)))));
// Do filter
filterChain.doFilter(request, response);
return;
} }
filterChain.doFilter(request, response); // Get token from request
String token = getTokenFromRequest(request);
// Clear context if (StringUtils.isBlank(token)) {
SecurityContextHolder.clearContext(); getFailureHandler().onFailure(request, response, new AuthenticationException("You have to login before accessing admin api"));
return;
}
// Get user id from cache
Optional<Integer> optionalUserId = cacheStore.getAny(SecurityUtils.buildTokenAccessKey(token), Integer.class);
if (!optionalUserId.isPresent()) {
getFailureHandler().onFailure(request, response, new AuthenticationException("The token has been expired or not exist").setErrorData(token));
return;
}
// Get the user
User user = userService.getById(optionalUserId.get());
// Build user detail
UserDetail userDetail = new UserDetail(user);
// Set security
SecurityContextHolder.setContext(new SecurityContextImpl(new AuthenticationImpl(userDetail)));
// Do filter
filterChain.doFilter(request, response);
} }
@Override @Override

View File

@ -39,9 +39,7 @@ public class ApiAuthenticationFilter extends AbstractAuthenticationFilter {
} }
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { protected void doAuthenticate(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
super.doFilterInternal(request, response, filterChain);
// Get token // Get token
String token = getTokenFromRequest(request); String token = getTokenFromRequest(request);