Trojan-UoT & UDP-nameserver: Fix forgotten release buffer; UDP dispatcher: Simplified and optimized (#5050)

pull/5070/head
patterniha 2025-08-29 16:31:46 +02:00 committed by GitHub
parent 82ea7a3cc5
commit 593ededd3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 83 additions and 45 deletions

View File

@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
// HandleResponse handles udp response packet from remote DNS server. // HandleResponse handles udp response packet from remote DNS server.
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
ipRec, err := parseResponse(packet.Payload.Bytes()) payload := packet.Payload
ipRec, err := parseResponse(payload.Bytes())
payload.Release()
if err != nil { if err != nil {
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp") errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
return return
@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
newReq.msg = &newMsg newReq.msg = &newMsg
s.addPendingRequest(&newReq) s.addPendingRequest(&newReq)
b, _ := dns.PackMessage(newReq.msg) b, _ := dns.PackMessage(newReq.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b) s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
return return
} }
@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
} }
s.addPendingRequest(udpReq) s.addPendingRequest(udpReq)
b, _ := dns.PackMessage(req.msg) b, _ := dns.PackMessage(req.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b) s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
} }
} }

View File

@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
} }
out: out:
err := h.proxy.Process(ctx, link, h) err := h.proxy.Process(ctx, link, h)
var errC error
if err != nil { if err != nil {
if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) { errC = errors.Cause(err)
if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
err = nil err = nil
} }
} }
@ -250,9 +252,13 @@ out:
session.SubmitOutboundErrorToOriginator(ctx, err) session.SubmitOutboundErrorToOriginator(ctx, err)
errors.LogInfo(ctx, err.Error()) errors.LogInfo(ctx, err.Error())
common.Interrupt(link.Writer) common.Interrupt(link.Writer)
} else {
if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
common.Interrupt(link.Writer)
} else { } else {
common.Close(link.Writer) common.Close(link.Writer)
} }
}
common.Interrupt(link.Reader) common.Interrupt(link.Reader)
} }

View File

@ -2,6 +2,7 @@ package mux
import ( import (
"context" "context"
goerrors "errors"
"io" "io"
"sync" "sync"
"time" "time"
@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
ctx := session.ContextWithOutbounds(context.Background(), outbounds) ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil { if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
errors.LogInfoInner(ctx, err, "failed to handler mux client connection") errC := errors.Cause(errP)
if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
}
} }
common.Must(c.Close()) common.Must(c.Close())
cancel() cancel()
@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
select { select {
case <-m.done.Wait(): case <-m.done.Wait():
m.sessionManager.Close() m.sessionManager.Close()
common.Close(m.link.Writer) common.Interrupt(m.link.Writer)
common.Interrupt(m.link.Reader) common.Interrupt(m.link.Reader)
return return
case <-m.timer.C: case <-m.timer.C:
@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil return nil
} }
func fetchInput(ctx context.Context, s *Session, output buf.Writer) { func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
outbounds := session.OutboundsFromContext(ctx) outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1] ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream transferType := protocol.TransferTypeStream
@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false) defer s.Close(false)
defer writer.Close() defer writer.Close()
defer timer.Reset(time.Second * 16)
errors.LogInfo(ctx, "dispatching request to ", ob.Target) errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil { if err := writeFirstPayload(s.input, writer); err != nil {
@ -308,9 +313,9 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
s.input = link.Reader s.input = link.Reader
s.output = link.Writer s.output = link.Writer
if _, ok := link.Reader.(*pipe.Reader); ok { if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer) go fetchInput(ctx, s, m.link.Writer, m.timer)
} else { } else {
fetchInput(ctx, s, m.link.Writer) fetchInput(ctx, s, m.link.Writer, m.timer)
} }
return true return true
} }

View File

@ -318,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader} reader := &buf.BufferedReader{Reader: w.link.Reader}
defer w.sessionManager.Close() defer w.sessionManager.Close()
defer common.Close(w.link.Writer)
defer common.Interrupt(w.link.Reader) defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
for { for {
select { select {

View File

@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
} }
a := addr.AsAddress() a := addr.AsAddress()
return a != net.AnyIP return a != net.AnyIP && a != net.AnyIPv6
} }
// Process implements proxy.Outbound. // Process implements proxy.Outbound.
@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
} }
} }
} }
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr()) destAddr := b.UDP.RawNetAddr()
if destAddr == nil { if destAddr == nil {
b.Release() b.Release()
continue continue

View File

@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
} }
} }
if err != nil { if err != nil {
if errors.Cause(err) == io.EOF {
return nil
}
return err return err
} }
} }

View File

