diff --git a/app/proxyman/config.pb.go b/app/proxyman/config.pb.go index 2dfe2931..03613e88 100644 --- a/app/proxyman/config.pb.go +++ b/app/proxyman/config.pb.go @@ -595,6 +595,8 @@ type MultiplexingConfig struct { Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` // Max number of concurrent connections that one Mux connection can handle. Concurrency uint32 `protobuf:"varint,2,opt,name=concurrency,proto3" json:"concurrency,omitempty"` + // Both(0), TCP(1), UDP(2). + Only uint32 `protobuf:"varint,3,opt,name=only,proto3" json:"only,omitempty"` } func (x *MultiplexingConfig) Reset() { @@ -643,6 +645,13 @@ func (x *MultiplexingConfig) GetConcurrency() uint32 { return 0 } +func (x *MultiplexingConfig) GetOnly() uint32 { + if x != nil { + return x.Only + } + return 0 +} + type AllocationStrategy_AllocationStrategyConcurrency struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -856,21 +865,22 @@ var file_app_proxyman_config_proto_rawDesc = []byte{ 0x28, 0x0b, 0x32, 0x25, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x2e, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x6d, 0x75, 0x6c, 0x74, 0x69, - 0x70, 0x6c, 0x65, 0x78, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x50, 0x0a, 0x12, + 0x70, 0x6c, 0x65, 0x78, 0x53, 0x65, 0x74, 0x74, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x64, 0x0a, 0x12, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x78, 0x69, 0x6e, 0x67, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x2a, 0x23, - 0x0a, 0x0e, 0x4b, 0x6e, 0x6f, 0x77, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x73, - 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x4c, - 0x53, 0x10, 0x01, 0x42, 0x55, 0x0a, 0x15, 0x63, 0x6f, 0x6d, 0x2e, 0x78, 0x72, 0x61, 0x79, 0x2e, - 0x61, 0x70, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x50, 0x01, 0x5a, 0x26, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, 0x74, 0x6c, 0x73, 0x2f, - 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, 0x70, 0x2f, 0x70, 0x72, - 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0xaa, 0x02, 0x11, 0x58, 0x72, 0x61, 0x79, 0x2e, 0x41, 0x70, - 0x70, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x0d, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x63, 0x75, 0x72, 0x72, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x12, + 0x0a, 0x04, 0x6f, 0x6e, 0x6c, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x04, 0x6f, 0x6e, + 0x6c, 0x79, 0x2a, 0x23, 0x0a, 0x0e, 0x4b, 0x6e, 0x6f, 0x77, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x73, 0x12, 0x08, 0x0a, 0x04, 0x48, 0x54, 0x54, 0x50, 0x10, 0x00, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x4c, 0x53, 0x10, 0x01, 0x42, 0x55, 0x0a, 0x15, 0x63, 0x6f, 0x6d, 0x2e, 0x78, + 0x72, 0x61, 0x79, 0x2e, 0x61, 0x70, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, + 0x50, 0x01, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x78, + 0x74, 0x6c, 0x73, 0x2f, 0x78, 0x72, 0x61, 0x79, 0x2d, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x61, 0x70, + 0x70, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0xaa, 0x02, 0x11, 0x58, 0x72, 0x61, + 0x79, 0x2e, 0x41, 0x70, 0x70, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x6d, 0x61, 0x6e, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/app/proxyman/config.proto b/app/proxyman/config.proto index 24216d2c..54f63436 100644 --- a/app/proxyman/config.proto +++ b/app/proxyman/config.proto @@ -98,4 +98,6 @@ message MultiplexingConfig { bool enabled = 1; // Max number of concurrent connections that one Mux connection can handle. uint32 concurrency = 2; + // Both(0), TCP(1), UDP(2). + uint32 only = 3; } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 42554b72..89e2862d 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -111,7 +111,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning() } h.mux = &mux.ClientManager{ - Enabled: h.senderSettings.MultiplexSettings.Enabled, + Enabled: config.Enabled, Picker: &mux.IncrementalWorkerPicker{ Factory: &mux.DialingWorkerFactory{ Proxy: proxyHandler, @@ -122,6 +122,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou }, }, }, + Only: config.Only, } } @@ -136,7 +137,9 @@ func (h *Handler) Tag() string { // Dispatch implements proxy.Outbound.Dispatch. func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { - if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) { + outbound := session.OutboundFromContext(ctx) + if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) && + (h.mux.Only == 0 || (outbound != nil && h.mux.Only == uint32(outbound.Target.Network))) { if err := h.mux.Dispatch(ctx, link); err != nil { err := newError("failed to process mux outbound traffic").Base(err) session.SubmitOutboundErrorToOriginator(ctx, err) diff --git a/common/mux/client.go b/common/mux/client.go index 2019738f..f933ef4c 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -14,6 +14,7 @@ import ( "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/signal/done" "github.com/xtls/xray-core/common/task" + "github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet" @@ -23,6 +24,7 @@ import ( type ClientManager struct { Enabled bool // wheather mux is enabled from user config Picker WorkerPicker + Only uint32 } func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error { @@ -247,22 +249,20 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { transferType = protocol.TransferTypePacket } s.transferType = transferType - writer := NewWriter(s.ID, dest, output, transferType) - defer s.Close() + writer := NewWriter(s.ID, dest, output, transferType, xudp.GetGlobalID(ctx)) + defer s.Close(false) defer writer.Close() newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx)) if err := writeFirstPayload(s.input, writer); err != nil { newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) writer.hasError = true - common.Interrupt(s.input) return } if err := buf.Copy(s.input, writer); err != nil { newError("failed to fetch all input").Base(err).WriteToLog(session.ExportIDToError(ctx)) writer.hasError = true - common.Interrupt(s.input) return } } @@ -335,15 +335,8 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere err := buf.Copy(rr, s.output) if err != nil && buf.IsWriteError(err) { newError("failed to write to downstream. closing session ", s.ID).Base(err).WriteToLog() - - // Notify remote peer to close this session. - closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream) - closingWriter.Close() - - drainErr := buf.Copy(rr, buf.Discard) - common.Interrupt(s.input) - s.Close() - return drainErr + s.Close(false) + return buf.Copy(rr, buf.Discard) } return err @@ -351,12 +344,7 @@ func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error { if s, found := m.sessionManager.Get(meta.SessionID); found { - if meta.Option.Has(OptionError) { - common.Interrupt(s.input) - common.Interrupt(s.output) - } - common.Interrupt(s.input) - s.Close() + s.Close(false) } if meta.Option.Has(OptionData) { return buf.Copy(NewStreamReader(reader), buf.Discard) diff --git a/common/mux/frame.go b/common/mux/frame.go index 30f3c1db..ab57d771 100644 --- a/common/mux/frame.go +++ b/common/mux/frame.go @@ -58,6 +58,7 @@ type FrameMetadata struct { SessionID uint16 Option bitmask.Byte SessionStatus SessionStatus + GlobalID [8]byte } func (f FrameMetadata) WriteTo(b *buf.Buffer) error { @@ -81,6 +82,9 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error { if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil { return err } + if b.UDP != nil { + b.Write(f.GlobalID[:]) + } } else if b.UDP != nil { b.WriteByte(byte(TargetNetworkUDP)) addrParser.WriteAddressPort(b, b.UDP.Address, b.UDP.Port) @@ -144,5 +148,10 @@ func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error { } } + if f.SessionStatus == SessionStatusNew && f.Option.Has(OptionData) && + f.Target.Network == net.Network_UDP && b.Len() >= 8 { + copy(f.GlobalID[:], b.Bytes()) + } + return nil } diff --git a/common/mux/mux_test.go b/common/mux/mux_test.go index 39def2ab..f326ffd7 100644 --- a/common/mux/mux_test.go +++ b/common/mux/mux_test.go @@ -32,13 +32,13 @@ func TestReaderWriter(t *testing.T) { pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024)) dest := net.TCPDestination(net.DomainAddress("example.com"), 80) - writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream) + writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{}) dest2 := net.TCPDestination(net.LocalHostIP, 443) - writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream) + writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{}) dest3 := net.TCPDestination(net.LocalHostIPv6, 18374) - writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream) + writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{}) writePayload := func(writer *Writer, payload ...byte) error { b := buf.New() diff --git a/common/mux/server.go b/common/mux/server.go index df461be7..e64e038f 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -2,6 +2,7 @@ package mux import ( "context" + "fmt" "io" "github.com/xtls/xray-core/common" @@ -11,6 +12,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" + "github.com/xtls/xray-core/common/xudp" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport" @@ -99,7 +101,7 @@ func handle(ctx context.Context, s *Session, output buf.Writer) { } writer.Close() - s.Close() + s.Close(false) } func (w *ServerWorker) ActiveConnections() uint32 { @@ -131,6 +133,81 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, } ctx = log.ContextWithAccessMessage(ctx, msg) } + + if meta.GlobalID != [8]byte{} { + mb, err := NewPacketReader(reader, &meta.Target).ReadMultiBuffer() + if err != nil { + return err + } + XUDPManager.Lock() + x := XUDPManager.Map[meta.GlobalID] + if x == nil { + x = &XUDP{GlobalID: meta.GlobalID} + XUDPManager.Map[meta.GlobalID] = x + XUDPManager.Unlock() + } else { + if x.Status == Initializing { // nearly impossible + XUDPManager.Unlock() + if xudp.Show { + fmt.Printf("XUDP hit: %v err: conflict\n", meta.GlobalID) + } + // It's not a good idea to return an err here, so just let client wait. + // Client will receive an End frame after sending a Keep frame. + return nil + } + x.Status = Initializing + XUDPManager.Unlock() + x.Mux.Close(false) // detach from previous Mux + b := buf.New() + b.Write(mb[0].Bytes()) + b.UDP = mb[0].UDP + if err = x.Mux.output.WriteMultiBuffer(mb); err != nil { + x.Interrupt() + mb = buf.MultiBuffer{b} + } else { + b.Release() + mb = nil + } + if xudp.Show { + fmt.Printf("XUDP hit: %v err: %v\n", meta.GlobalID, err) + } + } + if mb != nil { + ctx = session.ContextWithTimeoutOnly(ctx, true) + // Actually, it won't return an error in Xray-core's implementations. + link, err := w.dispatcher.Dispatch(ctx, meta.Target) + if err != nil { + err = newError("failed to dispatch request to ", meta.Target).Base(err) + if xudp.Show { + fmt.Printf("XUDP new: %v err: %v\n", meta.GlobalID, err) + } + return err // it will break the whole Mux connection + } + link.Writer.WriteMultiBuffer(mb) // it's meaningless to test a new pipe + x.Mux = &Session{ + input: link.Reader, + output: link.Writer, + } + if xudp.Show { + fmt.Printf("XUDP new: %v err: %v\n", meta.GlobalID, err) + } + } + x.Mux = &Session{ + input: x.Mux.input, + output: x.Mux.output, + parent: w.sessionManager, + ID: meta.SessionID, + transferType: protocol.TransferTypePacket, + XUDP: x, + } + go handle(ctx, x.Mux, w.link.Writer) + x.Status = Active + if !w.sessionManager.Add(x.Mux) { + x.Mux.Close(false) + } + return nil + } + link, err := w.dispatcher.Dispatch(ctx, meta.Target) if err != nil { if meta.Option.Has(OptionData) { @@ -157,8 +234,7 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, rr := s.NewReader(reader, &meta.Target) if err := buf.Copy(rr, s.output); err != nil { buf.Copy(rr, buf.Discard) - common.Interrupt(s.input) - return s.Close() + return s.Close(false) } return nil } @@ -182,15 +258,8 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere if err != nil && buf.IsWriteError(err) { newError("failed to write to downstream writer. closing session ", s.ID).Base(err).WriteToLog() - - // Notify remote peer to close this session. - closingWriter := NewResponseWriter(meta.SessionID, w.link.Writer, protocol.TransferTypeStream) - closingWriter.Close() - - drainErr := buf.Copy(rr, buf.Discard) - common.Interrupt(s.input) - s.Close() - return drainErr + s.Close(false) + return buf.Copy(rr, buf.Discard) } return err @@ -198,12 +267,7 @@ func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.Buffere func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error { if s, found := w.sessionManager.Get(meta.SessionID); found { - if meta.Option.Has(OptionError) { - common.Interrupt(s.input) - common.Interrupt(s.output) - } - common.Interrupt(s.input) - s.Close() + s.Close(false) } if meta.Option.Has(OptionData) { return buf.Copy(NewStreamReader(reader), buf.Discard) diff --git a/common/mux/session.go b/common/mux/session.go index 2f21b97a..650e3545 100644 --- a/common/mux/session.go +++ b/common/mux/session.go @@ -1,12 +1,18 @@ package mux import ( + "fmt" + "io" + "runtime" "sync" + "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/xudp" + "github.com/xtls/xray-core/transport/pipe" ) type SessionManager struct { @@ -61,21 +67,25 @@ func (m *SessionManager) Allocate() *Session { return s } -func (m *SessionManager) Add(s *Session) { +func (m *SessionManager) Add(s *Session) bool { m.Lock() defer m.Unlock() if m.closed { - return + return false } m.count++ m.sessions[s.ID] = s + return true } -func (m *SessionManager) Remove(id uint16) { - m.Lock() - defer m.Unlock() +func (m *SessionManager) Remove(locked bool, id uint16) { + if !locked { + m.Lock() + defer m.Unlock() + } + locked = true if m.closed { return @@ -83,9 +93,11 @@ func (m *SessionManager) Remove(id uint16) { delete(m.sessions, id) - if len(m.sessions) == 0 { - m.sessions = make(map[uint16]*Session, 16) - } + /* + if len(m.sessions) == 0 { + m.sessions = make(map[uint16]*Session, 16) + } + */ } func (m *SessionManager) Get(id uint16) (*Session, bool) { @@ -127,8 +139,7 @@ func (m *SessionManager) Close() error { m.closed = true for _, s := range m.sessions { - common.Close(s.input) - common.Close(s.output) + s.Close(true) } m.sessions = nil @@ -142,13 +153,42 @@ type Session struct { parent *SessionManager ID uint16 transferType protocol.TransferType + closed bool + XUDP *XUDP } // Close closes all resources associated with this session. -func (s *Session) Close() error { - common.Close(s.output) - common.Close(s.input) - s.parent.Remove(s.ID) +func (s *Session) Close(locked bool) error { + if !locked { + s.parent.Lock() + defer s.parent.Unlock() + } + locked = true + if s.closed { + return nil + } + s.closed = true + if s.XUDP == nil { + common.Interrupt(s.input) + common.Close(s.output) + } else { + // Stop existing handle(), then trigger writer.Close(). + // Note that s.output may be dispatcher.SizeStatWriter. + s.input.(*pipe.Reader).ReturnAnError(io.EOF) + runtime.Gosched() + // If the error set by ReturnAnError still exists, clear it. + s.input.(*pipe.Reader).Recover() + XUDPManager.Lock() + if s.XUDP.Status == Active { + s.XUDP.Expire = time.Now().Add(time.Minute) + s.XUDP.Status = Expiring + if xudp.Show { + fmt.Printf("XUDP put: %v\n", s.XUDP.GlobalID) + } + } + XUDPManager.Unlock() + } + s.parent.Remove(locked, s.ID) return nil } @@ -159,3 +199,47 @@ func (s *Session) NewReader(reader *buf.BufferedReader, dest *net.Destination) b } return NewPacketReader(reader, dest) } + +const ( + Initializing = 0 + Active = 1 + Expiring = 2 +) + +type XUDP struct { + GlobalID [8]byte + Status uint64 + Expire time.Time + Mux *Session +} + +func (x *XUDP) Interrupt() { + common.Interrupt(x.Mux.input) + common.Close(x.Mux.output) +} + +var XUDPManager struct { + sync.Mutex + Map map[[8]byte]*XUDP +} + +func init() { + XUDPManager.Map = make(map[[8]byte]*XUDP) + go func() { + for { + time.Sleep(time.Minute) + now := time.Now() + XUDPManager.Lock() + for id, x := range XUDPManager.Map { + if x.Status == Expiring && now.After(x.Expire) { + x.Interrupt() + delete(XUDPManager.Map, id) + if xudp.Show { + fmt.Printf("XUDP del: %v\n", id) + } + } + } + XUDPManager.Unlock() + } + }() +} diff --git a/common/mux/session_test.go b/common/mux/session_test.go index 7139df10..d81ad8c4 100644 --- a/common/mux/session_test.go +++ b/common/mux/session_test.go @@ -44,7 +44,7 @@ func TestSessionManagerClose(t *testing.T) { if m.CloseIfNoSession() { t.Error("able to close") } - m.Remove(s.ID) + m.Remove(false, s.ID) if !m.CloseIfNoSession() { t.Error("not able to close") } diff --git a/common/mux/writer.go b/common/mux/writer.go index f7a22b2d..a6dc551d 100644 --- a/common/mux/writer.go +++ b/common/mux/writer.go @@ -15,15 +15,17 @@ type Writer struct { followup bool hasError bool transferType protocol.TransferType + globalID [8]byte } -func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer { +func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte) *Writer { return &Writer{ id: id, dest: dest, writer: writer, followup: false, transferType: transferType, + globalID: globalID, } } @@ -40,6 +42,7 @@ func (w *Writer) getNextFrameMeta() FrameMetadata { meta := FrameMetadata{ SessionID: w.id, Target: w.dest, + GlobalID: w.globalID, } if w.followup { diff --git a/common/session/context.go b/common/session/context.go index 2959807e..71e4b154 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -2,10 +2,14 @@ package session import ( "context" + _ "unsafe" "github.com/xtls/xray-core/features/routing" ) +//go:linkname IndependentCancelCtx context.newCancelCtx +func IndependentCancelCtx(parent context.Context) context.Context + type sessionKey int const ( @@ -17,6 +21,7 @@ const ( sockoptSessionKey trackedConnectionErrorKey dispatcherKey + timeoutOnlyKey ) // ContextWithID returns a new context with the given ID. @@ -131,3 +136,14 @@ func DispatcherFromContext(ctx context.Context) routing.Dispatcher { } return nil } + +func ContextWithTimeoutOnly(ctx context.Context, only bool) context.Context { + return context.WithValue(ctx, timeoutOnlyKey, only) +} + +func TimeoutOnlyFromContext(ctx context.Context) bool { + if val, ok := ctx.Value(timeoutOnlyKey).(bool); ok { + return val + } + return false +} diff --git a/common/session/session.go b/common/session/session.go index 656a2404..83c48fde 100644 --- a/common/session/session.go +++ b/common/session/session.go @@ -42,6 +42,8 @@ type Inbound struct { Gateway net.Destination // Tag of the inbound proxy that handles the connection. Tag string + // Name of the inbound proxy that handles the connection. + Name string // User is the user that authencates for the inbound. May be nil if the protocol allows anounymous traffic. User *protocol.MemoryUser // Conn is actually internet.Connection. May be nil. diff --git a/common/task/task.go b/common/task/task.go index 52b0d44b..eeba1dcd 100644 --- a/common/task/task.go +++ b/common/task/task.go @@ -38,6 +38,12 @@ func Run(ctx context.Context, tasks ...func() error) error { }(task) } + /* + if altctx := ctx.Value("altctx"); altctx != nil { + ctx = altctx.(context.Context) + } + */ + for i := 0; i < n; i++ { select { case err := <-done: @@ -48,5 +54,11 @@ func Run(ctx context.Context, tasks ...func() error) error { } } + /* + if cancel := ctx.Value("cancel"); cancel != nil { + cancel.(context.CancelFunc)() + } + */ + return nil } diff --git a/common/xudp/xudp.go b/common/xudp/xudp.go index 80a35e41..65096d16 100644 --- a/common/xudp/xudp.go +++ b/common/xudp/xudp.go @@ -1,30 +1,76 @@ package xudp import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" "io" + "os" + "strings" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" + "lukechampine.com/blake3" ) -var addrParser = protocol.NewAddressParser( +var AddrParser = protocol.NewAddressParser( protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv4), net.AddressFamilyIPv4), protocol.AddressFamilyByte(byte(protocol.AddressTypeDomain), net.AddressFamilyDomain), protocol.AddressFamilyByte(byte(protocol.AddressTypeIPv6), net.AddressFamilyIPv6), protocol.PortThenAddress(), ) -func NewPacketWriter(writer buf.Writer, dest net.Destination) *PacketWriter { +var ( + Show bool + BaseKey [32]byte +) + +const ( + EnvShow = "XRAY_XUDP_SHOW" + EnvBaseKey = "XRAY_XUDP_BASEKEY" +) + +func init() { + if strings.ToLower(os.Getenv(EnvShow)) == "true" { + Show = true + } + if raw := os.Getenv(EnvBaseKey); raw != "" { + if key, _ := base64.RawURLEncoding.DecodeString(raw); len(key) == len(BaseKey) { + copy(BaseKey[:], key) + return + } else { + panic(EnvBaseKey + ": invalid value: " + raw) + } + } + rand.Read(BaseKey[:]) +} + +func GetGlobalID(ctx context.Context) (globalID [8]byte) { + if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.Network == net.Network_UDP && + (inbound.Name == "dokodemo-door" || inbound.Name == "socks" || inbound.Name == "shadowsocks") { + h := blake3.New(8, BaseKey[:]) + h.Write([]byte(inbound.Source.String())) + copy(globalID[:], h.Sum(nil)) + fmt.Printf("XUDP inbound.Source.String(): %v\tglobalID: %v\n", inbound.Source.String(), globalID) + } + return +} + +func NewPacketWriter(writer buf.Writer, dest net.Destination, globalID [8]byte) *PacketWriter { return &PacketWriter{ - Writer: writer, - Dest: dest, + Writer: writer, + Dest: dest, + GlobalID: globalID, } } type PacketWriter struct { - Writer buf.Writer - Dest net.Destination + Writer buf.Writer + Dest net.Destination + GlobalID [8]byte } func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { @@ -42,14 +88,17 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { eb.WriteByte(1) // New eb.WriteByte(1) // Opt eb.WriteByte(2) // UDP - addrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port) + AddrParser.WriteAddressPort(eb, w.Dest.Address, w.Dest.Port) + if b.UDP != nil { // make sure it's user's proxy request + eb.Write(w.GlobalID[:]) + } w.Dest.Network = net.Network_Unknown } else { eb.WriteByte(2) // Keep eb.WriteByte(1) if b.UDP != nil { eb.WriteByte(2) - addrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port) + AddrParser.WriteAddressPort(eb, b.UDP.Address, b.UDP.Port) } } l := eb.Len() - 2 @@ -98,7 +147,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { case 2: if l != 4 { b.Advance(5) - addr, port, err := addrParser.ReadAddressPort(nil, b) + addr, port, err := AddrParser.ReadAddressPort(nil, b) if err != nil { b.Release() return nil, err diff --git a/go.mod b/go.mod index 7d534bac..76d7c0d7 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( google.golang.org/protobuf v1.30.0 gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h12.io/socks v1.0.3 + lukechampine.com/blake3 v1.1.7 ) require ( @@ -55,5 +56,4 @@ require ( google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - lukechampine.com/blake3 v1.1.7 // indirect ) diff --git a/infra/conf/xray.go b/infra/conf/xray.go index 949e5534..2306e380 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -10,6 +10,7 @@ import ( "github.com/xtls/xray-core/app/dispatcher" "github.com/xtls/xray-core/app/proxyman" "github.com/xtls/xray-core/app/stats" + "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/serial" core "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/transport/internet" @@ -107,8 +108,9 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) { } type MuxConfig struct { - Enabled bool `json:"enabled"` - Concurrency int16 `json:"concurrency"` + Enabled bool `json:"enabled"` + Concurrency int16 `json:"concurrency"` + Only string `json:"only"` } // Build creates MultiplexingConfig, Concurrency < 0 completely disables mux. @@ -116,16 +118,23 @@ func (m *MuxConfig) Build() *proxyman.MultiplexingConfig { if m.Concurrency < 0 { return nil } - - var con uint32 = 8 - if m.Concurrency > 0 { - con = uint32(m.Concurrency) + if m.Concurrency == 0 { + m.Concurrency = 8 } - return &proxyman.MultiplexingConfig{ + config := &proxyman.MultiplexingConfig{ Enabled: m.Enabled, - Concurrency: con, + Concurrency: uint32(m.Concurrency), + } + + switch strings.ToLower(m.Only) { + case "tcp": + config.Only = uint32(net.Network_TCP) + case "udp": + config.Only = uint32(net.Network_UDP) } + + return config } type InboundDetourAllocationConfig struct { diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index ae123c28..be05e4f7 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -148,6 +148,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. } } + if session.TimeoutOnlyFromContext(ctx) { + ctx, _ = context.WithCancel(context.Background()) + } + ctx, cancel := context.WithCancel(ctx) timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index d0fb69f9..42d8256f 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -103,6 +103,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st inbound := session.InboundFromContext(ctx) if inbound != nil { + inbound.Name = "dokodemo-door" inbound.User = &protocol.MemoryUser{ Level: d.config.UserLevel, } diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 15ebc22b..8630ab9c 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -149,9 +149,20 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } defer conn.Close() + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + plcy := h.policy() ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, plcy.Timeouts.ConnectionIdle) requestDone := func() error { defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) @@ -186,6 +197,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return nil } + if newCtx != nil { + ctx = newCtx + } + if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil { return newError("connection ends").Base(err) } diff --git a/proxy/http/client.go b/proxy/http/client.go index 71a10e69..b1661011 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -128,8 +128,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter p = c.policyManager.ForLevel(user.Level) } + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, p.Timeouts.ConnectionIdle) requestFunc := func() error { defer timer.SetTimeout(p.Timeouts.DownlinkOnly) @@ -140,6 +151,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) } + if newCtx != nil { + ctx = newCtx + } + responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { return newError("connection ends").Base(err) diff --git a/proxy/http/server.go b/proxy/http/server.go index cdcf2e3a..6b00fe2b 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -85,6 +85,7 @@ type readerOnly struct { func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) if inbound != nil { + inbound.Name = "http" inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/shadowsocks/client.go b/proxy/shadowsocks/client.go index 2d8a4e81..e22b11c7 100644 --- a/proxy/shadowsocks/client.go +++ b/proxy/shadowsocks/client.go @@ -96,9 +96,24 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } request.User = user + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + sessionPolicy := c.policyManager.ForLevel(user.Level) ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, sessionPolicy.Timeouts.ConnectionIdle) + + if newCtx != nil { + ctx = newCtx + } if request.Command == protocol.RequestCommandTCP { requestDone := func() error { diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index 140c6704..1d89db5e 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -113,6 +113,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis if inbound == nil { panic("no inbound metadata") } + inbound.Name = "shadowsocks" var dest *net.Destination diff --git a/proxy/shadowsocks_2022/inbound.go b/proxy/shadowsocks_2022/inbound.go index 550aadd1..1c2ae1d2 100644 --- a/proxy/shadowsocks_2022/inbound.go +++ b/proxy/shadowsocks_2022/inbound.go @@ -3,7 +3,7 @@ package shadowsocks_2022 import ( "context" - "github.com/sagernet/sing-shadowsocks" + shadowsocks "github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks/shadowaead_2022" C "github.com/sagernet/sing/common" B "github.com/sagernet/sing/common/buf" @@ -64,6 +64,7 @@ func (i *Inbound) Network() []net.Network { func (i *Inbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) + inbound.Name = "shadowsocks-2022" var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_multi.go b/proxy/shadowsocks_2022/inbound_multi.go index 695de8e2..77a34427 100644 --- a/proxy/shadowsocks_2022/inbound_multi.go +++ b/proxy/shadowsocks_2022/inbound_multi.go @@ -153,6 +153,7 @@ func (i *MultiUserInbound) Network() []net.Network { func (i *MultiUserInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) + inbound.Name = "shadowsocks-2022-multi" var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/inbound_relay.go b/proxy/shadowsocks_2022/inbound_relay.go index 3e0043ee..d07babb8 100644 --- a/proxy/shadowsocks_2022/inbound_relay.go +++ b/proxy/shadowsocks_2022/inbound_relay.go @@ -85,6 +85,7 @@ func (i *RelayInbound) Network() []net.Network { func (i *RelayInbound) Process(ctx context.Context, network net.Network, connection stat.Connection, dispatcher routing.Dispatcher) error { inbound := session.InboundFromContext(ctx) + inbound.Name = "shadowsocks-2022-relay" var metadata M.Metadata if inbound.Source.IsValid() { diff --git a/proxy/shadowsocks_2022/outbound.go b/proxy/shadowsocks_2022/outbound.go index eb38c017..41e239dc 100644 --- a/proxy/shadowsocks_2022/outbound.go +++ b/proxy/shadowsocks_2022/outbound.go @@ -6,7 +6,7 @@ import ( "runtime" "time" - "github.com/sagernet/sing-shadowsocks" + shadowsocks "github.com/sagernet/sing-shadowsocks" "github.com/sagernet/sing-shadowsocks/shadowaead_2022" C "github.com/sagernet/sing/common" B "github.com/sagernet/sing/common/buf" @@ -88,6 +88,10 @@ func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer int return newError("failed to connect to server").Base(err) } + if session.TimeoutOnlyFromContext(ctx) { + ctx, _ = context.WithCancel(context.Background()) + } + if network == net.Network_TCP { serverConn := o.method.DialEarlyConn(connection, toSocksaddr(destination)) var handshake bool diff --git a/proxy/socks/client.go b/proxy/socks/client.go index f1690bec..1993aa0b 100644 --- a/proxy/socks/client.go +++ b/proxy/socks/client.go @@ -151,8 +151,19 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter newError("failed to clear deadline after handshake").Base(err).WriteToLog(session.ExportIDToError(ctx)) } + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, p.Timeouts.ConnectionIdle) var requestFunc func() error var responseFunc func() error @@ -183,6 +194,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter } } + if newCtx != nil { + ctx = newCtx + } + responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { return newError("connection ends").Base(err) diff --git a/proxy/socks/server.go b/proxy/socks/server.go index ce15163c..184ecd08 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -64,6 +64,7 @@ func (s *Server) Network() []net.Network { // Process implements proxy.Inbound. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { if inbound := session.InboundFromContext(ctx); inbound != nil { + inbound.Name = "socks" inbound.User = &protocol.MemoryUser{ Level: s.config.UserLevel, } diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index ffd10359..2605239d 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -93,9 +93,20 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter Flow: account.Flow, } + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + sessionPolicy := c.policyManager.ForLevel(user.Level) ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, sessionPolicy.Timeouts.ConnectionIdle) postRequest := func() error { defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) @@ -149,6 +160,10 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)) } + if newCtx != nil { + ctx = newCtx + } + responseDoneAndCloseWriter := task.OnSuccess(getResponse, task.Close(link.Writer)) if err := task.Run(ctx, postRequest, responseDoneAndCloseWriter); err != nil { return newError("connection ends").Base(err) diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go index 029d4eff..368374ff 100644 --- a/proxy/trojan/server.go +++ b/proxy/trojan/server.go @@ -217,6 +217,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con if inbound == nil { panic("no inbound metadata") } + inbound.Name = "trojan" inbound.User = user sessionPolicy = s.policyManager.ForLevel(user.Level) diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index b3def4bb..c8a69444 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -438,6 +438,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if inbound == nil { panic("no inbound metadata") } + inbound.Name = "vless" inbound.User = request.User account := request.User.Account.(*vless.MemoryAccount) diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index b7bc6964..cb2a1b76 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -170,9 +170,20 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + sessionPolicy := h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, sessionPolicy.Timeouts.ConnectionIdle) clientReader := link.Reader // .(*pipe.Reader) clientWriter := link.Writer // .(*pipe.Writer) @@ -200,7 +211,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte // default: serverWriter := bufferWriter serverWriter := encoding.EncodeBodyAddons(bufferWriter, request, requestAddons) if request.Command == protocol.RequestCommandMux && request.Port == 666 { - serverWriter = xudp.NewPacketWriter(serverWriter, target) + serverWriter = xudp.NewPacketWriter(serverWriter, target, xudp.GetGlobalID(ctx)) } userUUID := account.ID.Bytes() timeoutReader, ok := clientReader.(buf.TimeoutReader) @@ -300,6 +311,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return nil } + if newCtx != nil { + ctx = newCtx + } + if err := task.Run(ctx, postRequest, task.OnSuccess(getResponse, task.Close(clientWriter))); err != nil { return newError("connection ends").Base(err).AtInfo() } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 00b07f14..eb24a6c6 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -287,6 +287,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection s if inbound == nil { panic("no inbound metadata") } + inbound.Name = "vmess" inbound.User = request.User sessionPolicy = h.policyManager.ForLevel(request.User.Level) diff --git a/proxy/vmess/outbound/outbound.go b/proxy/vmess/outbound/outbound.go index e7c6466e..64c29225 100644 --- a/proxy/vmess/outbound/outbound.go +++ b/proxy/vmess/outbound/outbound.go @@ -138,11 +138,22 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte behaviorSeed := crc64.Checksum(hashkdf.Sum(nil), crc64.MakeTable(crc64.ISO)) + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + session := encoding.NewClientSession(ctx, isAEAD, protocol.DefaultIDHash, int64(behaviorSeed)) sessionPolicy := h.policyManager.ForLevel(request.User.Level) ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, sessionPolicy.Timeouts.ConnectionIdle) if request.Command == protocol.RequestCommandUDP && h.cone && request.Port != 53 && request.Port != 443 { request.Command = protocol.RequestCommandMux @@ -164,7 +175,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } bodyWriter2 := bodyWriter if request.Command == protocol.RequestCommandMux && request.Port == 666 { - bodyWriter = xudp.NewPacketWriter(bodyWriter, target) + bodyWriter = xudp.NewPacketWriter(bodyWriter, target, xudp.GetGlobalID(ctx)) } if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { return newError("failed to write first payload").Base(err) @@ -208,6 +219,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte return buf.Copy(bodyReader, output, buf.UpdateActivity(timer)) } + if newCtx != nil { + ctx = newCtx + } + responseDonePost := task.OnSuccess(responseDone, task.Close(output)) if err := task.Run(ctx, requestDone, responseDonePost); err != nil { return newError("connection ends").Base(err) diff --git a/proxy/wireguard/wireguard.go b/proxy/wireguard/wireguard.go index 2b7e1c87..0d4994f5 100644 --- a/proxy/wireguard/wireguard.go +++ b/proxy/wireguard/wireguard.go @@ -127,10 +127,21 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte addr = net.IPAddress(ips[0]) } + var newCtx context.Context + var newCancel context.CancelFunc + if session.TimeoutOnlyFromContext(ctx) { + newCtx, newCancel = context.WithCancel(context.Background()) + } + p := h.policyManager.ForLevel(0) ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, p.Timeouts.ConnectionIdle) + timer := signal.CancelAfterInactivity(ctx, func() { + cancel() + if newCancel != nil { + newCancel() + } + }, p.Timeouts.ConnectionIdle) addrPort := netip.AddrPortFrom(toNetIpAddr(addr), destination.Port.Value()) var requestFunc func() error @@ -166,6 +177,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte } } + if newCtx != nil { + ctx = newCtx + } + responseDonePost := task.OnSuccess(responseFunc, task.Close(link.Writer)) if err := task.Run(ctx, requestFunc, responseDonePost); err != nil { return newError("connection ends").Base(err) diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index 14a18e63..a60bc485 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -37,6 +37,7 @@ type pipe struct { readSignal *signal.Notifier writeSignal *signal.Notifier done *done.Instance + errChan chan error option pipeOption state state } @@ -92,6 +93,8 @@ func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) { select { case <-p.readSignal.Wait(): case <-p.done.Wait(): + case err = <-p.errChan: + return nil, err } } } diff --git a/transport/pipe/pipe.go b/transport/pipe/pipe.go index 0b22c2db..735cc091 100644 --- a/transport/pipe/pipe.go +++ b/transport/pipe/pipe.go @@ -59,6 +59,7 @@ func New(opts ...Option) (*Reader, *Writer) { readSignal: signal.NewNotifier(), writeSignal: signal.NewNotifier(), done: done.New(), + errChan: make(chan error, 1), option: pipeOption{ limit: -1, }, diff --git a/transport/pipe/reader.go b/transport/pipe/reader.go index 66733436..79f0ac03 100644 --- a/transport/pipe/reader.go +++ b/transport/pipe/reader.go @@ -25,3 +25,17 @@ func (r *Reader) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error func (r *Reader) Interrupt() { r.pipe.Interrupt() } + +// ReturnAnError makes ReadMultiBuffer return an error, only once. +func (r *Reader) ReturnAnError(err error) { + r.pipe.errChan <- err +} + +// Recover catches an error set by ReturnAnError, if exists. +func (r *Reader) Recover() (err error) { + select { + case err = <-r.pipe.errChan: + default: + } + return +}