From 9f5dcb15910aadc7ef450514747576827a389853 Mon Sep 17 00:00:00 2001 From: patterniha <71074308+patterniha@users.noreply.github.com> Date: Wed, 10 Sep 2025 02:33:19 +0200 Subject: [PATCH] MUX: Prevent goroutine leak (#5110) --- app/reverse/bridge.go | 24 +++++++++++++-- app/reverse/portal.go | 10 +++++- common/mux/client.go | 22 ++++++++------ common/mux/server.go | 62 ++++++++++++++++++++++++++++++-------- common/mux/session.go | 14 +++++++-- common/mux/session_test.go | 4 +-- proxy/proxy.go | 6 ++-- 7 files changed, 109 insertions(+), 33 deletions(-) diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index b86d153b..fe3b2c3a 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -9,6 +9,7 @@ import ( "github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" @@ -53,6 +54,9 @@ func (b *Bridge) cleanup() { if w.IsActive() { activeWorkers = append(activeWorkers, w) } + if w.Closed() { + w.Timer.SetTimeout(0) + } } if len(activeWorkers) != len(b.workers) { @@ -98,6 +102,7 @@ type BridgeWorker struct { Worker *mux.ServerWorker Dispatcher routing.Dispatcher State Control_State + Timer *signal.ActivityTimer } func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) { @@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo } w.Worker = worker + terminate := func() { + worker.Close() + } + w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second) return w, nil } @@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool { return w.State == Control_ACTIVE && !w.Worker.Closed() } +func (w *BridgeWorker) Closed() bool { + return w.Worker.Closed() +} + func (w *BridgeWorker) Connections() uint32 { return w.Worker.ActiveConnections() } @@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) { for { mb, err := reader.ReadMultiBuffer() if err != nil { - break + if w.Closed() { + w.Timer.SetTimeout(0) + } else { + w.Timer.SetTimeout(24 * time.Hour) + } + return } + w.Timer.Update() for _, b := range mb { var ctl Control if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil { errors.LogInfoInner(context.Background(), err, "failed to parse proto message") - break + w.Timer.SetTimeout(0) + return } if ctl.State != w.State { w.State = ctl.State diff --git a/app/reverse/portal.go b/app/reverse/portal.go index 11bfc514..7e3f2caf 100644 --- a/app/reverse/portal.go +++ b/app/reverse/portal.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/transport" @@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error { for _, w := range p.workers { if !w.Closed() { activeWorkers = append(activeWorkers, w) + } else { + w.timer.SetTimeout(0) } } @@ -225,6 +228,7 @@ type PortalWorker struct { reader buf.Reader draining bool counter uint32 + timer *signal.ActivityTimer } func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { @@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { if !f { return nil, errors.New("unable to dispatch control connection") } + terminate := func() { + client.Close() + } w := &PortalWorker{ client: client, reader: downlinkReader, writer: uplinkWriter, + timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak } w.control = &task.Periodic{ Execute: w.heartbeat, @@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error { msg.State = Control_DRAIN defer func() { - w.client.GetTimer().Reset(time.Second * 16) common.Close(w.writer) common.Interrupt(w.reader) w.writer = nil @@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error { b, err := proto.Marshal(msg) common.Must(err) mb := buf.MergeBytes(nil, b) + w.timer.Update() return w.writer.WriteMultiBuffer(mb) } return nil diff --git a/common/mux/client.go b/common/mux/client.go index dddb6371..93357574 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -219,14 +219,16 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} { return m.done.Wait() } -func (m *ClientWorker) GetTimer() *time.Ticker { - return m.timer +func (m *ClientWorker) Close() error { + return m.done.Close() } func (m *ClientWorker) monitor() { defer m.timer.Stop() for { + checkSize := m.sessionManager.Size() + checkCount := m.sessionManager.Count() select { case <-m.done.Wait(): m.sessionManager.Close() @@ -234,8 +236,7 @@ func (m *ClientWorker) monitor() { common.Interrupt(m.link.Reader) return case <-m.timer.C: - size := m.sessionManager.Size() - if size == 0 && m.sessionManager.CloseIfNoSession() { + if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) { common.Must(m.done.Close()) } } @@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { return nil } -func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) { +func fetchInput(ctx context.Context, s *Session, output buf.Writer) { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] transferType := protocol.TransferTypeStream @@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time. writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) defer s.Close(false) defer writer.Close() - defer timer.Reset(time.Second * 16) errors.LogInfo(ctx, "dispatching request to ", ob.Target) if err := writeFirstPayload(s.input, writer); err != nil { @@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool } s.input = link.Reader s.output = link.Writer - if _, ok := link.Reader.(*pipe.Reader); ok { - go fetchInput(ctx, s, m.link.Writer, m.timer) - } else { - fetchInput(ctx, s, m.link.Writer, m.timer) + go fetchInput(ctx, s, m.link.Writer) + if _, ok := link.Reader.(*pipe.Reader); !ok { + select { + case <-ctx.Done(): + case <-s.done.Wait(): + } } return true } diff --git a/common/mux/server.go b/common/mux/server.go index ac121a9f..70c5ed24 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -3,6 +3,7 @@ package mux import ( "context" "io" + "time" "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/common" @@ -12,6 +13,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" @@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t return s.dispatcher.DispatchLink(ctx, dest, link) } link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) - _, err := NewServerWorker(ctx, s.dispatcher, link) - return err + worker, err := NewServerWorker(ctx, s.dispatcher, link) + if err != nil { + return err + } + select { + case <-ctx.Done(): + case <-worker.done.Wait(): + } + return nil } // Start implements common.Runnable. @@ -81,6 +90,8 @@ type ServerWorker struct { dispatcher routing.Dispatcher link *transport.Link sessionManager *SessionManager + done *done.Instance + timer *time.Ticker } func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) { @@ -88,15 +99,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport. dispatcher: d, link: link, sessionManager: NewSessionManager(), + done: done.New(), + timer: time.NewTicker(60 * time.Second), } if inbound := session.InboundFromContext(ctx); inbound != nil { inbound.CanSpliceCopy = 3 } - if _, ok := link.Reader.(*pipe.Reader); ok { - go worker.run(ctx) - } else { - worker.run(ctx) - } + go worker.run(ctx) + go worker.monitor() return worker, nil } @@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) { s.Close(false) } +func (w *ServerWorker) monitor() { + defer w.timer.Stop() + + for { + checkSize := w.sessionManager.Size() + checkCount := w.sessionManager.Count() + select { + case <-w.done.Wait(): + w.sessionManager.Close() + common.Interrupt(w.link.Writer) + common.Interrupt(w.link.Reader) + return + case <-w.timer.C: + if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) { + common.Must(w.done.Close()) + } + } + } +} + func (w *ServerWorker) ActiveConnections() uint32 { return uint32(w.sessionManager.Size()) } func (w *ServerWorker) Closed() bool { - return w.sessionManager.Closed() + return w.done.Done() +} + +func (w *ServerWorker) WaitClosed() <-chan struct{} { + return w.done.Wait() +} + +func (w *ServerWorker) Close() error { + return w.done.Close() } func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error { @@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead } func (w *ServerWorker) run(ctx context.Context) { - reader := &buf.BufferedReader{Reader: w.link.Reader} + defer func() { + common.Must(w.done.Close()) + }() - defer w.sessionManager.Close() - defer common.Interrupt(w.link.Reader) - defer common.Interrupt(w.link.Writer) + reader := &buf.BufferedReader{Reader: w.link.Reader} for { select { diff --git a/common/mux/session.go b/common/mux/session.go index 8bcb01bb..66b9674c 100644 --- a/common/mux/session.go +++ b/common/mux/session.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/transport/pipe" ) @@ -53,7 +54,7 @@ func (m *SessionManager) Count() int { func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session { m.Lock() defer m.Unlock() - + MaxConcurrency := int(Strategy.MaxConcurrency) MaxConnection := uint16(Strategy.MaxConnection) @@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session { s := &Session{ ID: m.count, parent: m, + done: done.New(), } m.sessions[s.ID] = s return s @@ -115,7 +117,7 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) { return s, found } -func (m *SessionManager) CloseIfNoSession() bool { +func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool { m.Lock() defer m.Unlock() @@ -123,11 +125,13 @@ func (m *SessionManager) CloseIfNoSession() bool { return true } - if len(m.sessions) != 0 { + if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) { return false } m.closed = true + + m.sessions = nil return true } @@ -157,6 +161,7 @@ type Session struct { ID uint16 transferType protocol.TransferType closed bool + done *done.Instance XUDP *XUDP } @@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error { return nil } s.closed = true + if s.done != nil { + s.done.Close() + } if s.XUDP == nil { common.Interrupt(s.input) common.Close(s.output) diff --git a/common/mux/session_test.go b/common/mux/session_test.go index a8491a9c..8ef27877 100644 --- a/common/mux/session_test.go +++ b/common/mux/session_test.go @@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) { m := NewSessionManager() s := m.Allocate(&ClientStrategy{}) - if m.CloseIfNoSession() { + if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) { t.Error("able to close") } m.Remove(false, s.ID) - if !m.CloseIfNoSession() { + if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) { t.Error("not able to close") } } diff --git a/proxy/proxy.go b/proxy/proxy.go index 3a1af8f7..f759fdc0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -678,10 +678,10 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net errors.LogInfo(ctx, "CopyRawConn splice") statWriter, _ := writer.(*dispatcher.SizeStatWriter) //runtime.Gosched() // necessary - time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice - timer.SetTimeout(8 * time.Hour) // prevent leak, just in case + time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice + timer.SetTimeout(24 * time.Hour) // prevent leak, just in case if inTimer != nil { - inTimer.SetTimeout(8 * time.Hour) + inTimer.SetTimeout(24 * time.Hour) } w, err := tc.ReadFrom(readerConn) if readCounter != nil {