fix domain length issue for all proxies

pull/642/merge v2.42
Darien Raymond 2017-10-22 20:17:06 +02:00
parent 9f392eb506
commit 26818a2602
7 changed files with 36 additions and 19 deletions

View File

@ -81,11 +81,10 @@ func (f FrameMetadata) AsSupplier() buf.Supplier {
length += 17 length += 17
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
domain := addr.Domain() domain := addr.Domain()
nDomain := len(domain) if protocol.IsDomainTooLong(domain) {
if nDomain > 256 { return 0, newError("domain name too long: ", domain)
nDomain = 256
domain = domain[:256]
} }
nDomain := len(domain)
b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain)) b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain))
b = append(b, domain...) b = append(b, domain...)
length += nDomain + 2 length += nDomain + 2

View File

@ -97,3 +97,7 @@ func (sc *SecurityConfig) AsSecurity() Security {
} }
return NormSecurity(Security(sc.Type)) return NormSecurity(Security(sc.Type))
} }
func IsDomainTooLong(domain string) bool {
return len(domain) > 256
}

View File

@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"io" "io"
"v2ray.com/core/common"
"v2ray.com/core/common/bitmask" "v2ray.com/core/common/bitmask"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/crypto" "v2ray.com/core/common/crypto"
@ -160,19 +161,23 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
header.AppendBytes(AddrTypeIPv6) header.AppendBytes(AddrTypeIPv6)
header.Append([]byte(request.Address.IP())) header.Append([]byte(request.Address.IP()))
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
header.AppendBytes(AddrTypeDomain, byte(len(request.Address.Domain()))) domain := request.Address.Domain()
header.Append([]byte(request.Address.Domain())) if protocol.IsDomainTooLong(domain) {
return nil, newError("domain name too long: ", domain)
}
header.AppendBytes(AddrTypeDomain, byte(len(domain)))
common.Must(header.AppendSupplier(serial.WriteString(domain)))
default: default:
return nil, newError("unsupported address type: ", request.Address.Family()) return nil, newError("unsupported address type: ", request.Address.Family())
} }
header.AppendSupplier(serial.WriteUint16(uint16(request.Port))) common.Must(header.AppendSupplier(serial.WriteUint16(uint16(request.Port))))
if request.Option.Has(RequestOptionOneTimeAuth) { if request.Option.Has(RequestOptionOneTimeAuth) {
header.SetByte(0, header.Byte(0)|0x10) header.SetByte(0, header.Byte(0)|0x10)
authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
header.AppendSupplier(authenticator.Authenticate(header.Bytes())) common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes())))
} }
_, err = writer.Write(header.Bytes()) _, err = writer.Write(header.Bytes())

View File

@ -3,6 +3,7 @@ package socks
import ( import (
"io" "io"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
@ -253,14 +254,13 @@ func appendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error
buffer.AppendBytes(0x04) buffer.AppendBytes(0x04)
buffer.Append(address.IP()) buffer.Append(address.IP())
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
n := byte(len(address.Domain())) if protocol.IsDomainTooLong(address.Domain()) {
if int(n) != len(address.Domain()) { return newError("Super long domain is not supported in Socks protocol: ", address.Domain())
return newError("Super long domain is not supported in Socks protocol. ", address.Domain())
} }
buffer.AppendBytes(0x03, byte(len(address.Domain()))) buffer.AppendBytes(0x03, byte(len(address.Domain())))
buffer.AppendSupplier(serial.WriteString(address.Domain())) common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain())))
} }
buffer.AppendSupplier(serial.WriteUint16(port.Value())) common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
return nil return nil
} }

View File

@ -60,12 +60,12 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession {
return session return session
} }
func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error {
timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)()
account, err := header.User.GetTypedAccount() account, err := header.User.GetTypedAccount()
if err != nil { if err != nil {
log.Trace(newError("failed to get user account: ", err).AtError()) log.Trace(newError("failed to get user account: ", err).AtError())
return return nil
} }
idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes()) idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes())
common.Must2(idHash.Write(timestamp.Bytes(nil))) common.Must2(idHash.Write(timestamp.Bytes(nil)))
@ -95,8 +95,13 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
buffer = append(buffer, byte(protocol.AddressTypeIPv6)) buffer = append(buffer, byte(protocol.AddressTypeIPv6))
buffer = append(buffer, header.Address.IP()...) buffer = append(buffer, header.Address.IP()...)
case net.AddressFamilyDomain: case net.AddressFamilyDomain:
buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(len(header.Address.Domain()))) domain := header.Address.Domain()
buffer = append(buffer, header.Address.Domain()...) if protocol.IsDomainTooLong(domain) {
return newError("long domain not supported: ", domain)
}
nDomain := len(domain)
buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(nDomain))
buffer = append(buffer, domain...)
} }
} }
@ -117,6 +122,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv)
aesStream.XORKeyStream(buffer, buffer) aesStream.XORKeyStream(buffer, buffer)
common.Must2(writer.Write(buffer)) common.Must2(writer.Write(buffer))
return nil
} }
func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"testing" "testing"
"v2ray.com/core/common"
"v2ray.com/core/common/buf" "v2ray.com/core/common/buf"
"v2ray.com/core/common/net" "v2ray.com/core/common/net"
"v2ray.com/core/common/protocol" "v2ray.com/core/common/protocol"
@ -38,7 +39,7 @@ func TestRequestSerialization(t *testing.T) {
buffer := buf.New() buffer := buf.New()
client := NewClientSession(protocol.DefaultIDHash) client := NewClientSession(protocol.DefaultIDHash)
client.EncodeRequestHeader(expectedRequest, buffer) common.Must(client.EncodeRequestHeader(expectedRequest, buffer))
buffer2 := buf.New() buffer2 := buf.New()
buffer2.Append(buffer.Bytes()) buffer2.Append(buffer.Bytes())

View File

@ -108,7 +108,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
requestDone := signal.ExecuteAsync(func() error { requestDone := signal.ExecuteAsync(func() error {
writer := buf.NewBufferedWriter(conn) writer := buf.NewBufferedWriter(conn)
session.EncodeRequestHeader(request, writer) if err := session.EncodeRequestHeader(request, writer); err != nil {
return newError("failed to encode request").Base(err).AtWarning()
}
bodyWriter := session.EncodeRequestBody(request, writer) bodyWriter := session.EncodeRequestBody(request, writer)
firstPayload, err := input.ReadTimeout(time.Millisecond * 500) firstPayload, err := input.ReadTimeout(time.Millisecond * 500)