First step of upcoming refactor for Xray-core: Add TimeoutWrapperReader; Use DispatchLink() in Tunnel/Socks/HTTP inbounds

https://github.com/XTLS/Xray-core/pull/5067#issuecomment-3236833240

Fixes https://github.com/XTLS/Xray-core/pull/4952#issuecomment-3229878125 for client's Xray-core
pull/5064/head
RPRX 2025-08-29 12:35:56 +00:00 committed by GitHub
parent 4976085ddb
commit 56a45ad578
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 93 additions and 206 deletions

View File

@ -29,7 +29,7 @@ var errSniffingTimeout = errors.New("timeout on sniffing")
type cachedReader struct { type cachedReader struct {
sync.Mutex sync.Mutex
reader *pipe.Reader reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
cache buf.MultiBuffer cache buf.MultiBuffer
} }
@ -87,7 +87,9 @@ func (r *cachedReader) Interrupt() {
r.cache = buf.ReleaseMulti(r.cache) r.cache = buf.ReleaseMulti(r.cache)
} }
r.Unlock() r.Unlock()
r.reader.Interrupt() if p, ok := r.reader.(*pipe.Reader); ok {
p.Interrupt()
}
} }
// DefaultDispatcher is a default implementation of Dispatcher. // DefaultDispatcher is a default implementation of Dispatcher.
@ -319,7 +321,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
d.routedDispatch(ctx, outbound, destination) d.routedDispatch(ctx, outbound, destination)
} else { } else {
cReader := &cachedReader{ cReader := &cachedReader{
reader: outbound.Reader.(*pipe.Reader), reader: outbound.Reader.(buf.TimeoutReader),
} }
outbound.Reader = cReader outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network) result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)

View File

@ -24,9 +24,46 @@ var ErrReadTimeout = errors.New("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout. // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface { type TimeoutReader interface {
Reader
ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error) ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
} }
type TimeoutWrapperReader struct {
Reader
mb MultiBuffer
err error
done chan struct{}
}
func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.done != nil {
<-r.done
r.done = nil
return r.mb, r.err
}
r.mb = nil
r.err = nil
return r.Reader.ReadMultiBuffer()
}
func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
if r.done == nil {
r.done = make(chan struct{})
go func() {
r.mb, r.err = r.Reader.ReadMultiBuffer()
close(r.done)
}()
}
time.Sleep(duration)
select {
case <-r.done:
r.done = nil
return r.mb, r.err
default:
return nil, nil
}
}
// Writer extends io.Writer with MultiBuffer. // Writer extends io.Writer with MultiBuffer.
type Writer interface { type Writer interface {
// WriteMultiBuffer writes a MultiBuffer into underlying writer. // WriteMultiBuffer writes a MultiBuffer into underlying writer.

View File

@ -307,7 +307,11 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
} }
s.input = link.Reader s.input = link.Reader
s.output = link.Writer s.output = link.Writer
go fetchInput(ctx, s, m.link.Writer) if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer)
} else {
fetchInput(ctx, s, m.link.Writer)
}
return true return true
} }

View File

@ -87,7 +87,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
link: link, link: link,
sessionManager: NewSessionManager(), sessionManager: NewSessionManager(),
} }
go worker.run(ctx) if inbound := session.InboundFromContext(ctx); inbound != nil {
inbound.CanSpliceCopy = 3
}
if _, ok := link.Reader.(*pipe.Reader); ok {
go worker.run(ctx)
} else {
worker.run(ctx)
}
return worker, nil return worker, nil
} }

View File

