From 8460d016ab4db845c49743076b66cc67bbd5cd38 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Fri, 9 Feb 2018 17:48:09 +0100 Subject: [PATCH] fix address parsing for mux --- proxy/vmess/encoding/encoding_test.go | 87 +++++++++++++++++++++++++ proxy/vmess/encoding/server.go | 93 ++++++++++++++++++--------- proxy/vmess/inbound/inbound.go | 5 -- proxy/vmess/outbound/outbound.go | 2 +- 4 files changed, 150 insertions(+), 37 deletions(-) diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index 79417392..2777f8cf 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -66,3 +66,90 @@ func TestRequestSerialization(t *testing.T) { // anti replay attack assert(err, IsNotNil) } + +func TestInvalidRequest(t *testing.T) { + assert := With(t) + + user := &protocol.User{ + Level: 0, + Email: "test@v2ray.com", + } + id := uuid.New() + account := &vmess.Account{ + Id: id.String(), + AlterId: 0, + } + user.Account = serial.ToTypedMessage(account) + + expectedRequest := &protocol.RequestHeader{ + Version: 1, + User: user, + Command: protocol.RequestCommand(100), + Address: net.DomainAddress("www.v2ray.com"), + Port: net.Port(443), + Security: protocol.Security(protocol.SecurityType_AES128_GCM), + } + + buffer := buf.New() + client := NewClientSession(protocol.DefaultIDHash) + common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) + + buffer2 := buf.New() + buffer2.Append(buffer.Bytes()) + + sessionHistory := NewSessionHistory() + defer common.Close(sessionHistory) + + userValidator := vmess.NewTimedUserValidator(protocol.DefaultIDHash) + userValidator.Add(user) + defer common.Close(userValidator) + + server := NewServerSession(userValidator, sessionHistory) + _, err := server.DecodeRequestHeader(buffer) + assert(err, IsNotNil) +} + +func TestMuxRequest(t *testing.T) { + assert := With(t) + + user := &protocol.User{ + Level: 0, + Email: "test@v2ray.com", + } + id := uuid.New() + account := &vmess.Account{ + Id: id.String(), + AlterId: 0, + } + user.Account = serial.ToTypedMessage(account) + + expectedRequest := &protocol.RequestHeader{ + Version: 1, + User: user, + Command: protocol.RequestCommandMux, + Security: protocol.Security(protocol.SecurityType_AES128_GCM), + } + + buffer := buf.New() + client := NewClientSession(protocol.DefaultIDHash) + common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) + + buffer2 := buf.New() + buffer2.Append(buffer.Bytes()) + + sessionHistory := NewSessionHistory() + defer common.Close(sessionHistory) + + userValidator := vmess.NewTimedUserValidator(protocol.DefaultIDHash) + userValidator.Add(user) + defer common.Close(userValidator) + + server := NewServerSession(userValidator, sessionHistory) + actualRequest, err := server.DecodeRequestHeader(buffer) + assert(err, IsNil) + + assert(expectedRequest.Version, Equals, actualRequest.Version) + assert(byte(expectedRequest.Command), Equals, byte(actualRequest.Command)) + assert(byte(expectedRequest.Option), Equals, byte(actualRequest.Option)) + assert(byte(expectedRequest.Security), Equals, byte(actualRequest.Security)) +} diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 8806540f..d58914c1 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "v2ray.com/core/common/dice" + "golang.org/x/crypto/chacha20poly1305" "v2ray.com/core/common" "v2ray.com/core/common/bitmask" @@ -103,6 +105,44 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH } } +func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) { + var address net.Address + var port net.Port + if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { + return address, port, newError("failed to read port and address type").Base(err) + } + port = net.PortFromBytes(buffer.BytesRange(-3, -1)) + + addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1)) + switch addressType { + case protocol.AddressTypeIPv4: + if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { + return address, port, newError("failed to read IPv4 address").Base(err) + } + address = net.IPAddress(buffer.BytesFrom(-4)) + case protocol.AddressTypeIPv6: + if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { + return address, port, newError("failed to read IPv6 address").Base(err) + } + address = net.IPAddress(buffer.BytesFrom(-16)) + case protocol.AddressTypeDomain: + if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + return address, port, newError("failed to read domain address").Base(err) + } + domainLength := int(buffer.Byte(buffer.Len() - 1)) + if domainLength == 0 { + return address, port, newError("zero length domain") + } + if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { + return address, port, newError("failed to read domain address").Base(err) + } + address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) + default: + return address, port, newError("invalid address type", addressType) + } + return address, port, nil +} + func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.RequestHeader, error) { buffer := buf.New() defer buffer.Release() @@ -128,7 +168,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv) decryptor := crypto.NewCryptionReader(aesStream, reader) - if err := buffer.Reset(buf.ReadFullFrom(decryptor, 41)); err != nil { + if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil { return nil, newError("failed to read request header").Base(err) } @@ -137,10 +177,6 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request Version: buffer.Byte(0), } - if request.Version != Version { - return nil, newError("invalid protocol version ", request.Version) - } - s.requestBodyIV = append([]byte(nil), buffer.BytesRange(1, 17)...) // 16 bytes s.requestBodyKey = append([]byte(nil), buffer.BytesRange(17, 33)...) // 16 bytes var sid sessionId @@ -159,33 +195,28 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request // 1 bytes reserved request.Command = protocol.RequestCommand(buffer.Byte(37)) - if request.Command != protocol.RequestCommandMux { - request.Port = net.PortFromBytes(buffer.BytesRange(38, 40)) - - switch protocol.AddressType(buffer.Byte(40)) { - case protocol.AddressTypeIPv4: - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil { - return nil, newError("failed to read IPv4 address").Base(err) - } - request.Address = net.IPAddress(buffer.BytesFrom(-4)) - case protocol.AddressTypeIPv6: - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 16)); err != nil { - return nil, newError("failed to read IPv6 address").Base(err) - } - request.Address = net.IPAddress(buffer.BytesFrom(-16)) - case protocol.AddressTypeDomain: - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 1)); err != nil { - return nil, newError("failed to read domain address").Base(err) - } - domainLength := int(buffer.Byte(buffer.Len() - 1)) - if domainLength == 0 { - return nil, newError("zero length domain").Base(err) - } - if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, domainLength)); err != nil { - return nil, newError("failed to read domain address").Base(err) - } - request.Address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) + invalidRequest := false + switch request.Command { + case protocol.RequestCommandMux: + request.Address = net.DomainAddress("v1.mux.cool") + request.Port = 0 + case protocol.RequestCommandTCP, protocol.RequestCommandUDP: + if addr, port, err := readAddress(buffer, decryptor); err == nil { + request.Address = addr + request.Port = port + } else { + invalidRequest = true + newError("failed to read address").Base(err).WriteToLog() } + default: + invalidRequest = true + } + + if invalidRequest { + randomLen := dice.Roll(32) + // Read random number of bytes for prevent detection. + buffer.AppendSupplier(buf.ReadFullFrom(decryptor, randomLen)) + return nil, newError("invalid request") } if padingLen > 0 { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 68a3c93d..145a60ba 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -239,11 +239,6 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i return err } - if request.Command == protocol.RequestCommandMux { - request.Address = net.DomainAddress("v1.mux.com") - request.Port = net.Port(0) - } - log.Record(&log.AccessMessage{ From: connection.RemoteAddr(), To: request.Destination(), diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index d195c8cc..2d78184d 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -75,7 +75,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial if target.Network == net.Network_UDP { command = protocol.RequestCommandUDP } - if target.Address.Family().IsDomain() && target.Address.Domain() == "v1.mux.com" { + if target.Address.Family().IsDomain() && target.Address.Domain() == "v1.mux.cool" { command = protocol.RequestCommandMux } request := &protocol.RequestHeader{