diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 39ee26a9..fa0a4da0 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -20,6 +20,7 @@ import ( "github.com/xtls/reality/hpke" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/utils" "github.com/xtls/xray-core/transport/internet" "golang.org/x/crypto/cryptobyte" ) @@ -77,24 +78,31 @@ func ApplyECH(c *Config, config *tls.Config) error { } type ECHConfigCache struct { - echConfig atomic.Pointer[[]byte] - expire atomic.Pointer[time.Time] + configRecord atomic.Pointer[echConfigRecord] // updateLock is not for preventing concurrent read/write, but for preventing concurrent update UpdateLock sync.Mutex } +type echConfigRecord struct { + config []byte + expire time.Time +} + +var GlobalECHConfigCache = utils.NewTypedSyncMap[string, *ECHConfigCache]() + // Update updates the ECH config for given domain and server. // this method is concurrent safe, only one update request will be sent, others get the cache. -// if isLockedUpdate is true, it means pass the lock to this function, it will release the lock after update. +// if isLockedUpdate is true, it will not try to acquire the lock. func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) { if !isLockedUpdate { - c.UpdateLock.Lock() + c.UpdateLock.Lock() + defer c.UpdateLock.Unlock() } - defer c.UpdateLock.Unlock() // Double check cache after acquiring lock - if c.expire.Load().After(time.Now()) { + configRecord := c.configRecord.Load() + if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) - return *c.echConfig.Load(), nil + return configRecord.config, nil } // Query ECH config from DNS server errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) @@ -102,50 +110,43 @@ func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate boo if err != nil { return nil, err } - c.echConfig.Store(&echConfig) - expire := time.Now().Add(time.Duration(ttl) * time.Second) - c.expire.Store(&expire) - return *c.echConfig.Load(), nil + configRecord = &echConfigRecord{ + config: echConfig, + expire: time.Now().Add(time.Duration(ttl) * time.Second), + } + c.configRecord.Store(configRecord) + return configRecord.config, nil } -var ( - GlobalECHConfigCache map[string]*ECHConfigCache - GlobalECHConfigCacheAccess sync.Mutex -) - // QueryRecord returns the ECH config for given domain. // If the record is not in cache or expired, it will query the DNS server and update the cache. func QueryRecord(domain string, server string) ([]byte, error) { - // Global cache init - GlobalECHConfigCacheAccess.Lock() - if GlobalECHConfigCache == nil { - GlobalECHConfigCache = make(map[string]*ECHConfigCache) - } - - echConfigCache := GlobalECHConfigCache[domain] - if echConfigCache == nil { + echConfigCache, ok := GlobalECHConfigCache.Load(domain) + if !ok { echConfigCache = &ECHConfigCache{} - echConfigCache.expire.Store(&time.Time{}) // zero value means initial state - GlobalECHConfigCache[domain] = echConfigCache + echConfigCache.configRecord.Store(&echConfigRecord{}) + echConfigCache, _ = GlobalECHConfigCache.LoadOrStore(domain, echConfigCache) } - if echConfigCache != nil && echConfigCache.expire.Load().After(time.Now()) { + configRecord := echConfigCache.configRecord.Load() + if configRecord.expire.After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) - GlobalECHConfigCacheAccess.Unlock() - return *echConfigCache.echConfig.Load(), nil + return configRecord.config, nil } - GlobalECHConfigCacheAccess.Unlock() // If expire is zero value, it means we are in initial state, wait for the query to finish // otherwise return old value immediately and update in a goroutine // but if the cache is too old, wait for update - if *echConfigCache.expire.Load() == (time.Time{}) || echConfigCache.expire.Load().Add(time.Hour*6).Before(time.Now()) { + if configRecord.expire == (time.Time{}) || configRecord.expire.Add(time.Hour*6).Before(time.Now()) { return echConfigCache.Update(domain, server, false) } else { // If someone already acquired the lock, it means it is updating, do not start another update goroutine if echConfigCache.UpdateLock.TryLock() { - go echConfigCache.Update(domain, server, true) + go func() { + defer echConfigCache.UpdateLock.Unlock() + echConfigCache.Update(domain, server, true) + }() } - return *echConfigCache.echConfig.Load(), nil + return configRecord.config, nil } }