diff --git a/proxy/vmess/encoding/client.go b/proxy/vmess/encoding/client.go index 596bb997..a0f4c261 100644 --- a/proxy/vmess/encoding/client.go +++ b/proxy/vmess/encoding/client.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "hash/fnv" "io" + "sync" "golang.org/x/crypto/chacha20poly1305" @@ -40,12 +41,16 @@ type ClientSession struct { responseHeader byte } +var clientSessionPool = sync.Pool{ + New: func() interface{} { return &ClientSession{} }, +} + // NewClientSession creates a new ClientSession. func NewClientSession(idHash protocol.IDHash) *ClientSession { randomBytes := make([]byte, 33) // 16 + 16 + 1 common.Must2(rand.Read(randomBytes)) - session := &ClientSession{} + session := clientSessionPool.Get().(*ClientSession) copy(session.requestBodyKey[:], randomBytes[:16]) copy(session.requestBodyIV[:], randomBytes[16:32]) session.responseHeader = randomBytes[32] @@ -56,6 +61,12 @@ func NewClientSession(idHash protocol.IDHash) *ClientSession { return session } +func ReleaseClientSession(session *ClientSession) { + session.idHash = nil + session.responseReader = nil + clientSessionPool.Put(session) +} + func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writer io.Writer) error { timestamp := protocol.NewTimestampGenerator(protocol.NowTime(), 30)() account := header.User.Account.(*vmess.InternalAccount) diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index b7cc1049..2aa83ce2 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -44,6 +44,8 @@ func TestRequestSerialization(t *testing.T) { buffer := buf.New() client := NewClientSession(protocol.DefaultIDHash) + defer ReleaseClientSession(client) + common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -57,6 +59,7 @@ func TestRequestSerialization(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) + defer ReleaseServerSession(server) actualRequest, err := server.DecodeRequestHeader(buffer) assert(err, IsNil) @@ -97,6 +100,8 @@ func TestInvalidRequest(t *testing.T) { buffer := buf.New() client := NewClientSession(protocol.DefaultIDHash) + defer ReleaseClientSession(client) + common.Must(client.EncodeRequestHeader(expectedRequest, buffer)) buffer2 := buf.New() @@ -110,6 +115,7 @@ func TestInvalidRequest(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) + defer ReleaseServerSession(server) _, err := server.DecodeRequestHeader(buffer) assert(err, IsNotNil) } @@ -150,6 +156,7 @@ func TestMuxRequest(t *testing.T) { defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) + defer ReleaseServerSession(server) actualRequest, err := server.DecodeRequestHeader(buffer) assert(err, IsNil) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 42117b58..e5736f5e 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -101,13 +101,24 @@ type ServerSession struct { responseHeader byte } +var serverSessionPool = sync.Pool{ + New: func() interface{} { return &ServerSession{} }, +} + // NewServerSession creates a new ServerSession, using the given UserValidator. // The ServerSession instance doesn't take ownership of the validator. func NewServerSession(validator *vmess.TimedUserValidator, sessionHistory *SessionHistory) *ServerSession { - return &ServerSession{ - userValidator: validator, - sessionHistory: sessionHistory, - } + session := serverSessionPool.Get().(*ServerSession) + session.userValidator = validator + session.sessionHistory = sessionHistory + return session +} + +func ReleaseServerSession(session *ServerSession) { + session.responseWriter = nil + session.userValidator = nil + session.sessionHistory = nil + serverSessionPool.Put(session) } func parseSecurityType(b byte) protocol.SecurityType { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index dd0f40a2..88b6288d 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -221,6 +221,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i reader := &buf.BufferedReader{Reader: buf.NewReader(connection)} svrSession := encoding.NewServerSession(h.clients, h.sessionHistory) + defer encoding.ReleaseServerSession(svrSession) + request, err := svrSession.DecodeRequestHeader(reader) if err != nil { diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index bfce2c28..690c00b9 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -106,6 +106,8 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia output := link.Writer session := encoding.NewClientSession(protocol.DefaultIDHash) + defer encoding.ReleaseClientSession(session) + sessionPolicy := v.v.PolicyManager().ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx)