Trojan-UoT & UDP-nameserver: Fix forgotten release buffer; UDP dispatcher: Simplified and optimized (#5050)

pull/5070/head
patterniha 2025-08-29 16:31:46 +02:00 committed by GitHub
parent 82ea7a3cc5
commit 593ededd3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 83 additions and 45 deletions

View File

@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
// HandleResponse handles udp response packet from remote DNS server.
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
ipRec, err := parseResponse(packet.Payload.Bytes())
payload := packet.Payload
ipRec, err := parseResponse(payload.Bytes())
payload.Release()
if err != nil {
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
return
@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
newReq.msg = &newMsg
s.addPendingRequest(&newReq)
b, _ := dns.PackMessage(newReq.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
return
}
@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
}
s.addPendingRequest(udpReq)
b, _ := dns.PackMessage(req.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
}
}

View File

@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
}
out:
err := h.proxy.Process(ctx, link, h)
var errC error
if err != nil {
if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) {
errC = errors.Cause(err)
if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
err = nil
}
}
@ -251,7 +253,11 @@ out:
errors.LogInfo(ctx, err.Error())
common.Interrupt(link.Writer)
} else {
common.Close(link.Writer)
if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
common.Interrupt(link.Writer)
} else {
common.Close(link.Writer)
}
}
common.Interrupt(link.Reader)
}

View File

@ -2,6 +2,7 @@ package mux
import (
"context"
goerrors "errors"
"io"
"sync"
"time"
@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
errC := errors.Cause(errP)
if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
}
}
common.Must(c.Close())
cancel()
@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer)
common.Interrupt(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-m.timer.C:
@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream
@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()
defer timer.Reset(time.Second * 16)
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil {
@ -308,9 +313,9 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
s.input = link.Reader
s.output = link.Writer
if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer)
go fetchInput(ctx, s, m.link.Writer, m.timer)
} else {
fetchInput(ctx, s, m.link.Writer)
fetchInput(ctx, s, m.link.Writer, m.timer)
}
return true
}

View File

@ -318,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader}
defer w.sessionManager.Close()
defer common.Close(w.link.Writer)
defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
for {
select {

View File

@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
}
a := addr.AsAddress()
return a != net.AnyIP
return a != net.AnyIP && a != net.AnyIPv6
}
// Process implements proxy.Outbound.
@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
}
}
}
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
destAddr := b.UDP.RawNetAddr()
if destAddr == nil {
b.Release()
continue

View File

@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
}
}
if err != nil {
if errors.Cause(err) == io.EOF {
return nil
}
return err
}
}

View File

@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
request := protocol.RequestHeaderFromContext(ctx)
payload := packet.Payload
if request == nil {
payload.Release()
return
}
payload := packet.Payload
if payload.UDP != nil {
request = &protocol.RequestHeader{
User: request.User,
@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
return
}
defer data.Release()
conn.Write(data.Bytes())
data.Release()
})
defer udpServer.RemoveRay()

View File

@ -196,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
request := protocol.RequestHeaderFromContext(ctx)
if request == nil {
payload.Release()
return
}
@ -214,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to write UDP response")
return
}
defer udpMessage.Release()
conn.Write(udpMessage.Bytes())
udpMessage.Release()
})
defer udpServer.RemoveRay()

View File

@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
target = b.UDP
}
if _, err := w.writePacket(b.Bytes(), *target); err != nil {
b.Release()
buf.ReleaseMulti(mb)
return err
}
b.Release()
}
return nil
}

View File

@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
type connEntry struct {
link *transport.Link
timer signal.ActivityUpdater
timer *signal.ActivityTimer
cancel context.CancelFunc
closed bool
}
func (c *connEntry) Close() error {
c.timer.SetTimeout(0)
return nil
}
func (c *connEntry) terminate() {
if c.closed {
panic("terminate called more than once")
}
c.closed = true
c.cancel()
common.Interrupt(c.link.Reader)
common.Interrupt(c.link.Writer)
}
type Dispatcher struct {
@ -32,6 +48,7 @@ type Dispatcher struct {
dispatcher routing.Dispatcher
callback ResponseCallback
callClose func() error
closed bool
}
func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
func (v *Dispatcher) RemoveRay() {
v.Lock()
defer v.Unlock()
v.removeRay()
}
func (v *Dispatcher) removeRay() {
v.closed = true
if v.conn != nil {
common.Interrupt(v.conn.link.Reader)
common.Close(v.conn.link.Writer)
v.conn.Close()
v.conn = nil
}
}
@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
v.Lock()
defer v.Unlock()
if v.closed {
return nil, errors.New("dispatcher is closed")
}
if v.conn != nil {
return v.conn, nil
if v.conn.closed {
v.conn = nil
} else {
return v.conn, nil
}
}
errors.LogInfo(ctx, "establishing new connection for ", dest)
ctx, cancel := context.WithCancel(ctx)
entry := &connEntry{}
removeRay := func() {
v.Lock()
defer v.Unlock()
// sometimes the entry is already removed by others, don't close again
if entry == v.conn {
cancel()
v.removeRay()
}
}
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
link, err := v.dispatcher.Dispatch(ctx, dest)
if err != nil {
cancel()
return nil, errors.New("failed to dispatch request to ", dest).Base(err)
}
*entry = connEntry{
entry := &connEntry{
link: link,
timer: timer,
cancel: removeRay,
cancel: cancel,
}
entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
v.conn = entry
go handleInput(ctx, entry, dest, v.callback, v.callClose)
return entry, nil
@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
if outputStream != nil {
if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
conn.cancel()
conn.Close()
return
}
}
@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
defer func() {
conn.cancel()
conn.Close()
if callClose != nil {
callClose()
}

View File

@ -200,16 +200,19 @@ func (p *pipe) Interrupt() {
p.Lock()
defer p.Unlock()
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
if p.state == closed {
p.state = errord
}
}
if p.state == closed || p.state == errord {
return
}
p.state = errord
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
}
common.Must(p.done.Close())
}