diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index 40db53a8..9ae19cbe 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -4,6 +4,7 @@ import ( "context" go_errors "errors" "io" + "strings" "sync" "time" @@ -168,11 +169,15 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. } ctx, cancel := context.WithCancel(ctx) - timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) + terminate := func() { + cancel() + conn.Close() + } + timer := signal.CancelAfterInactivity(ctx, terminate, h.timeout) + defer timer.SetTimeout(0) request := func() error { - defer conn.Close() - + defer timer.SetTimeout(0) for { b, err := reader.ReadMessage() if err == io.EOF { @@ -190,24 +195,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. if len(h.blockTypes) > 0 { for _, blocktype := range h.blockTypes { if blocktype == int32(qType) { - if h.nonIPQuery == "reject" { - go h.rejectNonIPQuery(id, qType, domain, writer) - } + b.Release() errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain) + if h.nonIPQuery == "reject" { + err := h.rejectNonIPQuery(id, qType, domain, writer) + if err != nil { + return err + } + } return nil } } } if isIPQuery { - go h.handleIPQuery(id, qType, domain, writer) + b.Release() + go h.handleIPQuery(id, qType, domain, writer, timer) + continue } - if isIPQuery || h.nonIPQuery == "drop" { + if h.nonIPQuery == "drop" { b.Release() continue } if h.nonIPQuery == "reject" { - go h.rejectNonIPQuery(id, qType, domain, writer) b.Release() + err := h.rejectNonIPQuery(id, qType, domain, writer) + if err != nil { + return err + } continue } } @@ -219,6 +233,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. } response := func() error { + defer timer.SetTimeout(0) for { b, err := connReader.ReadMessage() if err == io.EOF { @@ -244,7 +259,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet. return nil } -func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { +func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, timer *signal.ActivityTimer) { var ips []net.IP var err error @@ -319,16 +334,21 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, if err != nil { errors.LogInfoInner(context.Background(), err, "pack message") b.Release() - return + timer.SetTimeout(0) } b.Resize(0, int32(len(msgBytes))) if err := writer.WriteMessage(b); err != nil { errors.LogInfoInner(context.Background(), err, "write IP answer") + timer.SetTimeout(0) } } -func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { +func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) error { + domainT := strings.TrimSuffix(domain, ".") + if domainT == "" { + return errors.New("empty domain name") + } b := buf.New() rawBytes := b.Extend(buf.Size) builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ @@ -349,20 +369,22 @@ func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain stri if err != nil { errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err) b.Release() - return + return err } msgBytes, err := builder.Finish() if err != nil { errors.LogInfoInner(context.Background(), err, "pack reject message") b.Release() - return + return err } b.Resize(0, int32(len(msgBytes))) if err := writer.WriteMessage(b); err != nil { errors.LogInfoInner(context.Background(), err, "write reject answer") + return err } + return nil } type outboundConn struct { @@ -371,6 +393,7 @@ type outboundConn struct { conn net.Conn connReady chan struct{} + closed bool } func (c *outboundConn) dial() error { @@ -385,12 +408,16 @@ func (c *outboundConn) dial() error { func (c *outboundConn) Write(b []byte) (int, error) { c.access.Lock() + if c.closed { + c.access.Unlock() + return 0, errors.New("outbound connection closed") + } if c.conn == nil { if err := c.dial(); err != nil { c.access.Unlock() errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection") - return len(b), nil + return 0, err } } @@ -400,24 +427,27 @@ func (c *outboundConn) Write(b []byte) (int, error) { } func (c *outboundConn) Read(b []byte) (int, error) { - var conn net.Conn c.access.Lock() - conn = c.conn - c.access.Unlock() + if c.closed { + c.access.Unlock() + return 0, io.EOF + } - if conn == nil { + if c.conn == nil { + c.access.Unlock() _, open := <-c.connReady if !open { return 0, io.EOF } - conn = c.conn + return c.conn.Read(b) } - - return conn.Read(b) + c.access.Unlock() + return c.conn.Read(b) } func (c *outboundConn) Close() error { c.access.Lock() + c.closed = true close(c.connReady) if c.conn != nil { c.conn.Close()