@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
request := protocol.RequestHeaderFromContext(ctx) request := protocol.RequestHeaderFromContext(ctx)
payload := packet.Payload
if request == nil { if request == nil {
payload.Release()
return return
} }
payload := packet.Payload
if payload.UDP != nil { if payload.UDP != nil {
request = &protocol.RequestHeader{ request = &protocol.RequestHeader{
User: request.User, User: request.User,
@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to encode UDP packet") errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
return return
} }
defer data.Release()
conn.Write(data.Bytes()) conn.Write(data.Bytes())
data.Release()
}) })
defer udpServer.RemoveRay() defer udpServer.RemoveRay()

View File

@ -196,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
request := protocol.RequestHeaderFromContext(ctx) request := protocol.RequestHeaderFromContext(ctx)
if request == nil { if request == nil {
payload.Release()
return return
} }
@ -214,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to write UDP response") errors.LogWarningInner(ctx, err, "failed to write UDP response")
return return
} }
defer udpMessage.Release()
conn.Write(udpMessage.Bytes()) conn.Write(udpMessage.Bytes())
udpMessage.Release()
}) })
defer udpServer.RemoveRay() defer udpServer.RemoveRay()

View File

@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
target = b.UDP target = b.UDP
} }
if _, err := w.writePacket(b.Bytes(), *target); err != nil { if _, err := w.writePacket(b.Bytes(), *target); err != nil {
b.Release()
buf.ReleaseMulti(mb) buf.ReleaseMulti(mb)
return err return err
} }
b.Release()
} }
return nil return nil
} }

View File

@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
type connEntry struct { type connEntry struct {
link *transport.Link link *transport.Link
timer signal.ActivityUpdater timer *signal.ActivityTimer
cancel context.CancelFunc cancel context.CancelFunc
closed bool
}
func (c *connEntry) Close() error {
c.timer.SetTimeout(0)
return nil
}
func (c *connEntry) terminate() {
if c.closed {
panic("terminate called more than once")
}
c.closed = true
c.cancel()
common.Interrupt(c.link.Reader)
common.Interrupt(c.link.Writer)
} }
type Dispatcher struct { type Dispatcher struct {
@ -32,6 +48,7 @@ type Dispatcher struct {
dispatcher routing.Dispatcher dispatcher routing.Dispatcher
callback ResponseCallback callback ResponseCallback
callClose func() error callClose func() error
closed bool
} }
func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
func (v *Dispatcher) RemoveRay() { func (v *Dispatcher) RemoveRay() {
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
v.removeRay() v.closed = true
}
func (v *Dispatcher) removeRay() {
if v.conn != nil { if v.conn != nil {
common.Interrupt(v.conn.link.Reader) v.conn.Close()
common.Close(v.conn.link.Writer)
v.conn = nil v.conn = nil
} }
} }
@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
v.Lock() v.Lock()
defer v.Unlock() defer v.Unlock()
if v.closed {
return nil, errors.New("dispatcher is closed")
}
if v.conn != nil { if v.conn != nil {
if v.conn.closed {
v.conn = nil
} else {
return v.conn, nil return v.conn, nil
} }
}
errors.LogInfo(ctx, "establishing new connection for ", dest) errors.LogInfo(ctx, "establishing new connection for ", dest)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
entry := &connEntry{}
removeRay := func() {
v.Lock()
defer v.Unlock()
// sometimes the entry is already removed by others, don't close again
if entry == v.conn {
cancel()
v.removeRay()
}
}
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
link, err := v.dispatcher.Dispatch(ctx, dest) link, err := v.dispatcher.Dispatch(ctx, dest)
if err != nil { if err != nil {
cancel()
return nil, errors.New("failed to dispatch request to ", dest).Base(err) return nil, errors.New("failed to dispatch request to ", dest).Base(err)
} }
*entry = connEntry{ entry := &connEntry{
link: link, link: link,
timer: timer, cancel: cancel,
cancel: removeRay,
} }
entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
v.conn = entry v.conn = entry
go handleInput(ctx, entry, dest, v.callback, v.callClose) go handleInput(ctx, entry, dest, v.callback, v.callClose)
return entry, nil return entry, nil
@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
if outputStream != nil { if outputStream != nil {
if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
errors.LogInfoInner(ctx, err, "failed to write first UDP payload") errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
conn.cancel() conn.Close()
return return
} }
} }
@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
defer func() { defer func() {
conn.cancel() conn.Close()
if callClose != nil { if callClose != nil {
callClose() callClose()
} }

View File

@ -200,16 +200,19 @@ func (p *pipe) Interrupt() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
if p.state == closed {
p.state = errord
}
}
if p.state == closed || p.state == errord { if p.state == closed || p.state == errord {
return return
} }
p.state = errord p.state = errord
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
}
common.Must(p.done.Close()) common.Must(p.done.Close())
} }