diff --git a/common/net/testing/assert/address.go b/common/net/testing/assert/address.go index 4783c152..57c5efbe 100644 --- a/common/net/testing/assert/address.go +++ b/common/net/testing/assert/address.go @@ -25,13 +25,7 @@ func (subject *AddressSubject) DisplayString() string { } func (subject *AddressSubject) Equals(another v2net.Address) { - if subject.value.IsIPv4() && another.IsIPv4() { - IP(subject.value.IP()).Equals(another.IP()) - } else if subject.value.IsIPv6() && another.IsIPv6() { - IP(subject.value.IP()).Equals(another.IP()) - } else if subject.value.IsDomain() && another.IsDomain() { - assert.StringLiteral(subject.value.Domain()).Equals(another.Domain()) - } else { + if !subject.value.Equals(another) { subject.Fail(subject.DisplayString(), "equals to", another) } } diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 067cc4f8..079edf84 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -18,6 +18,7 @@ const ( type Request struct { Address v2net.Address Port v2net.Port + OTA bool } func ReadRequest(reader io.Reader) (*Request, error) { @@ -32,7 +33,10 @@ func ReadRequest(reader io.Reader) (*Request, error) { request := new(Request) - addrType := buffer.Value[0] + addrType := (buffer.Value[0] & 0x0F) + if (buffer.Value[0] & 0x10) == 0x10 { + request.OTA = true + } switch addrType { case AddrTypeIPv4: _, err := io.ReadFull(reader, buffer.Value[:4]) diff --git a/proxy/shadowsocks/protocol_test.go b/proxy/shadowsocks/protocol_test.go new file mode 100644 index 00000000..b230b22d --- /dev/null +++ b/proxy/shadowsocks/protocol_test.go @@ -0,0 +1,37 @@ +package shadowsocks_test + +import ( + "testing" + + "github.com/v2ray/v2ray-core/common/alloc" + v2net "github.com/v2ray/v2ray-core/common/net" + netassert "github.com/v2ray/v2ray-core/common/net/testing/assert" + . "github.com/v2ray/v2ray-core/proxy/shadowsocks" + v2testing "github.com/v2ray/v2ray-core/testing" + "github.com/v2ray/v2ray-core/testing/assert" +) + +func TestNormalRequestParsing(t *testing.T) { + v2testing.Current(t) + + buffer := alloc.NewSmallBuffer().Clear() + buffer.AppendBytes(1, 127, 0, 0, 1, 0, 80) + + request, err := ReadRequest(buffer) + assert.Error(err).IsNil() + netassert.Address(request.Address).Equals(v2net.IPAddress([]byte{127, 0, 0, 1})) + netassert.Port(request.Port).Equals(v2net.Port(80)) + assert.Bool(request.OTA).IsFalse() +} + +func TestOTARequest(t *testing.T) { + v2testing.Current(t) + + buffer := alloc.NewSmallBuffer().Clear() + buffer.AppendBytes(0x13, 13, 119, 119, 119, 46, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 0) + + request, err := ReadRequest(buffer) + assert.Error(err).IsNil() + netassert.Address(request.Address).Equals(v2net.DomainAddress("www.v2ray.com")) + assert.Bool(request.OTA).IsTrue() +}