MUX: Prevent goroutine leak (#5110)

pull/5118/head
patterniha 2025-09-10 02:33:19 +02:00 committed by GitHub
parent ce5c51d3ba
commit 9f5dcb1591
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 109 additions and 33 deletions

View File

@ -9,6 +9,7 @@ import (
"github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/mux"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
@ -53,6 +54,9 @@ func (b *Bridge) cleanup() {
if w.IsActive() { if w.IsActive() {
activeWorkers = append(activeWorkers, w) activeWorkers = append(activeWorkers, w)
} }
if w.Closed() {
w.Timer.SetTimeout(0)
}
} }
if len(activeWorkers) != len(b.workers) { if len(activeWorkers) != len(b.workers) {
@ -98,6 +102,7 @@ type BridgeWorker struct {
Worker *mux.ServerWorker Worker *mux.ServerWorker
Dispatcher routing.Dispatcher Dispatcher routing.Dispatcher
State Control_State State Control_State
Timer *signal.ActivityTimer
} }
func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) { func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
} }
w.Worker = worker w.Worker = worker
terminate := func() {
worker.Close()
}
w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second)
return w, nil return w, nil
} }
@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool {
return w.State == Control_ACTIVE && !w.Worker.Closed() return w.State == Control_ACTIVE && !w.Worker.Closed()
} }
func (w *BridgeWorker) Closed() bool {
return w.Worker.Closed()
}
func (w *BridgeWorker) Connections() uint32 { func (w *BridgeWorker) Connections() uint32 {
return w.Worker.ActiveConnections() return w.Worker.ActiveConnections()
} }
@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
for { for {
mb, err := reader.ReadMultiBuffer() mb, err := reader.ReadMultiBuffer()
if err != nil { if err != nil {
break if w.Closed() {
w.Timer.SetTimeout(0)
} else {
w.Timer.SetTimeout(24 * time.Hour)
} }
return
}
w.Timer.Update()
for _, b := range mb { for _, b := range mb {
var ctl Control var ctl Control
if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil { if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
errors.LogInfoInner(context.Background(), err, "failed to parse proto message") errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
break w.Timer.SetTimeout(0)
return
} }
if ctl.State != w.State { if ctl.State != w.State {
w.State = ctl.State w.State = ctl.State

View File

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/serial" "github.com/xtls/xray-core/common/serial"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/features/outbound" "github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error {
for _, w := range p.workers { for _, w := range p.workers {
if !w.Closed() { if !w.Closed() {
activeWorkers = append(activeWorkers, w) activeWorkers = append(activeWorkers, w)
} else {
w.timer.SetTimeout(0)
} }
} }
@ -225,6 +228,7 @@ type PortalWorker struct {
reader buf.Reader reader buf.Reader
draining bool draining bool
counter uint32 counter uint32
timer *signal.ActivityTimer
} }
func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) { func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
if !f { if !f {
return nil, errors.New("unable to dispatch control connection") return nil, errors.New("unable to dispatch control connection")
} }
terminate := func() {
client.Close()
}
w := &PortalWorker{ w := &PortalWorker{
client: client, client: client,
reader: downlinkReader, reader: downlinkReader,
writer: uplinkWriter, writer: uplinkWriter,
timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
} }
w.control = &task.Periodic{ w.control = &task.Periodic{
Execute: w.heartbeat, Execute: w.heartbeat,
@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error {
msg.State = Control_DRAIN msg.State = Control_DRAIN
defer func() { defer func() {
w.client.GetTimer().Reset(time.Second * 16)
common.Close(w.writer) common.Close(w.writer)
common.Interrupt(w.reader) common.Interrupt(w.reader)
w.writer = nil w.writer = nil
@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error {
b, err := proto.Marshal(msg) b, err := proto.Marshal(msg)
common.Must(err) common.Must(err)
mb := buf.MergeBytes(nil, b) mb := buf.MergeBytes(nil, b)
w.timer.Update()
return w.writer.WriteMultiBuffer(mb) return w.writer.WriteMultiBuffer(mb)
} }
return nil return nil

View File

@ -219,14 +219,16 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
return m.done.Wait() return m.done.Wait()
} }
func (m *ClientWorker) GetTimer() *time.Ticker { func (m *ClientWorker) Close() error {
return m.timer return m.done.Close()
} }
func (m *ClientWorker) monitor() { func (m *ClientWorker) monitor() {
defer m.timer.Stop() defer m.timer.Stop()
for { for {
checkSize := m.sessionManager.Size()
checkCount := m.sessionManager.Count()
select { select {
case <-m.done.Wait(): case <-m.done.Wait():
m.sessionManager.Close() m.sessionManager.Close()
@ -234,8 +236,7 @@ func (m *ClientWorker) monitor() {
common.Interrupt(m.link.Reader) common.Interrupt(m.link.Reader)
return return
case <-m.timer.C: case <-m.timer.C:
size := m.sessionManager.Size() if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
if size == 0 && m.sessionManager.CloseIfNoSession() {
common.Must(m.done.Close()) common.Must(m.done.Close())
} }
} }
@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil return nil
} }
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) { func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
outbounds := session.OutboundsFromContext(ctx) outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1] ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream transferType := protocol.TransferTypeStream
@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false) defer s.Close(false)
defer writer.Close() defer writer.Close()
defer timer.Reset(time.Second * 16)
errors.LogInfo(ctx, "dispatching request to ", ob.Target) errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil { if err := writeFirstPayload(s.input, writer); err != nil {
@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
} }
s.input = link.Reader s.input = link.Reader
s.output = link.Writer s.output = link.Writer
if _, ok := link.Reader.(*pipe.Reader); ok { go fetchInput(ctx, s, m.link.Writer)
go fetchInput(ctx, s, m.link.Writer, m.timer) if _, ok := link.Reader.(*pipe.Reader); !ok {
} else { select {
fetchInput(ctx, s, m.link.Writer, m.timer) case <-ctx.Done():
case <-s.done.Wait():
}
} }
return true return true
} }

View File

@ -3,6 +3,7 @@ package mux
import ( import (
"context" "context"
"io" "io"
"time"
"github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/dispatcher"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
@ -12,6 +13,7 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport"
@ -63,9 +65,16 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
return s.dispatcher.DispatchLink(ctx, dest, link) return s.dispatcher.DispatchLink(ctx, dest, link)
} }
link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link) link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
_, err := NewServerWorker(ctx, s.dispatcher, link) worker, err := NewServerWorker(ctx, s.dispatcher, link)
if err != nil {
return err return err
} }
select {
case <-ctx.Done():
case <-worker.done.Wait():
}
return nil
}
// Start implements common.Runnable. // Start implements common.Runnable.
func (s *Server) Start() error { func (s *Server) Start() error {
@ -81,6 +90,8 @@ type ServerWorker struct {
dispatcher routing.Dispatcher dispatcher routing.Dispatcher
link *transport.Link link *transport.Link
sessionManager *SessionManager sessionManager *SessionManager
done *done.Instance
timer *time.Ticker
} }
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) { func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
@ -88,15 +99,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
dispatcher: d, dispatcher: d,
link: link, link: link,
sessionManager: NewSessionManager(), sessionManager: NewSessionManager(),
done: done.New(),
timer: time.NewTicker(60 * time.Second),
} }
if inbound := session.InboundFromContext(ctx); inbound != nil { if inbound := session.InboundFromContext(ctx); inbound != nil {
inbound.CanSpliceCopy = 3 inbound.CanSpliceCopy = 3
} }
if _, ok := link.Reader.(*pipe.Reader); ok {
go worker.run(ctx) go worker.run(ctx)
} else { go worker.monitor()
worker.run(ctx)
}
return worker, nil return worker, nil
} }
@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
s.Close(false) s.Close(false)
} }
func (w *ServerWorker) monitor() {
defer w.timer.Stop()
for {
checkSize := w.sessionManager.Size()
checkCount := w.sessionManager.Count()
select {
case <-w.done.Wait():
w.sessionManager.Close()
common.Interrupt(w.link.Writer)
common.Interrupt(w.link.Reader)
return
case <-w.timer.C:
if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
common.Must(w.done.Close())
}
}
}
}
func (w *ServerWorker) ActiveConnections() uint32 { func (w *ServerWorker) ActiveConnections() uint32 {
return uint32(w.sessionManager.Size()) return uint32(w.sessionManager.Size())
} }
func (w *ServerWorker) Closed() bool { func (w *ServerWorker) Closed() bool {
return w.sessionManager.Closed() return w.done.Done()
}
func (w *ServerWorker) WaitClosed() <-chan struct{} {
return w.done.Wait()
}
func (w *ServerWorker) Close() error {
return w.done.Close()
} }
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error { func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
} }
func (w *ServerWorker) run(ctx context.Context) { func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader} defer func() {
common.Must(w.done.Close())
}()
defer w.sessionManager.Close() reader := &buf.BufferedReader{Reader: w.link.Reader}
defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
for { for {
select { select {

View File

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/transport/pipe" "github.com/xtls/xray-core/transport/pipe"
) )
@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
s := &Session{ s := &Session{
ID: m.count, ID: m.count,
parent: m, parent: m,
done: done.New(),
} }
m.sessions[s.ID] = s m.sessions[s.ID] = s
return s return s
@ -115,7 +117,7 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
return s, found return s, found
} }
func (m *SessionManager) CloseIfNoSession() bool { func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
@ -123,11 +125,13 @@ func (m *SessionManager) CloseIfNoSession() bool {
return true return true
} }
if len(m.sessions) != 0 { if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
return false return false
} }
m.closed = true m.closed = true
m.sessions = nil
return true return true
} }
@ -157,6 +161,7 @@ type Session struct {
ID uint16 ID uint16
transferType protocol.TransferType transferType protocol.TransferType
closed bool closed bool
done *done.Instance
XUDP *XUDP XUDP *XUDP
} }
@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
return nil return nil
} }
s.closed = true s.closed = true
if s.done != nil {
s.done.Close()
}
if s.XUDP == nil { if s.XUDP == nil {
common.Interrupt(s.input) common.Interrupt(s.input)
common.Close(s.output) common.Close(s.output)

View File

@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
m := NewSessionManager() m := NewSessionManager()
s := m.Allocate(&ClientStrategy{}) s := m.Allocate(&ClientStrategy{})
if m.CloseIfNoSession() { if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("able to close") t.Error("able to close")
} }
m.Remove(false, s.ID) m.Remove(false, s.ID)
if !m.CloseIfNoSession() { if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
t.Error("not able to close") t.Error("not able to close")
} }
} }

View File

@ -679,9 +679,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
statWriter, _ := writer.(*dispatcher.SizeStatWriter) statWriter, _ := writer.(*dispatcher.SizeStatWriter)
//runtime.Gosched() // necessary //runtime.Gosched() // necessary
time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
timer.SetTimeout(8 * time.Hour) // prevent leak, just in case timer.SetTimeout(24 * time.Hour) // prevent leak, just in case
if inTimer != nil { if inTimer != nil {
inTimer.SetTimeout(8 * time.Hour) inTimer.SetTimeout(24 * time.Hour)
} }
w, err := tc.ReadFrom(readerConn) w, err := tc.ReadFrom(readerConn)
if readCounter != nil { if readCounter != nil {