diff --git a/common/reflect/marshal_test.go b/common/reflect/marshal_test.go index 359abae0..82194279 100644 --- a/common/reflect/marshal_test.go +++ b/common/reflect/marshal_test.go @@ -204,9 +204,7 @@ func getConfig() string { "security": "none", "wsSettings": { "path": "/?ed=2048", - "headers": { - "Host": "bing.com" - } + "host": "bing.com" } } } diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index c62acee7..c4937ba5 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -163,13 +163,13 @@ func (c *WebSocketConfig) Build() (proto.Message, error) { path = u.String() } } - // If http host is not set in the Host field, but in headers field, we add it to Host Field here. - // If we don't do that, http host will be overwritten as address. - // Host priority: Host field > headers field > address. - if c.Host == "" && c.Headers["host"] != "" { - c.Host = c.Headers["host"] - } else if c.Host == "" && c.Headers["Host"] != "" { - c.Host = c.Headers["Host"] + // Priority (client): host > serverName > address + for k, v := range c.Headers { + errors.PrintDeprecatedFeatureWarning(`"host" in "headers"`, `independent "host"`) + if c.Host == "" { + c.Host = v + } + delete(c.Headers, k) } config := &websocket.Config{ Path: path, @@ -202,15 +202,11 @@ func (c *HttpUpgradeConfig) Build() (proto.Message, error) { path = u.String() } } - // If http host is not set in the Host field, but in headers field, we add it to Host Field here. - // If we don't do that, http host will be overwritten as address. - // Host priority: Host field > headers field > address. - if c.Host == "" && c.Headers["host"] != "" { - c.Host = c.Headers["host"] - delete(c.Headers, "host") - } else if c.Host == "" && c.Headers["Host"] != "" { - c.Host = c.Headers["Host"] - delete(c.Headers, "Host") + // Priority (client): host > serverName > address + for k := range c.Headers { + if strings.ToLower(k) == "host" { + return nil, errors.New(`"headers" can't contain "host"`) + } } config := &httpupgrade.Config{ Path: path, @@ -274,13 +270,11 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) { c = &extra } - // If http host is not set in the Host field, but in headers field, we add it to Host Field here. - // If we don't do that, http host will be overwritten as address. - // Host priority: Host field > headers field > address. - if c.Host == "" && c.Headers["host"] != "" { - c.Host = c.Headers["host"] - } else if c.Host == "" && c.Headers["Host"] != "" { - c.Host = c.Headers["Host"] + // Priority (client): host > serverName > address + for k := range c.Headers { + if strings.ToLower(k) == "host" { + return nil, errors.New(`"headers" can't contain "host"`) + } } if c.Xmux.MaxConnections != nil && c.Xmux.MaxConnections.To > 0 && c.Xmux.MaxConcurrency != nil && c.Xmux.MaxConcurrency.To > 0 { diff --git a/infra/conf/xray_test.go b/infra/conf/xray_test.go index d225dbf9..1c0fff8d 100644 --- a/infra/conf/xray_test.go +++ b/infra/conf/xray_test.go @@ -48,9 +48,7 @@ func TestXrayConfig(t *testing.T) { "streamSettings": { "network": "ws", "wsSettings": { - "headers": { - "host": "example.domain" - }, + "host": "example.domain", "path": "" }, "tlsSettings": { @@ -139,9 +137,6 @@ func TestXrayConfig(t *testing.T) { ProtocolName: "websocket", Settings: serial.ToTypedMessage(&websocket.Config{ Host: "example.domain", - Header: map[string]string{ - "host": "example.domain", - }, }), }, }, diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index 013f4f28..c10bd97e 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -53,9 +53,10 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * var conn net.Conn var requestURL url.URL - if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { - tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) - if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil { + tConfig := tls.ConfigFromStreamSettings(streamSettings) + if tConfig != nil { + tlsConfig := tConfig.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + if fingerprint := tls.GetFingerprint(tConfig.Fingerprint); fingerprint != nil { conn = tls.UClient(pconn, tlsConfig, fingerprint) if err := conn.(*tls.UConn).WebsocketHandshakeContext(ctx); err != nil { return nil, err @@ -69,12 +70,17 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * requestURL.Scheme = "http" } - requestURL.Host = dest.NetAddr() + requestURL.Host = transportConfiguration.Host + if requestURL.Host == "" && tConfig != nil { + requestURL.Host = tConfig.ServerName + } + if requestURL.Host == "" { + requestURL.Host = dest.Address.String() + } requestURL.Path = transportConfiguration.GetNormalizedPath() req := &http.Request{ Method: http.MethodGet, URL: &requestURL, - Host: transportConfiguration.Host, Header: make(http.Header), } for key, value := range transportConfiguration.Header { diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index 4036dfc2..8fa87501 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -259,8 +259,14 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me requestURL.Scheme = "http" } requestURL.Host = transportConfiguration.Host + if requestURL.Host == "" && tlsConfig != nil { + requestURL.Host = tlsConfig.ServerName + } + if requestURL.Host == "" && realityConfig != nil { + requestURL.Host = realityConfig.ServerName + } if requestURL.Host == "" { - requestURL.Host = dest.NetAddr() + requestURL.Host = dest.Address.String() } sessionIdUuid := uuid.New() @@ -279,16 +285,25 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } globalDialerAccess.Unlock() memory2 := streamSettings.DownloadSettings - httpClient2, muxRes2 = getHTTPClient(ctx, *memory2.Destination, memory2) // just panic - if tls.ConfigFromStreamSettings(memory2) != nil || reality.ConfigFromStreamSettings(memory2) != nil { + dest2 := *memory2.Destination // just panic + httpClient2, muxRes2 = getHTTPClient(ctx, dest2, memory2) + tlsConfig2 := tls.ConfigFromStreamSettings(memory2) + realityConfig2 := reality.ConfigFromStreamSettings(memory2) + if tlsConfig2 != nil || realityConfig2 != nil { requestURL2.Scheme = "https" } else { requestURL2.Scheme = "http" } config2 := memory2.ProtocolSettings.(*Config) requestURL2.Host = config2.Host + if requestURL2.Host == "" && tlsConfig2 != nil { + requestURL2.Host = tlsConfig2.ServerName + } + if requestURL2.Host == "" && realityConfig2 != nil { + requestURL2.Host = realityConfig2.ServerName + } if requestURL2.Host == "" { - requestURL2.Host = memory2.Destination.NetAddr() + requestURL2.Host = dest2.Address.String() } requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String() requestURL2.RawQuery = config2.GetNormalizedQuery() diff --git a/transport/internet/websocket/config.go b/transport/internet/websocket/config.go index af66ebbe..4f2c0a14 100644 --- a/transport/internet/websocket/config.go +++ b/transport/internet/websocket/config.go @@ -23,7 +23,6 @@ func (c *Config) GetRequestHeader() http.Header { for k, v := range c.Header { header.Add(k, v) } - header.Set("Host", c.Host) return header } diff --git a/transport/internet/websocket/dialer.go b/transport/internet/websocket/dialer.go index 0659c7de..60330fd7 100644 --- a/transport/internet/websocket/dialer.go +++ b/transport/internet/websocket/dialer.go @@ -58,11 +58,12 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in protocol := "ws" - if config := tls.ConfigFromStreamSettings(streamSettings); config != nil { + tConfig := tls.ConfigFromStreamSettings(streamSettings) + if tConfig != nil { protocol = "wss" - tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) + tlsConfig := tConfig.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1")) dialer.TLSClientConfig = tlsConfig - if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil { + if fingerprint := tls.GetFingerprint(tConfig.Fingerprint); fingerprint != nil { dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) { // Like the NetDial in the dialer pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) @@ -103,6 +104,14 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in } header := wsSettings.GetRequestHeader() + // See dialer.DialContext() + header.Set("Host", wsSettings.Host) + if header.Get("Host") == "" && tConfig != nil { + header.Set("Host", tConfig.ServerName) + } + if header.Get("Host") == "" { + header.Set("Host", dest.Address.String()) + } if ed != nil { // RawURLEncoding is support by both V2Ray/V2Fly and XRay. header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))