From 26818a2602b85a7960c97a467e78be1147a11ea4 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 22 Oct 2017 20:17:06 +0200 Subject: [PATCH] fix domain length issue for all proxies --- app/proxyman/mux/frame.go | 7 +++---- common/protocol/headers.go | 4 ++++ proxy/shadowsocks/protocol.go | 13 +++++++++---- proxy/socks/protocol.go | 10 +++++----- proxy/vmess/encoding/client.go | 14 ++++++++++---- proxy/vmess/encoding/encoding_test.go | 3 ++- proxy/vmess/outbound/outbound.go | 4 +++- 7 files changed, 36 insertions(+), 19 deletions(-) diff --git a/app/proxyman/mux/frame.go b/app/proxyman/mux/frame.go index 0824cff8..462dd0da 100644 --- a/app/proxyman/mux/frame.go +++ b/app/proxyman/mux/frame.go @@ -81,11 +81,10 @@ func (f FrameMetadata) AsSupplier() buf.Supplier { length += 17 case net.AddressFamilyDomain: domain := addr.Domain() - nDomain := len(domain) - if nDomain > 256 { - nDomain = 256 - domain = domain[:256] + if protocol.IsDomainTooLong(domain) { + return 0, newError("domain name too long: ", domain) } + nDomain := len(domain) b = append(b, byte(protocol.AddressTypeDomain), byte(nDomain)) b = append(b, domain...) length += nDomain + 2 diff --git a/common/protocol/headers.go b/common/protocol/headers.go index fb64ce71..ecf798f0 100644 --- a/common/protocol/headers.go +++ b/common/protocol/headers.go @@ -97,3 +97,7 @@ func (sc *SecurityConfig) AsSecurity() Security { } return NormSecurity(Security(sc.Type)) } + +func IsDomainTooLong(domain string) bool { + return len(domain) > 256 +} diff --git a/proxy/shadowsocks/protocol.go b/proxy/shadowsocks/protocol.go index 25a360fa..529b2cb2 100644 --- a/proxy/shadowsocks/protocol.go +++ b/proxy/shadowsocks/protocol.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "io" + "v2ray.com/core/common" "v2ray.com/core/common/bitmask" "v2ray.com/core/common/buf" "v2ray.com/core/common/crypto" @@ -160,19 +161,23 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri header.AppendBytes(AddrTypeIPv6) header.Append([]byte(request.Address.IP())) case net.AddressFamilyDomain: - header.AppendBytes(AddrTypeDomain, byte(len(request.Address.Domain()))) - header.Append([]byte(request.Address.Domain())) + domain := request.Address.Domain() + if protocol.IsDomainTooLong(domain) { + return nil, newError("domain name too long: ", domain) + } + header.AppendBytes(AddrTypeDomain, byte(len(domain))) + common.Must(header.AppendSupplier(serial.WriteString(domain))) default: return nil, newError("unsupported address type: ", request.Address.Family()) } - header.AppendSupplier(serial.WriteUint16(uint16(request.Port))) + common.Must(header.AppendSupplier(serial.WriteUint16(uint16(request.Port)))) if request.Option.Has(RequestOptionOneTimeAuth) { header.SetByte(0, header.Byte(0)|0x10) authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv)) - header.AppendSupplier(authenticator.Authenticate(header.Bytes())) + common.Must(header.AppendSupplier(authenticator.Authenticate(header.Bytes()))) } _, err = writer.Write(header.Bytes()) diff --git a/proxy/socks/protocol.go b/proxy/socks/protocol.go index d1432bb8..f956b6f0 100644 --- a/proxy/socks/protocol.go +++ b/proxy/socks/protocol.go @@ -3,6 +3,7 @@ package socks import ( "io" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -253,14 +254,13 @@ func appendAddress(buffer *buf.Buffer, address net.Address, port net.Port) error buffer.AppendBytes(0x04) buffer.Append(address.IP()) case net.AddressFamilyDomain: - n := byte(len(address.Domain())) - if int(n) != len(address.Domain()) { - return newError("Super long domain is not supported in Socks protocol. ", address.Domain()) + if protocol.IsDomainTooLong(address.Domain()) { + return newError("Super long domain is not supported in Socks protocol: ", address.Domain()) } buffer.AppendBytes(0x03, byte(len(address.Domain()))) - buffer.AppendSupplier(serial.WriteString(address.Domain())) + common.Must(buffer.AppendSupplier(serial.WriteString(address.Domain()))) } - buffer.AppendSupplier(serial.WriteUint16(port.Value())) + common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value()))) return nil } diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index fff48651..e35a3126 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -60,12 +60,12 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { return session } -func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) { +func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account, err := header.User.GetTypedAccount() if err != nil { log.Trace(newError("failed to get user account: ", err).AtError()) - return + return nil } idHash := c.idHash(account.(*vmess.InternalAccount).AnyValidID().Bytes()) common.Must2(idHash.Write(timestamp.Bytes(nil))) @@ -95,8 +95,13 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ buffer = append(buffer, byte(protocol.AddressTypeIPv6)) buffer = append(buffer, header.Address.IP()...) case net.AddressFamilyDomain: - buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(len(header.Address.Domain()))) - buffer = append(buffer, header.Address.Domain()...) + domain := header.Address.Domain() + if protocol.IsDomainTooLong(domain) { + return newError("long domain not supported: ", domain) + } + nDomain := len(domain) + buffer = append(buffer, byte(protocol.AddressTypeDomain), byte(nDomain)) + buffer = append(buffer, domain...) } } @@ -117,6 +122,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ aesStream := crypto.NewAesEncryptionStream(account.(*vmess.InternalAccount).ID.CmdKey(), iv) aesStream.XORKeyStream(buffer, buffer) common.Must2(writer.Write(buffer)) + return nil } func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, writer io.Writer) buf.Writer { diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index a31db21d..043e3384 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" "v2ray.com/core/common/protocol" @@ -38,7 +39,7 @@ func TestRequestSerialization(t *testing.T) { buffer := buf.New() client := NewClientSession(protocol.DefaultIDHash) - client.EncodeRequestHeader(expectedRequest, buffer) + common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() buffer2.Append(buffer.Bytes()) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index 27c41776..62e33149 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -108,7 +108,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial requestDone := signal.ExecuteAsync(func() error { writer := buf.NewBufferedWriter(conn) - session.EncodeRequestHeader(request, writer) + if err := session.EncodeRequestHeader(request, writer); err != nil { + return newError("failed to encode request").Base(err).AtWarning() + } bodyWriter := session.EncodeRequestBody(request, writer) firstPayload, err := input.ReadTimeout(time.Millisecond * 500)