diff --git a/app/dns/cache_controller.go b/app/dns/cache_controller.go index f23c414d..85818039 100644 --- a/app/dns/cache_controller.go +++ b/app/dns/cache_controller.go @@ -3,24 +3,37 @@ package dns import ( "context" go_errors "errors" + "runtime" + "sync" + "time" + "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/signal/pubsub" "github.com/xtls/xray-core/common/task" dns_feature "github.com/xtls/xray-core/features/dns" + "golang.org/x/net/dns/dnsmessage" - "sync" - "time" + "golang.org/x/sync/singleflight" +) + +const ( + minSizeForEmptyRebuild = 512 + shrinkAbsoluteThreshold = 10240 + shrinkRatioThreshold = 0.65 + migrationBatchSize = 4096 ) type CacheController struct { sync.RWMutex - ips map[string]*record - pub *pubsub.Service - cacheCleanup *task.Periodic - name string - disableCache bool + ips map[string]*record + dirtyips map[string]*record + pub *pubsub.Service + cacheCleanup *task.Periodic + name string + disableCache bool + highWatermark int + requestGroup singleflight.Group } func NewCacheController(name string, disableCache bool) *CacheController { @@ -32,7 +45,7 @@ func NewCacheController(name string, disableCache bool) *CacheController { } c.cacheCleanup = &task.Periodic{ - Interval: time.Minute, + Interval: 300 * time.Second, Execute: c.CacheCleanup, } return c @@ -40,131 +53,253 @@ func NewCacheController(name string, disableCache bool) *CacheController { // CacheCleanup clears expired items from cache func (c *CacheController) CacheCleanup() error { - now := time.Now() - c.Lock() - defer c.Unlock() - - if len(c.ips) == 0 { - return errors.New("nothing to do. stopping...") + expiredKeys, err := c.collectExpiredKeys() + if err != nil { + return err } - - for domain, record := range c.ips { - if record.A != nil && record.A.Expire.Before(now) { - record.A = nil - } - if record.AAAA != nil && record.AAAA.Expire.Before(now) { - record.AAAA = nil - } - - if record.A == nil && record.AAAA == nil { - errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain) - delete(c.ips, domain) - } else { - c.ips[domain] = record - } + if len(expiredKeys) == 0 { + return nil } - - if len(c.ips) == 0 { - c.ips = make(map[string]*record) - } - + c.writeAndShrink(expiredKeys) return nil } -func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) { - elapsed := time.Since(req.start) +func (c *CacheController) collectExpiredKeys() ([]string, error) { + c.RLock() + defer c.RUnlock() + + if len(c.ips) == 0 { + return nil, errors.New("nothing to do. stopping...") + } + + // skip collection if a migration is in progress + if c.dirtyips != nil { + return nil, nil + } + + now := time.Now() + expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate + + for domain, rec := range c.ips { + if (rec.A != nil && rec.A.Expire.Before(now)) || + (rec.AAAA != nil && rec.AAAA.Expire.Before(now)) { + expiredKeys = append(expiredKeys, domain) + } + } + + return expiredKeys, nil +} + +func (c *CacheController) writeAndShrink(expiredKeys []string) { + c.Lock() + defer c.Unlock() + + // double check to prevent upper call multiple cleanup tasks + if c.dirtyips != nil { + return + } + + lenBefore := len(c.ips) + if lenBefore > c.highWatermark { + c.highWatermark = lenBefore + } + + now := time.Now() + for _, domain := range expiredKeys { + rec := c.ips[domain] + if rec == nil { + continue + } + if rec.A != nil && rec.A.Expire.Before(now) { + rec.A = nil + } + if rec.AAAA != nil && rec.AAAA.Expire.Before(now) { + rec.AAAA = nil + } + if rec.A == nil && rec.AAAA == nil { + delete(c.ips, domain) + } + } + + lenAfter := len(c.ips) + + if lenAfter == 0 { + if c.highWatermark >= minSizeForEmptyRebuild { + errors.LogDebug(context.Background(), c.name, + " rebuilding empty cache map to reclaim memory.", + " size_before_cleanup=", lenBefore, + " peak_size_before_rebuild=", c.highWatermark, + ) + + c.ips = make(map[string]*record) + c.highWatermark = 0 + } + return + } + + if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold && + float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold { + errors.LogDebug(context.Background(), c.name, + " shrinking cache map to reclaim memory.", + " new_size=", lenAfter, + " peak_size_before_shrink=", c.highWatermark, + " reduction_since_peak=", reductionFromPeak, + ) + + c.dirtyips = c.ips + c.ips = make(map[string]*record, int(float64(lenAfter)*1.1)) + c.highWatermark = lenAfter + go c.migrate() + } + +} + +type migrationEntry struct { + key string + value *record +} + +func (c *CacheController) migrate() { + defer func() { + if r := recover(); r != nil { + errors.LogError(context.Background(), c.name, " panic during cache migration: ", r) + c.Lock() + c.dirtyips = nil + // c.ips = make(map[string]*record) + // c.highWatermark = 0 + c.Unlock() + } + }() + + c.RLock() + dirtyips := c.dirtyips + c.RUnlock() + + // double check to prevent upper call multiple cleanup tasks + if dirtyips == nil { + return + } + + errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.") + + batch := make([]migrationEntry, 0, migrationBatchSize) + for domain, recD := range dirtyips { + batch = append(batch, migrationEntry{domain, recD}) + + if len(batch) >= migrationBatchSize { + c.flush(batch) + batch = batch[:0] + runtime.Gosched() + } + } + if len(batch) > 0 { + c.flush(batch) + } c.Lock() - rec, found := c.ips[req.domain] - if !found { - rec = &record{} - } - - switch req.reqType { - case dnsmessage.TypeA: - rec.A = ipRec - case dnsmessage.TypeAAAA: - rec.AAAA = ipRec - } - - errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed) - c.ips[req.domain] = rec - - switch req.reqType { - case dnsmessage.TypeA: - c.pub.Publish(req.domain+"4", nil) - if !c.disableCache { - _, _, err := rec.AAAA.getIPs() - if !go_errors.Is(err, errRecordNotFound) { - c.pub.Publish(req.domain+"6", nil) - } - } - case dnsmessage.TypeAAAA: - c.pub.Publish(req.domain+"6", nil) - if !c.disableCache { - _, _, err := rec.A.getIPs() - if !go_errors.Is(err, errRecordNotFound) { - c.pub.Publish(req.domain+"4", nil) - } - } - } - + c.dirtyips = nil c.Unlock() + + errors.LogDebug(context.Background(), c.name, " cache migration completed.") +} + +func (c *CacheController) flush(batch []migrationEntry) { + c.Lock() + defer c.Unlock() + + for _, dirty := range batch { + if cur := c.ips[dirty.key]; cur != nil { + merge := &record{} + if cur.A == nil { + merge.A = dirty.value.A + } else { + merge.A = cur.A + } + if cur.AAAA == nil { + merge.AAAA = dirty.value.AAAA + } else { + merge.AAAA = cur.AAAA + } + c.ips[dirty.key] = merge + } else { + c.ips[dirty.key] = dirty.value + } + } +} + +func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) { + rtt := time.Since(req.start) + + switch req.reqType { + case dnsmessage.TypeA: + c.pub.Publish(req.domain+"4", rep) + case dnsmessage.TypeAAAA: + c.pub.Publish(req.domain+"6", rep) + } + + if c.disableCache { + errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt) + return + } + + c.Lock() + lockWait := time.Since(req.start) - rtt + + newRec := &record{} + oldRec := c.ips[req.domain] + var dirtyRec *record + if c.dirtyips != nil { + dirtyRec = c.dirtyips[req.domain] + } + + var pubRecord *IPRecord + var pubSuffix string + + switch req.reqType { + case dnsmessage.TypeA: + newRec.A = rep + if oldRec != nil && oldRec.AAAA != nil { + newRec.AAAA = oldRec.AAAA + pubRecord = oldRec.AAAA + } else if dirtyRec != nil && dirtyRec.AAAA != nil { + pubRecord = dirtyRec.AAAA + } + pubSuffix = "6" + case dnsmessage.TypeAAAA: + newRec.AAAA = rep + if oldRec != nil && oldRec.A != nil { + newRec.A = oldRec.A + pubRecord = oldRec.A + } else if dirtyRec != nil && dirtyRec.A != nil { + pubRecord = dirtyRec.A + } + pubSuffix = "4" + } + + c.ips[req.domain] = newRec + c.Unlock() + + if pubRecord != nil { + _, _ /*ttl*/, err := pubRecord.getIPs() + if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) { + c.pub.Publish(req.domain+pubSuffix, pubRecord) + } + } + + errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait) + common.Must(c.cacheCleanup.Start()) } -func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { +func (c *CacheController) findRecords(domain string) *record { c.RLock() - record, found := c.ips[domain] - c.RUnlock() + defer c.RUnlock() - if !found { - return nil, 0, errRecordNotFound + rec := c.ips[domain] + if rec == nil && c.dirtyips != nil { + rec = c.dirtyips[domain] } - - var errs []error - var allIPs []net.IP - var rTTL uint32 = dns_feature.DefaultTTL - - mergeReq := option.IPv4Enable && option.IPv6Enable - - if option.IPv4Enable { - ips, ttl, err := record.A.getIPs() - if !mergeReq || go_errors.Is(err, errRecordNotFound) { - return ips, ttl, err - } - if ttl < rTTL { - rTTL = ttl - } - if len(ips) > 0 { - allIPs = append(allIPs, ips...) - } else { - errs = append(errs, err) - } - } - - if option.IPv6Enable { - ips, ttl, err := record.AAAA.getIPs() - if !mergeReq || go_errors.Is(err, errRecordNotFound) { - return ips, ttl, err - } - if ttl < rTTL { - rTTL = ttl - } - if len(ips) > 0 { - allIPs = append(allIPs, ips...) - } else { - errs = append(errs, err) - } - } - - if len(allIPs) > 0 { - return allIPs, rTTL, nil - } - if go_errors.Is(errs[0], errs[1]) { - return nil, rTTL, errs[0] - } - return nil, rTTL, errors.Combine(errs...) + return rec } func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) { diff --git a/app/dns/dnscommon.go b/app/dns/dnscommon.go index 15ac2efe..e4ec8ae8 100644 --- a/app/dns/dnscommon.go +++ b/app/dns/dnscommon.go @@ -17,6 +17,7 @@ import ( ) // Fqdn normalizes domain make sure it ends with '.' +// case-sensitive func Fqdn(domain string) string { if len(domain) > 0 && strings.HasSuffix(domain, ".") { return domain diff --git a/app/dns/nameserver_cached.go b/app/dns/nameserver_cached.go new file mode 100644 index 00000000..d1872b0c --- /dev/null +++ b/app/dns/nameserver_cached.go @@ -0,0 +1,149 @@ +package dns + +import ( + "context" + go_errors "errors" + "time" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/common/log" + "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/signal/pubsub" + "github.com/xtls/xray-core/features/dns" +) + +type CachedNameserver interface { + getCacheController() *CacheController + + sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns.IPOption) +} + +// queryIP is called from dns.Server->queryIPTimeout +func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.IPOption) ([]net.IP, uint32, error) { + fqdn := Fqdn(domain) + + cache := s.getCacheController() + if !cache.disableCache { + if rec := cache.findRecords(fqdn); rec != nil { + ips, ttl, err := merge(option, rec.A, rec.AAAA) + if !go_errors.Is(err, errRecordNotFound) { + // errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips) + log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) + return ips, ttl, err + } + } + } else { + errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", fqdn, " at ", cache.name) + } + + return fetch(ctx, s, fqdn, option) +} + +func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) { + key := fqdn + "f" + switch { + case option.IPv4Enable && option.IPv6Enable: + key = key + "46" + case option.IPv4Enable: + key = key + "4" + case option.IPv6Enable: + key = key + "6" + } + + v, _, _ := s.getCacheController().requestGroup.Do(key, func() (any, error) { + return doFetch(ctx, s, fqdn, option), nil + }) + ret := v.(result) + + return ret.ips, ret.ttl, ret.error +} + +type result struct { + ips []net.IP + ttl uint32 + error +} + +func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) result { + sub4, sub6 := s.getCacheController().registerSubscribers(fqdn, option) + defer closeSubscribers(sub4, sub6) + + noResponseErrCh := make(chan error, 2) + onEvent := func(sub *pubsub.Subscriber) (*IPRecord, error) { + if sub == nil { + return nil, nil + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-noResponseErrCh: + return nil, err + case msg := <-sub.Wait(): + sub.Close() + return msg.(*IPRecord), nil // should panic + } + } + + start := time.Now() + s.sendQuery(ctx, noResponseErrCh, fqdn, option) + + rec4, err4 := onEvent(sub4) + rec6, err6 := onEvent(sub6) + + var errs []error + if err4 != nil { + errs = append(errs, err4) + } + if err6 != nil { + errs = append(errs, err6) + } + + ips, ttl, err := merge(option, rec4, rec6, errs...) + log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) + return result{ips, ttl, err} +} + +func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, uint32, error) { + var allIPs []net.IP + var rTTL uint32 = dns.DefaultTTL + + mergeReq := option.IPv4Enable && option.IPv6Enable + + if option.IPv4Enable { + ips, ttl, err := rec4.getIPs() // it's safe + if !mergeReq || go_errors.Is(err, errRecordNotFound) { + return ips, ttl, err + } + if ttl < rTTL { + rTTL = ttl + } + if len(ips) > 0 { + allIPs = append(allIPs, ips...) + } else { + errs = append(errs, err) + } + } + + if option.IPv6Enable { + ips, ttl, err := rec6.getIPs() // it's safe + if !mergeReq || go_errors.Is(err, errRecordNotFound) { + return ips, ttl, err + } + if ttl < rTTL { + rTTL = ttl + } + if len(ips) > 0 { + allIPs = append(allIPs, ips...) + } else { + errs = append(errs, err) + } + } + + if len(allIPs) > 0 { + return allIPs, rTTL, nil + } + if len(errs) == 2 && go_errors.Is(errs[0], errs[1]) { + return nil, rTTL, errs[0] + } + return nil, rTTL, errors.Combine(errs...) +} diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index cba59423..ebdea6a2 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "crypto/tls" - go_errors "errors" "fmt" "io" "net/http" @@ -121,10 +120,16 @@ func (s *DoHNameServer) newReqID() uint16 { return 0 } -func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { - errors.LogInfo(ctx, s.Name(), " querying: ", domain) +// getCacheController implements CachedNameserver. +func (s *DoHNameServer) getCacheController() *CacheController { + return s.cacheController +} - if s.Name()+"." == "DOH//"+domain { +// sendQuery implements CachedNameserver. +func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) { + errors.LogInfo(ctx, s.Name(), " querying: ", fqdn) + + if s.Name()+"." == "DOH//"+fqdn { errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.") noResponseErrCh <- errors.New("tries to resolve itself!", s.Name()) return @@ -132,7 +137,7 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er // As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467 // Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300)))) + reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300)))) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -166,23 +171,23 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er b, err := dns.PackMessage(r.msg) if err != nil { - errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain) + errors.LogErrorInner(ctx, err, "failed to pack dns query for ", fqdn) noResponseErrCh <- err return } resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes()) if err != nil { - errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain) + errors.LogErrorInner(ctx, err, "failed to retrieve response for ", fqdn) noResponseErrCh <- err return } rec, err := parseResponse(resp) if err != nil { - errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain) + errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", fqdn) noResponseErrCh <- err return } - s.cacheController.updateIP(r, rec) + s.cacheController.updateRecord(r, rec) }(req) } } @@ -216,49 +221,6 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, } // QueryIP implements Server. -func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl - fqdn := Fqdn(domain) - sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) - defer closeSubscribers(sub4, sub6) - - if s.cacheController.disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) - } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - if !go_errors.Is(err, errRecordNotFound) { - errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) - return ips, ttl, err - } - } - - noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) - start := time.Now() - - if sub4 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub4.Wait(): - sub4.Close() - } - } - if sub6 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub6.Wait(): - sub6.Close() - } - } - - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - +func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { + return queryIP(ctx, s, domain, option) } diff --git a/app/dns/nameserver_quic.go b/app/dns/nameserver_quic.go index 75b9854e..c294ecda 100644 --- a/app/dns/nameserver_quic.go +++ b/app/dns/nameserver_quic.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/binary" - go_errors "errors" "net/url" "sync" "time" @@ -59,7 +58,7 @@ func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICN return s, nil } -// Name returns client name +// Name implements Server. func (s *QUICNameServer) Name() string { return s.cacheController.name } @@ -68,10 +67,14 @@ func (s *QUICNameServer) newReqID() uint16 { return 0 } -func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { - errors.LogInfo(ctx, s.Name(), " querying: ", domain) +// getCacheController implements CachedNameServer. +func (s *QUICNameServer) getCacheController() *CacheController { return s.cacheController } - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) +// sendQuery implements CachedNameServer. +func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) { + errors.LogInfo(ctx, s.Name(), " querying: ", fqdn) + + reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -167,57 +170,14 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- e noResponseErrCh <- err return } - s.cacheController.updateIP(r, rec) + s.cacheController.updateRecord(r, rec) }(req) } } -// QueryIP is called from dns.Server->queryIPTimeout +// QueryIP implements Server. func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - fqdn := Fqdn(domain) - sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) - defer closeSubscribers(sub4, sub6) - - if s.cacheController.disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) - } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - if !go_errors.Is(err, errRecordNotFound) { - errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) - return ips, ttl, err - } - } - - noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) - start := time.Now() - - if sub4 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub4.Wait(): - sub4.Close() - } - } - if sub6 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub6.Wait(): - sub6.Close() - } - } - - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - + return queryIP(ctx, s, domain, option) } func isActive(s *quic.Conn) bool { diff --git a/app/dns/nameserver_tcp.go b/app/dns/nameserver_tcp.go index a5e81ae0..b9a2c2de 100644 --- a/app/dns/nameserver_tcp.go +++ b/app/dns/nameserver_tcp.go @@ -4,14 +4,12 @@ import ( "bytes" "context" "encoding/binary" - go_errors "errors" "net/url" "sync/atomic" "time" "github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net/cnc" "github.com/xtls/xray-core/common/protocol/dns" @@ -99,10 +97,16 @@ func (s *TCPNameServer) newReqID() uint16 { return uint16(atomic.AddUint32(&s.reqID, 1)) } -func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) { - errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain) +// getCacheController implements CachedNameserver. +func (s *TCPNameServer) getCacheController() *CacheController { + return s.cacheController +} - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) +// sendQuery implements CachedNameserver. +func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) { + errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn) + + reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) var deadline time.Time if d, ok := ctx.Deadline(); ok { @@ -195,55 +199,12 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er return } - s.cacheController.updateIP(r, rec) + s.cacheController.updateRecord(r, rec) }(req) } } // QueryIP implements Server. func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - fqdn := Fqdn(domain) - sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) - defer closeSubscribers(sub4, sub6) - - if s.cacheController.disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) - } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - if !go_errors.Is(err, errRecordNotFound) { - errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) - return ips, ttl, err - } - } - - noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) - start := time.Now() - - if sub4 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub4.Wait(): - sub4.Close() - } - } - if sub6 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub6.Wait(): - sub6.Close() - } - } - - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - + return queryIP(ctx, s, domain, option) } diff --git a/app/dns/nameserver_udp.go b/app/dns/nameserver_udp.go index e29f6e24..67a57b7f 100644 --- a/app/dns/nameserver_udp.go +++ b/app/dns/nameserver_udp.go @@ -2,7 +2,6 @@ package dns import ( "context" - go_errors "errors" "strings" "sync" "sync/atomic" @@ -10,7 +9,6 @@ import ( "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol/dns" udp_proto "github.com/xtls/xray-core/common/protocol/udp" @@ -134,7 +132,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot } } - s.cacheController.updateIP(&req.dnsRequest, ipRec) + s.cacheController.updateRecord(&req.dnsRequest, ipRec) } func (s *ClassicNameServer) newReqID() uint16 { @@ -150,10 +148,16 @@ func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) { common.Must(s.requestsCleanup.Start()) } -func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) { - errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain) +// getCacheController implements CachedNameserver. +func (s *ClassicNameServer) getCacheController() *CacheController { + return s.cacheController +} - reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) +// sendQuery implements CachedNameserver. +func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, fqdn string, option dns_feature.IPOption) { + errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn) + + reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0)) for _, req := range reqs { udpReq := &udpDnsRequest{ @@ -170,48 +174,5 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai // QueryIP implements Server. func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { - fqdn := Fqdn(domain) - sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option) - defer closeSubscribers(sub4, sub6) - - if s.cacheController.disableCache { - errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name()) - } else { - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - if !go_errors.Is(err, errRecordNotFound) { - errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) - return ips, ttl, err - } - } - - noResponseErrCh := make(chan error, 2) - s.sendQuery(ctx, noResponseErrCh, fqdn, option) - start := time.Now() - - if sub4 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub4.Wait(): - sub4.Close() - } - } - if sub6 != nil { - select { - case <-ctx.Done(): - return nil, 0, ctx.Err() - case err := <-noResponseErrCh: - return nil, 0, err - case <-sub6.Wait(): - sub6.Close() - } - } - - ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option) - log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) - return ips, ttl, err - + return queryIP(ctx, s, domain, option) }