Xray-core/common/mux/client.go

404 lines
9.3 KiB
Go

package mux
import (
"context"
"io"
"sync"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/pipe"
)
type ClientManager struct {
Enabled bool // wheather mux is enabled from user config
Picker WorkerPicker
}
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
for i := 0; i < 16; i++ {
worker, err := m.Picker.PickAvailable()
if err != nil {
return err
}
if worker.Dispatch(ctx, link) {
return nil
}
}
return newError("unable to find an available mux client").AtWarning()
}
type WorkerPicker interface {
PickAvailable() (*ClientWorker, error)
}
type IncrementalWorkerPicker struct {
Factory ClientWorkerFactory
access sync.Mutex
workers []*ClientWorker
cleanupTask *task.Periodic
}
func (p *IncrementalWorkerPicker) cleanupFunc() error {
p.access.Lock()
defer p.access.Unlock()
if len(p.workers) == 0 {
return newError("no worker")
}
p.cleanup()
return nil
}
func (p *IncrementalWorkerPicker) cleanup() {
var activeWorkers []*ClientWorker
for _, w := range p.workers {
if !w.Closed() {
activeWorkers = append(activeWorkers, w)
}
}
p.workers = activeWorkers
}
func (p *IncrementalWorkerPicker) findAvailable() int {
for idx, w := range p.workers {
if !w.IsFull() {
return idx
}
}
return -1
}
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
p.access.Lock()
defer p.access.Unlock()
idx := p.findAvailable()
if idx >= 0 {
n := len(p.workers)
if n > 1 && idx != n-1 {
p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
}
return p.workers[idx], false, nil
}
p.cleanup()
worker, err := p.Factory.Create()
if err != nil {
return nil, false, err
}
p.workers = append(p.workers, worker)
if p.cleanupTask == nil {
p.cleanupTask = &task.Periodic{
Interval: time.Second * 30,
Execute: p.cleanupFunc,
}
}
return worker, true, nil
}
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
worker, start, err := p.pickInternal()
if start {
common.Must(p.cleanupTask.Start())
}
return worker, err
}
type ClientWorkerFactory interface {
Create() (*ClientWorker, error)
}
type DialingWorkerFactory struct {
Proxy proxy.Outbound
Dialer internet.Dialer
Strategy ClientStrategy
}
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
uplinkReader, upLinkWriter := pipe.New(opts...)
downlinkReader, downlinkWriter := pipe.New(opts...)
c, err := NewClientWorker(transport.Link{
Reader: downlinkReader,
Writer: upLinkWriter,
}, f.Strategy)
if err != nil {
return nil, err
}
go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
})
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
errors.New("failed to handler mux client connection").Base(err).WriteToLog()
}
common.Must(c.Close())
cancel()
}(f.Proxy, f.Dialer, c.done)
return c, nil
}
type ClientStrategy struct {
MaxConcurrency uint32
MaxConnection uint32
}
type ClientWorker struct {
sessionManager *SessionManager
link transport.Link
done *done.Instance
strategy ClientStrategy
}
var (
muxCoolAddress = net.DomainAddress("v1.mux.cool")
muxCoolPort = net.Port(9527)
)
// NewClientWorker creates a new mux.Client.
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
c := &ClientWorker{
sessionManager: NewSessionManager(),
link: stream,
done: done.New(),
strategy: s,
}
go c.fetchOutput()
go c.monitor()
return c, nil
}
func (m *ClientWorker) TotalConnections() uint32 {
return uint32(m.sessionManager.Count())
}
func (m *ClientWorker) ActiveConnections() uint32 {
return uint32(m.sessionManager.Size())
}
// Closed returns true if this Client is closed.
func (m *ClientWorker) Closed() bool {
return m.done.Done()
}
func (m *ClientWorker) monitor() {
timer := time.NewTicker(time.Second * 16)
defer timer.Stop()
for {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-timer.C:
size := m.sessionManager.Size()
if size == 0 && m.sessionManager.CloseIfNoSession() {
common.Must(m.done.Close())
}
}
}
}
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
return writer.WriteMultiBuffer(buf.MultiBuffer{})
}
if err != nil {
return err
}
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
dest := session.OutboundFromContext(ctx).Target
transferType := protocol.TransferTypeStream
if dest.Network == net.Network_UDP {
transferType = protocol.TransferTypePacket
}
s.transferType = transferType
writer := NewWriter(s.ID, dest, output, transferType)
defer s.Close()
defer writer.Close()
newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
if err := writeFirstPayload(s.input, writer); err != nil {
newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
if err := buf.Copy(s.input, writer); err != nil {
newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx))
writer.hasError = true
common.Interrupt(s.input)
return
}
}
func (m *ClientWorker) IsClosing() bool {
sm := m.sessionManager
if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
return true
}
return false
}
func (m *ClientWorker) IsFull() bool {
if m.IsClosing() || m.Closed() {
return true
}
sm := m.sessionManager
if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
return true
}
return false
}
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
if m.IsFull() || m.Closed() {
return false
}
sm := m.sessionManager
s := sm.Allocate()
if s == nil {
return false
}
s.input = link.Reader
s.output = link.Writer
go fetchInput(ctx, s, m.link.Writer)
return true
}
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
if !meta.Option.Has(OptionData) {
return nil
}
s, found := m.sessionManager.Get(meta.SessionID)
if !found {
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
rr := s.NewReader(reader, &meta.Target)
err := buf.Copy(rr, s.output)
if err != nil && buf.IsWriteError(err) {
newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog()
// Notify remote peer to close this session.
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
closingWriter.Close()
drainErr := buf.Copy(rr, buf.Discard)
common.Interrupt(s.input)
s.Close()
return drainErr
}
return err
}
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
if s, found := m.sessionManager.Get(meta.SessionID); found {
if meta.Option.Has(OptionError) {
common.Interrupt(s.input)
common.Interrupt(s.output)
}
s.Close()
}
if meta.Option.Has(OptionData) {
return buf.Copy(NewStreamReader(reader), buf.Discard)
}
return nil
}
func (m *ClientWorker) fetchOutput() {
defer func() {
common.Must(m.done.Close())
}()
reader := &buf.BufferedReader{Reader: m.link.Reader}
var meta FrameMetadata
for {
err := meta.Unmarshal(reader)
if err != nil {
if errors.Cause(err) != io.EOF {
newError("failed to read metadata").Base(err).WriteToLog()
}
break
}
switch meta.SessionStatus {
case SessionStatusKeepAlive:
err = m.handleStatueKeepAlive(&meta, reader)
case SessionStatusEnd:
err = m.handleStatusEnd(&meta, reader)
case SessionStatusNew:
err = m.handleStatusNew(&meta, reader)
case SessionStatusKeep:
err = m.handleStatusKeep(&meta, reader)
default:
status := meta.SessionStatus
newError("unknown status: ", status).AtError().WriteToLog()
return
}
if err != nil {
newError("failed to process data").Base(err).WriteToLog()
return
}
}
}