undertow下WebSocket与mcp不兼容问题(临时处理)

pull/8757/head
chenrui 2025-08-22 16:32:09 +08:00
parent 3d4e6ba940
commit 016ab6e5ba
4 changed files with 84 additions and 19 deletions

View File

@ -16,10 +16,10 @@ public class WebSocketConfig {
* ServerEndpointExporter * ServerEndpointExporter
* bean使@ServerEndpointWebsocket endpoint * bean使@ServerEndpointWebsocket endpoint
*/ */
@Bean // @Bean
public ServerEndpointExporter serverEndpointExporter() { // public ServerEndpointExporter serverEndpointExporter() {
return new ServerEndpointExporter(); // return new ServerEndpointExporter();
} // }
@Bean @Bean
public WebsocketFilter websocketFilter(){ public WebsocketFilter websocketFilter(){

View File

@ -1,10 +1,9 @@
package org.jeecg.config.security; package org.jeecg.config.security;
import io.undertow.servlet.spec.HttpServletRequestImpl;
import io.undertow.util.HttpString;
import jakarta.servlet.FilterChain; import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -12,34 +11,100 @@ import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.filter.OncePerRequestFilter;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Set;
/** /**
* querytoken * querytoken
* @author eightmonth * Undertow Tomcat ClassCastException
* @date 2024/7/3 14:04 *
*
* 1. Authorization Bearer <token>
* 2. token
* 3. X-Access-Token
*
* token Authorization /
*/ */
@Component @Component
@Order(value = Integer.MIN_VALUE) @Order(value = Integer.MIN_VALUE)
public class CopyTokenFilter extends OncePerRequestFilter { public class CopyTokenFilter extends OncePerRequestFilter {
@Override @Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
// 以下为undertow定制代码如切换其它servlet容器需要同步更换 // 容器无关实现:根据 header/参数提取 token并以 Authorization 注入
HttpServletRequestImpl undertowRequest = (HttpServletRequestImpl) request; String tokenHeader = request.getHeader("Authorization");
String token = request.getHeader("Authorization"); String candidate = null;
if (StringUtils.hasText(token)) { if (StringUtils.hasText(tokenHeader)) {
undertowRequest.getExchange().getRequestHeaders().remove("Authorization"); String trimmed = tokenHeader.trim();
undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + token); if (startsWithIgnoreCase(trimmed, "Bearer ")) {
candidate = trimmed;
} else if (!trimmed.contains(" ")) { // 纯 token无空格视为需要规范化
candidate = trimmed;
} // 其他认证方案(如 Basic ...)保持不处理
} else { } else {
String bearerToken = request.getParameter("token"); String bearerToken = request.getParameter("token");
String headerBearerToken = request.getHeader("X-Access-Token"); String headerBearerToken = request.getHeader("X-Access-Token");
if (StringUtils.hasText(bearerToken)) { if (StringUtils.hasText(bearerToken)) {
undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + bearerToken); candidate = bearerToken.trim();
} else if (StringUtils.hasText(headerBearerToken)) { } else if (StringUtils.hasText(headerBearerToken)) {
undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + headerBearerToken); candidate = headerBearerToken.trim();
} }
} }
filterChain.doFilter(undertowRequest, response);
}
if (StringUtils.hasText(candidate)) {
final String authValue = startsWithIgnoreCase(candidate, "Bearer ") ? candidate : ("Bearer " + candidate);
HttpServletRequest wrapped = new AuthorizationHeaderRequestWrapper(request, authValue);
filterChain.doFilter(wrapped, response);
return;
}
filterChain.doFilter(request, response);
}
private boolean startsWithIgnoreCase(String str, String prefix) {
if (str == null || prefix == null) {
return false;
}
if (prefix.length() > str.length()) {
return false;
}
return str.regionMatches(true, 0, prefix, 0, prefix.length());
}
private static class AuthorizationHeaderRequestWrapper extends HttpServletRequestWrapper {
private final String authorization;
AuthorizationHeaderRequestWrapper(HttpServletRequest request, String authorization) {
super(request);
this.authorization = authorization;
}
@Override
public String getHeader(String name) {
if ("Authorization".equalsIgnoreCase(name)) {
return authorization;
}
return super.getHeader(name);
}
@Override
public Enumeration<String> getHeaders(String name) {
if ("Authorization".equalsIgnoreCase(name)) {
return Collections.enumeration(Collections.singletonList(authorization));
}
return super.getHeaders(name);
}
@Override
public Enumeration<String> getHeaderNames() {
Set<String> names = new LinkedHashSet<>();
Enumeration<String> e = super.getHeaderNames();
while (e.hasMoreElements()) {
names.add(e.nextElement());
}
names.add("Authorization");
return Collections.enumeration(names);
}
}
} }

View File

@ -22,7 +22,7 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Component @Component
@ServerEndpoint("/vxeSocket/{userId}/{pageId}") //@ServerEndpoint("/vxeSocket/{userId}/{pageId}")
public class VxeSocket { public class VxeSocket {
/** /**

View File

@ -21,7 +21,7 @@ import lombok.extern.slf4j.Slf4j;
*/ */
@Component @Component
@Slf4j @Slf4j
@ServerEndpoint("/websocket/{userId}") //@ServerEndpoint("/websocket/{userId}")
public class WebSocket { public class WebSocket {
/**线程安全Map*/ /**线程安全Map*/