DNS outbound: Fix some issues (#5081)

pull/5086/head
patterniha 2025-09-04 23:51:21 +02:00 committed by RPRX
parent 8b579bf3ec
commit 197b319f9a
1 changed files with 52 additions and 22 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
go_errors "errors" go_errors "errors"
"io" "io"
"strings"
"sync" "sync"
"time" "time"
@ -168,11 +169,15 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
} }
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) terminate := func() {
cancel()
conn.Close()
}
timer := signal.CancelAfterInactivity(ctx, terminate, h.timeout)
defer timer.SetTimeout(0)
request := func() error { request := func() error {
defer conn.Close() defer timer.SetTimeout(0)
for { for {
b, err := reader.ReadMessage() b, err := reader.ReadMessage()
if err == io.EOF { if err == io.EOF {
@ -190,24 +195,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
if len(h.blockTypes) > 0 { if len(h.blockTypes) > 0 {
for _, blocktype := range h.blockTypes { for _, blocktype := range h.blockTypes {
if blocktype == int32(qType) { if blocktype == int32(qType) {
if h.nonIPQuery == "reject" { b.Release()
go h.rejectNonIPQuery(id, qType, domain, writer)
}
errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain) errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
if h.nonIPQuery == "reject" {
err := h.rejectNonIPQuery(id, qType, domain, writer)
if err != nil {
return err
}
}
return nil return nil
} }
} }
} }
if isIPQuery { if isIPQuery {
go h.handleIPQuery(id, qType, domain, writer) b.Release()
go h.handleIPQuery(id, qType, domain, writer, timer)
continue
} }
if isIPQuery || h.nonIPQuery == "drop" { if h.nonIPQuery == "drop" {
b.Release() b.Release()
continue continue
} }
if h.nonIPQuery == "reject" { if h.nonIPQuery == "reject" {
go h.rejectNonIPQuery(id, qType, domain, writer)
b.Release() b.Release()
err := h.rejectNonIPQuery(id, qType, domain, writer)
if err != nil {
return err
}
continue continue
} }
} }
@ -219,6 +233,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
} }
response := func() error { response := func() error {
defer timer.SetTimeout(0)
for { for {
b, err := connReader.ReadMessage() b, err := connReader.ReadMessage()
if err == io.EOF { if err == io.EOF {
@ -244,7 +259,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
return nil return nil
} }
func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, timer *signal.ActivityTimer) {
var ips []net.IP var ips []net.IP
var err error var err error
@ -319,16 +334,21 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
if err != nil { if err != nil {
errors.LogInfoInner(context.Background(), err, "pack message") errors.LogInfoInner(context.Background(), err, "pack message")
b.Release() b.Release()
return timer.SetTimeout(0)
} }
b.Resize(0, int32(len(msgBytes))) b.Resize(0, int32(len(msgBytes)))
if err := writer.WriteMessage(b); err != nil { if err := writer.WriteMessage(b); err != nil {
errors.LogInfoInner(context.Background(), err, "write IP answer") errors.LogInfoInner(context.Background(), err, "write IP answer")
timer.SetTimeout(0)
} }
} }
func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) error {
domainT := strings.TrimSuffix(domain, ".")
if domainT == "" {
return errors.New("empty domain name")
}
b := buf.New() b := buf.New()
rawBytes := b.Extend(buf.Size) rawBytes := b.Extend(buf.Size)
builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
@ -349,20 +369,22 @@ func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain stri
if err != nil { if err != nil {
errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err) errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err)
b.Release() b.Release()
return return err
} }
msgBytes, err := builder.Finish() msgBytes, err := builder.Finish()
if err != nil { if err != nil {
errors.LogInfoInner(context.Background(), err, "pack reject message") errors.LogInfoInner(context.Background(), err, "pack reject message")
b.Release() b.Release()
return return err
} }
b.Resize(0, int32(len(msgBytes))) b.Resize(0, int32(len(msgBytes)))
if err := writer.WriteMessage(b); err != nil { if err := writer.WriteMessage(b); err != nil {
errors.LogInfoInner(context.Background(), err, "write reject answer") errors.LogInfoInner(context.Background(), err, "write reject answer")
return err
} }
return nil
} }
type outboundConn struct { type outboundConn struct {
@ -371,6 +393,7 @@ type outboundConn struct {
conn net.Conn conn net.Conn
connReady chan struct{} connReady chan struct{}
closed bool
} }
func (c *outboundConn) dial() error { func (c *outboundConn) dial() error {
@ -385,12 +408,16 @@ func (c *outboundConn) dial() error {
func (c *outboundConn) Write(b []byte) (int, error) { func (c *outboundConn) Write(b []byte) (int, error) {
c.access.Lock() c.access.Lock()
if c.closed {
c.access.Unlock()
return 0, errors.New("outbound connection closed")
}
if c.conn == nil { if c.conn == nil {
if err := c.dial(); err != nil { if err := c.dial(); err != nil {
c.access.Unlock() c.access.Unlock()
errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection") errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
return len(b), nil return 0, err
} }
} }
@ -400,24 +427,27 @@ func (c *outboundConn) Write(b []byte) (int, error) {
} }
func (c *outboundConn) Read(b []byte) (int, error) { func (c *outboundConn) Read(b []byte) (int, error) {
var conn net.Conn
c.access.Lock() c.access.Lock()
conn = c.conn if c.closed {
c.access.Unlock() c.access.Unlock()
return 0, io.EOF
}
if conn == nil { if c.conn == nil {
c.access.Unlock()
_, open := <-c.connReady _, open := <-c.connReady
if !open { if !open {
return 0, io.EOF return 0, io.EOF
} }
conn = c.conn return c.conn.Read(b)
} }
c.access.Unlock()
return conn.Read(b) return c.conn.Read(b)
} }
func (c *outboundConn) Close() error { func (c *outboundConn) Close() error {
c.access.Lock() c.access.Lock()
c.closed = true
close(c.connReady) close(c.connReady)
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()