package dns import ( "context" go_errors "errors" "time" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/log" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/signal/pubsub" "github.com/xtls/xray-core/features/dns" ) type CachedNameserver interface { getCacheController() *CacheController sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns.IPOption) } // queryIP is called from dns.Server->queryIPTimeout func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.IPOption) ([]net.IP, uint32, error) { fqdn := Fqdn(domain) cache := s.getCacheController() if !cache.disableCache { if rec := cache.findRecords(fqdn); rec != nil { ips, ttl, err := merge(option, rec.A, rec.AAAA) if !go_errors.Is(err, errRecordNotFound) { if ttl > 0 { errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips) log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) return ips, uint32(ttl), err } if cache.serveStale && (cache.serveExpiredTTL == 0 || cache.serveExpiredTTL < ttl) { errors.LogDebugInner(ctx, err, cache.name, " cache OPTIMISTE ", fqdn, " -> ", ips) log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheOptimiste, Elapsed: 0, Error: err}) go pull(ctx, s, fqdn, option) return ips, 1, err } } } } else { errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", fqdn, " at ", cache.name) } return fetch(ctx, s, fqdn, option) } func pull(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) { nctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 8*time.Second) defer cancel() fetch(nctx, s, fqdn, option) } func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) { key := fqdn switch { case option.IPv4Enable && option.IPv6Enable: key = key + "46" case option.IPv4Enable: key = key + "4" case option.IPv6Enable: key = key + "6" } v, _, _ := s.getCacheController().requestGroup.Do(key, func() (any, error) { return doFetch(ctx, s, fqdn, option), nil }) ret := v.(result) return ret.ips, ret.ttl, ret.error } type result struct { ips []net.IP ttl uint32 error } func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) result { sub4, sub6 := s.getCacheController().registerSubscribers(fqdn, option) defer closeSubscribers(sub4, sub6) noResponseErrCh := make(chan error, 2) onEvent := func(sub *pubsub.Subscriber) (*IPRecord, error) { if sub == nil { return nil, nil } select { case <-ctx.Done(): return nil, ctx.Err() case err := <-noResponseErrCh: return nil, err case msg := <-sub.Wait(): sub.Close() return msg.(*IPRecord), nil // should panic } } start := time.Now() s.sendQuery(ctx, noResponseErrCh, fqdn, option) rec4, err4 := onEvent(sub4) rec6, err6 := onEvent(sub6) var errs []error if err4 != nil { errs = append(errs, err4) } if err6 != nil { errs = append(errs, err6) } ips, ttl, err := merge(option, rec4, rec6, errs...) var rTTL uint32 if ttl > 0 { rTTL = uint32(ttl) } else if ttl == 0 && go_errors.Is(err, errRecordNotFound) { rTTL = 0 } else { // edge case: where a fast rep's ttl expires during the rtt of a slower, parallel query rTTL = 1 } log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) return result{ips, rTTL, err} } func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, int32, error) { var allIPs []net.IP var rTTL int32 = dns.DefaultTTL mergeReq := option.IPv4Enable && option.IPv6Enable if option.IPv4Enable { ips, ttl, err := rec4.getIPs() // it's safe if !mergeReq || go_errors.Is(err, errRecordNotFound) { return ips, ttl, err } if ttl < rTTL { rTTL = ttl } if len(ips) > 0 { allIPs = append(allIPs, ips...) } else { errs = append(errs, err) } } if option.IPv6Enable { ips, ttl, err := rec6.getIPs() // it's safe if !mergeReq || go_errors.Is(err, errRecordNotFound) { return ips, ttl, err } if ttl < rTTL { rTTL = ttl } if len(ips) > 0 { allIPs = append(allIPs, ips...) } else { errs = append(errs, err) } } if len(allIPs) > 0 { return allIPs, rTTL, nil } if len(errs) == 2 && go_errors.Is(errs[0], errs[1]) { return nil, rTTL, errs[0] } return nil, rTTL, errors.Combine(errs...) }