diff --git a/common/protocol/address.go b/common/protocol/address.go new file mode 100644 index 00000000..44c1abb9 --- /dev/null +++ b/common/protocol/address.go @@ -0,0 +1,183 @@ +package protocol + +import ( + "io" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" +) + +type AddressOption func(*AddressParser) + +func PortThenAddress() AddressOption { + return func(p *AddressParser) { + p.portFirst = true + } +} + +func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption { + return func(p *AddressParser) { + p.addrTypeMap[b] = f + p.addrByteMap[f] = b + } +} + +type AddressTypeParser func(byte) byte + +func WithAddressTypeParser(atp AddressTypeParser) AddressOption { + return func(p *AddressParser) { + p.typeParser = atp + } +} + +type AddressParser struct { + addrTypeMap map[byte]net.AddressFamily + addrByteMap map[net.AddressFamily]byte + portFirst bool + typeParser AddressTypeParser +} + +func NewAddressParser(options ...AddressOption) *AddressParser { + p := &AddressParser{ + addrTypeMap: make(map[byte]net.AddressFamily, 8), + addrByteMap: make(map[net.AddressFamily]byte, 8), + } + for _, opt := range options { + opt(p) + } + return p +} + +func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) { + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { + return 0, err + } + return net.PortFromBytes(b.BytesFrom(-2)), nil +} + +func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) { + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + return nil, err + } + + addrType := b.Byte(b.Len() - 1) + if p.typeParser != nil { + addrType = p.typeParser(addrType) + } + + addrFamily, valid := p.addrTypeMap[addrType] + if !valid { + return nil, newError("unknown address type: ", addrType) + } + + switch addrFamily { + case net.AddressFamilyIPv4: + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { + return nil, err + } + return net.IPAddress(b.BytesFrom(-4)), nil + case net.AddressFamilyIPv6: + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { + return nil, err + } + return net.IPAddress(b.BytesFrom(-16)), nil + case net.AddressFamilyDomain: + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { + return nil, err + } + domainLength := int(b.Byte(b.Len() - 1)) + if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { + return nil, err + } + return net.DomainAddress(string(b.BytesFrom(-domainLength))), nil + default: + panic("impossible case") + } +} + +func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) { + if buffer == nil { + buffer = buf.New() + defer buffer.Release() + } + + if p.portFirst { + port, err := p.readPort(buffer, input) + if err != nil { + return nil, 0, err + } + addr, err := p.readAddress(buffer, input) + if err != nil { + return nil, 0, err + } + return addr, port, nil + } + + addr, err := p.readAddress(buffer, input) + if err != nil { + return nil, 0, err + } + + port, err := p.readPort(buffer, input) + if err != nil { + return nil, 0, err + } + + return addr, port, nil +} + +func (p *AddressParser) writePort(writer io.Writer, port net.Port) error { + if _, err := writer.Write(port.Bytes(nil)); err != nil { + return err + } + return nil +} + +func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) error { + tb, valid := p.addrByteMap[address.Family()] + if !valid { + return newError("unknown address family", address.Family()) + } + + switch address.Family() { + case net.AddressFamilyIPv4, net.AddressFamilyIPv6: + if _, err := writer.Write([]byte{tb}); err != nil { + return err + } + if _, err := writer.Write(address.IP()); err != nil { + return err + } + case net.AddressFamilyDomain: + domain := address.Domain() + if IsDomainTooLong(domain) { + return newError("Super long domain is not supported: ", domain) + } + if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil { + return err + } + if _, err := writer.Write([]byte(domain)); err != nil { + return err + } + } + return nil +} + +func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error { + if p.portFirst { + if err := p.writePort(writer, port); err != nil { + return err + } + if err := p.writeAddress(writer, addr); err != nil { + return err + } + return nil + } + + if err := p.writeAddress(writer, addr); err != nil { + return err + } + if err := p.writePort(writer, port); err != nil { + return err + } + return nil +} diff --git a/common/protocol/address_test.go b/common/protocol/address_test.go new file mode 100644 index 00000000..5e724d51 --- /dev/null +++ b/common/protocol/address_test.go @@ -0,0 +1,70 @@ +package protocol_test + +import ( + "bytes" + "testing" + + "v2ray.com/core/common/buf" + "v2ray.com/core/common/net" + . "v2ray.com/core/common/protocol" + . "v2ray.com/ext/assert" +) + +func TestAddressParser(t *testing.T) { + assert := With(t) + + data := []struct { + Options []AddressOption + Input []byte + Address net.Address + Port net.Port + Error bool + }{ + { + Options: []AddressOption{}, + Input: []byte{0, 0, 0, 0, 0}, + Error: true, + }, + { + Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)}, + Input: []byte{1, 0, 0, 0, 0, 0, 53}, + Address: net.IPAddress([]byte{0, 0, 0, 0}), + Port: net.Port(53), + }, + { + Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)}, + Input: []byte{1, 0, 0, 0, 0}, + Error: true, + }, + { + Options: []AddressOption{AddressFamilyByte(0x04, net.AddressFamilyIPv6)}, + Input: []byte{4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80}, + Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}), + Port: net.Port(80), + }, + { + Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)}, + Input: []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80}, + Address: net.DomainAddress("v2ray.com"), + Port: net.Port(80), + }, + { + Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)}, + Input: []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0}, + Error: true, + }, + } + + for _, tc := range data { + b := buf.New() + parser := NewAddressParser(tc.Options...) + addr, port, err := parser.ReadAddressPort(b, bytes.NewReader(tc.Input)) + b.Release() + if tc.Error { + assert(err, IsNotNil) + } else { + assert(addr, Equals, tc.Address) + assert(port, Equals, tc.Port) + } + } +} diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 48fbd486..fb3f22b9 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -11,16 +11,20 @@ import ( "v2ray.com/core/common/dice" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" - "v2ray.com/core/proxy/socks" ) const ( Version = 1 RequestOptionOneTimeAuth bitmask.Byte = 0x01 +) - AddrTypeIPv4 = 1 - AddrTypeIPv6 = 4 - AddrTypeDomain = 3 +var addrParser = protocol.NewAddressParser( + protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), + protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6), + protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain), + protocol.WithAddressTypeParser(func(b byte) byte { + return b & 0x0F + }), ) // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts. @@ -58,10 +62,21 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea Command: protocol.RequestCommandTCP, } - if err := buffer.Reset(buf.ReadFullFrom(br, 1)); err != nil { - return nil, nil, newError("failed to read address type").Base(err) + buffer.Clear() + + addr, port, err := addrParser.ReadAddressPort(buffer, br) + + if err != nil { + // Invalid address. Continue to read some bytes to confuse client. + nBytes := dice.Roll(32) + buffer.Clear() + buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes)) + return nil, nil, newError("failed to read address").Base(err) } + request.Address = addr + request.Port = port + if !account.Cipher.IsAEAD() { if (buffer.Byte(0) & 0x10) == 0x10 { request.Option.Set(RequestOptionOneTimeAuth) @@ -76,20 +91,6 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea } } - addrType := (buffer.Byte(0) & 0x0F) - - addr, port, err := socks.ReadAddress(buffer, addrType, br) - if err != nil { - // Invalid address. Continue to read some bytes to confuse client. - nBytes := dice.Roll(32) - buffer.Clear() - buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes)) - return nil, nil, newError("failed to read address").Base(err) - } - - request.Address = addr - request.Port = port - if request.Option.Has(RequestOptionOneTimeAuth) { actualAuth := make([]byte, AuthSize) authenticator.Authenticate(buffer.Bytes())(actualAuth) @@ -150,7 +151,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri header := buf.NewLocal(512) - if err := socks.AppendAddress(header, request.Address, request.Port); err != nil { + if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil { return nil, newError("failed to write address").Base(err) } @@ -230,7 +231,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff } iv := buffer.Bytes() - if err := socks.AppendAddress(buffer, request.Address, request.Port); err != nil { + if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil { return nil, newError("failed to write address").Base(err) } @@ -301,26 +302,15 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques } } - addrType := (payload.Byte(0) & 0x0F) - payload.SliceFrom(1) + payload.SetByte(0, payload.Byte(0)&0x0F) - switch addrType { - case AddrTypeIPv4: - request.Address = net.IPAddress(payload.BytesTo(4)) - payload.SliceFrom(4) - case AddrTypeIPv6: - request.Address = net.IPAddress(payload.BytesTo(16)) - payload.SliceFrom(16) - case AddrTypeDomain: - domainLength := int(payload.Byte(0)) - request.Address = net.DomainAddress(string(payload.BytesRange(1, 1+domainLength))) - payload.SliceFrom(1 + domainLength) - default: - return nil, nil, newError("unknown address type: ", addrType).AtError() + addr, port, err := addrParser.ReadAddressPort(nil, payload) + if err != nil { + return nil, nil, newError("failed to parse address").Base(err) } - request.Port = net.PortFromBytes(payload.BytesTo(2)) - payload.SliceFrom(2) + request.Address = addr + request.Port = port return request, payload, nil } diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index 8ad4235c..2090296a 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -34,6 +34,12 @@ const ( statusCmdNotSupport = 0x07 ) +var addrParser = protocol.NewAddressParser( + protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), + protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6), + protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain), +) + type ServerSession struct { config *ServerConfig port net.Port @@ -122,7 +128,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol return nil, newError("failed to write auth response").Base(err) } } - if err := buffer.Reset(buf.ReadFullFrom(reader, 4)); err != nil { + if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil { return nil, newError("failed to read request").Base(err) } @@ -139,13 +145,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol request.Command = protocol.RequestCommandUDP } - addrType := buffer.Byte(3) - buffer.Clear() request.Version = socks5Version - addr, port, err := ReadAddress(buffer, addrType, reader) + addr, port, err := addrParser.ReadAddressPort(buffer, reader) if err != nil { return nil, newError("failed to read address").Base(err) } @@ -229,30 +233,10 @@ func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte return err } -// AppendAddress appends Socks address into the given buffer. -func AppendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error { - switch address.Family() { - case net.AddressFamilyIPv4: - buffer.AppendBytes(addrTypeIPv4) - buffer.Append(address.IP()) - case net.AddressFamilyIPv6: - buffer.AppendBytes(addrTypeIPv6) - buffer.Append(address.IP()) - case net.AddressFamilyDomain: - if protocol.IsDomainTooLong(address.Domain()) { - return newError("Super long domain is not supported in Socks protocol: ", address.Domain()) - } - buffer.AppendBytes(addrTypeDomain, byte(len(address.Domain()))) - common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain()))) - } - common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) - return nil -} - func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error { buffer := buf.NewLocal(64) buffer.AppendBytes(socks5Version, errCode, 0x00 /* reserved */) - if err := AppendAddress(buffer, address, port); err != nil { + if err := addrParser.WriteAddressPort(buffer, address, port); err != nil { return err } @@ -269,9 +253,9 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po return err } -func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) { - if len(packet) < 5 { - return nil, nil, newError("insufficient length of packet.") +func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) { + if packet.Len() < 5 { + return nil, newError("insufficient length of packet.") } request := &protocol.RequestHeader{ Version: socks5Version, @@ -279,50 +263,25 @@ func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) { } // packet[0] and packet[1] are reserved - if packet[2] != 0 /* fragments */ { - return nil, nil, newError("discarding fragmented payload.") + if packet.Byte(2) != 0 /* fragments */ { + return nil, newError("discarding fragmented payload.") } - addrType := packet[3] - var dataBegin int + packet.SliceFrom(3) - switch addrType { - case addrTypeIPv4: - if len(packet) < 10 { - return nil, nil, newError("insufficient length of packet") - } - ip := packet[4:8] - request.Port = net.PortFromBytes(packet[8:10]) - request.Address = net.IPAddress(ip) - dataBegin = 10 - case addrTypeIPv6: - if len(packet) < 22 { - return nil, nil, newError("insufficient length of packet") - } - ip := packet[4:20] - request.Port = net.PortFromBytes(packet[20:22]) - request.Address = net.IPAddress(ip) - dataBegin = 22 - case addrTypeDomain: - domainLength := int(packet[4]) - if len(packet) < 5+domainLength+2 { - return nil, nil, newError("insufficient length of packet") - } - domain := string(packet[5 : 5+domainLength]) - request.Port = net.PortFromBytes(packet[5+domainLength : 5+domainLength+2]) - request.Address = net.ParseAddress(domain) - dataBegin = 5 + domainLength + 2 - default: - return nil, nil, newError("unknown address type ", addrType) + addr, port, err := addrParser.ReadAddressPort(nil, packet) + if err != nil { + return nil, newError("failed to read UDP header").Base(err) } - - return request, packet[dataBegin:], nil + request.Address = addr + request.Port = port + return request, nil } func EncodeUDPPacket(request *protocol.RequestHeader, data []byte) (*buf.Buffer, error) { b := buf.New() b.AppendBytes(0, 0, 0 /* Fragment */) - if err := AppendAddress(b, request.Address, request.Port); err != nil { + if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil { return nil, err } b.Append(data) @@ -342,12 +301,9 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) { if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil { return nil, err } - _, data, err := DecodeUDPPacket(b.Bytes()) - if err != nil { + if _, err := DecodeUDPPacket(b); err != nil { return nil, err } - b.Clear() - b.Append(data) return buf.NewMultiBufferValue(b), nil } @@ -376,40 +332,6 @@ func (w *UDPWriter) Write(b []byte) (int, error) { return len(b), nil } -func ReadAddress(b *buf.Buffer, addrType byte, reader io.Reader) (net.Address, net.Port, error) { - var address net.Address - switch addrType { - case addrTypeIPv4: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { - return nil, 0, err - } - address = net.IPAddress(b.BytesFrom(-4)) - case addrTypeIPv6: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { - return nil, 0, err - } - address = net.IPAddress(b.BytesFrom(-16)) - case addrTypeDomain: - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { - return nil, 0, err - } - domainLength := int(b.Byte(b.Len() - 1)) - if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { - return nil, 0, err - } - address = net.DomainAddress(string(b.BytesFrom(-domainLength))) - default: - return nil, 0, newError("unknown address type: ", addrType) - } - - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil { - return nil, 0, err - } - port := net.PortFromBytes(b.BytesFrom(-2)) - - return address, port, nil -} - func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) { authByte := byte(authNotRequired) if request.User != nil { @@ -462,7 +384,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i command = byte(cmdUDPPort) } b.AppendBytes(socks5Version, command, 0x00 /* reserved */) - if err := AppendAddress(b, request.Address, request.Port); err != nil { + if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil { return nil, err } @@ -471,7 +393,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i } b.Clear() - if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { + if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { return nil, err } @@ -480,11 +402,9 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i return nil, newError("server rejects request: ", resp) } - addrType := b.Byte(3) - b.Clear() - address, port, err := ReadAddress(b, addrType, reader) + address, port, err := addrParser.ReadAddressPort(b, reader) if err != nil { return nil, err } diff --git a/proxy/socks/protocol_test.go b/proxy/socks/protocol_test.go index bd14dfe5..5857471b 100644 --- a/proxy/socks/protocol_test.go +++ b/proxy/socks/protocol_test.go @@ -1,7 +1,6 @@ package socks_test import ( - "bytes" "testing" "v2ray.com/core/common/buf" @@ -34,56 +33,3 @@ func TestUDPEncoding(t *testing.T) { assert(err, IsNil) assert(decodedPayload[0].Bytes(), Equals, content) } - -func TestReadAddress(t *testing.T) { - assert := With(t) - - data := []struct { - AddrType byte - Input []byte - Address net.Address - Port net.Port - Error bool - }{ - { - AddrType: 0, - Input: []byte{0, 0, 0, 0}, - Error: true, - }, - { - AddrType: 1, - Input: []byte{0, 0, 0, 0, 0, 53}, - Address: net.IPAddress([]byte{0, 0, 0, 0}), - Port: net.Port(53), - }, - { - AddrType: 4, - Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80}, - Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}), - Port: net.Port(80), - }, - { - AddrType: 3, - Input: []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80}, - Address: net.DomainAddress("v2ray.com"), - Port: net.Port(80), - }, - { - AddrType: 3, - Input: []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0}, - Error: true, - }, - } - - for _, tc := range data { - b := buf.New() - addr, port, err := ReadAddress(b, tc.AddrType, bytes.NewBuffer(tc.Input)) - b.Release() - if tc.Error { - assert(err, IsNotNil) - } else { - assert(addr, Equals, tc.Address) - assert(port, Equals, tc.Port) - } - } -} diff --git a/proxy/socks/server.go b/proxy/socks/server.go index cae3208e..e80efbd1 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -185,18 +185,20 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, } for _, payload := range mpayload { - request, data, err := DecodeUDPPacket(payload.Bytes()) + request, err := DecodeUDPPacket(payload) if err != nil { newError("failed to parse UDP request").Base(err).WithContext(ctx).WriteToLog() + payload.Release() continue } - if len(data) == 0 { + if payload.IsEmpty() { + payload.Release() continue } - newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug().WithContext(ctx).WriteToLog() + newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog() if source, ok := proxy.SourceFromContext(ctx); ok { log.Record(&log.AccessMessage{ From: source, @@ -206,9 +208,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, }) } - dataBuf := buf.New() - dataBuf.Append(data) - udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) { + udpServer.Dispatch(ctx, request.Destination(), payload, func(payload *buf.Buffer) { defer payload.Release() newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog() diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 6fe18741..688e0e24 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -15,7 +15,6 @@ import ( "v2ray.com/core/common/buf" "v2ray.com/core/common/crypto" "v2ray.com/core/common/dice" - "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" "v2ray.com/core/proxy/vmess" @@ -82,23 +81,8 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ buffer.AppendBytes(security, byte(0), byte(header.Command)) if header.Command != protocol.RequestCommandMux { - common.Must(buffer.AppendSupplier(serial.WriteUint16(header.Port.Value()))) - - switch header.Address.Family() { - case net.AddressFamilyIPv4: - buffer.AppendBytes(byte(protocol.AddressTypeIPv4)) - buffer.Append(header.Address.IP()) - case net.AddressFamilyIPv6: - buffer.AppendBytes(byte(protocol.AddressTypeIPv6)) - buffer.Append(header.Address.IP()) - case net.AddressFamilyDomain: - domain := header.Address.Domain() - if protocol.IsDomainTooLong(domain) { - return newError("long domain not supported: ", domain) - } - nDomain := len(domain) - buffer.AppendBytes(byte(protocol.AddressTypeDomain), byte(nDomain)) - common.Must(buffer.AppendSupplier(serial.WriteString(domain))) + if err := addrParser.WriteAddressPort(buffer, header.Address, header.Port); err != nil { + return newError("failed to writer address and port").Base(err) } } diff --git a/proxy/vmess/encoding/const.go b/proxy/vmess/encoding/const.go deleted file mode 100644 index d9d85878..00000000 --- a/proxy/vmess/encoding/const.go +++ /dev/null @@ -1,5 +0,0 @@ -package encoding - -const ( - Version = byte(1) -) diff --git a/proxy/vmess/encoding/encoding.go b/proxy/vmess/encoding/encoding.go index 3772a39d..b9e36c7f 100644 --- a/proxy/vmess/encoding/encoding.go +++ b/proxy/vmess/encoding/encoding.go @@ -1,3 +1,19 @@ package encoding +import ( + "v2ray.com/core/common/net" + "v2ray.com/core/common/protocol" +) + //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg encoding -path Proxy,VMess,Encoding + +const ( + Version = byte(1) +) + +var addrParser = protocol.NewAddressParser( + protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), + protocol.AddressFamilyByte(0x02, net.AddressFamilyDomain), + protocol.AddressFamilyByte(0x03, net.AddressFamilyIPv6), + protocol.PortThenAddress(), +) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 10d0e0a3..3d778cc8 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -105,44 +105,6 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH } } -func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) { - var address net.Address - var port net.Port - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil { - return address, port, newError("failed to read port and address type").Base(err) - } - port = net.PortFromBytes(buffer.BytesRange(-3, -1)) - - addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1)) - switch addressType { - case protocol.AddressTypeIPv4: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil { - return address, port, newError("failed to read IPv4 address").Base(err) - } - address = net.IPAddress(buffer.BytesFrom(-4)) - case protocol.AddressTypeIPv6: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil { - return address, port, newError("failed to read IPv6 address").Base(err) - } - address = net.IPAddress(buffer.BytesFrom(-16)) - case protocol.AddressTypeDomain: - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil { - return address, port, newError("failed to read domain address").Base(err) - } - domainLength := int(buffer.Byte(buffer.Len() - 1)) - if domainLength == 0 { - return address, port, newError("zero length domain") - } - if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil { - return address, port, newError("failed to read domain address").Base(err) - } - address = net.DomainAddress(string(buffer.BytesFrom(-domainLength))) - default: - return address, port, newError("invalid address type", addressType) - } - return address, port, nil -} - func parseSecurityType(b byte) protocol.SecurityType { if _, f := protocol.SecurityType_name[int32(b)]; f { return protocol.SecurityType(b) @@ -221,7 +183,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request request.Address = net.DomainAddress("v1.mux.cool") request.Port = 0 case protocol.RequestCommandTCP, protocol.RequestCommandUDP: - if addr, port, err := readAddress(buffer, decryptor); err == nil { + if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil { request.Address = addr request.Port = port } else {