From b00ee6736974047aa2d1b994ab683808ffcce0cb Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 7 Jan 2017 21:57:24 +0100 Subject: [PATCH] refine socks udp handling --- proxy/socks/protocol.go | 69 ++++++ proxy/socks/protocol/socks.go | 311 ---------------------------- proxy/socks/protocol/socks4.go | 38 ---- proxy/socks/protocol/socks4_test.go | 39 ---- proxy/socks/protocol/socks_test.go | 177 ---------------- proxy/socks/protocol/udp.go | 91 -------- proxy/socks/protocol/udp_test.go | 36 ---- proxy/socks/server.go | 6 +- proxy/socks/server_udp.go | 46 ++-- testing/scenarios/socks_end_test.go | 6 +- 10 files changed, 94 insertions(+), 725 deletions(-) delete mode 100644 proxy/socks/protocol/socks.go delete mode 100644 proxy/socks/protocol/socks4.go delete mode 100644 proxy/socks/protocol/socks4_test.go delete mode 100644 proxy/socks/protocol/socks_test.go delete mode 100644 proxy/socks/protocol/udp.go delete mode 100644 proxy/socks/protocol/udp_test.go diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index e9526561..dc34edc0 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -272,3 +272,72 @@ func writeSocks4Response(writer io.Writer, errCode byte, address v2net.Address, _, err := writer.Write(buffer.Bytes()) return err } + +func DecodeUDPPacket(packet []byte) (*protocol.RequestHeader, []byte, error) { + if len(packet) < 5 { + return nil, nil, errors.New("Socks|UDP: Insufficient length of packet.") + } + request := &protocol.RequestHeader{ + Version: socks5Version, + Command: protocol.RequestCommandUDP, + } + + // packet[0] and packet[1] are reserved + if packet[2] != 0 /* fragments */ { + return nil, nil, errors.New("Socks|UDP: Fragmented payload.") + } + + addrType := packet[3] + var dataBegin int + + switch addrType { + case addrTypeIPv4: + if len(packet) < 10 { + return nil, nil, errors.New("Socks|UDP: Insufficient length of packet.") + } + ip := packet[4:8] + request.Port = v2net.PortFromBytes(packet[8:10]) + request.Address = v2net.IPAddress(ip) + dataBegin = 10 + case addrTypeIPv6: + if len(packet) < 22 { + return nil, nil, errors.New("Socks|UDP: Insufficient length of packet.") + } + ip := packet[4:20] + request.Port = v2net.PortFromBytes(packet[20:22]) + request.Address = v2net.IPAddress(ip) + dataBegin = 22 + case addrTypeDomain: + domainLength := int(packet[4]) + if len(packet) < 5+domainLength+2 { + return nil, nil, errors.New("Socks|UDP: Insufficient length of packet.") + } + domain := string(packet[5 : 5+domainLength]) + request.Port = v2net.PortFromBytes(packet[5+domainLength : 5+domainLength+2]) + request.Address = v2net.ParseAddress(domain) + dataBegin = 5 + domainLength + 2 + default: + return nil, nil, errors.New("Socks|UDP: Unknown address type ", addrType) + } + + return request, packet[dataBegin:], nil +} + +func EncodeUDPPacket(request *protocol.RequestHeader, data []byte) *buf.Buffer { + b := buf.NewSmall() + b.AppendBytes(0, 0, 0 /* Fragment */) + switch request.Address.Family() { + case v2net.AddressFamilyIPv4: + b.AppendBytes(addrTypeIPv4) + b.Append(request.Address.IP()) + case v2net.AddressFamilyIPv6: + b.AppendBytes(addrTypeIPv6) + b.Append(request.Address.IP()) + case v2net.AddressFamilyDomain: + b.AppendBytes(addrTypeDomain, byte(len(request.Address.Domain()))) + b.AppendSupplier(serial.WriteString(request.Address.Domain())) + } + b.AppendSupplier(serial.WriteUint16(request.Port.Value())) + b.Append(data) + return b +} diff --git a/proxy/socks/protocol/socks.go b/proxy/socks/protocol/socks.go deleted file mode 100644 index 3e5ec250..00000000 --- a/proxy/socks/protocol/socks.go +++ /dev/null @@ -1,311 +0,0 @@ -package protocol - -import ( - "io" - - "v2ray.com/core/common/buf" - "v2ray.com/core/common/crypto" - "v2ray.com/core/common/errors" - "v2ray.com/core/common/log" - v2net "v2ray.com/core/common/net" - "v2ray.com/core/proxy" -) - -const ( - socksVersion = byte(0x05) - socks4Version = byte(0x04) - - AuthNotRequired = byte(0x00) - AuthGssApi = byte(0x01) - AuthUserPass = byte(0x02) - AuthNoMatchingMethod = byte(0xFF) - - Socks4RequestGranted = byte(90) - Socks4RequestRejected = byte(91) -) - -// Authentication request header of Socks5 protocol -type Socks5AuthenticationRequest struct { - version byte - nMethods byte - authMethods [256]byte -} - -func (request *Socks5AuthenticationRequest) HasAuthMethod(method byte) bool { - for i := 0; i < int(request.nMethods); i++ { - if request.authMethods[i] == method { - return true - } - } - return false -} - -func ReadAuthentication(reader io.Reader) (auth Socks5AuthenticationRequest, auth4 Socks4AuthenticationRequest, err error) { - buffer := make([]byte, 256) - - nBytes, err := reader.Read(buffer) - if err != nil { - return - } - if nBytes < 2 { - err = errors.New("Socks: Insufficient header.") - return - } - - if buffer[0] == socks4Version { - auth4.Version = buffer[0] - auth4.Command = buffer[1] - auth4.Port = v2net.PortFromBytes(buffer[2:4]) - copy(auth4.IP[:], buffer[4:8]) - err = Socks4Downgrade - return - } - - auth.version = buffer[0] - if auth.version != socksVersion { - log.Warning("Socks: Unknown protocol version ", auth.version) - err = proxy.ErrInvalidProtocolVersion - return - } - - auth.nMethods = buffer[1] - if auth.nMethods <= 0 { - log.Warning("Socks: Zero length of authentication methods") - err = crypto.ErrAuthenticationFailed - return - } - - if nBytes-2 != int(auth.nMethods) { - log.Warning("Socks: Unmatching number of auth methods, expecting ", auth.nMethods, ", but got ", nBytes) - err = crypto.ErrAuthenticationFailed - return - } - copy(auth.authMethods[:], buffer[2:nBytes]) - return -} - -type Socks5AuthenticationResponse struct { - version byte - authMethod byte -} - -func NewAuthenticationResponse(authMethod byte) *Socks5AuthenticationResponse { - return &Socks5AuthenticationResponse{ - version: socksVersion, - authMethod: authMethod, - } -} - -func WriteAuthentication(writer io.Writer, r *Socks5AuthenticationResponse) error { - _, err := writer.Write([]byte{r.version, r.authMethod}) - return err -} - -type Socks5UserPassRequest struct { - version byte - username string - password string -} - -func (request Socks5UserPassRequest) Username() string { - return request.username -} - -func (request Socks5UserPassRequest) Password() string { - return request.password -} - -func (request Socks5UserPassRequest) AuthDetail() string { - return request.username + ":" + request.password -} - -func ReadUserPassRequest(reader io.Reader) (request Socks5UserPassRequest, err error) { - buffer := buf.NewLocal(512) - defer buffer.Release() - - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)) - if err != nil { - return - } - request.version = buffer.Byte(0) - nUsername := int(buffer.Byte(1)) - - buffer.Clear() - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, nUsername)) - if err != nil { - return - } - request.username = buffer.String() - - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)) - if err != nil { - return - } - nPassword := int(buffer.Byte(0)) - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, nPassword)) - if err != nil { - return - } - request.password = buffer.String() - return -} - -type Socks5UserPassResponse struct { - version byte - status byte -} - -func NewSocks5UserPassResponse(status byte) Socks5UserPassResponse { - return Socks5UserPassResponse{ - version: socksVersion, - status: status, - } -} - -func WriteUserPassResponse(writer io.Writer, response Socks5UserPassResponse) error { - _, err := writer.Write([]byte{response.version, response.status}) - return err -} - -const ( - AddrTypeIPv4 = byte(0x01) - AddrTypeIPv6 = byte(0x04) - AddrTypeDomain = byte(0x03) - - CmdConnect = byte(0x01) - CmdBind = byte(0x02) - CmdUdpAssociate = byte(0x03) -) - -type Socks5Request struct { - Version byte - Command byte - AddrType byte - IPv4 [4]byte - Domain string - IPv6 [16]byte - Port v2net.Port -} - -func ReadRequest(reader io.Reader) (request *Socks5Request, err error) { - buffer := buf.NewLocal(512) - defer buffer.Release() - - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)) - if err != nil { - return - } - - request = &Socks5Request{ - Version: buffer.Byte(0), - Command: buffer.Byte(1), - // buffer[2] is a reserved field - AddrType: buffer.Byte(3), - } - switch request.AddrType { - case AddrTypeIPv4: - _, err = io.ReadFull(reader, request.IPv4[:]) - if err != nil { - return - } - case AddrTypeDomain: - buffer.Clear() - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)) - if err != nil { - return - } - domainLength := int(buffer.Byte(0)) - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, domainLength)) - if err != nil { - return - } - - request.Domain = string(buffer.BytesFrom(-domainLength)) - case AddrTypeIPv6: - _, err = io.ReadFull(reader, request.IPv6[:]) - if err != nil { - return - } - default: - err = errors.Format("Socks: Unexpected address type %d", request.AddrType) - return - } - - err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)) - if err != nil { - return - } - - request.Port = v2net.PortFromBytes(buffer.BytesFrom(-2)) - return -} - -func (request *Socks5Request) Destination() v2net.Destination { - switch request.AddrType { - case AddrTypeIPv4: - return v2net.TCPDestination(v2net.IPAddress(request.IPv4[:]), request.Port) - case AddrTypeIPv6: - return v2net.TCPDestination(v2net.IPAddress(request.IPv6[:]), request.Port) - case AddrTypeDomain: - return v2net.TCPDestination(v2net.ParseAddress(request.Domain), request.Port) - default: - panic("Unknown address type") - } -} - -const ( - ErrorSuccess = byte(0x00) - ErrorGeneralFailure = byte(0x01) - ErrorConnectionNotAllowed = byte(0x02) - ErrorNetworkUnreachable = byte(0x03) - ErrorHostUnUnreachable = byte(0x04) - ErrorConnectionRefused = byte(0x05) - ErrorTTLExpired = byte(0x06) - ErrorCommandNotSupported = byte(0x07) - ErrorAddressTypeNotSupported = byte(0x08) -) - -type Socks5Response struct { - Version byte - Error byte - AddrType byte - IPv4 [4]byte - Domain string - IPv6 [16]byte - Port v2net.Port -} - -func NewSocks5Response() *Socks5Response { - return &Socks5Response{ - Version: socksVersion, - } -} - -func (r *Socks5Response) SetIPv4(ipv4 []byte) { - r.AddrType = AddrTypeIPv4 - copy(r.IPv4[:], ipv4) -} - -func (r *Socks5Response) SetIPv6(ipv6 []byte) { - r.AddrType = AddrTypeIPv6 - copy(r.IPv6[:], ipv6) -} - -func (r *Socks5Response) SetDomain(domain string) { - r.AddrType = AddrTypeDomain - r.Domain = domain -} - -func (r *Socks5Response) Write(writer io.Writer) { - writer.Write([]byte{r.Version, r.Error, 0x00 /* reserved */, r.AddrType}) - switch r.AddrType { - case 0x01: - writer.Write(r.IPv4[:]) - case 0x03: - writer.Write([]byte{byte(len(r.Domain))}) - writer.Write([]byte(r.Domain)) - case 0x04: - writer.Write(r.IPv6[:]) - } - writer.Write(r.Port.Bytes(nil)) -} diff --git a/proxy/socks/protocol/socks4.go b/proxy/socks/protocol/socks4.go deleted file mode 100644 index 7bb0ed4e..00000000 --- a/proxy/socks/protocol/socks4.go +++ /dev/null @@ -1,38 +0,0 @@ -package protocol - -import ( - "io" - "v2ray.com/core/common/errors" - v2net "v2ray.com/core/common/net" -) - -var ( - Socks4Downgrade = errors.New("Downgraded to Socks 4.") -) - -type Socks4AuthenticationRequest struct { - Version byte - Command byte - Port v2net.Port - IP [4]byte -} - -type Socks4AuthenticationResponse struct { - result byte - port uint16 - ip []byte -} - -func NewSocks4AuthenticationResponse(result byte, port v2net.Port, ip []byte) *Socks4AuthenticationResponse { - return &Socks4AuthenticationResponse{ - result: result, - port: port.Value(), - ip: ip, - } -} - -func (r *Socks4AuthenticationResponse) Write(writer io.Writer) { - writer.Write([]byte{ - byte(0x00), r.result, byte(r.port >> 8), byte(r.port), - r.ip[0], r.ip[1], r.ip[2], r.ip[3]}) -} diff --git a/proxy/socks/protocol/socks4_test.go b/proxy/socks/protocol/socks4_test.go deleted file mode 100644 index 75b9b898..00000000 --- a/proxy/socks/protocol/socks4_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package protocol - -import ( - "bytes" - "testing" - - "v2ray.com/core/common/buf" - v2net "v2ray.com/core/common/net" - "v2ray.com/core/testing/assert" -) - -func TestSocks4AuthenticationRequestRead(t *testing.T) { - assert := assert.On(t) - - rawRequest := []byte{ - 0x04, // version - 0x01, // command - 0x00, 0x35, - 0x72, 0x72, 0x72, 0x72, - } - _, request4, err := ReadAuthentication(bytes.NewReader(rawRequest)) - assert.Error(err).Equals(Socks4Downgrade) - assert.Byte(request4.Version).Equals(0x04) - assert.Byte(request4.Command).Equals(0x01) - assert.Port(request4.Port).Equals(v2net.Port(53)) - assert.Bytes(request4.IP[:]).Equals([]byte{0x72, 0x72, 0x72, 0x72}) -} - -func TestSocks4AuthenticationResponseToBytes(t *testing.T) { - assert := assert.On(t) - - response := NewSocks4AuthenticationResponse(byte(0x10), 443, []byte{1, 2, 3, 4}) - - buffer := buf.NewLocal(2048) - defer buffer.Release() - - response.Write(buffer) - assert.Bytes(buffer.Bytes()).Equals([]byte{0x00, 0x10, 0x01, 0xBB, 0x01, 0x02, 0x03, 0x04}) -} diff --git a/proxy/socks/protocol/socks_test.go b/proxy/socks/protocol/socks_test.go deleted file mode 100644 index 45f11930..00000000 --- a/proxy/socks/protocol/socks_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package protocol - -import ( - "bytes" - "io" - "testing" - - "v2ray.com/core/common/buf" - "v2ray.com/core/common/crypto" - v2net "v2ray.com/core/common/net" - "v2ray.com/core/proxy" - "v2ray.com/core/testing/assert" -) - -func TestHasAuthenticationMethod(t *testing.T) { - assert := assert.On(t) - - request := Socks5AuthenticationRequest{ - version: socksVersion, - nMethods: byte(0x02), - authMethods: [256]byte{0x01, 0x02}, - } - - assert.Bool(request.HasAuthMethod(byte(0x01))).IsTrue() - - request.authMethods[0] = byte(0x03) - assert.Bool(request.HasAuthMethod(byte(0x01))).IsFalse() -} - -func TestAuthenticationRequestRead(t *testing.T) { - assert := assert.On(t) - - buffer := buf.New() - buffer.AppendBytes( - 0x05, // version - 0x01, // nMethods - 0x02, // methods - ) - request, _, err := ReadAuthentication(buffer) - assert.Error(err).IsNil() - assert.Byte(request.version).Equals(0x05) - assert.Byte(request.nMethods).Equals(0x01) - assert.Byte(request.authMethods[0]).Equals(0x02) -} - -func TestAuthenticationResponseWrite(t *testing.T) { - assert := assert.On(t) - - response := NewAuthenticationResponse(byte(0x05)) - - buffer := bytes.NewBuffer(make([]byte, 0, 10)) - WriteAuthentication(buffer, response) - assert.Bytes(buffer.Bytes()).Equals([]byte{socksVersion, byte(0x05)}) -} - -func TestRequestRead(t *testing.T) { - assert := assert.On(t) - - rawRequest := []byte{ - 0x05, // version - 0x01, // cmd connect - 0x00, // reserved - 0x01, // ipv4 type - 0x72, 0x72, 0x72, 0x72, // 114.114.114.114 - 0x00, 0x35, // port 53 - } - request, err := ReadRequest(bytes.NewReader(rawRequest)) - assert.Error(err).IsNil() - assert.Byte(request.Version).Equals(0x05) - assert.Byte(request.Command).Equals(0x01) - assert.Byte(request.AddrType).Equals(0x01) - assert.Bytes(request.IPv4[:]).Equals([]byte{0x72, 0x72, 0x72, 0x72}) - assert.Port(request.Port).Equals(v2net.Port(53)) -} - -func TestResponseWrite(t *testing.T) { - assert := assert.On(t) - - response := Socks5Response{ - socksVersion, - ErrorSuccess, - AddrTypeIPv4, - [4]byte{0x72, 0x72, 0x72, 0x72}, - "", - [16]byte{}, - v2net.Port(53), - } - buffer := buf.NewLocal(2048) - defer buffer.Release() - - response.Write(buffer) - expectedBytes := []byte{ - socksVersion, - ErrorSuccess, - byte(0x00), - AddrTypeIPv4, - 0x72, 0x72, 0x72, 0x72, - byte(0x00), byte(0x035), - } - assert.Bytes(buffer.Bytes()).Equals(expectedBytes) -} - -func TestSetIPv6(t *testing.T) { - assert := assert.On(t) - - response := NewSocks5Response() - response.SetIPv6([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}) - - buffer := buf.NewLocal(2048) - defer buffer.Release() - response.Write(buffer) - assert.Bytes(buffer.Bytes()).Equals([]byte{ - socksVersion, 0, 0, AddrTypeIPv6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0}) -} - -func TestSetDomain(t *testing.T) { - assert := assert.On(t) - - response := NewSocks5Response() - response.SetDomain("v2ray.com") - - buffer := buf.NewLocal(2048) - defer buffer.Release() - response.Write(buffer) - assert.Bytes(buffer.Bytes()).Equals([]byte{ - socksVersion, 0, 0, AddrTypeDomain, 9, 118, 50, 114, 97, 121, 46, 99, 111, 109, 0, 0}) -} - -func TestEmptyAuthRequest(t *testing.T) { - assert := assert.On(t) - - _, _, err := ReadAuthentication(buf.New()) - assert.Error(err).Equals(io.EOF) -} - -func TestSingleByteAuthRequest(t *testing.T) { - assert := assert.On(t) - - _, _, err := ReadAuthentication(bytes.NewReader(make([]byte, 1))) - assert.Error(err).IsNotNil() -} - -func TestZeroAuthenticationMethod(t *testing.T) { - assert := assert.On(t) - - buffer := buf.New() - buffer.AppendBytes(5, 0) - _, _, err := ReadAuthentication(buffer) - assert.Error(err).Equals(crypto.ErrAuthenticationFailed) -} -func TestWrongProtocolVersion(t *testing.T) { - assert := assert.On(t) - - buffer := buf.New() - buffer.AppendBytes(6, 1, 0) - _, _, err := ReadAuthentication(buffer) - assert.Error(err).Equals(proxy.ErrInvalidProtocolVersion) -} - -func TestEmptyRequest(t *testing.T) { - assert := assert.On(t) - - _, err := ReadRequest(buf.New()) - assert.Error(err).Equals(io.EOF) -} - -func TestIPv6Request(t *testing.T) { - assert := assert.On(t) - - b := buf.New() - b.AppendBytes(5, 1, 0, 4, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 0, 8) - request, err := ReadRequest(b) - assert.Error(err).IsNil() - assert.Byte(request.Command).Equals(1) - assert.Bytes(request.IPv6[:]).Equals([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}) - assert.Port(request.Port).Equals(8) -} diff --git a/proxy/socks/protocol/udp.go b/proxy/socks/protocol/udp.go deleted file mode 100644 index f4e142a6..00000000 --- a/proxy/socks/protocol/udp.go +++ /dev/null @@ -1,91 +0,0 @@ -package protocol - -import ( - "v2ray.com/core/common/buf" - "v2ray.com/core/common/errors" - v2net "v2ray.com/core/common/net" - "v2ray.com/core/common/serial" -) - -var ( - ErrorUnknownAddressType = errors.New("Unknown Address Type.") -) - -type Socks5UDPRequest struct { - Fragment byte - Address v2net.Address - Port v2net.Port - Data *buf.Buffer -} - -func (request *Socks5UDPRequest) Destination() v2net.Destination { - return v2net.UDPDestination(request.Address, request.Port) -} - -func (request *Socks5UDPRequest) Write(buffer *buf.Buffer) { - buffer.AppendBytes(0, 0, request.Fragment) - switch request.Address.Family() { - case v2net.AddressFamilyIPv4: - buffer.AppendBytes(AddrTypeIPv4) - buffer.Append(request.Address.IP()) - case v2net.AddressFamilyIPv6: - buffer.AppendBytes(AddrTypeIPv6) - buffer.Append(request.Address.IP()) - case v2net.AddressFamilyDomain: - buffer.AppendBytes(AddrTypeDomain, byte(len(request.Address.Domain()))) - buffer.Append([]byte(request.Address.Domain())) - } - buffer.AppendSupplier(serial.WriteUint16(request.Port.Value())) - buffer.Append(request.Data.Bytes()) -} - -func ReadUDPRequest(packet []byte) (*Socks5UDPRequest, error) { - if len(packet) < 5 { - return nil, errors.New("Socks|UDP: Insufficient length of packet.") - } - request := new(Socks5UDPRequest) - - // packet[0] and packet[1] are reserved - request.Fragment = packet[2] - - addrType := packet[3] - var dataBegin int - - switch addrType { - case AddrTypeIPv4: - if len(packet) < 10 { - return nil, errors.New("Socks|UDP: Insufficient length of packet.") - } - ip := packet[4:8] - request.Port = v2net.PortFromBytes(packet[8:10]) - request.Address = v2net.IPAddress(ip) - dataBegin = 10 - case AddrTypeIPv6: - if len(packet) < 22 { - return nil, errors.New("Socks|UDP: Insufficient length of packet.") - } - ip := packet[4:20] - request.Port = v2net.PortFromBytes(packet[20:22]) - request.Address = v2net.IPAddress(ip) - dataBegin = 22 - case AddrTypeDomain: - domainLength := int(packet[4]) - if len(packet) < 5+domainLength+2 { - return nil, errors.New("Socks|UDP: Insufficient length of packet.") - } - domain := string(packet[5 : 5+domainLength]) - request.Port = v2net.PortFromBytes(packet[5+domainLength : 5+domainLength+2]) - request.Address = v2net.ParseAddress(domain) - dataBegin = 5 + domainLength + 2 - default: - return nil, errors.Format("Socks|UDP: Unknown address type %d", addrType) - } - - if len(packet) > dataBegin { - b := buf.NewSmall() - b.Append(packet[dataBegin:]) - request.Data = b - } - - return request, nil -} diff --git a/proxy/socks/protocol/udp_test.go b/proxy/socks/protocol/udp_test.go deleted file mode 100644 index 576a29aa..00000000 --- a/proxy/socks/protocol/udp_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package protocol - -import ( - "testing" - - v2net "v2ray.com/core/common/net" - "v2ray.com/core/testing/assert" -) - -func TestSingleByteUDPRequest(t *testing.T) { - assert := assert.On(t) - - request, err := ReadUDPRequest(make([]byte, 1)) - if request != nil { - t.Fail() - } - assert.Error(err).IsNotNil() -} - -func TestDomainAddressRequest(t *testing.T) { - assert := assert.On(t) - - payload := make([]byte, 0, 1024) - payload = append(payload, 0, 0, 1, AddrTypeDomain, byte(len("v2ray.com"))) - payload = append(payload, []byte("v2ray.com")...) - payload = append(payload, 0, 80) - payload = append(payload, []byte("Actual payload")...) - - request, err := ReadUDPRequest(payload) - assert.Error(err).IsNil() - - assert.Byte(request.Fragment).Equals(1) - assert.Address(request.Address).EqualsString("v2ray.com") - assert.Port(request.Port).Equals(v2net.Port(80)) - assert.String(request.Data.String()).Equals("Actual payload") -} diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 91dacb1c..cc36614a 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -13,7 +13,7 @@ import ( "v2ray.com/core/common/errors" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" - proto "v2ray.com/core/common/protocol" + "v2ray.com/core/common/protocol" "v2ray.com/core/common/serial" "v2ray.com/core/common/signal" "v2ray.com/core/proxy" @@ -120,7 +120,7 @@ func (v *Server) handleConnection(connection internet.Connection) { return } - if request.Command == proto.RequestCommandTCP { + if request.Command == protocol.RequestCommandTCP { dest := request.Destination() session := &proxy.SessionInfo{ Source: clientAddr, @@ -134,7 +134,7 @@ func (v *Server) handleConnection(connection internet.Connection) { return } - if request.Command == proto.RequestCommandUDP { + if request.Command == protocol.RequestCommandUDP { v.handleUDP() return } diff --git a/proxy/socks/server_udp.go b/proxy/socks/server_udp.go index b05e2b72..3677cdc3 100644 --- a/proxy/socks/server_udp.go +++ b/proxy/socks/server_udp.go @@ -5,7 +5,6 @@ import ( "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" "v2ray.com/core/proxy" - "v2ray.com/core/proxy/socks/protocol" "v2ray.com/core/transport/internet/udp" ) @@ -24,39 +23,33 @@ func (v *Server) listenUDP() error { } func (v *Server) handleUDPPayload(payload *buf.Buffer, session *proxy.SessionInfo) { + defer payload.Release() + source := session.Source log.Info("Socks: Client UDP connection from ", source) - request, err := protocol.ReadUDPRequest(payload.Bytes()) - payload.Release() + request, data, err := DecodeUDPPacket(payload.Bytes()) if err != nil { - log.Error("Socks: Failed to parse UDP request: ", err) - return - } - if request.Data.Len() == 0 { - request.Data.Release() - return - } - if request.Fragment != 0 { - log.Warning("Socks: Dropping fragmented UDP packets.") - // TODO handle fragments - request.Data.Release() + log.Error("Socks|Server: Failed to parse UDP request: ", err) return } - log.Info("Socks: Send packet to ", request.Destination(), " with ", request.Data.Len(), " bytes") + if len(data) == 0 { + return + } + + log.Info("Socks: Send packet to ", request.Destination(), " with ", len(data), " bytes") log.Access(source, request.Destination, log.AccessAccepted, "") - v.udpServer.Dispatch(&proxy.SessionInfo{Source: source, Destination: request.Destination(), Inbound: v.meta}, request.Data, func(destination v2net.Destination, payload *buf.Buffer) { - response := &protocol.Socks5UDPRequest{ - Fragment: 0, - Address: request.Destination().Address, - Port: request.Destination().Port, - Data: payload, - } + + dataBuf := buf.NewSmall() + dataBuf.Append(data) + v.udpServer.Dispatch(&proxy.SessionInfo{Source: source, Destination: request.Destination(), Inbound: v.meta}, dataBuf, func(destination v2net.Destination, payload *buf.Buffer) { + defer payload.Release() + log.Info("Socks: Writing back UDP response with ", payload.Len(), " bytes to ", destination) - udpMessage := buf.NewLocal(2048) - response.Write(udpMessage) + udpMessage := EncodeUDPPacket(request, payload.Bytes()) + defer udpMessage.Release() v.udpMutex.RLock() if !v.accepting { @@ -65,10 +58,9 @@ func (v *Server) handleUDPPayload(payload *buf.Buffer, session *proxy.SessionInf } nBytes, err := v.udpHub.WriteTo(udpMessage.Bytes(), destination) v.udpMutex.RUnlock() - udpMessage.Release() - response.Data.Release() + if err != nil { - log.Error("Socks: failed to write UDP message (", nBytes, " bytes) to ", destination, ": ", err) + log.Warning("Socks: failed to write UDP message (", nBytes, " bytes) to ", destination, ": ", err) } }) } diff --git a/testing/scenarios/socks_end_test.go b/testing/scenarios/socks_end_test.go index 71718339..23979064 100644 --- a/testing/scenarios/socks_end_test.go +++ b/testing/scenarios/socks_end_test.go @@ -200,7 +200,7 @@ func TestUDPAssociate(t *testing.T) { return buffer }, } - _, err := udpServer.Start() + udpServerAddr, err := udpServer.Start() assert.Error(err).IsNil() defer udpServer.Close() @@ -223,7 +223,7 @@ func TestUDPAssociate(t *testing.T) { assert.Error(err).IsNil() assert.Bytes(authResponse[:nBytes]).Equals([]byte{socks5Version, 0}) - connectRequest := socks5Request(byte(3), v2net.TCPDestination(v2net.IPAddress([]byte{127, 0, 0, 1}), udpServer.Port)) + connectRequest := socks5Request(byte(3), v2net.TCPDestination(v2net.LocalHostIP, udpServer.Port)) nBytes, err = conn.Write(connectRequest) assert.Int(nBytes).Equals(len(connectRequest)) assert.Error(err).IsNil() @@ -241,7 +241,7 @@ func TestUDPAssociate(t *testing.T) { for i := 0; i < 100; i++ { udpPayload := "UDP request to udp server." - udpRequest := socks5UDPRequest(v2net.UDPDestination(v2net.LocalHostIP, udpServer.Port), []byte(udpPayload)) + udpRequest := socks5UDPRequest(udpServerAddr, []byte(udpPayload)) nBytes, err = udpConn.Write(udpRequest) assert.Int(nBytes).Equals(len(udpRequest))