mirror of https://github.com/XTLS/Xray-core
				
				
				
			Fix buffer.UDP destination override (#2356)
							parent
							
								
									e013dce1df
								
							
						
					
					
						commit
						b8bd243df5
					
				| 
						 | 
				
			
			@ -4,7 +4,6 @@ package dispatcher
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
| 
						 | 
				
			
			@ -135,77 +134,10 @@ func (*DefaultDispatcher) Start() error {
 | 
			
		|||
// Close implements common.Closable.
 | 
			
		||||
func (*DefaultDispatcher) Close() error { return nil }
 | 
			
		||||
 | 
			
		||||
func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sniffing session.SniffingRequest) (*transport.Link, *transport.Link) {
 | 
			
		||||
	downOpt := pipe.OptionsFromContext(ctx)
 | 
			
		||||
	upOpt := downOpt
 | 
			
		||||
 | 
			
		||||
	if network == net.Network_UDP {
 | 
			
		||||
		var ip2domain *sync.Map // net.IP.String() => domain, this map is used by server side when client turn on fakedns
 | 
			
		||||
		// Client will send domain address in the buffer.UDP.Address, server record all possible target IP addrs.
 | 
			
		||||
		// When target replies, server will restore the domain and send back to client.
 | 
			
		||||
		// Note: this map is not global but per connection context
 | 
			
		||||
		upOpt = append(upOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
 | 
			
		||||
			for i, buffer := range mb {
 | 
			
		||||
				if buffer.UDP == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				addr := buffer.UDP.Address
 | 
			
		||||
				if addr.Family().IsIP() {
 | 
			
		||||
					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(addr) && sniffing.Enabled {
 | 
			
		||||
						domain := fkr0.GetDomainFromFakeDNS(addr)
 | 
			
		||||
						if len(domain) > 0 {
 | 
			
		||||
							buffer.UDP.Address = net.DomainAddress(domain)
 | 
			
		||||
							newError("[fakedns client] override with domain: ", domain, " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
						} else {
 | 
			
		||||
							newError("[fakedns client] failed to find domain! :", addr.String(), " for xUDP buffer at ", i).AtWarning().WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					if ip2domain == nil {
 | 
			
		||||
						ip2domain = new(sync.Map)
 | 
			
		||||
						newError("[fakedns client] create a new map").WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
					}
 | 
			
		||||
					domain := addr.Domain()
 | 
			
		||||
					ips, err := d.dns.LookupIP(domain, dns.IPOption{true, true, false})
 | 
			
		||||
					if err == nil {
 | 
			
		||||
						for _, ip := range ips {
 | 
			
		||||
							ip2domain.Store(ip.String(), domain)
 | 
			
		||||
						}
 | 
			
		||||
						newError("[fakedns client] candidate ip: "+fmt.Sprintf("%v", ips), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
					} else {
 | 
			
		||||
						newError("[fakedns client] failed to look up IP for ", domain, " for xUDP buffer at ", i).Base(err).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return mb
 | 
			
		||||
		}))
 | 
			
		||||
		downOpt = append(downOpt, pipe.OnTransmission(func(mb buf.MultiBuffer) buf.MultiBuffer {
 | 
			
		||||
			for i, buffer := range mb {
 | 
			
		||||
				if buffer.UDP == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				addr := buffer.UDP.Address
 | 
			
		||||
				if addr.Family().IsIP() {
 | 
			
		||||
					if ip2domain == nil {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					if domain, found := ip2domain.Load(addr.IP().String()); found {
 | 
			
		||||
						buffer.UDP.Address = net.DomainAddress(domain.(string))
 | 
			
		||||
						newError("[fakedns client] restore domain: ", domain.(string), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok {
 | 
			
		||||
						fakeIp := fkr0.GetFakeIPForDomain(addr.Domain())
 | 
			
		||||
						buffer.UDP.Address = fakeIp[0]
 | 
			
		||||
						newError("[fakedns client] restore FakeIP: ", buffer.UDP, fmt.Sprintf("%v", fakeIp), " for xUDP buffer at ", i).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return mb
 | 
			
		||||
		}))
 | 
			
		||||
	}
 | 
			
		||||
	uplinkReader, uplinkWriter := pipe.New(upOpt...)
 | 
			
		||||
	downlinkReader, downlinkWriter := pipe.New(downOpt...)
 | 
			
		||||
func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
 | 
			
		||||
	opt := pipe.OptionsFromContext(ctx)
 | 
			
		||||
	uplinkReader, uplinkWriter := pipe.New(opt...)
 | 
			
		||||
	downlinkReader, downlinkWriter := pipe.New(opt...)
 | 
			
		||||
 | 
			
		||||
	inboundLink := &transport.Link{
 | 
			
		||||
		Reader: downlinkReader,
 | 
			
		||||
| 
						 | 
				
			
			@ -263,7 +195,7 @@ func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResu
 | 
			
		|||
		protocolString = resComp.ProtocolForDomainResult()
 | 
			
		||||
	}
 | 
			
		||||
	for _, p := range request.OverrideDestinationForProtocol {
 | 
			
		||||
		if strings.HasPrefix(protocolString, p) {
 | 
			
		||||
		if strings.HasPrefix(protocolString, p) || strings.HasPrefix(protocolString, p) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
		if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
 | 
			
		||||
| 
						 | 
				
			
			@ -287,7 +219,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 | 
			
		|||
		panic("Dispatcher: Invalid destination.")
 | 
			
		||||
	}
 | 
			
		||||
	ob := &session.Outbound{
 | 
			
		||||
		Target: destination,
 | 
			
		||||
		OriginalTarget: destination,
 | 
			
		||||
		Target:         destination,
 | 
			
		||||
	}
 | 
			
		||||
	ctx = session.ContextWithOutbound(ctx, ob)
 | 
			
		||||
	content := session.ContentFromContext(ctx)
 | 
			
		||||
| 
						 | 
				
			
			@ -295,9 +228,8 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 | 
			
		|||
		content = new(session.Content)
 | 
			
		||||
		ctx = session.ContextWithContent(ctx, content)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sniffingRequest := content.SniffingRequest
 | 
			
		||||
	inbound, outbound := d.getLink(ctx, destination.Network, sniffingRequest)
 | 
			
		||||
	inbound, outbound := d.getLink(ctx)
 | 
			
		||||
	if !sniffingRequest.Enabled {
 | 
			
		||||
		go d.routedDispatch(ctx, outbound, destination)
 | 
			
		||||
	} else {
 | 
			
		||||
| 
						 | 
				
			
			@ -314,7 +246,15 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 | 
			
		|||
				domain := result.Domain()
 | 
			
		||||
				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
				destination.Address = net.ParseAddress(domain)
 | 
			
		||||
				if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
 | 
			
		||||
				protocol := result.Protocol()
 | 
			
		||||
				if resComp, ok := result.(SnifferResultComposite); ok {
 | 
			
		||||
					protocol = resComp.ProtocolForDomainResult()
 | 
			
		||||
				}
 | 
			
		||||
				isFakeIP := false
 | 
			
		||||
				if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
 | 
			
		||||
					isFakeIP = true
 | 
			
		||||
				}
 | 
			
		||||
				if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
 | 
			
		||||
					ob.RouteTarget = destination
 | 
			
		||||
				} else {
 | 
			
		||||
					ob.Target = destination
 | 
			
		||||
| 
						 | 
				
			
			@ -332,7 +272,8 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 | 
			
		|||
		return newError("Dispatcher: Invalid destination.")
 | 
			
		||||
	}
 | 
			
		||||
	ob := &session.Outbound{
 | 
			
		||||
		Target: destination,
 | 
			
		||||
		OriginalTarget: destination,
 | 
			
		||||
		Target:         destination,
 | 
			
		||||
	}
 | 
			
		||||
	ctx = session.ContextWithOutbound(ctx, ob)
 | 
			
		||||
	content := session.ContentFromContext(ctx)
 | 
			
		||||
| 
						 | 
				
			
			@ -356,7 +297,15 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
 | 
			
		|||
			domain := result.Domain()
 | 
			
		||||
			newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 | 
			
		||||
			destination.Address = net.ParseAddress(domain)
 | 
			
		||||
			if sniffingRequest.RouteOnly && result.Protocol() != "fakedns" {
 | 
			
		||||
			protocol := result.Protocol()
 | 
			
		||||
			if resComp, ok := result.(SnifferResultComposite); ok {
 | 
			
		||||
				protocol = resComp.ProtocolForDomainResult()
 | 
			
		||||
			}
 | 
			
		||||
			isFakeIP := false
 | 
			
		||||
			if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && ob.Target.Address.Family().IsIP() && fkr0.IsIPInIPPool(ob.Target.Address) {
 | 
			
		||||
				isFakeIP = true
 | 
			
		||||
			}
 | 
			
		||||
			if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
 | 
			
		||||
				ob.RouteTarget = destination
 | 
			
		||||
			} else {
 | 
			
		||||
				ob.Target = destination
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -8,6 +8,7 @@ import (
 | 
			
		|||
 | 
			
		||||
	"github.com/xtls/xray-core/app/proxyman"
 | 
			
		||||
	"github.com/xtls/xray-core/common"
 | 
			
		||||
	"github.com/xtls/xray-core/common/buf"
 | 
			
		||||
	"github.com/xtls/xray-core/common/mux"
 | 
			
		||||
	"github.com/xtls/xray-core/common/net"
 | 
			
		||||
	"github.com/xtls/xray-core/common/net/cnc"
 | 
			
		||||
| 
						 | 
				
			
			@ -166,6 +167,11 @@ func (h *Handler) Tag() string {
 | 
			
		|||
 | 
			
		||||
// Dispatch implements proxy.Outbound.Dispatch.
 | 
			
		||||
func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
 | 
			
		||||
	outbound := session.OutboundFromContext(ctx)
 | 
			
		||||
	if outbound.Target.Network == net.Network_UDP && outbound.OriginalTarget.Address != nil && outbound.OriginalTarget.Address != outbound.Target.Address {
 | 
			
		||||
		link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
 | 
			
		||||
		link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: outbound.Target.Address, OriginalDest: outbound.OriginalTarget.Address}
 | 
			
		||||
	}
 | 
			
		||||
	if h.mux != nil {
 | 
			
		||||
		test := func(err error) {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
| 
						 | 
				
			
			@ -175,7 +181,6 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
 | 
			
		|||
				common.Interrupt(link.Writer)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		outbound := session.OutboundFromContext(ctx)
 | 
			
		||||
		if outbound.Target.Network == net.Network_UDP && outbound.Target.Port == 443 {
 | 
			
		||||
			switch h.udp443 {
 | 
			
		||||
			case "reject":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,38 @@
 | 
			
		|||
package buf
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/xtls/xray-core/common/net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type EndpointOverrideReader struct {
 | 
			
		||||
	Reader
 | 
			
		||||
	Dest         net.Address
 | 
			
		||||
	OriginalDest net.Address
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *EndpointOverrideReader) ReadMultiBuffer() (MultiBuffer, error) {
 | 
			
		||||
	mb, err := r.Reader.ReadMultiBuffer()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		for _, b := range mb {
 | 
			
		||||
			if b.UDP != nil && b.UDP.Address == r.OriginalDest {
 | 
			
		||||
				b.UDP.Address = r.Dest
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return mb, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type EndpointOverrideWriter struct {
 | 
			
		||||
	Writer
 | 
			
		||||
	Dest         net.Address
 | 
			
		||||
	OriginalDest net.Address
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *EndpointOverrideWriter) WriteMultiBuffer(mb MultiBuffer) error {
 | 
			
		||||
	for _, b := range mb {
 | 
			
		||||
		if b.UDP != nil && b.UDP.Address == w.Dest {
 | 
			
		||||
			b.UDP.Address = w.OriginalDest
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return w.Writer.WriteMultiBuffer(mb)
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -55,8 +55,9 @@ type Inbound struct {
 | 
			
		|||
// Outbound is the metadata of an outbound connection.
 | 
			
		||||
type Outbound struct {
 | 
			
		||||
	// Target address of the outbound connection.
 | 
			
		||||
	Target      net.Destination
 | 
			
		||||
	RouteTarget net.Destination
 | 
			
		||||
	OriginalTarget net.Destination
 | 
			
		||||
	Target         net.Destination
 | 
			
		||||
	RouteTarget    net.Destination
 | 
			
		||||
	// Gateway address
 | 
			
		||||
	Gateway net.Address
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,7 +24,6 @@ const (
 | 
			
		|||
type pipeOption struct {
 | 
			
		||||
	limit           int32 // maximum buffer size in bytes
 | 
			
		||||
	discardOverflow bool
 | 
			
		||||
	onTransmission  func(buffer buf.MultiBuffer) buf.MultiBuffer
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (o *pipeOption) isFull(curSize int32) bool {
 | 
			
		||||
| 
						 | 
				
			
			@ -141,10 +140,6 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
 | 
			
		|||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if p.option.onTransmission != nil {
 | 
			
		||||
		mb = p.option.onTransmission(mb)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		err := p.writeMultiBufferInternal(mb)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,7 +3,6 @@ package pipe
 | 
			
		|||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/xtls/xray-core/common/buf"
 | 
			
		||||
	"github.com/xtls/xray-core/common/signal"
 | 
			
		||||
	"github.com/xtls/xray-core/common/signal/done"
 | 
			
		||||
	"github.com/xtls/xray-core/features/policy"
 | 
			
		||||
| 
						 | 
				
			
			@ -26,12 +25,6 @@ func WithSizeLimit(limit int32) Option {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OnTransmission(hook func(mb buf.MultiBuffer) buf.MultiBuffer) Option {
 | 
			
		||||
	return func(option *pipeOption) {
 | 
			
		||||
		option.onTransmission = hook
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DiscardOverflow returns an Option for Pipe to discard writes if full.
 | 
			
		||||
func DiscardOverflow() Option {
 | 
			
		||||
	return func(opt *pipeOption) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue