From 99e736ac61f5702f67e61bad2c54e75f06498cce Mon Sep 17 00:00:00 2001 From: patterniha <71074308+patterniha@users.noreply.github.com> Date: Mon, 20 Oct 2025 21:04:10 +0330 Subject: [PATCH] Patch: if A-ttl is not expired but AAAA-ttl is expired, we should only send AAAA-query and vice versa 1. if A-ttl is not expired but AAAA-ttl is expired, we should only send AAAA-query and vice versa 2. `sendQuery` send each query in new goroutine so there is no need to run it in new goroutine. --- app/dns/cache_controller.go | 50 ++++++++++++++++++------------------- app/dns/nameserver_doh.go | 20 ++++++++++++--- app/dns/nameserver_quic.go | 20 ++++++++++++--- app/dns/nameserver_tcp.go | 20 ++++++++++++--- app/dns/nameserver_udp.go | 20 ++++++++++++--- 5 files changed, 89 insertions(+), 41 deletions(-) diff --git a/app/dns/cache_controller.go b/app/dns/cache_controller.go index d0c42e09..9d2081b7 100644 --- a/app/dns/cache_controller.go +++ b/app/dns/cache_controller.go @@ -102,20 +102,8 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) { switch req.reqType { case dnsmessage.TypeA: c.pub.Publish(req.domain+"4", nil) - if !c.disableCache { - _, _, err := rec.AAAA.getIPs() - if !go_errors.Is(err, errRecordNotFound) { - c.pub.Publish(req.domain+"6", nil) - } - } case dnsmessage.TypeAAAA: c.pub.Publish(req.domain+"6", nil) - if !c.disableCache { - _, _, err := rec.A.getIPs() - if !go_errors.Is(err, errRecordNotFound) { - c.pub.Publish(req.domain+"4", nil) - } - } } c.Unlock() @@ -124,13 +112,13 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) { } } -func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, error) { +func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, bool, bool, error) { c.RLock() record, found := c.ips[domain] c.RUnlock() if !found { - return nil, 0, errRecordNotFound + return nil, 0, true, true, errRecordNotFound } var errs []error @@ -139,43 +127,55 @@ func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPO mergeReq := option.IPv4Enable && option.IPv6Enable + isARecordExpired := true if option.IPv4Enable { ips, ttl, err := record.A.getIPs() - if !mergeReq || go_errors.Is(err, errRecordNotFound) { - return ips, ttl, err + if ttl > 0 { + isARecordExpired = false + } + if !mergeReq { + return ips, ttl, isARecordExpired, true, err } if ttl < rTTL { rTTL = ttl } if len(ips) > 0 { allIPs = append(allIPs, ips...) - } else { - errs = append(errs, err) } + errs = append(errs, err) + } + isAAAARecordExpired := true if option.IPv6Enable { ips, ttl, err := record.AAAA.getIPs() - if !mergeReq || go_errors.Is(err, errRecordNotFound) { - return ips, ttl, err + if ttl > 0 { + isAAAARecordExpired = false + } + if !mergeReq { + return ips, ttl, true, isAAAARecordExpired, err } if ttl < rTTL { rTTL = ttl } if len(ips) > 0 { allIPs = append(allIPs, ips...) - } else { - errs = append(errs, err) } + errs = append(errs, err) + + } + + if go_errors.Is(errs[0], errRecordNotFound) || go_errors.Is(errs[1], errRecordNotFound) { + return nil, 0, isARecordExpired, isAAAARecordExpired, errRecordNotFound } if len(allIPs) > 0 { - return allIPs, rTTL, nil + return allIPs, rTTL, isARecordExpired, isAAAARecordExpired, nil } if go_errors.Is(errs[0], errs[1]) { - return nil, rTTL, errs[0] + return nil, rTTL, isARecordExpired, isAAAARecordExpired, errs[0] } - return nil, rTTL, errors.Combine(errs...) + return nil, rTTL, isARecordExpired, isAAAARecordExpired, errors.Combine(errs...) } func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) { diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index c5ef6bbe..640c8530 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -229,10 +229,22 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) defer closeSubscribers(sub4, sub6) + queryOption := option + if s.cacheController.disableCache { errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option) + if sub4 != nil && !isARecordExpired { + sub4.Close() + sub4 = nil + queryOption.IPv4Enable = false + } + if sub6 != nil && !isAAAARecordExpired { + sub6.Close() + sub6 = nil + queryOption.IPv6Enable = false + } if !go_errors.Is(err, errRecordNotFound) { if ttl > 0 { errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) @@ -241,14 +253,14 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f } if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) { errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips) - go s.sendQuery(ctx, nil, fqdn, option) + s.sendQuery(ctx, nil, fqdn, queryOption) return ips, 1, err } } } noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) + s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption) start := time.Now() if sub4 != nil { @@ -272,7 +284,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f } } - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option) log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) var rTTL uint32 if ttl <= 0 { diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index c37fba42..42f55ea8 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -196,10 +196,22 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_ sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) defer closeSubscribers(sub4, sub6) + queryOption := option + if s.cacheController.disableCache { errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option) + if sub4 != nil && !isARecordExpired { + sub4.Close() + sub4 = nil + queryOption.IPv4Enable = false + } + if sub6 != nil && !isAAAARecordExpired { + sub6.Close() + sub6 = nil + queryOption.IPv6Enable = false + } if !go_errors.Is(err, errRecordNotFound) { if ttl > 0 { errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) @@ -208,14 +220,14 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_ } if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) { errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips) - go s.sendQuery(ctx, nil, fqdn, option) + s.sendQuery(ctx, nil, fqdn, queryOption) return ips, 1, err } } } noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) + s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption) start := time.Now() if sub4 != nil { @@ -239,7 +251,7 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_ } } - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option) log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) var rTTL uint32 if ttl <= 0 { diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index 129d1b26..be76e96f 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -224,10 +224,22 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) defer closeSubscribers(sub4, sub6) + queryOption := option + if s.cacheController.disableCache { errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option) + if sub4 != nil && !isARecordExpired { + sub4.Close() + sub4 = nil + queryOption.IPv4Enable = false + } + if sub6 != nil && !isAAAARecordExpired { + sub6.Close() + sub6 = nil + queryOption.IPv6Enable = false + } if !go_errors.Is(err, errRecordNotFound) { if ttl > 0 { errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) @@ -236,14 +248,14 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f } if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) { errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips) - go s.sendQuery(ctx, nil, fqdn, option) + s.sendQuery(ctx, nil, fqdn, queryOption) return ips, 1, err } } } noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) + s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption) start := time.Now() if sub4 != nil { @@ -267,7 +279,7 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f } } - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option) log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) var rTTL uint32 if ttl <= 0 { diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index 161e9b2d..734dbe50 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -174,10 +174,22 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) defer closeSubscribers(sub4, sub6) + queryOption := option + if s.cacheController.disableCache { errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option) + if sub4 != nil && !isARecordExpired { + sub4.Close() + sub4 = nil + queryOption.IPv4Enable = false + } + if sub6 != nil && !isAAAARecordExpired { + sub6.Close() + sub6 = nil + queryOption.IPv6Enable = false + } if !go_errors.Is(err, errRecordNotFound) { if ttl > 0 { errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) @@ -186,14 +198,14 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d } if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) { errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips) - go s.sendQuery(ctx, nil, fqdn, option) + s.sendQuery(ctx, nil, fqdn, queryOption) return ips, 1, err } } } noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) + s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption) start := time.Now() if sub4 != nil { @@ -217,7 +229,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d } } - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) + ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option) log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) var rTTL uint32 if ttl <= 0 {