mirror of
https://github.com/XTLS/Xray-core.git
synced 2025-12-15 09:34:00 +08:00
refactor(dns): enhance cache safety, optimize performance, and refactor query logic (#5248)
This commit is contained in:
@@ -3,24 +3,37 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
dns_feature "github.com/xtls/xray-core/features/dns"
|
||||
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
"sync"
|
||||
"time"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
minSizeForEmptyRebuild = 512
|
||||
shrinkAbsoluteThreshold = 10240
|
||||
shrinkRatioThreshold = 0.65
|
||||
migrationBatchSize = 4096
|
||||
)
|
||||
|
||||
type CacheController struct {
|
||||
sync.RWMutex
|
||||
ips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cacheCleanup *task.Periodic
|
||||
name string
|
||||
disableCache bool
|
||||
ips map[string]*record
|
||||
dirtyips map[string]*record
|
||||
pub *pubsub.Service
|
||||
cacheCleanup *task.Periodic
|
||||
name string
|
||||
disableCache bool
|
||||
highWatermark int
|
||||
requestGroup singleflight.Group
|
||||
}
|
||||
|
||||
func NewCacheController(name string, disableCache bool) *CacheController {
|
||||
@@ -32,7 +45,7 @@ func NewCacheController(name string, disableCache bool) *CacheController {
|
||||
}
|
||||
|
||||
c.cacheCleanup = &task.Periodic{
|
||||
Interval: time.Minute,
|
||||
Interval: 300 * time.Second,
|
||||
Execute: c.CacheCleanup,
|
||||
}
|
||||
return c
|
||||
@@ -40,131 +53,253 @@ func NewCacheController(name string, disableCache bool) *CacheController {
|
||||
|
||||
// CacheCleanup clears expired items from cache
|
||||
func (c *CacheController) CacheCleanup() error {
|
||||
now := time.Now()
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if len(c.ips) == 0 {
|
||||
return errors.New("nothing to do. stopping...")
|
||||
expiredKeys, err := c.collectExpiredKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for domain, record := range c.ips {
|
||||
if record.A != nil && record.A.Expire.Before(now) {
|
||||
record.A = nil
|
||||
}
|
||||
if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
||||
record.AAAA = nil
|
||||
}
|
||||
|
||||
if record.A == nil && record.AAAA == nil {
|
||||
errors.LogDebug(context.Background(), c.name, "cache cleanup ", domain)
|
||||
delete(c.ips, domain)
|
||||
} else {
|
||||
c.ips[domain] = record
|
||||
}
|
||||
if len(expiredKeys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(c.ips) == 0 {
|
||||
c.ips = make(map[string]*record)
|
||||
}
|
||||
|
||||
c.writeAndShrink(expiredKeys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
|
||||
elapsed := time.Since(req.start)
|
||||
func (c *CacheController) collectExpiredKeys() ([]string, error) {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
if len(c.ips) == 0 {
|
||||
return nil, errors.New("nothing to do. stopping...")
|
||||
}
|
||||
|
||||
// skip collection if a migration is in progress
|
||||
if c.dirtyips != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate
|
||||
|
||||
for domain, rec := range c.ips {
|
||||
if (rec.A != nil && rec.A.Expire.Before(now)) ||
|
||||
(rec.AAAA != nil && rec.AAAA.Expire.Before(now)) {
|
||||
expiredKeys = append(expiredKeys, domain)
|
||||
}
|
||||
}
|
||||
|
||||
return expiredKeys, nil
|
||||
}
|
||||
|
||||
func (c *CacheController) writeAndShrink(expiredKeys []string) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// double check to prevent upper call multiple cleanup tasks
|
||||
if c.dirtyips != nil {
|
||||
return
|
||||
}
|
||||
|
||||
lenBefore := len(c.ips)
|
||||
if lenBefore > c.highWatermark {
|
||||
c.highWatermark = lenBefore
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for _, domain := range expiredKeys {
|
||||
rec := c.ips[domain]
|
||||
if rec == nil {
|
||||
continue
|
||||
}
|
||||
if rec.A != nil && rec.A.Expire.Before(now) {
|
||||
rec.A = nil
|
||||
}
|
||||
if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
|
||||
rec.AAAA = nil
|
||||
}
|
||||
if rec.A == nil && rec.AAAA == nil {
|
||||
delete(c.ips, domain)
|
||||
}
|
||||
}
|
||||
|
||||
lenAfter := len(c.ips)
|
||||
|
||||
if lenAfter == 0 {
|
||||
if c.highWatermark >= minSizeForEmptyRebuild {
|
||||
errors.LogDebug(context.Background(), c.name,
|
||||
" rebuilding empty cache map to reclaim memory.",
|
||||
" size_before_cleanup=", lenBefore,
|
||||
" peak_size_before_rebuild=", c.highWatermark,
|
||||
)
|
||||
|
||||
c.ips = make(map[string]*record)
|
||||
c.highWatermark = 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
|
||||
float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
|
||||
errors.LogDebug(context.Background(), c.name,
|
||||
" shrinking cache map to reclaim memory.",
|
||||
" new_size=", lenAfter,
|
||||
" peak_size_before_shrink=", c.highWatermark,
|
||||
" reduction_since_peak=", reductionFromPeak,
|
||||
)
|
||||
|
||||
c.dirtyips = c.ips
|
||||
c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
|
||||
c.highWatermark = lenAfter
|
||||
go c.migrate()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type migrationEntry struct {
|
||||
key string
|
||||
value *record
|
||||
}
|
||||
|
||||
func (c *CacheController) migrate() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
|
||||
c.Lock()
|
||||
c.dirtyips = nil
|
||||
// c.ips = make(map[string]*record)
|
||||
// c.highWatermark = 0
|
||||
c.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
c.RLock()
|
||||
dirtyips := c.dirtyips
|
||||
c.RUnlock()
|
||||
|
||||
// double check to prevent upper call multiple cleanup tasks
|
||||
if dirtyips == nil {
|
||||
return
|
||||
}
|
||||
|
||||
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items.")
|
||||
|
||||
batch := make([]migrationEntry, 0, migrationBatchSize)
|
||||
for domain, recD := range dirtyips {
|
||||
batch = append(batch, migrationEntry{domain, recD})
|
||||
|
||||
if len(batch) >= migrationBatchSize {
|
||||
c.flush(batch)
|
||||
batch = batch[:0]
|
||||
runtime.Gosched()
|
||||
}
|
||||
}
|
||||
if len(batch) > 0 {
|
||||
c.flush(batch)
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
rec, found := c.ips[req.domain]
|
||||
if !found {
|
||||
rec = &record{}
|
||||
}
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
rec.A = ipRec
|
||||
case dnsmessage.TypeAAAA:
|
||||
rec.AAAA = ipRec
|
||||
}
|
||||
|
||||
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)
|
||||
c.ips[req.domain] = rec
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
c.pub.Publish(req.domain+"4", nil)
|
||||
if !c.disableCache {
|
||||
_, _, err := rec.AAAA.getIPs()
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
c.pub.Publish(req.domain+"6", nil)
|
||||
}
|
||||
}
|
||||
case dnsmessage.TypeAAAA:
|
||||
c.pub.Publish(req.domain+"6", nil)
|
||||
if !c.disableCache {
|
||||
_, _, err := rec.A.getIPs()
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
c.pub.Publish(req.domain+"4", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.dirtyips = nil
|
||||
c.Unlock()
|
||||
|
||||
errors.LogDebug(context.Background(), c.name, " cache migration completed.")
|
||||
}
|
||||
|
||||
func (c *CacheController) flush(batch []migrationEntry) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
for _, dirty := range batch {
|
||||
if cur := c.ips[dirty.key]; cur != nil {
|
||||
merge := &record{}
|
||||
if cur.A == nil {
|
||||
merge.A = dirty.value.A
|
||||
} else {
|
||||
merge.A = cur.A
|
||||
}
|
||||
if cur.AAAA == nil {
|
||||
merge.AAAA = dirty.value.AAAA
|
||||
} else {
|
||||
merge.AAAA = cur.AAAA
|
||||
}
|
||||
c.ips[dirty.key] = merge
|
||||
} else {
|
||||
c.ips[dirty.key] = dirty.value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
|
||||
rtt := time.Since(req.start)
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
c.pub.Publish(req.domain+"4", rep)
|
||||
case dnsmessage.TypeAAAA:
|
||||
c.pub.Publish(req.domain+"6", rep)
|
||||
}
|
||||
|
||||
if c.disableCache {
|
||||
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
|
||||
return
|
||||
}
|
||||
|
||||
c.Lock()
|
||||
lockWait := time.Since(req.start) - rtt
|
||||
|
||||
newRec := &record{}
|
||||
oldRec := c.ips[req.domain]
|
||||
var dirtyRec *record
|
||||
if c.dirtyips != nil {
|
||||
dirtyRec = c.dirtyips[req.domain]
|
||||
}
|
||||
|
||||
var pubRecord *IPRecord
|
||||
var pubSuffix string
|
||||
|
||||
switch req.reqType {
|
||||
case dnsmessage.TypeA:
|
||||
newRec.A = rep
|
||||
if oldRec != nil && oldRec.AAAA != nil {
|
||||
newRec.AAAA = oldRec.AAAA
|
||||
pubRecord = oldRec.AAAA
|
||||
} else if dirtyRec != nil && dirtyRec.AAAA != nil {
|
||||
pubRecord = dirtyRec.AAAA
|
||||
}
|
||||
pubSuffix = "6"
|
||||
case dnsmessage.TypeAAAA:
|
||||
newRec.AAAA = rep
|
||||
if oldRec != nil && oldRec.A != nil {
|
||||
newRec.A = oldRec.A
|
||||
pubRecord = oldRec.A
|
||||
} else if dirtyRec != nil && dirtyRec.A != nil {
|
||||
pubRecord = dirtyRec.A
|
||||
}
|
||||
pubSuffix = "4"
|
||||
}
|
||||
|
||||
c.ips[req.domain] = newRec
|
||||
c.Unlock()
|
||||
|
||||
if pubRecord != nil {
|
||||
_, _ /*ttl*/, err := pubRecord.getIPs()
|
||||
if /*ttl >= 0 &&*/ !go_errors.Is(err, errRecordNotFound) {
|
||||
c.pub.Publish(req.domain+pubSuffix, pubRecord)
|
||||
}
|
||||
}
|
||||
|
||||
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)
|
||||
|
||||
common.Must(c.cacheCleanup.Start())
|
||||
}
|
||||
|
||||
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
func (c *CacheController) findRecords(domain string) *record {
|
||||
c.RLock()
|
||||
record, found := c.ips[domain]
|
||||
c.RUnlock()
|
||||
defer c.RUnlock()
|
||||
|
||||
if !found {
|
||||
return nil, 0, errRecordNotFound
|
||||
rec := c.ips[domain]
|
||||
if rec == nil && c.dirtyips != nil {
|
||||
rec = c.dirtyips[domain]
|
||||
}
|
||||
|
||||
var errs []error
|
||||
var allIPs []net.IP
|
||||
var rTTL uint32 = dns_feature.DefaultTTL
|
||||
|
||||
mergeReq := option.IPv4Enable && option.IPv6Enable
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err := record.A.getIPs()
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ips, ttl, err := record.AAAA.getIPs()
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allIPs) > 0 {
|
||||
return allIPs, rTTL, nil
|
||||
}
|
||||
if go_errors.Is(errs[0], errs[1]) {
|
||||
return nil, rTTL, errs[0]
|
||||
}
|
||||
return nil, rTTL, errors.Combine(errs...)
|
||||
return rec
|
||||
}
|
||||
|
||||
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
)
|
||||
|
||||
// Fqdn normalizes domain make sure it ends with '.'
|
||||
// case-sensitive
|
||||
func Fqdn(domain string) string {
|
||||
if len(domain) > 0 && strings.HasSuffix(domain, ".") {
|
||||
return domain
|
||||
|
||||
149
app/dns/nameserver_cached.go
Normal file
149
app/dns/nameserver_cached.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/signal/pubsub"
|
||||
"github.com/xtls/xray-core/features/dns"
|
||||
)
|
||||
|
||||
type CachedNameserver interface {
|
||||
getCacheController() *CacheController
|
||||
|
||||
sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns.IPOption)
|
||||
}
|
||||
|
||||
// queryIP is called from dns.Server->queryIPTimeout
|
||||
func queryIP(ctx context.Context, s CachedNameserver, domain string, option dns.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
|
||||
cache := s.getCacheController()
|
||||
if !cache.disableCache {
|
||||
if rec := cache.findRecords(fqdn); rec != nil {
|
||||
ips, ttl, err := merge(option, rec.A, rec.AAAA)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
// errors.LogDebugInner(ctx, err, cache.name, " cache HIT ", fqdn, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: cache.name, Domain: fqdn, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", fqdn, " at ", cache.name)
|
||||
}
|
||||
|
||||
return fetch(ctx, s, fqdn, option)
|
||||
}
|
||||
|
||||
func fetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) ([]net.IP, uint32, error) {
|
||||
key := fqdn + "f"
|
||||
switch {
|
||||
case option.IPv4Enable && option.IPv6Enable:
|
||||
key = key + "46"
|
||||
case option.IPv4Enable:
|
||||
key = key + "4"
|
||||
case option.IPv6Enable:
|
||||
key = key + "6"
|
||||
}
|
||||
|
||||
v, _, _ := s.getCacheController().requestGroup.Do(key, func() (any, error) {
|
||||
return doFetch(ctx, s, fqdn, option), nil
|
||||
})
|
||||
ret := v.(result)
|
||||
|
||||
return ret.ips, ret.ttl, ret.error
|
||||
}
|
||||
|
||||
type result struct {
|
||||
ips []net.IP
|
||||
ttl uint32
|
||||
error
|
||||
}
|
||||
|
||||
func doFetch(ctx context.Context, s CachedNameserver, fqdn string, option dns.IPOption) result {
|
||||
sub4, sub6 := s.getCacheController().registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
onEvent := func(sub *pubsub.Subscriber) (*IPRecord, error) {
|
||||
if sub == nil {
|
||||
return nil, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, err
|
||||
case msg := <-sub.Wait():
|
||||
sub.Close()
|
||||
return msg.(*IPRecord), nil // should panic
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
|
||||
rec4, err4 := onEvent(sub4)
|
||||
rec6, err6 := onEvent(sub6)
|
||||
|
||||
var errs []error
|
||||
if err4 != nil {
|
||||
errs = append(errs, err4)
|
||||
}
|
||||
if err6 != nil {
|
||||
errs = append(errs, err6)
|
||||
}
|
||||
|
||||
ips, ttl, err := merge(option, rec4, rec6, errs...)
|
||||
log.Record(&log.DNSLog{Server: s.getCacheController().name, Domain: fqdn, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return result{ips, ttl, err}
|
||||
}
|
||||
|
||||
func merge(option dns.IPOption, rec4 *IPRecord, rec6 *IPRecord, errs ...error) ([]net.IP, uint32, error) {
|
||||
var allIPs []net.IP
|
||||
var rTTL uint32 = dns.DefaultTTL
|
||||
|
||||
mergeReq := option.IPv4Enable && option.IPv6Enable
|
||||
|
||||
if option.IPv4Enable {
|
||||
ips, ttl, err := rec4.getIPs() // it's safe
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if option.IPv6Enable {
|
||||
ips, ttl, err := rec6.getIPs() // it's safe
|
||||
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
|
||||
return ips, ttl, err
|
||||
}
|
||||
if ttl < rTTL {
|
||||
rTTL = ttl
|
||||
}
|
||||
if len(ips) > 0 {
|
||||
allIPs = append(allIPs, ips...)
|
||||
} else {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allIPs) > 0 {
|
||||
return allIPs, rTTL, nil
|
||||
}
|
||||
if len(errs) == 2 && go_errors.Is(errs[0], errs[1]) {
|
||||
return nil, rTTL, errs[0]
|
||||
}
|
||||
return nil, rTTL, errors.Combine(errs...)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
go_errors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -121,10 +120,16 @@ func (s *DoHNameServer) newReqID() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", domain)
|
||||
// getCacheController implements CachedNameserver.
|
||||
func (s *DoHNameServer) getCacheController() *CacheController {
|
||||
return s.cacheController
|
||||
}
|
||||
|
||||
if s.Name()+"." == "DOH//"+domain {
|
||||
// sendQuery implements CachedNameserver.
|
||||
func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", fqdn)
|
||||
|
||||
if s.Name()+"." == "DOH//"+fqdn {
|
||||
errors.LogError(ctx, s.Name(), " tries to resolve itself! Use IP or set \"hosts\" instead.")
|
||||
noResponseErrCh <- errors.New("tries to resolve itself!", s.Name())
|
||||
return
|
||||
@@ -132,7 +137,7 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
|
||||
|
||||
// As we don't want our traffic pattern looks like DoH, we use Random-Length Padding instead of Block-Length Padding recommended in RFC 8467
|
||||
// Although DoH server like 1.1.1.1 will pad the response to Block-Length 468, at least it is better than no padding for response at all
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
|
||||
reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, int(crypto.RandBetween(100, 300))))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
@@ -166,23 +171,23 @@ func (s *DoHNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
|
||||
|
||||
b, err := dns.PackMessage(r.msg)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", domain)
|
||||
errors.LogErrorInner(ctx, err, "failed to pack dns query for ", fqdn)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", domain)
|
||||
errors.LogErrorInner(ctx, err, "failed to retrieve response for ", fqdn)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
rec, err := parseResponse(resp)
|
||||
if err != nil {
|
||||
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", domain)
|
||||
errors.LogErrorInner(ctx, err, "failed to handle DOH response for ", fqdn)
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
s.cacheController.updateIP(r, rec)
|
||||
s.cacheController.updateRecord(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
@@ -216,49 +221,6 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) { // nolint: dupl
|
||||
fqdn := Fqdn(domain)
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
return queryIP(ctx, s, domain, option)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
go_errors "errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -59,7 +58,7 @@ func NewQUICNameServer(url *url.URL, disableCache bool, clientIP net.IP) (*QUICN
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Name returns client name
|
||||
// Name implements Server.
|
||||
func (s *QUICNameServer) Name() string {
|
||||
return s.cacheController.name
|
||||
}
|
||||
@@ -68,10 +67,14 @@ func (s *QUICNameServer) newReqID() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", domain)
|
||||
// getCacheController implements CachedNameServer.
|
||||
func (s *QUICNameServer) getCacheController() *CacheController { return s.cacheController }
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
// sendQuery implements CachedNameServer.
|
||||
func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
|
||||
errors.LogInfo(ctx, s.Name(), " querying: ", fqdn)
|
||||
|
||||
reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
@@ -167,57 +170,14 @@ func (s *QUICNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- e
|
||||
noResponseErrCh <- err
|
||||
return
|
||||
}
|
||||
s.cacheController.updateIP(r, rec)
|
||||
s.cacheController.updateRecord(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
// QueryIP is called from dns.Server->queryIPTimeout
|
||||
// QueryIP implements Server.
|
||||
func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
return queryIP(ctx, s, domain, option)
|
||||
}
|
||||
|
||||
func isActive(s *quic.Conn) bool {
|
||||
|
||||
@@ -4,14 +4,12 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
go_errors "errors"
|
||||
"net/url"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/net/cnc"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
@@ -99,10 +97,16 @@ func (s *TCPNameServer) newReqID() uint16 {
|
||||
return uint16(atomic.AddUint32(&s.reqID, 1))
|
||||
}
|
||||
|
||||
func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
|
||||
// getCacheController implements CachedNameserver.
|
||||
func (s *TCPNameServer) getCacheController() *CacheController {
|
||||
return s.cacheController
|
||||
}
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
// sendQuery implements CachedNameserver.
|
||||
func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- error, fqdn string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn)
|
||||
|
||||
reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
var deadline time.Time
|
||||
if d, ok := ctx.Deadline(); ok {
|
||||
@@ -195,55 +199,12 @@ func (s *TCPNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<- er
|
||||
return
|
||||
}
|
||||
|
||||
s.cacheController.updateIP(r, rec)
|
||||
s.cacheController.updateRecord(r, rec)
|
||||
}(req)
|
||||
}
|
||||
}
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
return queryIP(ctx, s, domain, option)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
go_errors "errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -10,7 +9,6 @@ import (
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/log"
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol/dns"
|
||||
udp_proto "github.com/xtls/xray-core/common/protocol/udp"
|
||||
@@ -134,7 +132,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
}
|
||||
}
|
||||
|
||||
s.cacheController.updateIP(&req.dnsRequest, ipRec)
|
||||
s.cacheController.updateRecord(&req.dnsRequest, ipRec)
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) newReqID() uint16 {
|
||||
@@ -150,10 +148,16 @@ func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
|
||||
common.Must(s.requestsCleanup.Start())
|
||||
}
|
||||
|
||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domain string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", domain)
|
||||
// getCacheController implements CachedNameserver.
|
||||
func (s *ClassicNameServer) getCacheController() *CacheController {
|
||||
return s.cacheController
|
||||
}
|
||||
|
||||
reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
// sendQuery implements CachedNameserver.
|
||||
func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, fqdn string, option dns_feature.IPOption) {
|
||||
errors.LogDebug(ctx, s.Name(), " querying DNS for: ", fqdn)
|
||||
|
||||
reqs := buildReqMsgs(fqdn, option, s.newReqID, genEDNS0Options(s.clientIP, 0))
|
||||
|
||||
for _, req := range reqs {
|
||||
udpReq := &udpDnsRequest{
|
||||
@@ -170,48 +174,5 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
|
||||
|
||||
// QueryIP implements Server.
|
||||
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, uint32, error) {
|
||||
fqdn := Fqdn(domain)
|
||||
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
|
||||
defer closeSubscribers(sub4, sub6)
|
||||
|
||||
if s.cacheController.disableCache {
|
||||
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
|
||||
} else {
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
if !go_errors.Is(err, errRecordNotFound) {
|
||||
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})
|
||||
return ips, ttl, err
|
||||
}
|
||||
}
|
||||
|
||||
noResponseErrCh := make(chan error, 2)
|
||||
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
|
||||
start := time.Now()
|
||||
|
||||
if sub4 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub4.Wait():
|
||||
sub4.Close()
|
||||
}
|
||||
}
|
||||
if sub6 != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, 0, ctx.Err()
|
||||
case err := <-noResponseErrCh:
|
||||
return nil, 0, err
|
||||
case <-sub6.Wait():
|
||||
sub6.Close()
|
||||
}
|
||||
}
|
||||
|
||||
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
|
||||
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
|
||||
return ips, ttl, err
|
||||
|
||||
return queryIP(ctx, s, domain, option)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user