mirror of https://github.com/XTLS/Xray-core
				
				
				
			DNS outbound: Fix some issues (#5081)
							parent
							
								
									8b579bf3ec
								
							
						
					
					
						commit
						197b319f9a
					
				| 
						 | 
				
			
			@ -4,6 +4,7 @@ import (
 | 
			
		|||
	"context"
 | 
			
		||||
	go_errors "errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -168,11 +169,15 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithCancel(ctx)
 | 
			
		||||
	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
 | 
			
		||||
	terminate := func() {
 | 
			
		||||
		cancel()
 | 
			
		||||
		conn.Close()
 | 
			
		||||
	}
 | 
			
		||||
	timer := signal.CancelAfterInactivity(ctx, terminate, h.timeout)
 | 
			
		||||
	defer timer.SetTimeout(0)
 | 
			
		||||
 | 
			
		||||
	request := func() error {
 | 
			
		||||
		defer conn.Close()
 | 
			
		||||
 | 
			
		||||
		defer timer.SetTimeout(0)
 | 
			
		||||
		for {
 | 
			
		||||
			b, err := reader.ReadMessage()
 | 
			
		||||
			if err == io.EOF {
 | 
			
		||||
| 
						 | 
				
			
			@ -190,24 +195,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 | 
			
		|||
				if len(h.blockTypes) > 0 {
 | 
			
		||||
					for _, blocktype := range h.blockTypes {
 | 
			
		||||
						if blocktype == int32(qType) {
 | 
			
		||||
							if h.nonIPQuery == "reject" {
 | 
			
		||||
								go h.rejectNonIPQuery(id, qType, domain, writer)
 | 
			
		||||
							}
 | 
			
		||||
							b.Release()
 | 
			
		||||
							errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
 | 
			
		||||
							if h.nonIPQuery == "reject" {
 | 
			
		||||
								err := h.rejectNonIPQuery(id, qType, domain, writer)
 | 
			
		||||
								if err != nil {
 | 
			
		||||
									return err
 | 
			
		||||
								}
 | 
			
		||||
							}
 | 
			
		||||
							return nil
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				if isIPQuery {
 | 
			
		||||
					go h.handleIPQuery(id, qType, domain, writer)
 | 
			
		||||
					b.Release()
 | 
			
		||||
					go h.handleIPQuery(id, qType, domain, writer, timer)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				if isIPQuery || h.nonIPQuery == "drop" {
 | 
			
		||||
				if h.nonIPQuery == "drop" {
 | 
			
		||||
					b.Release()
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				if h.nonIPQuery == "reject" {
 | 
			
		||||
					go h.rejectNonIPQuery(id, qType, domain, writer)
 | 
			
		||||
					b.Release()
 | 
			
		||||
					err := h.rejectNonIPQuery(id, qType, domain, writer)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return err
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
| 
						 | 
				
			
			@ -219,6 +233,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 | 
			
		|||
	}
 | 
			
		||||
 | 
			
		||||
	response := func() error {
 | 
			
		||||
		defer timer.SetTimeout(0)
 | 
			
		||||
		for {
 | 
			
		||||
			b, err := connReader.ReadMessage()
 | 
			
		||||
			if err == io.EOF {
 | 
			
		||||
| 
						 | 
				
			
			@ -244,7 +259,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 | 
			
		|||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
 | 
			
		||||
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, timer *signal.ActivityTimer) {
 | 
			
		||||
	var ips []net.IP
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -319,16 +334,21 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		errors.LogInfoInner(context.Background(), err, "pack message")
 | 
			
		||||
		b.Release()
 | 
			
		||||
		return
 | 
			
		||||
		timer.SetTimeout(0)
 | 
			
		||||
	}
 | 
			
		||||
	b.Resize(0, int32(len(msgBytes)))
 | 
			
		||||
 | 
			
		||||
	if err := writer.WriteMessage(b); err != nil {
 | 
			
		||||
		errors.LogInfoInner(context.Background(), err, "write IP answer")
 | 
			
		||||
		timer.SetTimeout(0)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
 | 
			
		||||
func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) error {
 | 
			
		||||
	domainT := strings.TrimSuffix(domain, ".")
 | 
			
		||||
	if domainT == "" {
 | 
			
		||||
		return errors.New("empty domain name")
 | 
			
		||||
	}
 | 
			
		||||
	b := buf.New()
 | 
			
		||||
	rawBytes := b.Extend(buf.Size)
 | 
			
		||||
	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
 | 
			
		||||
| 
						 | 
				
			
			@ -349,20 +369,22 @@ func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain stri
 | 
			
		|||
	if err != nil {
 | 
			
		||||
		errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err)
 | 
			
		||||
		b.Release()
 | 
			
		||||
		return
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msgBytes, err := builder.Finish()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		errors.LogInfoInner(context.Background(), err, "pack reject message")
 | 
			
		||||
		b.Release()
 | 
			
		||||
		return
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	b.Resize(0, int32(len(msgBytes)))
 | 
			
		||||
 | 
			
		||||
	if err := writer.WriteMessage(b); err != nil {
 | 
			
		||||
		errors.LogInfoInner(context.Background(), err, "write reject answer")
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type outboundConn struct {
 | 
			
		||||
| 
						 | 
				
			
			@ -371,6 +393,7 @@ type outboundConn struct {
 | 
			
		|||
 | 
			
		||||
	conn      net.Conn
 | 
			
		||||
	connReady chan struct{}
 | 
			
		||||
	closed    bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *outboundConn) dial() error {
 | 
			
		||||
| 
						 | 
				
			
			@ -385,12 +408,16 @@ func (c *outboundConn) dial() error {
 | 
			
		|||
 | 
			
		||||
func (c *outboundConn) Write(b []byte) (int, error) {
 | 
			
		||||
	c.access.Lock()
 | 
			
		||||
	if c.closed {
 | 
			
		||||
		c.access.Unlock()
 | 
			
		||||
		return 0, errors.New("outbound connection closed")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if c.conn == nil {
 | 
			
		||||
		if err := c.dial(); err != nil {
 | 
			
		||||
			c.access.Unlock()
 | 
			
		||||
			errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
 | 
			
		||||
			return len(b), nil
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -400,24 +427,27 @@ func (c *outboundConn) Write(b []byte) (int, error) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
func (c *outboundConn) Read(b []byte) (int, error) {
 | 
			
		||||
	var conn net.Conn
 | 
			
		||||
	c.access.Lock()
 | 
			
		||||
	conn = c.conn
 | 
			
		||||
	c.access.Unlock()
 | 
			
		||||
	if c.closed {
 | 
			
		||||
		c.access.Unlock()
 | 
			
		||||
		return 0, io.EOF
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
	if c.conn == nil {
 | 
			
		||||
		c.access.Unlock()
 | 
			
		||||
		_, open := <-c.connReady
 | 
			
		||||
		if !open {
 | 
			
		||||
			return 0, io.EOF
 | 
			
		||||
		}
 | 
			
		||||
		conn = c.conn
 | 
			
		||||
		return c.conn.Read(b)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return conn.Read(b)
 | 
			
		||||
	c.access.Unlock()
 | 
			
		||||
	return c.conn.Read(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *outboundConn) Close() error {
 | 
			
		||||
	c.access.Lock()
 | 
			
		||||
	c.closed = true
 | 
			
		||||
	close(c.connReady)
 | 
			
		||||
	if c.conn != nil {
 | 
			
		||||
		c.conn.Close()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue