mirror of https://github.com/XTLS/Xray-core
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
153 lines
3.8 KiB
153 lines
3.8 KiB
package websocket_test |
|
|
|
import ( |
|
"context" |
|
"runtime" |
|
"testing" |
|
"time" |
|
|
|
"github.com/xtls/xray-core/common" |
|
"github.com/xtls/xray-core/common/net" |
|
"github.com/xtls/xray-core/common/protocol/tls/cert" |
|
"github.com/xtls/xray-core/testing/servers/tcp" |
|
"github.com/xtls/xray-core/transport/internet" |
|
"github.com/xtls/xray-core/transport/internet/stat" |
|
"github.com/xtls/xray-core/transport/internet/tls" |
|
. "github.com/xtls/xray-core/transport/internet/websocket" |
|
) |
|
|
|
func Test_listenWSAndDial(t *testing.T) { |
|
listenPort := tcp.PickPort() |
|
listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ |
|
ProtocolName: "websocket", |
|
ProtocolSettings: &Config{ |
|
Path: "ws", |
|
}, |
|
}, func(conn stat.Connection) { |
|
go func(c stat.Connection) { |
|
defer c.Close() |
|
|
|
var b [1024]byte |
|
c.SetReadDeadline(time.Now().Add(2 * time.Second)) |
|
_, err := c.Read(b[:]) |
|
if err != nil { |
|
return |
|
} |
|
|
|
common.Must2(c.Write([]byte("Response"))) |
|
}(conn) |
|
}) |
|
common.Must(err) |
|
|
|
ctx := context.Background() |
|
streamSettings := &internet.MemoryStreamConfig{ |
|
ProtocolName: "websocket", |
|
ProtocolSettings: &Config{Path: "ws"}, |
|
} |
|
conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) |
|
|
|
common.Must(err) |
|
_, err = conn.Write([]byte("Test connection 1")) |
|
common.Must(err) |
|
|
|
var b [1024]byte |
|
n, err := conn.Read(b[:]) |
|
common.Must(err) |
|
if string(b[:n]) != "Response" { |
|
t.Error("response: ", string(b[:n])) |
|
} |
|
|
|
common.Must(conn.Close()) |
|
conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) |
|
common.Must(err) |
|
_, err = conn.Write([]byte("Test connection 2")) |
|
common.Must(err) |
|
n, err = conn.Read(b[:]) |
|
common.Must(err) |
|
if string(b[:n]) != "Response" { |
|
t.Error("response: ", string(b[:n])) |
|
} |
|
common.Must(conn.Close()) |
|
|
|
common.Must(listen.Close()) |
|
} |
|
|
|
func TestDialWithRemoteAddr(t *testing.T) { |
|
listenPort := tcp.PickPort() |
|
listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{ |
|
ProtocolName: "websocket", |
|
ProtocolSettings: &Config{ |
|
Path: "ws", |
|
}, |
|
}, func(conn stat.Connection) { |
|
go func(c stat.Connection) { |
|
defer c.Close() |
|
|
|
var b [1024]byte |
|
_, err := c.Read(b[:]) |
|
// common.Must(err) |
|
if err != nil { |
|
return |
|
} |
|
|
|
_, err = c.Write([]byte(c.RemoteAddr().String())) |
|
common.Must(err) |
|
}(conn) |
|
}) |
|
common.Must(err) |
|
|
|
conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), &internet.MemoryStreamConfig{ |
|
ProtocolName: "websocket", |
|
ProtocolSettings: &Config{Path: "ws", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}}, |
|
}) |
|
|
|
common.Must(err) |
|
_, err = conn.Write([]byte("Test connection 1")) |
|
common.Must(err) |
|
|
|
var b [1024]byte |
|
n, err := conn.Read(b[:]) |
|
common.Must(err) |
|
if string(b[:n]) != "1.1.1.1:0" { |
|
t.Error("response: ", string(b[:n])) |
|
} |
|
|
|
common.Must(listen.Close()) |
|
} |
|
|
|
func Test_listenWSAndDial_TLS(t *testing.T) { |
|
listenPort := tcp.PickPort() |
|
if runtime.GOARCH == "arm64" { |
|
return |
|
} |
|
|
|
start := time.Now() |
|
|
|
streamSettings := &internet.MemoryStreamConfig{ |
|
ProtocolName: "websocket", |
|
ProtocolSettings: &Config{ |
|
Path: "wss", |
|
}, |
|
SecurityType: "tls", |
|
SecuritySettings: &tls.Config{ |
|
AllowInsecure: true, |
|
Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))}, |
|
}, |
|
} |
|
listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) { |
|
go func() { |
|
_ = conn.Close() |
|
}() |
|
}) |
|
common.Must(err) |
|
defer listen.Close() |
|
|
|
conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings) |
|
common.Must(err) |
|
_ = conn.Close() |
|
|
|
end := time.Now() |
|
if !end.Before(start.Add(time.Second * 5)) { |
|
t.Error("end: ", end, " start: ", start) |
|
} |
|
}
|
|
|