From bc8ba765b03606aee19b9ff2ffb30f3081611ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=8E=E6=89=87=E6=BB=91=E7=BF=94=E7=BF=BC?= Date: Sat, 26 Jul 2025 06:22:42 +0000 Subject: [PATCH] Bug fix and refine --- transport/internet/tls/ech.go | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 93cce476..39ee26a9 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -80,12 +80,17 @@ type ECHConfigCache struct { echConfig atomic.Pointer[[]byte] expire atomic.Pointer[time.Time] // updateLock is not for preventing concurrent read/write, but for preventing concurrent update - updateLock sync.Mutex + UpdateLock sync.Mutex } -func (c *ECHConfigCache) update(domain string, server string) ([]byte, error) { - c.updateLock.Lock() - defer c.updateLock.Unlock() +// Update updates the ECH config for given domain and server. +// this method is concurrent safe, only one update request will be sent, others get the cache. +// if isLockedUpdate is true, it means pass the lock to this function, it will release the lock after update. +func (c *ECHConfigCache) Update(domain string, server string, isLockedUpdate bool) ([]byte, error) { + if !isLockedUpdate { + c.UpdateLock.Lock() + } + defer c.UpdateLock.Unlock() // Double check cache after acquiring lock if c.expire.Load().After(time.Now()) { errors.LogDebug(context.Background(), "Cache hit for domain after double check: ", domain) @@ -132,12 +137,13 @@ func QueryRecord(domain string, server string) ([]byte, error) { // If expire is zero value, it means we are in initial state, wait for the query to finish // otherwise return old value immediately and update in a goroutine - if *echConfigCache.expire.Load() == (time.Time{}) { - return echConfigCache.update(domain, server) + // but if the cache is too old, wait for update + if *echConfigCache.expire.Load() == (time.Time{}) || echConfigCache.expire.Load().Add(time.Hour*6).Before(time.Now()) { + return echConfigCache.Update(domain, server, false) } else { // If someone already acquired the lock, it means it is updating, do not start another update goroutine - if echConfigCache.updateLock.TryLock() { - go echConfigCache.update(domain, server) + if echConfigCache.UpdateLock.TryLock() { + go echConfigCache.Update(domain, server, true) } return *echConfigCache.echConfig.Load(), nil } @@ -210,7 +216,12 @@ func dnsQuery(server string, domain string) ([]byte, uint32, error) { defer cancel() // use xray's internet.DialSystem as mentioned above conn, err := internet.DialSystem(dnsTimeoutCtx, dest, nil) - defer conn.Close() + defer func() { + err := conn.Close() + if err != nil { + errors.LogDebug(context.Background(), "Failed to close connection: ", err) + } + }() if err != nil { return []byte{}, 0, err }