package splithttp import ( "context" "crypto/rand" "math" "math/big" "sync/atomic" "time" "github.com/xtls/xray-core/common/errors" ) type XmuxConn interface { IsClosed() bool } type XmuxClient struct { XmuxConn XmuxConn OpenUsage atomic.Int32 leftUsage int32 expirationTime time.Time LeftRequests atomic.Int32 } type XmuxManager struct { xmuxConfig XmuxConfig concurrency int32 connections int32 newConnFunc func() XmuxConn xmuxClients []*XmuxClient } func NewXmuxManager(xmuxConfig XmuxConfig, newConnFunc func() XmuxConn) *XmuxManager { return &XmuxManager{ xmuxConfig: xmuxConfig, concurrency: xmuxConfig.GetNormalizedMaxConcurrency().rand(), connections: xmuxConfig.GetNormalizedMaxConnections().rand(), newConnFunc: newConnFunc, xmuxClients: make([]*XmuxClient, 0), } } func (m *XmuxManager) newXmuxClient() *XmuxClient { xmuxClient := &XmuxClient{ XmuxConn: m.newConnFunc(), leftUsage: -1, expirationTime: time.UnixMilli(0), } if x := m.xmuxConfig.GetNormalizedCMaxReuseTimes().rand(); x > 0 { xmuxClient.leftUsage = x - 1 } if x := m.xmuxConfig.GetNormalizedCMaxLifetimeMs().rand(); x > 0 { xmuxClient.expirationTime = time.Now().Add(time.Duration(x) * time.Millisecond) } xmuxClient.LeftRequests.Store(math.MaxInt32) if x := m.xmuxConfig.GetNormalizedCMaxRequestTimes().rand(); x > 0 { xmuxClient.LeftRequests.Store(x) } m.xmuxClients = append(m.xmuxClients, xmuxClient) return xmuxClient } func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { // when locking for i := 0; i < len(m.xmuxClients); { xmuxClient := m.xmuxClients[i] if xmuxClient.XmuxConn.IsClosed() || xmuxClient.leftUsage == 0 || (xmuxClient.expirationTime != time.UnixMilli(0) && time.Now().After(xmuxClient.expirationTime)) || xmuxClient.LeftRequests.Load() <= 0 { errors.LogDebug(ctx, "XMUX: removing xmuxClient, IsClosed() = ", xmuxClient.XmuxConn.IsClosed(), ", OpenUsage = ", xmuxClient.OpenUsage.Load(), ", leftUsage = ", xmuxClient.leftUsage, ", expirationTime = ", xmuxClient.expirationTime, ", LeftRequests = ", xmuxClient.LeftRequests.Load()) m.xmuxClients = append(m.xmuxClients[:i], m.xmuxClients[i+1:]...) } else { i++ } } if len(m.xmuxClients) == 0 { errors.LogDebug(ctx, "XMUX: creating xmuxClient because xmuxClients is empty") return m.newXmuxClient() } if m.connections > 0 && len(m.xmuxClients) < int(m.connections) { errors.LogDebug(ctx, "XMUX: creating xmuxClient because maxConnections was not hit, xmuxClients = ", len(m.xmuxClients)) return m.newXmuxClient() } xmuxClients := make([]*XmuxClient, 0) if m.concurrency > 0 { for _, xmuxClient := range m.xmuxClients { if xmuxClient.OpenUsage.Load() < m.concurrency { xmuxClients = append(xmuxClients, xmuxClient) } } } else { xmuxClients = m.xmuxClients } if len(xmuxClients) == 0 { errors.LogDebug(ctx, "XMUX: creating xmuxClient because maxConcurrency was hit, xmuxClients = ", len(m.xmuxClients)) return m.newXmuxClient() } i, _ := rand.Int(rand.Reader, big.NewInt(int64(len(xmuxClients)))) xmuxClient := xmuxClients[i.Int64()] if xmuxClient.leftUsage > 0 { xmuxClient.leftUsage -= 1 } return xmuxClient }