From 778876df9efee9518a6a34069449629996ab7f04 Mon Sep 17 00:00:00 2001
From: Shelikhoo <xiaokangwang@outlook.com>
Date: Thu, 29 Apr 2021 06:29:42 +0800
Subject: [PATCH] VMess AEAD based packet length

(cherry picked from commit 08221600082a79376bdc262f2ffec1a3129ae98d)
---
 common/protocol/headers.go       |   2 +
 infra/conf/vmess.go              |   8 ++-
 proxy/vmess/account.go           |  15 ++++-
 proxy/vmess/encoding/auth.go     |  10 +++
 proxy/vmess/encoding/client.go   |  46 +++++++++++++
 proxy/vmess/encoding/server.go   |  46 +++++++++++++
 proxy/vmess/outbound/outbound.go |   4 ++
 testing/scenarios/vmess_test.go  | 107 +++++++++++++++++++++++++++++++
 8 files changed, 232 insertions(+), 6 deletions(-)

diff --git a/common/protocol/headers.go b/common/protocol/headers.go
index 921b4d32..1dcc467e 100644
--- a/common/protocol/headers.go
+++ b/common/protocol/headers.go
@@ -38,6 +38,8 @@ const (
 	RequestOptionChunkMasking bitmask.Byte = 0x04
 
 	RequestOptionGlobalPadding bitmask.Byte = 0x08
+
+	RequestOptionAuthenticatedLength bitmask.Byte = 0x10
 )
 
 type RequestHeader struct {
diff --git a/infra/conf/vmess.go b/infra/conf/vmess.go
index 37557c17..9759890d 100644
--- a/infra/conf/vmess.go
+++ b/infra/conf/vmess.go
@@ -15,9 +15,10 @@ import (
 )
 
 type VMessAccount struct {
-	ID       string `json:"id"`
-	AlterIds uint16 `json:"alterId"`
-	Security string `json:"security"`
+	ID          string `json:"id"`
+	AlterIds    uint16 `json:"alterId"`
+	Security    string `json:"security"`
+	Experiments string `json:"experiments"`
 }
 
 // Build implements Buildable
@@ -43,6 +44,7 @@ func (a *VMessAccount) Build() *vmess.Account {
 		SecuritySettings: &protocol.SecurityConfig{
 			Type: st,
 		},
+		TestsEnabled: a.Experiments,
 	}
 }
 
diff --git a/proxy/vmess/account.go b/proxy/vmess/account.go
index a95576b9..8f5e6004 100644
--- a/proxy/vmess/account.go
+++ b/proxy/vmess/account.go
@@ -1,6 +1,8 @@
 package vmess
 
 import (
+	"strings"
+
 	"github.com/xtls/xray-core/common/dice"
 	"github.com/xtls/xray-core/common/protocol"
 	"github.com/xtls/xray-core/common/uuid"
@@ -14,6 +16,8 @@ type MemoryAccount struct {
 	AlterIDs []*protocol.ID
 	// Security type of the account. Used for client connections.
 	Security protocol.SecurityType
+
+	AuthenticatedLengthExperiment bool
 }
 
 // AnyValidID returns an ID that is either the main ID or one of the alternative IDs if any.
@@ -41,9 +45,14 @@ func (a *Account) AsAccount() (protocol.Account, error) {
 		return nil, newError("failed to parse ID").Base(err).AtError()
 	}
 	protoID := protocol.NewID(id)
+	var AuthenticatedLength bool
+	if strings.Contains(a.TestsEnabled, "AuthenticatedLength") {
+		AuthenticatedLength = true
+	}
 	return &MemoryAccount{
-		ID:       protoID,
-		AlterIDs: protocol.NewAlterIDs(protoID, uint16(a.AlterId)),
-		Security: a.SecuritySettings.GetSecurityType(),
+		ID:                            protoID,
+		AlterIDs:                      protocol.NewAlterIDs(protoID, uint16(a.AlterId)),
+		Security:                      a.SecuritySettings.GetSecurityType(),
+		AuthenticatedLengthExperiment: AuthenticatedLength,
 	}, nil
 }
diff --git a/proxy/vmess/encoding/auth.go b/proxy/vmess/encoding/auth.go
index 6233643f..23536de6 100644
--- a/proxy/vmess/encoding/auth.go
+++ b/proxy/vmess/encoding/auth.go
@@ -7,6 +7,8 @@ import (
 
 	"github.com/xtls/xray-core/common"
 
+	"github.com/xtls/xray-core/common/crypto"
+
 	"golang.org/x/crypto/sha3"
 )
 
@@ -116,3 +118,11 @@ func (s *ShakeSizeParser) NextPaddingLen() uint16 {
 func (s *ShakeSizeParser) MaxPaddingLen() uint16 {
 	return 64
 }
+
+type AEADSizeParser struct {
+	crypto.AEADChunkSizeParser
+}
+
+func NewAEADSizeParser(auth *crypto.AEADAuthenticator) *AEADSizeParser {
+	return &AEADSizeParser{crypto.AEADChunkSizeParser{Auth: auth}}
+}
diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go
index d46e8fb6..82ab5dee 100644
--- a/proxy/vmess/encoding/client.go
+++ b/proxy/vmess/encoding/client.go
@@ -171,6 +171,17 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
 	case protocol.SecurityType_CHACHA20_POLY1305:
 		aead, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.requestBodyKey[:]))
