mirror of https://github.com/v2ray/v2ray-core
unify all address reading and writing
parent
a059ee2c00
commit
af1abf687c
|
@ -0,0 +1,183 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
)
|
||||
|
||||
type AddressOption func(*AddressParser)
|
||||
|
||||
func PortThenAddress() AddressOption {
|
||||
return func(p *AddressParser) {
|
||||
p.portFirst = true
|
||||
}
|
||||
}
|
||||
|
||||
func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
|
||||
return func(p *AddressParser) {
|
||||
p.addrTypeMap[b] = f
|
||||
p.addrByteMap[f] = b
|
||||
}
|
||||
}
|
||||
|
||||
type AddressTypeParser func(byte) byte
|
||||
|
||||
func WithAddressTypeParser(atp AddressTypeParser) AddressOption {
|
||||
return func(p *AddressParser) {
|
||||
p.typeParser = atp
|
||||
}
|
||||
}
|
||||
|
||||
type AddressParser struct {
|
||||
addrTypeMap map[byte]net.AddressFamily
|
||||
addrByteMap map[net.AddressFamily]byte
|
||||
portFirst bool
|
||||
typeParser AddressTypeParser
|
||||
}
|
||||
|
||||
func NewAddressParser(options ...AddressOption) *AddressParser {
|
||||
p := &AddressParser{
|
||||
addrTypeMap: make(map[byte]net.AddressFamily, 8),
|
||||
addrByteMap: make(map[net.AddressFamily]byte, 8),
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(p)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return net.PortFromBytes(b.BytesFrom(-2)), nil
|
||||
}
|
||||
|
||||
func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrType := b.Byte(b.Len() - 1)
|
||||
if p.typeParser != nil {
|
||||
addrType = p.typeParser(addrType)
|
||||
}
|
||||
|
||||
addrFamily, valid := p.addrTypeMap[addrType]
|
||||
if !valid {
|
||||
return nil, newError("unknown address type: ", addrType)
|
||||
}
|
||||
|
||||
switch addrFamily {
|
||||
case net.AddressFamilyIPv4:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.IPAddress(b.BytesFrom(-4)), nil
|
||||
case net.AddressFamilyIPv6:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.IPAddress(b.BytesFrom(-16)), nil
|
||||
case net.AddressFamilyDomain:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domainLength := int(b.Byte(b.Len() - 1))
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.DomainAddress(string(b.BytesFrom(-domainLength))), nil
|
||||
default:
|
||||
panic("impossible case")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
|
||||
if buffer == nil {
|
||||
buffer = buf.New()
|
||||
defer buffer.Release()
|
||||
}
|
||||
|
||||
if p.portFirst {
|
||||
port, err := p.readPort(buffer, input)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
addr, err := p.readAddress(buffer, input)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
addr, err := p.readAddress(buffer, input)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
port, err := p.readPort(buffer, input)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return addr, port, nil
|
||||
}
|
||||
|
||||
func (p *AddressParser) writePort(writer io.Writer, port net.Port) error {
|
||||
if _, err := writer.Write(port.Bytes(nil)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) error {
|
||||
tb, valid := p.addrByteMap[address.Family()]
|
||||
if !valid {
|
||||
return newError("unknown address family", address.Family())
|
||||
}
|
||||
|
||||
switch address.Family() {
|
||||
case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
|
||||
if _, err := writer.Write([]byte{tb}); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write(address.IP()); err != nil {
|
||||
return err
|
||||
}
|
||||
case net.AddressFamilyDomain:
|
||||
domain := address.Domain()
|
||||
if IsDomainTooLong(domain) {
|
||||
return newError("Super long domain is not supported: ", domain)
|
||||
}
|
||||
if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := writer.Write([]byte(domain)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
|
||||
if p.portFirst {
|
||||
if err := p.writePort(writer, port); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.writeAddress(writer, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := p.writeAddress(writer, addr); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.writePort(writer, port); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,70 @@
|
|||
package protocol_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/net"
|
||||
. "v2ray.com/core/common/protocol"
|
||||
. "v2ray.com/ext/assert"
|
||||
)
|
||||
|
||||
func TestAddressParser(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
data := []struct {
|
||||
Options []AddressOption
|
||||
Input []byte
|
||||
Address net.Address
|
||||
Port net.Port
|
||||
Error bool
|
||||
}{
|
||||
{
|
||||
Options: []AddressOption{},
|
||||
Input: []byte{0, 0, 0, 0, 0},
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)},
|
||||
Input: []byte{1, 0, 0, 0, 0, 0, 53},
|
||||
Address: net.IPAddress([]byte{0, 0, 0, 0}),
|
||||
Port: net.Port(53),
|
||||
},
|
||||
{
|
||||
Options: []AddressOption{AddressFamilyByte(0x01, net.AddressFamilyIPv4)},
|
||||
Input: []byte{1, 0, 0, 0, 0},
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Options: []AddressOption{AddressFamilyByte(0x04, net.AddressFamilyIPv6)},
|
||||
Input: []byte{4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80},
|
||||
Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
|
||||
Port: net.Port(80),
|
||||
},
|
||||
{
|
||||
Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)},
|
||||
Input: []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80},
|
||||
Address: net.DomainAddress("v2ray.com"),
|
||||
Port: net.Port(80),
|
||||
},
|
||||
{
|
||||
Options: []AddressOption{AddressFamilyByte(0x03, net.AddressFamilyDomain)},
|
||||
Input: []byte{3, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0},
|
||||
Error: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range data {
|
||||
b := buf.New()
|
||||
parser := NewAddressParser(tc.Options...)
|
||||
addr, port, err := parser.ReadAddressPort(b, bytes.NewReader(tc.Input))
|
||||
b.Release()
|
||||
if tc.Error {
|
||||
assert(err, IsNotNil)
|
||||
} else {
|
||||
assert(addr, Equals, tc.Address)
|
||||
assert(port, Equals, tc.Port)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,16 +11,20 @@ import (
|
|||
"v2ray.com/core/common/dice"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/proxy/socks"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = 1
|
||||
RequestOptionOneTimeAuth bitmask.Byte = 0x01
|
||||
)
|
||||
|
||||
AddrTypeIPv4 = 1
|
||||
AddrTypeIPv6 = 4
|
||||
AddrTypeDomain = 3
|
||||
var addrParser = protocol.NewAddressParser(
|
||||
protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
|
||||
protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
|
||||
protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
|
||||
protocol.WithAddressTypeParser(func(b byte) byte {
|
||||
return b & 0x0F
|
||||
}),
|
||||
)
|
||||
|
||||
// ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
|
||||
|
@ -58,10 +62,21 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
|||
Command: protocol.RequestCommandTCP,
|
||||
}
|
||||
|
||||
if err := buffer.Reset(buf.ReadFullFrom(br, 1)); err != nil {
|
||||
return nil, nil, newError("failed to read address type").Base(err)
|
||||
buffer.Clear()
|
||||
|
||||
addr, port, err := addrParser.ReadAddressPort(buffer, br)
|
||||
|
||||
if err != nil {
|
||||
// Invalid address. Continue to read some bytes to confuse client.
|
||||
nBytes := dice.Roll(32)
|
||||
buffer.Clear()
|
||||
buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes))
|
||||
return nil, nil, newError("failed to read address").Base(err)
|
||||
}
|
||||
|
||||
request.Address = addr
|
||||
request.Port = port
|
||||
|
||||
if !account.Cipher.IsAEAD() {
|
||||
if (buffer.Byte(0) & 0x10) == 0x10 {
|
||||
request.Option.Set(RequestOptionOneTimeAuth)
|
||||
|
@ -76,20 +91,6 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
|
|||
}
|
||||
}
|
||||
|
||||
addrType := (buffer.Byte(0) & 0x0F)
|
||||
|
||||
addr, port, err := socks.ReadAddress(buffer, addrType, br)
|
||||
if err != nil {
|
||||
// Invalid address. Continue to read some bytes to confuse client.
|
||||
nBytes := dice.Roll(32)
|
||||
buffer.Clear()
|
||||
buffer.AppendSupplier(buf.ReadFullFrom(br, nBytes))
|
||||
return nil, nil, newError("failed to read address").Base(err)
|
||||
}
|
||||
|
||||
request.Address = addr
|
||||
request.Port = port
|
||||
|
||||
if request.Option.Has(RequestOptionOneTimeAuth) {
|
||||
actualAuth := make([]byte, AuthSize)
|
||||
authenticator.Authenticate(buffer.Bytes())(actualAuth)
|
||||
|
@ -150,7 +151,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
|
|||
|
||||
header := buf.NewLocal(512)
|
||||
|
||||
if err := socks.AppendAddress(header, request.Address, request.Port); err != nil {
|
||||
if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil {
|
||||
return nil, newError("failed to write address").Base(err)
|
||||
}
|
||||
|
||||
|
@ -230,7 +231,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
|
|||
}
|
||||
iv := buffer.Bytes()
|
||||
|
||||
if err := socks.AppendAddress(buffer, request.Address, request.Port); err != nil {
|
||||
if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil {
|
||||
return nil, newError("failed to write address").Base(err)
|
||||
}
|
||||
|
||||
|
@ -301,26 +302,15 @@ func DecodeUDPPacket(user *protocol.User, payload *buf.Buffer) (*protocol.Reques
|
|||
}
|
||||
}
|
||||
|
||||
addrType := (payload.Byte(0) & 0x0F)
|
||||
payload.SliceFrom(1)
|
||||
payload.SetByte(0, payload.Byte(0)&0x0F)
|
||||
|
||||
switch addrType {
|
||||
case AddrTypeIPv4:
|
||||
request.Address = net.IPAddress(payload.BytesTo(4))
|
||||
payload.SliceFrom(4)
|
||||
case AddrTypeIPv6:
|
||||
request.Address = net.IPAddress(payload.BytesTo(16))
|
||||
payload.SliceFrom(16)
|
||||
case AddrTypeDomain:
|
||||
domainLength := int(payload.Byte(0))
|
||||
request.Address = net.DomainAddress(string(payload.BytesRange(1, 1+domainLength)))
|
||||
payload.SliceFrom(1 + domainLength)
|
||||
default:
|
||||
return nil, nil, newError("unknown address type: ", addrType).AtError()
|
||||
addr, port, err := addrParser.ReadAddressPort(nil, payload)
|
||||
if err != nil {
|
||||
return nil, nil, newError("failed to parse address").Base(err)
|
||||
}
|
||||
|
||||
request.Port = net.PortFromBytes(payload.BytesTo(2))
|
||||
payload.SliceFrom(2)
|
||||
request.Address = addr
|
||||
request.Port = port
|
||||
|
||||
return request, payload, nil
|
||||
}
|
||||
|
|
|
@ -34,6 +34,12 @@ const (
|
|||
statusCmdNotSupport = 0x07
|
||||
)
|
||||
|
||||
var addrParser = protocol.NewAddressParser(
|
||||
protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
|
||||
protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
|
||||
protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
|
||||
)
|
||||
|
||||
type ServerSession struct {
|
||||
config *ServerConfig
|
||||
port net.Port
|
||||
|
@ -122,7 +128,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|||
return nil, newError("failed to write auth response").Base(err)
|
||||
}
|
||||
}
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil {
|
||||
return nil, newError("failed to read request").Base(err)
|
||||
}
|
||||
|
||||
|
@ -139,13 +145,11 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
|
|||
request.Command = protocol.RequestCommandUDP
|
||||
}
|
||||
|
||||
addrType := buffer.Byte(3)
|
||||
|
||||
buffer.Clear()
|
||||
|
||||
request.Version = socks5Version
|
||||
|
||||
addr, port, err := ReadAddress(buffer, addrType, reader)
|
||||
addr, port, err := addrParser.ReadAddressPort(buffer, reader)
|
||||
if err != nil {
|
||||
return nil, newError("failed to read address").Base(err)
|
||||
}
|
||||
|
@ -229,30 +233,10 @@ func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte
|
|||
return err
|
||||
}
|
||||
|
||||
// AppendAddress appends Socks address into the given buffer.
|
||||
func AppendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error {
|
||||
switch address.Family() {
|
||||
case net.AddressFamilyIPv4:
|
||||
buffer.AppendBytes(addrTypeIPv4)
|
||||
buffer.Append(address.IP())
|
||||
case net.AddressFamilyIPv6:
|
||||
buffer.AppendBytes(addrTypeIPv6)
|
||||
buffer.Append(address.IP())
|
||||
case net.AddressFamilyDomain:
|
||||
if protocol.IsDomainTooLong(address.Domain()) {
|
||||
return newError("Super long domain is not supported in Socks protocol: ", address.Domain())
|
||||
}
|
||||
buffer.AppendBytes(addrTypeDomain, byte(len(address.Domain())))
|
||||
common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain())))
|
||||
}
|
||||
common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
|
||||
buffer := buf.NewLocal(64)
|
||||
buffer.AppendBytes(socks5Version, errCode, 0x00 /* reserved */)
|
||||
if err := AppendAddress(buffer, address, port); err != nil {
|
||||
if err := addrParser.WriteAddressPort(buffer, address, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -269,9 +253,9 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
|
|||
return err
|
||||
}
|
||||
|
||||
func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) {
|
||||
if len(packet) < 5 {
|
||||
return nil, nil, newError("insufficient length of packet.")
|
||||
func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
|
||||
if packet.Len() < 5 {
|
||||
return nil, newError("insufficient length of packet.")
|
||||
}
|
||||
request := &protocol.RequestHeader{
|
||||
Version: socks5Version,
|
||||
|
@ -279,50 +263,25 @@ func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) {
|
|||
}
|
||||
|
||||
// packet[0] and packet[1] are reserved
|
||||
if packet[2] != 0 /* fragments */ {
|
||||
return nil, nil, newError("discarding fragmented payload.")
|
||||
if packet.Byte(2) != 0 /* fragments */ {
|
||||
return nil, newError("discarding fragmented payload.")
|
||||
}
|
||||
|
||||
addrType := packet[3]
|
||||
var dataBegin int
|
||||
packet.SliceFrom(3)
|
||||
|
||||
switch addrType {
|
||||
case addrTypeIPv4:
|
||||
if len(packet) < 10 {
|
||||
return nil, nil, newError("insufficient length of packet")
|
||||
}
|
||||
ip := packet[4:8]
|
||||
request.Port = net.PortFromBytes(packet[8:10])
|
||||
request.Address = net.IPAddress(ip)
|
||||
dataBegin = 10
|
||||
case addrTypeIPv6:
|
||||
if len(packet) < 22 {
|
||||
return nil, nil, newError("insufficient length of packet")
|
||||
}
|
||||
ip := packet[4:20]
|
||||
request.Port = net.PortFromBytes(packet[20:22])
|
||||
request.Address = net.IPAddress(ip)
|
||||
dataBegin = 22
|
||||
case addrTypeDomain:
|
||||
domainLength := int(packet[4])
|
||||
if len(packet) < 5+domainLength+2 {
|
||||
return nil, nil, newError("insufficient length of packet")
|
||||
}
|
||||
domain := string(packet[5 : 5+domainLength])
|
||||
request.Port = net.PortFromBytes(packet[5+domainLength : 5+domainLength+2])
|
||||
request.Address = net.ParseAddress(domain)
|
||||
dataBegin = 5 + domainLength + 2
|
||||
default:
|
||||
return nil, nil, newError("unknown address type ", addrType)
|
||||
addr, port, err := addrParser.ReadAddressPort(nil, packet)
|
||||
if err != nil {
|
||||
return nil, newError("failed to read UDP header").Base(err)
|
||||
}
|
||||
|
||||
return request, packet[dataBegin:], nil
|
||||
request.Address = addr
|
||||
request.Port = port
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func EncodeUDPPacket(request *protocol.RequestHeader, data []byte) (*buf.Buffer, error) {
|
||||
b := buf.New()
|
||||
b.AppendBytes(0, 0, 0 /* Fragment */)
|
||||
if err := AppendAddress(b, request.Address, request.Port); err != nil {
|
||||
if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Append(data)
|
||||
|
@ -342,12 +301,9 @@ func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
|
|||
if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, data, err := DecodeUDPPacket(b.Bytes())
|
||||
if err != nil {
|
||||
if _, err := DecodeUDPPacket(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Clear()
|
||||
b.Append(data)
|
||||
return buf.NewMultiBufferValue(b), nil
|
||||
}
|
||||
|
||||
|
@ -376,40 +332,6 @@ func (w *UDPWriter) Write(b []byte) (int, error) {
|
|||
return len(b), nil
|
||||
}
|
||||
|
||||
func ReadAddress(b *buf.Buffer, addrType byte, reader io.Reader) (net.Address, net.Port, error) {
|
||||
var address net.Address
|
||||
switch addrType {
|
||||
case addrTypeIPv4:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
address = net.IPAddress(b.BytesFrom(-4))
|
||||
case addrTypeIPv6:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
address = net.IPAddress(b.BytesFrom(-16))
|
||||
case addrTypeDomain:
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
domainLength := int(b.Byte(b.Len() - 1))
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
address = net.DomainAddress(string(b.BytesFrom(-domainLength)))
|
||||
default:
|
||||
return nil, 0, newError("unknown address type: ", addrType)
|
||||
}
|
||||
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
port := net.PortFromBytes(b.BytesFrom(-2))
|
||||
|
||||
return address, port, nil
|
||||
}
|
||||
|
||||
func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer io.Writer) (*protocol.RequestHeader, error) {
|
||||
authByte := byte(authNotRequired)
|
||||
if request.User != nil {
|
||||
|
@ -462,7 +384,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
|||
command = byte(cmdUDPPort)
|
||||
}
|
||||
b.AppendBytes(socks5Version, command, 0x00 /* reserved */)
|
||||
if err := AppendAddress(b, request.Address, request.Port); err != nil {
|
||||
if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -471,7 +393,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
|||
}
|
||||
|
||||
b.Clear()
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -480,11 +402,9 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
|
|||
return nil, newError("server rejects request: ", resp)
|
||||
}
|
||||
|
||||
addrType := b.Byte(3)
|
||||
|
||||
b.Clear()
|
||||
|
||||
address, port, err := ReadAddress(b, addrType, reader)
|
||||
address, port, err := addrParser.ReadAddressPort(b, reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package socks_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"v2ray.com/core/common/buf"
|
||||
|
@ -34,56 +33,3 @@ func TestUDPEncoding(t *testing.T) {
|
|||
assert(err, IsNil)
|
||||
assert(decodedPayload[0].Bytes(), Equals, content)
|
||||
}
|
||||
|
||||
func TestReadAddress(t *testing.T) {
|
||||
assert := With(t)
|
||||
|
||||
data := []struct {
|
||||
AddrType byte
|
||||
Input []byte
|
||||
Address net.Address
|
||||
Port net.Port
|
||||
Error bool
|
||||
}{
|
||||
{
|
||||
AddrType: 0,
|
||||
Input: []byte{0, 0, 0, 0},
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
AddrType: 1,
|
||||
Input: []byte{0, 0, 0, 0, 0, 53},
|
||||
Address: net.IPAddress([]byte{0, 0, 0, 0}),
|
||||
Port: net.Port(53),
|
||||
},
|
||||
{
|
||||
AddrType: 4,
|
||||
Input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 80},
|
||||
Address: net.IPAddress([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}),
|
||||
Port: net.Port(80),
|
||||
},
|
||||
{
|
||||
AddrType: 3,
|
||||
Input: []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 80},
|
||||
Address: net.DomainAddress("v2ray.com"),
|
||||
Port: net.Port(80),
|
||||
},
|
||||
{
|
||||
AddrType: 3,
|
||||
Input: []byte{9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0},
|
||||
Error: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range data {
|
||||
b := buf.New()
|
||||
addr, port, err := ReadAddress(b, tc.AddrType, bytes.NewBuffer(tc.Input))
|
||||
b.Release()
|
||||
if tc.Error {
|
||||
assert(err, IsNotNil)
|
||||
} else {
|
||||
assert(addr, Equals, tc.Address)
|
||||
assert(port, Equals, tc.Port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -185,18 +185,20 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
|
|||
}
|
||||
|
||||
for _, payload := range mpayload {
|
||||
request, data, err := DecodeUDPPacket(payload.Bytes())
|
||||
request, err := DecodeUDPPacket(payload)
|
||||
|
||||
if err != nil {
|
||||
newError("failed to parse UDP request").Base(err).WithContext(ctx).WriteToLog()
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
if payload.IsEmpty() {
|
||||
payload.Release()
|
||||
continue
|
||||
}
|
||||
|
||||
newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug().WithContext(ctx).WriteToLog()
|
||||
newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog()
|
||||
if source, ok := proxy.SourceFromContext(ctx); ok {
|
||||
log.Record(&log.AccessMessage{
|
||||
From: source,
|
||||
|
@ -206,9 +208,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
|
|||
})
|
||||
}
|
||||
|
||||
dataBuf := buf.New()
|
||||
dataBuf.Append(data)
|
||||
udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
|
||||
udpServer.Dispatch(ctx, request.Destination(), payload, func(payload *buf.Buffer) {
|
||||
defer payload.Release()
|
||||
|
||||
newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WithContext(ctx).WriteToLog()
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"v2ray.com/core/common/buf"
|
||||
"v2ray.com/core/common/crypto"
|
||||
"v2ray.com/core/common/dice"
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
"v2ray.com/core/common/serial"
|
||||
"v2ray.com/core/proxy/vmess"
|
||||
|
@ -82,23 +81,8 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
|
|||
buffer.AppendBytes(security, byte(0), byte(header.Command))
|
||||
|
||||
if header.Command != protocol.RequestCommandMux {
|
||||
common.Must(buffer.AppendSupplier(serial.WriteUint16(header.Port.Value())))
|
||||
|
||||
switch header.Address.Family() {
|
||||
case net.AddressFamilyIPv4:
|
||||
buffer.AppendBytes(byte(protocol.AddressTypeIPv4))
|
||||
buffer.Append(header.Address.IP())
|
||||
case net.AddressFamilyIPv6:
|
||||
buffer.AppendBytes(byte(protocol.AddressTypeIPv6))
|
||||
buffer.Append(header.Address.IP())
|
||||
case net.AddressFamilyDomain:
|
||||
domain := header.Address.Domain()
|
||||
if protocol.IsDomainTooLong(domain) {
|
||||
return newError("long domain not supported: ", domain)
|
||||
}
|
||||
nDomain := len(domain)
|
||||
buffer.AppendBytes(byte(protocol.AddressTypeDomain), byte(nDomain))
|
||||
common.Must(buffer.AppendSupplier(serial.WriteString(domain)))
|
||||
if err := addrParser.WriteAddressPort(buffer, header.Address, header.Port); err != nil {
|
||||
return newError("failed to writer address and port").Base(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
package encoding
|
||||
|
||||
const (
|
||||
Version = byte(1)
|
||||
)
|
|
@ -1,3 +1,19 @@
|
|||
package encoding
|
||||
|
||||
import (
|
||||
"v2ray.com/core/common/net"
|
||||
"v2ray.com/core/common/protocol"
|
||||
)
|
||||
|
||||
//go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg encoding -path Proxy,VMess,Encoding
|
||||
|
||||
const (
|
||||
Version = byte(1)
|
||||
)
|
||||
|
||||
var addrParser = protocol.NewAddressParser(
|
||||
protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
|
||||
protocol.AddressFamilyByte(0x02, net.AddressFamilyDomain),
|
||||
protocol.AddressFamilyByte(0x03, net.AddressFamilyIPv6),
|
||||
protocol.PortThenAddress(),
|
||||
)
|
||||
|
|
|
@ -105,44 +105,6 @@ func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionH
|
|||
}
|
||||
}
|
||||
|
||||
func readAddress(buffer *buf.Buffer, reader io.Reader) (net.Address, net.Port, error) {
|
||||
var address net.Address
|
||||
var port net.Port
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
|
||||
return address, port, newError("failed to read port and address type").Base(err)
|
||||
}
|
||||
port = net.PortFromBytes(buffer.BytesRange(-3, -1))
|
||||
|
||||
addressType := protocol.AddressType(buffer.Byte(buffer.Len() - 1))
|
||||
switch addressType {
|
||||
case protocol.AddressTypeIPv4:
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
|
||||
return address, port, newError("failed to read IPv4 address").Base(err)
|
||||
}
|
||||
address = net.IPAddress(buffer.BytesFrom(-4))
|
||||
case protocol.AddressTypeIPv6:
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
|
||||
return address, port, newError("failed to read IPv6 address").Base(err)
|
||||
}
|
||||
address = net.IPAddress(buffer.BytesFrom(-16))
|
||||
case protocol.AddressTypeDomain:
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
|
||||
return address, port, newError("failed to read domain address").Base(err)
|
||||
}
|
||||
domainLength := int(buffer.Byte(buffer.Len() - 1))
|
||||
if domainLength == 0 {
|
||||
return address, port, newError("zero length domain")
|
||||
}
|
||||
if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
|
||||
return address, port, newError("failed to read domain address").Base(err)
|
||||
}
|
||||
address = net.DomainAddress(string(buffer.BytesFrom(-domainLength)))
|
||||
default:
|
||||
return address, port, newError("invalid address type", addressType)
|
||||
}
|
||||
return address, port, nil
|
||||
}
|
||||
|
||||
func parseSecurityType(b byte) protocol.SecurityType {
|
||||
if _, f := protocol.SecurityType_name[int32(b)]; f {
|
||||
return protocol.SecurityType(b)
|
||||
|
@ -221,7 +183,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|||
request.Address = net.DomainAddress("v1.mux.cool")
|
||||
request.Port = 0
|
||||
case protocol.RequestCommandTCP, protocol.RequestCommandUDP:
|
||||
if addr, port, err := readAddress(buffer, decryptor); err == nil {
|
||||
if addr, port, err := addrParser.ReadAddressPort(buffer, decryptor); err == nil {
|
||||
request.Address = addr
|
||||
request.Port = port
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue