From 56a45ad57893919bd43d7e764ab32bc5b1fb0c1e Mon Sep 17 00:00:00 2001 From: RPRX <63339210+RPRX@users.noreply.github.com> Date: Fri, 29 Aug 2025 12:35:56 +0000 Subject: [PATCH] First step of upcoming refactor for Xray-core: Add TimeoutWrapperReader; Use DispatchLink() in Tunnel/Socks/HTTP inbounds https://github.com/XTLS/Xray-core/pull/5067#issuecomment-3236833240 Fixes https://github.com/XTLS/Xray-core/pull/4952#issuecomment-3229878125 for client's Xray-core --- app/dispatcher/default.go | 8 ++- common/buf/io.go | 37 ++++++++++++ common/mux/client.go | 6 +- common/mux/server.go | 9 ++- proxy/dokodemo/dokodemo.go | 112 +++++-------------------------------- proxy/http/server.go | 61 +++++--------------- proxy/socks/server.go | 66 ++++------------------ 7 files changed, 93 insertions(+), 206 deletions(-) diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 544a0956..74e55ab4 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -29,7 +29,7 @@ var errSniffingTimeout = errors.New("timeout on sniffing") type cachedReader struct { sync.Mutex - reader *pipe.Reader + reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader cache buf.MultiBuffer } @@ -87,7 +87,9 @@ func (r *cachedReader) Interrupt() { r.cache = buf.ReleaseMulti(r.cache) } r.Unlock() - r.reader.Interrupt() + if p, ok := r.reader.(*pipe.Reader); ok { + p.Interrupt() + } } // DefaultDispatcher is a default implementation of Dispatcher. @@ -319,7 +321,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De d.routedDispatch(ctx, outbound, destination) } else { cReader := &cachedReader{ - reader: outbound.Reader.(*pipe.Reader), + reader: outbound.Reader.(buf.TimeoutReader), } outbound.Reader = cReader result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) diff --git a/common/buf/io.go b/common/buf/io.go index 0974b4f3..0b3cc6b2 100644 --- a/common/buf/io.go +++ b/common/buf/io.go @@ -24,9 +24,46 @@ var ErrReadTimeout = errors.New("IO timeout") // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout. type TimeoutReader interface { + Reader ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error) } +type TimeoutWrapperReader struct { + Reader + mb MultiBuffer + err error + done chan struct{} +} + +func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) { + if r.done != nil { + <-r.done + r.done = nil + return r.mb, r.err + } + r.mb = nil + r.err = nil + return r.Reader.ReadMultiBuffer() +} + +func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) { + if r.done == nil { + r.done = make(chan struct{}) + go func() { + r.mb, r.err = r.Reader.ReadMultiBuffer() + close(r.done) + }() + } + time.Sleep(duration) + select { + case <-r.done: + r.done = nil + return r.mb, r.err + default: + return nil, nil + } +} + // Writer extends io.Writer with MultiBuffer. type Writer interface { // WriteMultiBuffer writes a MultiBuffer into underlying writer. diff --git a/common/mux/client.go b/common/mux/client.go index 764cc4d9..6987f762 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -307,7 +307,11 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool } s.input = link.Reader s.output = link.Writer - go fetchInput(ctx, s, m.link.Writer) + if _, ok := link.Reader.(*pipe.Reader); ok { + go fetchInput(ctx, s, m.link.Writer) + } else { + fetchInput(ctx, s, m.link.Writer) + } return true } diff --git a/common/mux/server.go b/common/mux/server.go index 99a144a5..0a632e81 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -87,7 +87,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport. link: link, sessionManager: NewSessionManager(), } - go worker.run(ctx) + if inbound := session.InboundFromContext(ctx); inbound != nil { + inbound.CanSpliceCopy = 3 + } + if _, ok := link.Reader.(*pipe.Reader); ok { + go worker.run(ctx) + } else { + worker.run(ctx) + } return worker, nil } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 2d553300..c14fe34d 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -2,10 +2,8 @@ package dokodemo import ( "context" - "runtime" "strconv" "strings" - "sync/atomic" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/buf" @@ -14,11 +12,10 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal" - "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" + "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/tls" ) @@ -144,39 +141,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st }) errors.LogInfo(ctx, "received request for ", conn.RemoteAddr()) - plcy := d.policy() - ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) - - if inbound != nil { - inbound.Timer = timer - } - - ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer) - link, err := dispatcher.Dispatch(ctx, dest) - if err != nil { - return errors.New("failed to dispatch request").Base(err) - } - - requestCount := int32(1) - requestDone := func() error { - defer func() { - if atomic.AddInt32(&requestCount, -1) == 0 { - timer.SetTimeout(plcy.Timeouts.DownlinkOnly) - } - }() - - var reader buf.Reader - if dest.Network == net.Network_UDP { - reader = buf.NewPacketReader(conn) - } else { - reader = buf.NewReader(conn) - } - if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transport request").Base(err) - } - - return nil + var reader buf.Reader + if dest.Network == net.Network_TCP { + reader = buf.NewReader(conn) + } else { + reader = buf.NewPacketReader(conn) } var writer buf.Writer @@ -208,72 +177,17 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st return err } writer = NewPacketWriter(pConn, &dest, mark, back) - defer func() { - runtime.Gosched() - common.Interrupt(link.Reader) // maybe duplicated - runtime.Gosched() - writer.(*PacketWriter).Close() // close fake UDP conns - }() - /* - sockopt := &internet.SocketConfig{ - Tproxy: internet.SocketConfig_TProxy, - } - if dest.Address.Family().IsIP() { - sockopt.BindAddress = dest.Address.IP() - sockopt.BindPort = uint32(dest.Port) - } - if d.sockopt != nil { - sockopt.Mark = d.sockopt.Mark - } - tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt) - if err != nil { - return err - } - defer tConn.Close() - - writer = &buf.SequentialWriter{Writer: tConn} - tReader := buf.NewPacketReader(tConn) - requestCount++ - tproxyRequest = func() error { - defer func() { - if atomic.AddInt32(&requestCount, -1) == 0 { - timer.SetTimeout(plcy.Timeouts.DownlinkOnly) - } - }() - if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transport request (TPROXY conn)").Base(err) - } - return nil - } - */ + defer writer.(*PacketWriter).Close() // close fake UDP conns } } - responseDone := func() error { - defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) - - if network == net.Network_UDP && destinationOverridden { - buf.Copy(link.Reader, writer) // respect upload's timeout - return nil - } - - if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transport response").Base(err) - } - return nil + if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ + Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Writer: writer}, + ); err != nil { + return errors.New("failed to dispatch request").Base(err) } - - if err := task.Run(ctx, - task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)), - responseDone); err != nil { - runtime.Gosched() - common.Interrupt(link.Writer) - runtime.Gosched() - common.Interrupt(link.Reader) - return errors.New("connection ends").Base(err) - } - - return nil + return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process() } func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer { diff --git a/proxy/http/server.go b/proxy/http/server.go index 8d6290a3..12a5292c 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -18,12 +18,12 @@ import ( "github.com/xtls/xray-core/common/protocol" http_proto "github.com/xtls/xray-core/common/protocol/http" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal" "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/proxy" + "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet/stat" ) @@ -173,64 +173,31 @@ Start: return err } -func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { +func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { _, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) if err != nil { return errors.New("failed to write back OK response").Base(err) } - plcy := s.policy() - ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) - - if inbound != nil { - inbound.Timer = timer - } - - ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer) - link, err := dispatcher.Dispatch(ctx, dest) - if err != nil { - return err - } - - if reader.Buffered() > 0 { - payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered()))) + reader := buf.NewReader(conn) + if buffer.Buffered() > 0 { + payload, err := buf.ReadFrom(io.LimitReader(buffer, int64(buffer.Buffered()))) if err != nil { return err } - if err := link.Writer.WriteMultiBuffer(payload); err != nil { - return err - } - reader = nil + reader = &buf.BufferedReader{Reader: reader, Buffer: payload} + buffer = nil } - requestDone := func() error { - defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) - - return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)) + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 } - - responseDone := func() error { - if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 - } - defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) - - v2writer := buf.NewWriter(conn) - if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil { - return err - } - - return nil + if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ + Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Writer: buf.NewWriter(conn)}, + ); err != nil { + return errors.New("failed to dispatch request").Base(err) } - - closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer)) - if err := task.Run(ctx, closeWriter, responseDone); err != nil { - common.Interrupt(link.Reader) - common.Interrupt(link.Writer) - return errors.New("connection ends").Base(err) - } - return nil } diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 08e2e657..13d98851 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -14,13 +14,12 @@ import ( "github.com/xtls/xray-core/common/protocol" udp_proto "github.com/xtls/xray-core/common/protocol/udp" "github.com/xtls/xray-core/common/session" - "github.com/xtls/xray-core/common/signal" - "github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy/http" + "github.com/xtls/xray-core/transport" "github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/udp" ) @@ -158,8 +157,16 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche Reason: "", }) } - - return s.transport(ctx, reader, conn, dest, dispatcher, inbound) + if inbound.CanSpliceCopy == 2 { + inbound.CanSpliceCopy = 1 + } + if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{ + Reader: &buf.TimeoutWrapperReader{Reader: reader}, + Writer: buf.NewWriter(conn)}, + ); err != nil { + return errors.New("failed to dispatch request").Base(err) + } + return nil } if request.Command == protocol.RequestCommandUDP { @@ -178,54 +185,6 @@ func (*Server) handleUDP(c io.Reader) error { return common.Error2(io.Copy(buf.DiscardBytes, c)) } -func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { - ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle) - - if inbound != nil { - inbound.Timer = timer - } - - plcy := s.policy() - ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer) - link, err := dispatcher.Dispatch(ctx, dest) - if err != nil { - return err - } - - requestDone := func() error { - defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) - if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transport all TCP request").Base(err) - } - - return nil - } - - responseDone := func() error { - if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 - } - defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) - - v2writer := buf.NewWriter(writer) - if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil { - return errors.New("failed to transport all TCP response").Base(err) - } - - return nil - } - - requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) - if err := task.Run(ctx, requestDonePost, responseDone); err != nil { - common.Interrupt(link.Reader) - common.Interrupt(link.Writer) - return errors.New("connection ends").Base(err) - } - - return nil -} - func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) { errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String()) @@ -265,9 +224,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis if inbound != nil && inbound.Source.IsValid() { errors.LogInfo(ctx, "client UDP connection from ", inbound.Source) } - if inbound.CanSpliceCopy == 2 { - inbound.CanSpliceCopy = 1 - } var dest *net.Destination