@ -2,10 +2,8 @@ package dokodemo
import ( import (
"context" "context"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
@ -14,11 +12,10 @@ import (
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
) )
@ -144,39 +141,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
}) })
errors.LogInfo(ctx, "received request for ", conn.RemoteAddr()) errors.LogInfo(ctx, "received request for ", conn.RemoteAddr())
plcy := d.policy() var reader buf.Reader
ctx, cancel := context.WithCancel(ctx) if dest.Network == net.Network_TCP {
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) reader = buf.NewReader(conn)
} else {
if inbound != nil { reader = buf.NewPacketReader(conn)
inbound.Timer = timer
}
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return errors.New("failed to dispatch request").Base(err)
}
requestCount := int32(1)
requestDone := func() error {
defer func() {
if atomic.AddInt32(&requestCount, -1) == 0 {
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
}
}()
var reader buf.Reader
if dest.Network == net.Network_UDP {
reader = buf.NewPacketReader(conn)
} else {
reader = buf.NewReader(conn)
}
if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport request").Base(err)
}
return nil
} }
var writer buf.Writer var writer buf.Writer
@ -208,72 +177,17 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
return err return err
} }
writer = NewPacketWriter(pConn, &dest, mark, back) writer = NewPacketWriter(pConn, &dest, mark, back)
defer func() { defer writer.(*PacketWriter).Close() // close fake UDP conns
runtime.Gosched()
common.Interrupt(link.Reader) // maybe duplicated
runtime.Gosched()
writer.(*PacketWriter).Close() // close fake UDP conns
}()
/*
sockopt := &internet.SocketConfig{
Tproxy: internet.SocketConfig_TProxy,
}
if dest.Address.Family().IsIP() {
sockopt.BindAddress = dest.Address.IP()
sockopt.BindPort = uint32(dest.Port)
}
if d.sockopt != nil {
sockopt.Mark = d.sockopt.Mark
}
tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
if err != nil {
return err
}
defer tConn.Close()
writer = &buf.SequentialWriter{Writer: tConn}
tReader := buf.NewPacketReader(tConn)
requestCount++
tproxyRequest = func() error {
defer func() {
if atomic.AddInt32(&requestCount, -1) == 0 {
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
}
}()
if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport request (TPROXY conn)").Base(err)
}
return nil
}
*/
} }
} }
responseDone := func() error { if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) Reader: &buf.TimeoutWrapperReader{Reader: reader},
Writer: writer},
if network == net.Network_UDP && destinationOverridden { ); err != nil {
buf.Copy(link.Reader, writer) // respect upload's timeout return errors.New("failed to dispatch request").Base(err)
return nil
}
if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport response").Base(err)
}
return nil
} }
return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process()
if err := task.Run(ctx,
task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)),
responseDone); err != nil {
runtime.Gosched()
common.Interrupt(link.Writer)
runtime.Gosched()
common.Interrupt(link.Reader)
return errors.New("connection ends").Base(err)
}
return nil
} }
func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer { func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {

View File

@ -18,12 +18,12 @@ import (
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
http_proto "github.com/xtls/xray-core/common/protocol/http" http_proto "github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task" "github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
) )
@ -173,64 +173,31 @@ Start:
return err return err
} }
func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) _, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
if err != nil { if err != nil {
return errors.New("failed to write back OK response").Base(err) return errors.New("failed to write back OK response").Base(err)
} }
plcy := s.policy() reader := buf.NewReader(conn)
ctx, cancel := context.WithCancel(ctx) if buffer.Buffered() > 0 {
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle) payload, err := buf.ReadFrom(io.LimitReader(buffer, int64(buffer.Buffered())))
if inbound != nil {
inbound.Timer = timer
}
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err
}
if reader.Buffered() > 0 {
payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
if err != nil { if err != nil {
return err return err
} }
if err := link.Writer.WriteMultiBuffer(payload); err != nil { reader = &buf.BufferedReader{Reader: reader, Buffer: payload}
return err buffer = nil
}
reader = nil
} }
requestDone := func() error { if inbound.CanSpliceCopy == 2 {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) inbound.CanSpliceCopy = 1
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
} }
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
responseDone := func() error { Reader: &buf.TimeoutWrapperReader{Reader: reader},
if inbound.CanSpliceCopy == 2 { Writer: buf.NewWriter(conn)},
inbound.CanSpliceCopy = 1 ); err != nil {
} return errors.New("failed to dispatch request").Base(err)
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(conn)
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
return err
}
return nil
} }
closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return errors.New("connection ends").Base(err)
}
return nil return nil
} }

View File

@ -14,13 +14,12 @@ import (
"github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/protocol"
udp_proto "github.com/xtls/xray-core/common/protocol/udp" udp_proto "github.com/xtls/xray-core/common/protocol/udp"
"github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core" "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy" "github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy" "github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/http" "github.com/xtls/xray-core/proxy/http"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/udp" "github.com/xtls/xray-core/transport/internet/udp"
) )
@ -158,8 +157,16 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
Reason: "", Reason: "",
}) })
} }
if inbound.CanSpliceCopy == 2 {
return s.transport(ctx, reader, conn, dest, dispatcher, inbound) inbound.CanSpliceCopy = 1
}
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: reader},
Writer: buf.NewWriter(conn)},
); err != nil {
return errors.New("failed to dispatch request").Base(err)
}
return nil
} }
if request.Command == protocol.RequestCommandUDP { if request.Command == protocol.RequestCommandUDP {
@ -178,54 +185,6 @@ func (*Server) handleUDP(c io.Reader) error {
return common.Error2(io.Copy(buf.DiscardBytes, c)) return common.Error2(io.Copy(buf.DiscardBytes, c))
} }
func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
if inbound != nil {
inbound.Timer = timer
}
plcy := s.policy()
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err
}
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP request").Base(err)
}
return nil
}
responseDone := func() error {
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(writer)
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP response").Base(err)
}
return nil
}
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return errors.New("connection ends").Base(err)
}
return nil
}
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) { if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String()) errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
@ -265,9 +224,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
if inbound != nil && inbound.Source.IsValid() { if inbound != nil && inbound.Source.IsValid() {
errors.LogInfo(ctx, "client UDP connection from ", inbound.Source) errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
} }
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
var dest *net.Destination var dest *net.Destination