mirror of https://github.com/XTLS/Xray-core
340 lines
7.5 KiB
Go
340 lines
7.5 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
go_errors "errors"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/common/signal/pubsub"
|
|
"github.com/xtls/xray-core/common/task"
|
|
dns_feature "github.com/xtls/xray-core/features/dns"
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
const (
|
|
minSizeForEmptyRebuild = 512
|
|
shrinkAbsoluteThreshold = 10240
|
|
shrinkRatioThreshold = 0.65
|
|
migrationBatchSize = 4096
|
|
)
|
|
|
|
type CacheController struct {
|
|
name string
|
|
disableCache bool
|
|
serveStale bool
|
|
serveExpiredTTL int32
|
|
|
|
ips map[string]*record
|
|
dirtyips map[string]*record
|
|
|
|
sync.RWMutex
|
|
pub *pubsub.Service
|
|
cacheCleanup *task.Periodic
|
|
highWatermark int
|
|
requestGroup singleflight.Group
|
|
}
|
|
|
|
func NewCacheController(name string, disableCache bool, serveStale bool, serveExpiredTTL uint32) *CacheController {
|
|
c := &CacheController{
|
|
name: name,
|
|
disableCache: disableCache,
|
|
serveStale: serveStale,
|
|
serveExpiredTTL: -int32(serveExpiredTTL),
|
|
ips: make(map[string]*record),
|
|
pub: pubsub.NewService(),
|
|
}
|
|
|
|
c.cacheCleanup = &task.Periodic{
|
|
Interval: 300 * time.Second,
|
|
Execute: c.CacheCleanup,
|
|
}
|
|
return c
|
|
}
|
|
|
|
// CacheCleanup clears expired items from cache
|
|
func (c *CacheController) CacheCleanup() error {
|
|
expiredKeys, err := c.collectExpiredKeys()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(expiredKeys) == 0 {
|
|
return nil
|
|
}
|
|
c.writeAndShrink(expiredKeys)
|
|
return nil
|
|
}
|
|
|
|
func (c *CacheController) collectExpiredKeys() ([]string, error) {
|
|
c.RLock()
|
|
defer c.RUnlock()
|
|
|
|
if len(c.ips) == 0 {
|
|
return nil, errors.New("nothing to do. stopping...")
|
|
}
|
|
|
|
// skip collection if a migration is in progress
|
|
if c.dirtyips != nil {
|
|
return nil, nil
|
|
}
|
|
|
|
now := time.Now()
|
|
if c.serveStale && c.serveExpiredTTL != 0 {
|
|
now = now.Add(time.Duration(c.serveExpiredTTL) * time.Second)
|
|
}
|
|
|
|
expiredKeys := make([]string, 0, len(c.ips)/4) // pre-allocate
|
|
|
|
for domain, rec := range c.ips {
|
|
if (rec.A != nil && rec.A.Expire.Before(now)) ||
|
|
(rec.AAAA != nil && rec.AAAA.Expire.Before(now)) {
|
|
expiredKeys = append(expiredKeys, domain)
|
|
}
|
|
}
|
|
|
|
return expiredKeys, nil
|
|
}
|
|
|
|
func (c *CacheController) writeAndShrink(expiredKeys []string) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
// double check to prevent upper call multiple cleanup tasks
|
|
if c.dirtyips != nil {
|
|
return
|
|
}
|
|
|
|
lenBefore := len(c.ips)
|
|
if lenBefore > c.highWatermark {
|
|
c.highWatermark = lenBefore
|
|
}
|
|
|
|
now := time.Now()
|
|
if c.serveStale && c.serveExpiredTTL != 0 {
|
|
now = now.Add(time.Duration(c.serveExpiredTTL) * time.Second)
|
|
}
|
|
|
|
for _, domain := range expiredKeys {
|
|
rec := c.ips[domain]
|
|
if rec == nil {
|
|
continue
|
|
}
|
|
if rec.A != nil && rec.A.Expire.Before(now) {
|
|
rec.A = nil
|
|
}
|
|
if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
|
|
rec.AAAA = nil
|
|
}
|
|
if rec.A == nil && rec.AAAA == nil {
|
|
delete(c.ips, domain)
|
|
}
|
|
}
|
|
|
|
lenAfter := len(c.ips)
|
|
|
|
if lenAfter == 0 {
|
|
if c.highWatermark >= minSizeForEmptyRebuild {
|
|
errors.LogDebug(context.Background(), c.name,
|
|
" rebuilding empty cache map to reclaim memory.",
|
|
" size_before_cleanup=", lenBefore,
|
|
" peak_size_before_rebuild=", c.highWatermark,
|
|
)
|
|
|
|
c.ips = make(map[string]*record)
|
|
c.highWatermark = 0
|
|
}
|
|
return
|
|
}
|
|
|
|
if reductionFromPeak := c.highWatermark - lenAfter; reductionFromPeak > shrinkAbsoluteThreshold &&
|
|
float64(reductionFromPeak) > float64(c.highWatermark)*shrinkRatioThreshold {
|
|
errors.LogDebug(context.Background(), c.name,
|
|
" shrinking cache map to reclaim memory.",
|
|
" new_size=", lenAfter,
|
|
" peak_size_before_shrink=", c.highWatermark,
|
|
" reduction_since_peak=", reductionFromPeak,
|
|
)
|
|
|
|
c.dirtyips = c.ips
|
|
c.ips = make(map[string]*record, int(float64(lenAfter)*1.1))
|
|
c.highWatermark = lenAfter
|
|
go c.migrate()
|
|
}
|
|
|
|
}
|
|
|
|
type migrationEntry struct {
|
|
key string
|
|
value *record
|
|
}
|
|
|
|
func (c *CacheController) migrate() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
errors.LogError(context.Background(), c.name, " panic during cache migration: ", r)
|
|
c.Lock()
|
|
c.dirtyips = nil
|
|
// c.ips = make(map[string]*record)
|
|
// c.highWatermark = 0
|
|
c.Unlock()
|
|
}
|
|
}()
|
|
|
|
c.RLock()
|
|
dirtyips := c.dirtyips
|
|
c.RUnlock()
|
|
|
|
// double check to prevent upper call multiple cleanup tasks
|
|
if dirtyips == nil {
|
|
return
|
|
}
|
|
|
|
errors.LogDebug(context.Background(), c.name, " starting background cache migration for ", len(dirtyips), " items")
|
|
|
|
batch := make([]migrationEntry, 0, migrationBatchSize)
|
|
for domain, recD := range dirtyips {
|
|
batch = append(batch, migrationEntry{domain, recD})
|
|
|
|
if len(batch) >= migrationBatchSize {
|
|
c.flush(batch)
|
|
batch = batch[:0]
|
|
runtime.Gosched()
|
|
}
|
|
}
|
|
if len(batch) > 0 {
|
|
c.flush(batch)
|
|
}
|
|
|
|
c.Lock()
|
|
c.dirtyips = nil
|
|
c.Unlock()
|
|
|
|
errors.LogDebug(context.Background(), c.name, " cache migration completed")
|
|
}
|
|
|
|
func (c *CacheController) flush(batch []migrationEntry) {
|
|
c.Lock()
|
|
defer c.Unlock()
|
|
|
|
for _, dirty := range batch {
|
|
if cur := c.ips[dirty.key]; cur != nil {
|
|
merge := &record{}
|
|
if cur.A == nil {
|
|
merge.A = dirty.value.A
|
|
} else {
|
|
merge.A = cur.A
|
|
}
|
|
if cur.AAAA == nil {
|
|
merge.AAAA = dirty.value.AAAA
|
|
} else {
|
|
merge.AAAA = cur.AAAA
|
|
}
|
|
c.ips[dirty.key] = merge
|
|
} else {
|
|
c.ips[dirty.key] = dirty.value
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *CacheController) updateRecord(req *dnsRequest, rep *IPRecord) {
|
|
rtt := time.Since(req.start)
|
|
|
|
switch req.reqType {
|
|
case dnsmessage.TypeA:
|
|
c.pub.Publish(req.domain+"4", rep)
|
|
case dnsmessage.TypeAAAA:
|
|
c.pub.Publish(req.domain+"6", rep)
|
|
}
|
|
|
|
if c.disableCache {
|
|
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt)
|
|
return
|
|
}
|
|
|
|
c.Lock()
|
|
lockWait := time.Since(req.start) - rtt
|
|
|
|
newRec := &record{}
|
|
oldRec := c.ips[req.domain]
|
|
var dirtyRec *record
|
|
if c.dirtyips != nil {
|
|
dirtyRec = c.dirtyips[req.domain]
|
|
}
|
|
|
|
var pubRecord *IPRecord
|
|
var pubSuffix string
|
|
|
|
switch req.reqType {
|
|
case dnsmessage.TypeA:
|
|
newRec.A = rep
|
|
if oldRec != nil && oldRec.AAAA != nil {
|
|
newRec.AAAA = oldRec.AAAA
|
|
pubRecord = oldRec.AAAA
|
|
} else if dirtyRec != nil && dirtyRec.AAAA != nil {
|
|
pubRecord = dirtyRec.AAAA
|
|
}
|
|
pubSuffix = "6"
|
|
case dnsmessage.TypeAAAA:
|
|
newRec.AAAA = rep
|
|
if oldRec != nil && oldRec.A != nil {
|
|
newRec.A = oldRec.A
|
|
pubRecord = oldRec.A
|
|
} else if dirtyRec != nil && dirtyRec.A != nil {
|
|
pubRecord = dirtyRec.A
|
|
}
|
|
pubSuffix = "4"
|
|
}
|
|
|
|
c.ips[req.domain] = newRec
|
|
c.Unlock()
|
|
|
|
if pubRecord != nil {
|
|
_, ttl, err := pubRecord.getIPs()
|
|
if ttl > 0 && !go_errors.Is(err, errRecordNotFound) {
|
|
c.pub.Publish(req.domain+pubSuffix, pubRecord)
|
|
}
|
|
}
|
|
|
|
errors.LogInfo(context.Background(), c.name, " got answer: ", req.domain, " ", req.reqType, " -> ", rep.IP, ", rtt: ", rtt, ", lock: ", lockWait)
|
|
|
|
if !c.serveStale || c.serveExpiredTTL != 0 {
|
|
common.Must(c.cacheCleanup.Start())
|
|
}
|
|
}
|
|
|
|
func (c *CacheController) findRecords(domain string) *record {
|
|
c.RLock()
|
|
defer c.RUnlock()
|
|
|
|
rec := c.ips[domain]
|
|
if rec == nil && c.dirtyips != nil {
|
|
rec = c.dirtyips[domain]
|
|
}
|
|
return rec
|
|
}
|
|
|
|
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|
|
// ipv4 and ipv6 belong to different subscription groups
|
|
if option.IPv4Enable {
|
|
sub4 = c.pub.Subscribe(domain + "4")
|
|
}
|
|
if option.IPv6Enable {
|
|
sub6 = c.pub.Subscribe(domain + "6")
|
|
}
|
|
return
|
|
}
|
|
|
|
func closeSubscribers(sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {
|
|
if sub4 != nil {
|
|
sub4.Close()
|
|
}
|
|
if sub6 != nil {
|
|
sub6.Close()
|
|
}
|
|
}
|