@@ -181,6 +192,18 @@ func (c *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
+			common.Must(err)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
 	default:
 		panic("Unknown security type.")
@@ -312,6 +335,17 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			NonceGenerator:          GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
 	case protocol.SecurityType_CHACHA20_POLY1305:
 		aead, _ := chacha20poly1305.New(GenerateChacha20Poly1305Key(c.responseBodyKey[:]))
@@ -321,6 +355,18 @@ func (c *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			NonceGenerator:          GenerateChunkNonce(c.responseBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(c.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
+			common.Must(err)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(c.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
 	default:
 		panic("Unknown security type.")
diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go
index f4773894..428d0e69 100644
--- a/proxy/vmess/encoding/server.go
+++ b/proxy/vmess/encoding/server.go
@@ -362,6 +362,17 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
 
 	case protocol.SecurityType_CHACHA20_POLY1305:
@@ -372,6 +383,18 @@ func (s *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
+			common.Must(err)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType(), padding)
 
 	default:
@@ -480,6 +503,17 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			NonceGenerator:          GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD := crypto.NewAesGcm(AuthenticatedLengthKey)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
 
 	case protocol.SecurityType_CHACHA20_POLY1305:
@@ -490,6 +524,18 @@ func (s *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			NonceGenerator:          GenerateChunkNonce(s.responseBodyIV[:], uint32(aead.NonceSize())),
 			AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
 		}
+		if request.Option.Has(protocol.RequestOptionAuthenticatedLength) {
+			AuthenticatedLengthKey := vmessaead.KDF16(s.requestBodyKey[:], "auth_len")
+			AuthenticatedLengthKeyAEAD, err := chacha20poly1305.New(GenerateChacha20Poly1305Key(AuthenticatedLengthKey))
+			common.Must(err)
+
+			lengthAuth := &crypto.AEADAuthenticator{
+				AEAD:                    AuthenticatedLengthKeyAEAD,
+				NonceGenerator:          GenerateChunkNonce(s.requestBodyIV[:], uint32(aead.NonceSize())),
+				AdditionalDataGenerator: crypto.GenerateEmptyBytes(),
+			}
+			sizeParser = NewAEADSizeParser(lengthAuth)
+		}
 		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType(), padding)
 
 	default:
diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go
index 7265e9f3..a618daea 100644
--- a/proxy/vmess/outbound/outbound.go
+++ b/proxy/vmess/outbound/outbound.go
@@ -119,6 +119,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		request.Option.Clear(protocol.RequestOptionChunkMasking)
 	}
 
+	if account.AuthenticatedLengthExperiment {
+		request.Option.Set(protocol.RequestOptionAuthenticatedLength)
+	}
+
 	input := link.Reader
 	output := link.Writer
 
diff --git a/testing/scenarios/vmess_test.go b/testing/scenarios/vmess_test.go
index 45641deb..2f174cd0 100644
--- a/testing/scenarios/vmess_test.go
+++ b/testing/scenarios/vmess_test.go
@@ -1330,3 +1330,110 @@ func TestVMessZero(t *testing.T) {
 		t.Error(err)
 	}
 }
+
+func TestVMessGCMLengthAuth(t *testing.T) {
+	tcpServer := tcp.Server{
+		MsgProcessor: xor,
+	}
+	dest, err := tcpServer.Start()
+	common.Must(err)
+	defer tcpServer.Close()
+
+	userID := protocol.NewID(uuid.New())
+	serverPort := tcp.PickPort()
+	serverConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: clog.Severity_Debug,
+				ErrorLogType:  log.LogType_Console,
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(serverPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&inbound.Config{
+					User: []*protocol.User{
+						{
+							Account: serial.ToTypedMessage(&vmess.Account{
+								Id:      userID.String(),
+								AlterId: 64,
+							}),
+						},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+		},
+	}
+
+	clientPort := tcp.PickPort()
+	clientConfig := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&log.Config{
+				ErrorLogLevel: clog.Severity_Debug,
+				ErrorLogType:  log.LogType_Console,
+			}),
+		},
+		Inbound: []*core.InboundHandlerConfig{
+			{
+				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
+					PortRange: net.SinglePortRange(clientPort),
+					Listen:    net.NewIPOrDomain(net.LocalHostIP),
+				}),
+				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
+					Address: net.NewIPOrDomain(dest.Address),
+					Port:    uint32(dest.Port),
+					NetworkList: &net.NetworkList{
+						Network: []net.Network{net.Network_TCP},
+					},
+				}),
+			},
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&outbound.Config{
+					Receiver: []*protocol.ServerEndpoint{
+						{
+							Address: net.NewIPOrDomain(net.LocalHostIP),
+							Port:    uint32(serverPort),
+							User: []*protocol.User{
+								{
+									Account: serial.ToTypedMessage(&vmess.Account{
+										Id:      userID.String(),
+										AlterId: 64,
+										SecuritySettings: &protocol.SecurityConfig{
+											Type: protocol.SecurityType_AES128_GCM,
+										},
+										TestsEnabled: "AuthenticatedLength",
+									}),
+								},
+							},
+						},
+					},
+				}),
+			},
+		},
+	}
+
+	servers, err := InitializeServerConfigs(serverConfig, clientConfig)
+	if err != nil {
+		t.Fatal("Failed to initialize all servers: ", err.Error())
+	}
+	defer CloseAllServers(servers)
+
+	var errg errgroup.Group
+	for i := 0; i < 10; i++ {
+		errg.Go(testTCPConn(clientPort, 10240*1024, time.Second*40))
+	}
+
+	if err := errg.Wait(); err != nil {
+		t.Error(err)
+	}
+}