From 016ab6e5ba9931cb89ff2034ab66dfc18b2d6b34 Mon Sep 17 00:00:00 2001 From: chenrui Date: Fri, 22 Aug 2025 16:32:09 +0800 Subject: [PATCH] =?UTF-8?q?undertow=E4=B8=8BWebSocket=E4=B8=8Emcp=E4=B8=8D?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E9=97=AE=E9=A2=98(=E4=B8=B4=E6=97=B6?= =?UTF-8?q?=E5=A4=84=E7=90=86)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/jeecg/config/WebSocketConfig.java | 8 +- .../config/security/CopyTokenFilter.java | 91 ++++++++++++++++--- .../demo/mock/vxe/websocket/VxeSocket.java | 2 +- .../modules/message/websocket/WebSocket.java | 2 +- 4 files changed, 84 insertions(+), 19 deletions(-) diff --git a/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/WebSocketConfig.java b/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/WebSocketConfig.java index 801544b2c..013e62264 100644 --- a/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/WebSocketConfig.java +++ b/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/WebSocketConfig.java @@ -16,10 +16,10 @@ public class WebSocketConfig { * 注入ServerEndpointExporter, * 这个bean会自动注册使用了@ServerEndpoint注解声明的Websocket endpoint */ - @Bean - public ServerEndpointExporter serverEndpointExporter() { - return new ServerEndpointExporter(); - } +// @Bean +// public ServerEndpointExporter serverEndpointExporter() { +// return new ServerEndpointExporter(); +// } @Bean public WebsocketFilter websocketFilter(){ diff --git a/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/security/CopyTokenFilter.java b/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/security/CopyTokenFilter.java index de18fbeb0..c8f6b9614 100644 --- a/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/security/CopyTokenFilter.java +++ b/jeecg-boot/jeecg-boot-base-core/src/main/java/org/jeecg/config/security/CopyTokenFilter.java @@ -1,10 +1,9 @@ package org.jeecg.config.security; -import io.undertow.servlet.spec.HttpServletRequestImpl; -import io.undertow.util.HttpString; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; import jakarta.servlet.http.HttpServletResponse; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; @@ -12,34 +11,100 @@ import org.springframework.util.StringUtils; import org.springframework.web.filter.OncePerRequestFilter; import java.io.IOException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashSet; +import java.util.Set; /** * 复制仪盘表请求query体携带的token - * @author eightmonth - * @date 2024/7/3 14:04 + * 注意:改为容器无关实现,避免 Undertow 专有类型转换导致在 Tomcat 下 ClassCastException。 + * + * 来源优先级: + * 1. Authorization 头(若存在则规范为 Bearer 格式) + * 2. 查询参数 token + * 3. 自定义头 X-Access-Token + * + * 若最终获得 token,则通过请求包装器注入 Authorization 头,保持对下游过滤器/安全链可见。 */ @Component @Order(value = Integer.MIN_VALUE) public class CopyTokenFilter extends OncePerRequestFilter { @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - // 以下为undertow定制代码,如切换其它servlet容器,需要同步更换 - HttpServletRequestImpl undertowRequest = (HttpServletRequestImpl) request; - String token = request.getHeader("Authorization"); - if (StringUtils.hasText(token)) { - undertowRequest.getExchange().getRequestHeaders().remove("Authorization"); - undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + token); + // 容器无关实现:根据 header/参数提取 token,并以 Authorization 注入 + String tokenHeader = request.getHeader("Authorization"); + String candidate = null; + if (StringUtils.hasText(tokenHeader)) { + String trimmed = tokenHeader.trim(); + if (startsWithIgnoreCase(trimmed, "Bearer ")) { + candidate = trimmed; + } else if (!trimmed.contains(" ")) { // 纯 token,无空格,视为需要规范化 + candidate = trimmed; + } // 其他认证方案(如 Basic ...)保持不处理 } else { String bearerToken = request.getParameter("token"); String headerBearerToken = request.getHeader("X-Access-Token"); if (StringUtils.hasText(bearerToken)) { - undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + bearerToken); + candidate = bearerToken.trim(); } else if (StringUtils.hasText(headerBearerToken)) { - undertowRequest.getExchange().getRequestHeaders().add(new HttpString("Authorization"), "bearer " + headerBearerToken); + candidate = headerBearerToken.trim(); } } - filterChain.doFilter(undertowRequest, response); + + if (StringUtils.hasText(candidate)) { + final String authValue = startsWithIgnoreCase(candidate, "Bearer ") ? candidate : ("Bearer " + candidate); + HttpServletRequest wrapped = new AuthorizationHeaderRequestWrapper(request, authValue); + filterChain.doFilter(wrapped, response); + return; + } + + filterChain.doFilter(request, response); } + private boolean startsWithIgnoreCase(String str, String prefix) { + if (str == null || prefix == null) { + return false; + } + if (prefix.length() > str.length()) { + return false; + } + return str.regionMatches(true, 0, prefix, 0, prefix.length()); + } + private static class AuthorizationHeaderRequestWrapper extends HttpServletRequestWrapper { + private final String authorization; + + AuthorizationHeaderRequestWrapper(HttpServletRequest request, String authorization) { + super(request); + this.authorization = authorization; + } + + @Override + public String getHeader(String name) { + if ("Authorization".equalsIgnoreCase(name)) { + return authorization; + } + return super.getHeader(name); + } + + @Override + public Enumeration getHeaders(String name) { + if ("Authorization".equalsIgnoreCase(name)) { + return Collections.enumeration(Collections.singletonList(authorization)); + } + return super.getHeaders(name); + } + + @Override + public Enumeration getHeaderNames() { + Set names = new LinkedHashSet<>(); + Enumeration e = super.getHeaderNames(); + while (e.hasMoreElements()) { + names.add(e.nextElement()); + } + names.add("Authorization"); + return Collections.enumeration(names); + } + } } diff --git a/jeecg-boot/jeecg-boot-module/jeecg-module-demo/src/main/java/org/jeecg/modules/demo/mock/vxe/websocket/VxeSocket.java b/jeecg-boot/jeecg-boot-module/jeecg-module-demo/src/main/java/org/jeecg/modules/demo/mock/vxe/websocket/VxeSocket.java index 017e8b66e..6643ca9c5 100644 --- a/jeecg-boot/jeecg-boot-module/jeecg-module-demo/src/main/java/org/jeecg/modules/demo/mock/vxe/websocket/VxeSocket.java +++ b/jeecg-boot/jeecg-boot-module/jeecg-module-demo/src/main/java/org/jeecg/modules/demo/mock/vxe/websocket/VxeSocket.java @@ -22,7 +22,7 @@ import java.util.Map; */ @Slf4j @Component -@ServerEndpoint("/vxeSocket/{userId}/{pageId}") +//@ServerEndpoint("/vxeSocket/{userId}/{pageId}") public class VxeSocket { /** diff --git a/jeecg-boot/jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/message/websocket/WebSocket.java b/jeecg-boot/jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/message/websocket/WebSocket.java index a6cc51d65..6848a6f11 100644 --- a/jeecg-boot/jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/message/websocket/WebSocket.java +++ b/jeecg-boot/jeecg-module-system/jeecg-system-biz/src/main/java/org/jeecg/modules/message/websocket/WebSocket.java @@ -21,7 +21,7 @@ import lombok.extern.slf4j.Slf4j; */ @Component @Slf4j -@ServerEndpoint("/websocket/{userId}") +//@ServerEndpoint("/websocket/{userId}") public class WebSocket { /**线程安全Map*/