mirror of https://github.com/XTLS/Xray-core
MUX: Prevent goroutine leak (#5110)
parent
ce5c51d3ba
commit
9f5dcb1591
|
@ -9,6 +9,7 @@ import (
|
|||
"github.com/xtls/xray-core/common/mux"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
|
@ -53,6 +54,9 @@ func (b *Bridge) cleanup() {
|
|||
if w.IsActive() {
|
||||
activeWorkers = append(activeWorkers, w)
|
||||
}
|
||||
if w.Closed() {
|
||||
w.Timer.SetTimeout(0)
|
||||
}
|
||||
}
|
||||
|
||||
if len(activeWorkers) != len(b.workers) {
|
||||
|
@ -98,6 +102,7 @@ type BridgeWorker struct {
|
|||
Worker *mux.ServerWorker
|
||||
Dispatcher routing.Dispatcher
|
||||
State Control_State
|
||||
Timer *signal.ActivityTimer
|
||||
}
|
||||
|
||||
func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
|
||||
|
@ -125,6 +130,10 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
|
|||
}
|
||||
w.Worker = worker
|
||||
|
||||
terminate := func() {
|
||||
worker.Close()
|
||||
}
|
||||
w.Timer = signal.CancelAfterInactivity(ctx, terminate, 60*time.Second)
|
||||
return w, nil
|
||||
}
|
||||
|
||||
|
@ -144,6 +153,10 @@ func (w *BridgeWorker) IsActive() bool {
|
|||
return w.State == Control_ACTIVE && !w.Worker.Closed()
|
||||
}
|
||||
|
||||
func (w *BridgeWorker) Closed() bool {
|
||||
return w.Worker.Closed()
|
||||
}
|
||||
|
||||
func (w *BridgeWorker) Connections() uint32 {
|
||||
return w.Worker.ActiveConnections()
|
||||
}
|
||||
|
@ -153,13 +166,20 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) {
|
|||
for {
|
||||
mb, err := reader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
break
|
||||
if w.Closed() {
|
||||
w.Timer.SetTimeout(0)
|
||||
} else {
|
||||
w.Timer.SetTimeout(24 * time.Hour)
|
||||
}
|
||||
return
|
||||
}
|
||||
w.Timer.Update()
|
||||
for _, b := range mb {
|
||||
var ctl Control
|
||||
if err := proto.Unmarshal(b.Bytes(), &ctl); err != nil {
|
||||
errors.LogInfoInner(context.Background(), err, "failed to parse proto message")
|
||||
break
|
||||
w.Timer.SetTimeout(0)
|
||||
return
|
||||
}
|
||||
if ctl.State != w.State {
|
||||
w.State = ctl.State
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/serial"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/features/outbound"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
|
@ -159,6 +160,8 @@ func (p *StaticMuxPicker) cleanup() error {
|
|||
for _, w := range p.workers {
|
||||
if !w.Closed() {
|
||||
activeWorkers = append(activeWorkers, w)
|
||||
} else {
|
||||
w.timer.SetTimeout(0)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -225,6 +228,7 @@ type PortalWorker struct {
|
|||
reader buf.Reader
|
||||
draining bool
|
||||
counter uint32
|
||||
timer *signal.ActivityTimer
|
||||
}
|
||||
|
||||
func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
|
||||
|
@ -244,10 +248,14 @@ func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
|
|||
if !f {
|
||||
return nil, errors.New("unable to dispatch control connection")
|
||||
}
|
||||
terminate := func() {
|
||||
client.Close()
|
||||
}
|
||||
w := &PortalWorker{
|
||||
client: client,
|
||||
reader: downlinkReader,
|
||||
writer: uplinkWriter,
|
||||
timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
|
||||
}
|
||||
w.control = &task.Periodic{
|
||||
Execute: w.heartbeat,
|
||||
|
@ -274,7 +282,6 @@ func (w *PortalWorker) heartbeat() error {
|
|||
msg.State = Control_DRAIN
|
||||
|
||||
defer func() {
|
||||
w.client.GetTimer().Reset(time.Second * 16)
|
||||
common.Close(w.writer)
|
||||
common.Interrupt(w.reader)
|
||||
w.writer = nil
|
||||
|
@ -286,6 +293,7 @@ func (w *PortalWorker) heartbeat() error {
|
|||
b, err := proto.Marshal(msg)
|
||||
common.Must(err)
|
||||
mb := buf.MergeBytes(nil, b)
|
||||
w.timer.Update()
|
||||
return w.writer.WriteMultiBuffer(mb)
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -219,14 +219,16 @@ func (m *ClientWorker) WaitClosed() <-chan struct{} {
|
|||
return m.done.Wait()
|
||||
}
|
||||
|
||||
func (m *ClientWorker) GetTimer() *time.Ticker {
|
||||
return m.timer
|
||||
func (m *ClientWorker) Close() error {
|
||||
return m.done.Close()
|
||||
}
|
||||
|
||||
func (m *ClientWorker) monitor() {
|
||||
defer m.timer.Stop()
|
||||
|
||||
for {
|
||||
checkSize := m.sessionManager.Size()
|
||||
checkCount := m.sessionManager.Count()
|
||||
select {
|
||||
case <-m.done.Wait():
|
||||
m.sessionManager.Close()
|
||||
|
@ -234,8 +236,7 @@ func (m *ClientWorker) monitor() {
|
|||
common.Interrupt(m.link.Reader)
|
||||
return
|
||||
case <-m.timer.C:
|
||||
size := m.sessionManager.Size()
|
||||
if size == 0 && m.sessionManager.CloseIfNoSession() {
|
||||
if m.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
|
||||
common.Must(m.done.Close())
|
||||
}
|
||||
}
|
||||
|
@ -255,7 +256,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
|
||||
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds)-1]
|
||||
transferType := protocol.TransferTypeStream
|
||||
|
@ -266,7 +267,6 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.
|
|||
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
|
||||
defer s.Close(false)
|
||||
defer writer.Close()
|
||||
defer timer.Reset(time.Second * 16)
|
||||
|
||||
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
|
||||
if err := writeFirstPayload(s.input, writer); err != nil {
|
||||
|
@ -316,10 +316,12 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
|
|||
}
|
||||
s.input = link.Reader
|
||||
s.output = link.Writer
|
||||
if _, ok := link.Reader.(*pipe.Reader); ok {
|
||||
go fetchInput(ctx, s, m.link.Writer, m.timer)
|
||||
} else {
|
||||
fetchInput(ctx, s, m.link.Writer, m.timer)
|
||||
go fetchInput(ctx, s, m.link.Writer)
|
||||
if _, ok := link.Reader.(*pipe.Reader); !ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-s.done.Wait():
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package mux
|
|||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/app/dispatcher"
|
||||
"github.com/xtls/xray-core/common"
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal/done"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
|
@ -63,8 +65,15 @@ func (s *Server) DispatchLink(ctx context.Context, dest net.Destination, link *t
|
|||
return s.dispatcher.DispatchLink(ctx, dest, link)
|
||||
}
|
||||
link = s.dispatcher.(*dispatcher.DefaultDispatcher).WrapLink(ctx, link)
|
||||
_, err := NewServerWorker(ctx, s.dispatcher, link)
|
||||
return err
|
||||
worker, err := NewServerWorker(ctx, s.dispatcher, link)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-worker.done.Wait():
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start implements common.Runnable.
|
||||
|
@ -81,6 +90,8 @@ type ServerWorker struct {
|
|||
dispatcher routing.Dispatcher
|
||||
link *transport.Link
|
||||
sessionManager *SessionManager
|
||||
done *done.Instance
|
||||
timer *time.Ticker
|
||||
}
|
||||
|
||||
func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.Link) (*ServerWorker, error) {
|
||||
|
@ -88,15 +99,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
|
|||
dispatcher: d,
|
||||
link: link,
|
||||
sessionManager: NewSessionManager(),
|
||||
done: done.New(),
|
||||
timer: time.NewTicker(60 * time.Second),
|
||||
}
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||
inbound.CanSpliceCopy = 3
|
||||
}
|
||||
if _, ok := link.Reader.(*pipe.Reader); ok {
|
||||
go worker.run(ctx)
|
||||
} else {
|
||||
worker.run(ctx)
|
||||
}
|
||||
go worker.run(ctx)
|
||||
go worker.monitor()
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
|
@ -111,12 +121,40 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
|
|||
s.Close(false)
|
||||
}
|
||||
|
||||
func (w *ServerWorker) monitor() {
|
||||
defer w.timer.Stop()
|
||||
|
||||
for {
|
||||
checkSize := w.sessionManager.Size()
|
||||
checkCount := w.sessionManager.Count()
|
||||
select {
|
||||
case <-w.done.Wait():
|
||||
w.sessionManager.Close()
|
||||
common.Interrupt(w.link.Writer)
|
||||
common.Interrupt(w.link.Reader)
|
||||
return
|
||||
case <-w.timer.C:
|
||||
if w.sessionManager.CloseIfNoSessionAndIdle(checkSize, checkCount) {
|
||||
common.Must(w.done.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ServerWorker) ActiveConnections() uint32 {
|
||||
return uint32(w.sessionManager.Size())
|
||||
}
|
||||
|
||||
func (w *ServerWorker) Closed() bool {
|
||||
return w.sessionManager.Closed()
|
||||
return w.done.Done()
|
||||
}
|
||||
|
||||
func (w *ServerWorker) WaitClosed() <-chan struct{} {
|
||||
return w.done.Wait()
|
||||
}
|
||||
|
||||
func (w *ServerWorker) Close() error {
|
||||
return w.done.Close()
|
||||
}
|
||||
|
||||
func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
||||
|
@ -317,11 +355,11 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
|
|||
}
|
||||
|
||||
func (w *ServerWorker) run(ctx context.Context) {
|
||||
reader := &buf.BufferedReader{Reader: w.link.Reader}
|
||||
defer func() {
|
||||
common.Must(w.done.Close())
|
||||
}()
|
||||
|
||||
defer w.sessionManager.Close()
|
||||
defer common.Interrupt(w.link.Reader)
|
||||
defer common.Interrupt(w.link.Writer)
|
||||
reader := &buf.BufferedReader{Reader: w.link.Reader}
|
||||
|
||||
for {
|
||||
select {
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/signal/done"
|
||||
"github.com/xtls/xray-core/transport/pipe"
|
||||
)
|
||||
|
||||
|
@ -53,7 +54,7 @@ func (m *SessionManager) Count() int {
|
|||
func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
|
||||
MaxConcurrency := int(Strategy.MaxConcurrency)
|
||||
MaxConnection := uint16(Strategy.MaxConnection)
|
||||
|
||||
|
@ -65,6 +66,7 @@ func (m *SessionManager) Allocate(Strategy *ClientStrategy) *Session {
|
|||
s := &Session{
|
||||
ID: m.count,
|
||||
parent: m,
|
||||
done: done.New(),
|
||||
}
|
||||
m.sessions[s.ID] = s
|
||||
return s
|
||||
|
@ -115,7 +117,7 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
|
|||
return s, found
|
||||
}
|
||||
|
||||
func (m *SessionManager) CloseIfNoSession() bool {
|
||||
func (m *SessionManager) CloseIfNoSessionAndIdle(checkSize int, checkCount int) bool {
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
|
@ -123,11 +125,13 @@ func (m *SessionManager) CloseIfNoSession() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
if len(m.sessions) != 0 {
|
||||
if len(m.sessions) != 0 || checkSize != 0 || checkCount != int(m.count) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
|
||||
m.sessions = nil
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -157,6 +161,7 @@ type Session struct {
|
|||
ID uint16
|
||||
transferType protocol.TransferType
|
||||
closed bool
|
||||
done *done.Instance
|
||||
XUDP *XUDP
|
||||
}
|
||||
|
||||
|
@ -171,6 +176,9 @@ func (s *Session) Close(locked bool) error {
|
|||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
if s.done != nil {
|
||||
s.done.Close()
|
||||
}
|
||||
if s.XUDP == nil {
|
||||
common.Interrupt(s.input)
|
||||
common.Close(s.output)
|
||||
|
|
|
@ -41,11 +41,11 @@ func TestSessionManagerClose(t *testing.T) {
|
|||
m := NewSessionManager()
|
||||
s := m.Allocate(&ClientStrategy{})
|
||||
|
||||
if m.CloseIfNoSession() {
|
||||
if m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
|
||||
t.Error("able to close")
|
||||
}
|
||||
m.Remove(false, s.ID)
|
||||
if !m.CloseIfNoSession() {
|
||||
if !m.CloseIfNoSessionAndIdle(m.Size(), m.Count()) {
|
||||
t.Error("not able to close")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -678,10 +678,10 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
|
|||
errors.LogInfo(ctx, "CopyRawConn splice")
|
||||
statWriter, _ := writer.(*dispatcher.SizeStatWriter)
|
||||
//runtime.Gosched() // necessary
|
||||
time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
|
||||
timer.SetTimeout(8 * time.Hour) // prevent leak, just in case
|
||||
time.Sleep(time.Millisecond) // without this, there will be a rare ssl error for freedom splice
|
||||
timer.SetTimeout(24 * time.Hour) // prevent leak, just in case
|
||||
if inTimer != nil {
|
||||
inTimer.SetTimeout(8 * time.Hour)
|
||||
inTimer.SetTimeout(24 * time.Hour)
|
||||
}
|
||||
w, err := tc.ReadFrom(readerConn)
|
||||
if readCounter != nil {
|
||||
|
|
Loading…
Reference in New Issue