diff --git a/common/common.go b/common/common.go index c3bfa944..f0134243 100644 --- a/common/common.go +++ b/common/common.go @@ -7,6 +7,7 @@ import ( "go/build" "os" "path/filepath" + "reflect" "strings" "github.com/xtls/xray-core/common/errors" @@ -153,3 +154,14 @@ func GetModuleName(pathToProjectRoot string) (string, error) { } return moduleName, fmt.Errorf("no `go.mod` file in every parent directory of `%s`", pathToProjectRoot) } + +// CloseIfExists call obj.Close() if obj is not nil. +func CloseIfExists(obj any) error { + if obj != nil { + v := reflect.ValueOf(obj) + if !v.IsNil() { + return Close(obj) + } + } + return nil +} diff --git a/common/signal/timer.go b/common/signal/timer.go index ece9f496..d5b35605 100644 --- a/common/signal/timer.go +++ b/common/signal/timer.go @@ -3,6 +3,7 @@ package signal import ( "context" "sync" + "sync/atomic" "time" "github.com/xtls/xray-core/common" @@ -14,10 +15,12 @@ type ActivityUpdater interface { } type ActivityTimer struct { - sync.RWMutex + mu sync.RWMutex updated chan struct{} checkTask *task.Periodic onTimeout func() + consumed atomic.Bool + once sync.Once } func (t *ActivityTimer) Update() { @@ -37,39 +40,39 @@ func (t *ActivityTimer) check() error { } func (t *ActivityTimer) finish() { - t.Lock() - defer t.Unlock() + t.once.Do(func() { + t.consumed.Store(true) + t.mu.Lock() + defer t.mu.Unlock() - if t.onTimeout != nil { + common.CloseIfExists(t.checkTask) t.onTimeout() - t.onTimeout = nil - } - if t.checkTask != nil { - t.checkTask.Close() - t.checkTask = nil - } + }) } func (t *ActivityTimer) SetTimeout(timeout time.Duration) { + if t.consumed.Load() { + return + } if timeout == 0 { t.finish() return } - checkTask := &task.Periodic{ + t.mu.Lock() + defer t.mu.Unlock() + // double check, just in case + if t.consumed.Load() { + return + } + newCheckTask := &task.Periodic{ Interval: timeout, Execute: t.check, } - - t.Lock() - - if t.checkTask != nil { - t.checkTask.Close() - } - t.checkTask = checkTask + common.CloseIfExists(t.checkTask) + t.checkTask = newCheckTask t.Update() - common.Must(checkTask.Start()) - t.Unlock() + common.Must(newCheckTask.Start()) } func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {