From 593ededd3e3f8d94b5da0ac5b54203bcf37a81e7 Mon Sep 17 00:00:00 2001 From: patterniha <71074308+patterniha@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:31:46 +0200 Subject: [PATCH] Trojan-UoT & UDP-nameserver: Fix forgotten release buffer; UDP dispatcher: Simplified and optimized (#5050) --- app/dns/nameserver_udp.go | 8 +++- app/proxyman/outbound/handler.go | 10 ++++- common/mux/client.go | 17 +++++--- common/mux/server.go | 2 +- proxy/freedom/freedom.go | 4 +- proxy/proxy.go | 3 ++ proxy/shadowsocks/server.go | 6 +-- proxy/socks/server.go | 3 +- proxy/trojan/protocol.go | 2 + transport/internet/udp/dispatcher.go | 60 +++++++++++++++++----------- transport/pipe/impl.go | 13 +++--- 11 files changed, 83 insertions(+), 45 deletions(-) diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 3c25e612..e29f6e24 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error { // HandleResponse handles udp response packet from remote DNS server. func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { - ipRec, err := parseResponse(packet.Payload.Bytes()) + payload := packet.Payload + ipRec, err := parseResponse(payload.Bytes()) + payload.Release() if err != nil { errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp") return @@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot newReq.msg = &newMsg s.addPendingRequest(&newReq) b, _ := dns.PackMessage(newReq.msg) + copyDest := net.UDPDestination(s.address.Address, s.address.Port) + b.UDP = ©Dest s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b) return } @@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai } s.addPendingRequest(udpReq) b, _ := dns.PackMessage(req.msg) + copyDest := net.UDPDestination(s.address.Address, s.address.Port) + b.UDP = ©Dest s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b) } } diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index d0e670c5..ef5eed0e 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) { } out: err := h.proxy.Process(ctx, link, h) + var errC error if err != nil { - if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) { + errC = errors.Cause(err) + if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) { err = nil } } @@ -251,7 +253,11 @@ out: errors.LogInfo(ctx, err.Error()) common.Interrupt(link.Writer) } else { - common.Close(link.Writer) + if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) { + common.Interrupt(link.Writer) + } else { + common.Close(link.Writer) + } } common.Interrupt(link.Reader) } diff --git a/common/mux/client.go b/common/mux/client.go index 6987f762..e94fd3ad 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -2,6 +2,7 @@ package mux import ( "context" + goerrors "errors" "io" "sync" "time" @@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) { ctx := session.ContextWithOutbounds(context.Background(), outbounds) ctx, cancel := context.WithCancel(ctx) - if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { - errors.LogInfoInner(ctx, err, "failed to handler mux client connection") + if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil { + errC := errors.Cause(errP) + if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) { + errors.LogInfoInner(ctx, errP, "failed to handler mux client connection") + } } common.Must(c.Close()) cancel() @@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() { select { case <-m.done.Wait(): m.sessionManager.Close() - common.Close(m.link.Writer) + common.Interrupt(m.link.Writer) common.Interrupt(m.link.Reader) return case <-m.timer.C: @@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error { return nil } -func fetchInput(ctx context.Context, s *Session, output buf.Writer) { +func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) { outbounds := session.OutboundsFromContext(ctx) ob := outbounds[len(outbounds)-1] transferType := protocol.TransferTypeStream @@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) defer s.Close(false) defer writer.Close() + defer timer.Reset(time.Second * 16) errors.LogInfo(ctx, "dispatching request to ", ob.Target) if err := writeFirstPayload(s.input, writer); err != nil { @@ -308,9 +313,9 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool s.input = link.Reader s.output = link.Writer if _, ok := link.Reader.(*pipe.Reader); ok { - go fetchInput(ctx, s, m.link.Writer) + go fetchInput(ctx, s, m.link.Writer, m.timer) } else { - fetchInput(ctx, s, m.link.Writer) + fetchInput(ctx, s, m.link.Writer, m.timer) } return true } diff --git a/common/mux/server.go b/common/mux/server.go index 0a632e81..12e4a68f 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -318,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) { reader := &buf.BufferedReader{Reader: w.link.Reader} defer w.sessionManager.Close() - defer common.Close(w.link.Writer) defer common.Interrupt(w.link.Reader) + defer common.Interrupt(w.link.Writer) for { select { diff --git a/proxy/freedom/freedom.go b/proxy/freedom/freedom.go index 0e9937e3..f8d64812 100644 --- a/proxy/freedom/freedom.go +++ b/proxy/freedom/freedom.go @@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool { } a := addr.AsAddress() - return a != net.AnyIP + return a != net.AnyIP && a != net.AnyIPv6 } // Process implements proxy.Outbound. @@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { } } } - destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr()) + destAddr := b.UDP.RawNetAddr() if destAddr == nil { b.Release() continue diff --git a/proxy/proxy.go b/proxy/proxy.go index 049d9fbd..edfa63d0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net } } if err != nil { + if errors.Cause(err) == io.EOF { + return nil + } return err } } diff --git a/proxy/shadowsocks/server.go b/proxy/shadowsocks/server.go index ec022084..360ea38c 100644 --- a/proxy/shadowsocks/server.go +++ b/proxy/shadowsocks/server.go @@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { request := protocol.RequestHeaderFromContext(ctx) + payload := packet.Payload if request == nil { + payload.Release() return } - payload := packet.Payload - if payload.UDP != nil { request = &protocol.RequestHeader{ User: request.User, @@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis errors.LogWarningInner(ctx, err, "failed to encode UDP packet") return } - defer data.Release() conn.Write(data.Bytes()) + data.Release() }) defer udpServer.RemoveRay() diff --git a/proxy/socks/server.go b/proxy/socks/server.go index 13d98851..166deaa3 100644 --- a/proxy/socks/server.go +++ b/proxy/socks/server.go @@ -196,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis request := protocol.RequestHeaderFromContext(ctx) if request == nil { + payload.Release() return } @@ -214,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis errors.LogWarningInner(ctx, err, "failed to write UDP response") return } - defer udpMessage.Release() conn.Write(udpMessage.Bytes()) + udpMessage.Release() }) defer udpServer.RemoveRay() diff --git a/proxy/trojan/protocol.go b/proxy/trojan/protocol.go index 96a16638..889ccc5c 100644 --- a/proxy/trojan/protocol.go +++ b/proxy/trojan/protocol.go @@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { target = b.UDP } if _, err := w.writePacket(b.Bytes(), *target); err != nil { + b.Release() buf.ReleaseMulti(mb) return err } + b.Release() } return nil } diff --git a/transport/internet/udp/dispatcher.go b/transport/internet/udp/dispatcher.go index 22db4244..963ce662 100644 --- a/transport/internet/udp/dispatcher.go +++ b/transport/internet/udp/dispatcher.go @@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet) type connEntry struct { link *transport.Link - timer signal.ActivityUpdater + timer *signal.ActivityTimer cancel context.CancelFunc + closed bool +} + +func (c *connEntry) Close() error { + c.timer.SetTimeout(0) + return nil +} + +func (c *connEntry) terminate() { + if c.closed { + panic("terminate called more than once") + } + c.closed = true + c.cancel() + common.Interrupt(c.link.Reader) + common.Interrupt(c.link.Writer) } type Dispatcher struct { @@ -32,6 +48,7 @@ type Dispatcher struct { dispatcher routing.Dispatcher callback ResponseCallback callClose func() error + closed bool } func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { @@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di func (v *Dispatcher) RemoveRay() { v.Lock() defer v.Unlock() - v.removeRay() -} - -func (v *Dispatcher) removeRay() { + v.closed = true if v.conn != nil { - common.Interrupt(v.conn.link.Reader) - common.Close(v.conn.link.Writer) + v.conn.Close() v.conn = nil } } @@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (* v.Lock() defer v.Unlock() + if v.closed { + return nil, errors.New("dispatcher is closed") + } + if v.conn != nil { - return v.conn, nil + if v.conn.closed { + v.conn = nil + } else { + return v.conn, nil + } } errors.LogInfo(ctx, "establishing new connection for ", dest) ctx, cancel := context.WithCancel(ctx) - entry := &connEntry{} - removeRay := func() { - v.Lock() - defer v.Unlock() - // sometimes the entry is already removed by others, don't close again - if entry == v.conn { - cancel() - v.removeRay() - } - } - timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) link, err := v.dispatcher.Dispatch(ctx, dest) if err != nil { + cancel() return nil, errors.New("failed to dispatch request to ", dest).Base(err) } - *entry = connEntry{ + entry := &connEntry{ link: link, - timer: timer, - cancel: removeRay, + cancel: cancel, } + + entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute) v.conn = entry go handleInput(ctx, entry, dest, v.callback, v.callClose) return entry, nil @@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, if outputStream != nil { if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { errors.LogInfoInner(ctx, err, "failed to write first UDP payload") - conn.cancel() + conn.Close() return } } @@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { defer func() { - conn.cancel() + conn.Close() if callClose != nil { callClose() } diff --git a/transport/pipe/impl.go b/transport/pipe/impl.go index 4a30dbbb..e5d67827 100644 --- a/transport/pipe/impl.go +++ b/transport/pipe/impl.go @@ -200,16 +200,19 @@ func (p *pipe) Interrupt() { p.Lock() defer p.Unlock() + if !p.data.IsEmpty() { + buf.ReleaseMulti(p.data) + p.data = nil + if p.state == closed { + p.state = errord + } + } + if p.state == closed || p.state == errord { return } p.state = errord - if !p.data.IsEmpty() { - buf.ReleaseMulti(p.data) - p.data = nil - } - common.Must(p.done.Close()) }