Patch: if A-ttl is not expired but AAAA-ttl is expired, we should only send AAAA-query and vice versa

1. if A-ttl is not expired but AAAA-ttl is expired, we should only send  AAAA-query and vice versa

2. `sendQuery` send each query in new goroutine so there is no need to run it in new goroutine.
This commit is contained in:
patterniha
2025-10-20 21:04:10 +03:30
parent cd4f1cd4a5
commit 99e736ac61
5 changed files with 89 additions and 41 deletions

View File

@@ -102,20 +102,8 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
switch req.reqType {
case dnsmessage.TypeA:
c.pub.Publish(req.domain+"4", nil)
if !c.disableCache {
_, _, err := rec.AAAA.getIPs()
if !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+"6", nil)
}
}
case dnsmessage.TypeAAAA:
c.pub.Publish(req.domain+"6", nil)
if !c.disableCache {
_, _, err := rec.A.getIPs()
if !go_errors.Is(err, errRecordNotFound) {
c.pub.Publish(req.domain+"4", nil)
}
}
}
c.Unlock()
@@ -124,13 +112,13 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
}
}
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, error) {
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, bool, bool, error) {
c.RLock()
record, found := c.ips[domain]
c.RUnlock()
if !found {
return nil, 0, errRecordNotFound
return nil, 0, true, true, errRecordNotFound
}
var errs []error
@@ -139,43 +127,55 @@ func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPO
mergeReq := option.IPv4Enable && option.IPv6Enable
isARecordExpired := true
if option.IPv4Enable {
ips, ttl, err := record.A.getIPs()
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
return ips, ttl, err
if ttl > 0 {
isARecordExpired = false
}
if !mergeReq {
return ips, ttl, isARecordExpired, true, err
}
if ttl < rTTL {
rTTL = ttl
}
if len(ips) > 0 {
allIPs = append(allIPs, ips...)
} else {
errs = append(errs, err)
}
errs = append(errs, err)
}
isAAAARecordExpired := true
if option.IPv6Enable {
ips, ttl, err := record.AAAA.getIPs()
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
return ips, ttl, err
if ttl > 0 {
isAAAARecordExpired = false
}
if !mergeReq {
return ips, ttl, true, isAAAARecordExpired, err
}
if ttl < rTTL {
rTTL = ttl
}
if len(ips) > 0 {
allIPs = append(allIPs, ips...)
} else {
errs = append(errs, err)
}
errs = append(errs, err)
}
if go_errors.Is(errs[0], errRecordNotFound) || go_errors.Is(errs[1], errRecordNotFound) {
return nil, 0, isARecordExpired, isAAAARecordExpired, errRecordNotFound
}
if len(allIPs) > 0 {
return allIPs, rTTL, nil
return allIPs, rTTL, isARecordExpired, isAAAARecordExpired, nil
}
if go_errors.Is(errs[0], errs[1]) {
return nil, rTTL, errs[0]
return nil, rTTL, isARecordExpired, isAAAARecordExpired, errs[0]
}
return nil, rTTL, errors.Combine(errs...)
return nil, rTTL, isARecordExpired, isAAAARecordExpired, errors.Combine(errs...)
}
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {

View File

@@ -229,10 +229,22 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
defer closeSubscribers(sub4, sub6)
queryOption := option
if s.cacheController.disableCache {
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
} else {
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
if sub4 != nil && !isARecordExpired {
sub4.Close()
sub4 = nil
queryOption.IPv4Enable = false
}
if sub6 != nil && !isAAAARecordExpired {
sub6.Close()
sub6 = nil
queryOption.IPv6Enable = false
}
if !go_errors.Is(err, errRecordNotFound) {
if ttl > 0 {
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -241,14 +253,14 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
}
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
go s.sendQuery(ctx, nil, fqdn, option)
s.sendQuery(ctx, nil, fqdn, queryOption)
return ips, 1, err
}
}
}
noResponseErrCh := make(chan error, 2)
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
start := time.Now()
if sub4 != nil {
@@ -272,7 +284,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
}
}
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
var rTTL uint32
if ttl <= 0 {

View File

@@ -196,10 +196,22 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
defer closeSubscribers(sub4, sub6)
queryOption := option
if s.cacheController.disableCache {
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
} else {
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
if sub4 != nil && !isARecordExpired {
sub4.Close()
sub4 = nil
queryOption.IPv4Enable = false
}
if sub6 != nil && !isAAAARecordExpired {
sub6.Close()
sub6 = nil
queryOption.IPv6Enable = false
}
if !go_errors.Is(err, errRecordNotFound) {
if ttl > 0 {
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -208,14 +220,14 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
}
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
go s.sendQuery(ctx, nil, fqdn, option)
s.sendQuery(ctx, nil, fqdn, queryOption)
return ips, 1, err
}
}
}
noResponseErrCh := make(chan error, 2)
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
start := time.Now()
if sub4 != nil {
@@ -239,7 +251,7 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
}
}
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
var rTTL uint32
if ttl <= 0 {

View File

@@ -224,10 +224,22 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
defer closeSubscribers(sub4, sub6)
queryOption := option
if s.cacheController.disableCache {
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
} else {
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
if sub4 != nil && !isARecordExpired {
sub4.Close()
sub4 = nil
queryOption.IPv4Enable = false
}
if sub6 != nil && !isAAAARecordExpired {
sub6.Close()
sub6 = nil
queryOption.IPv6Enable = false
}
if !go_errors.Is(err, errRecordNotFound) {
if ttl > 0 {
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -236,14 +248,14 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
}
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
go s.sendQuery(ctx, nil, fqdn, option)
s.sendQuery(ctx, nil, fqdn, queryOption)
return ips, 1, err
}
}
}
noResponseErrCh := make(chan error, 2)
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
start := time.Now()
if sub4 != nil {
@@ -267,7 +279,7 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
}
}
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
var rTTL uint32
if ttl <= 0 {

View File

@@ -174,10 +174,22 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
defer closeSubscribers(sub4, sub6)
queryOption := option
if s.cacheController.disableCache {
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
} else {
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
if sub4 != nil && !isARecordExpired {
sub4.Close()
sub4 = nil
queryOption.IPv4Enable = false
}
if sub6 != nil && !isAAAARecordExpired {
sub6.Close()
sub6 = nil
queryOption.IPv6Enable = false
}
if !go_errors.Is(err, errRecordNotFound) {
if ttl > 0 {
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -186,14 +198,14 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
}
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
go s.sendQuery(ctx, nil, fqdn, option)
s.sendQuery(ctx, nil, fqdn, queryOption)
return ips, 1, err
}
}
}
noResponseErrCh := make(chan error, 2)
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
start := time.Now()
if sub4 != nil {
@@ -217,7 +229,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
}
}
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
var rTTL uint32
if ttl <= 0 {