diff --git a/common/collect/timed_queue.go b/common/collect/timed_queue.go index c8e16b9c..61791d86 100644 --- a/common/collect/timed_queue.go +++ b/common/collect/timed_queue.go @@ -41,55 +41,63 @@ func (queue *timedQueueImpl) Pop() interface{} { // TimedQueue is a priority queue that entries with oldest timestamp get removed first. type TimedQueue struct { - queue timedQueueImpl - access sync.RWMutex - removed chan interface{} + queue timedQueueImpl + access sync.Mutex + removedCallback func(interface{}) } -func NewTimedQueue(updateInterval int) *TimedQueue { +func NewTimedQueue(updateInterval int, removedCallback func(interface{})) *TimedQueue { queue := &TimedQueue{ - queue: make([]*timedQueueEntry, 0, 256), - removed: make(chan interface{}, 16), - access: sync.RWMutex{}, + queue: make([]*timedQueueEntry, 0, 256), + removedCallback: removedCallback, + access: sync.Mutex{}, } go queue.cleanup(time.Tick(time.Duration(updateInterval) * time.Second)) return queue } func (queue *TimedQueue) Add(value interface{}, time2Remove int64) { - queue.access.Lock() - heap.Push(&queue.queue, &timedQueueEntry{ + newEntry := &timedQueueEntry{ timeSec: time2Remove, value: value, - }) + } + var removedEntry *timedQueueEntry + queue.access.Lock() + nowSec := time.Now().Unix() + if queue.queue.Len() > 0 && queue.queue[0].timeSec < nowSec { + removedEntry = queue.queue[0] + queue.queue[0] = newEntry + heap.Fix(&queue.queue, 0) + } else { + heap.Push(&queue.queue, newEntry) + } queue.access.Unlock() -} - -func (queue *TimedQueue) RemovedEntries() <-chan interface{} { - return queue.removed + if removedEntry != nil { + queue.removedCallback(removedEntry) + } } func (queue *TimedQueue) cleanup(tick <-chan time.Time) { for now := range tick { nowSec := now.Unix() - for { - queue.access.RLock() - queueLen := queue.queue.Len() - queue.access.RUnlock() - if queueLen == 0 { - break - } - queue.access.RLock() - entry := queue.queue[0] - queue.access.RUnlock() - if entry.timeSec > nowSec { - break + removedEntries := make([]*timedQueueEntry, 0, 128) + queue.access.Lock() + changed := false + for i := 0; i < queue.queue.Len(); i++ { + entry := queue.queue[i] + if entry.timeSec < nowSec { + removedEntries = append(removedEntries, entry) + queue.queue.Swap(i, queue.queue.Len()-1) + queue.queue.Pop() + changed = true } - queue.access.Lock() - heap.Pop(&queue.queue) - queue.access.Unlock() - - queue.removed <- entry.value + } + if changed { + heap.Init(&queue.queue) + } + queue.access.Unlock() + for _, entry := range removedEntries { + queue.removedCallback(entry.value) } } } diff --git a/common/collect/timed_queue_test.go b/common/collect/timed_queue_test.go index 8ab9e764..940f64e1 100644 --- a/common/collect/timed_queue_test.go +++ b/common/collect/timed_queue_test.go @@ -14,14 +14,9 @@ func TestTimedQueue(t *testing.T) { removed := make(map[string]bool) nowSec := time.Now().Unix() - q := NewTimedQueue(2) - - go func() { - for { - entry := <-q.RemovedEntries() - removed[entry.(string)] = true - } - }() + q := NewTimedQueue(2, func(v interface{}) { + removed[v.(string)] = true + }) q.Add("Value1", nowSec) q.Add("Value2", nowSec+5) diff --git a/proxy/vmess/protocol/user/userset.go b/proxy/vmess/protocol/user/userset.go index 50207458..333dfa5d 100644 --- a/proxy/vmess/protocol/user/userset.go +++ b/proxy/vmess/protocol/user/userset.go @@ -32,22 +32,19 @@ type indexTimePair struct { func NewTimedUserSet() UserSet { tus := &TimedUserSet{ - validUsers: make([]vmess.User, 0, 16), - userHash: make(map[string]indexTimePair, 512), - userHashDeleteQueue: collect.NewTimedQueue(updateIntervalSec), - access: sync.RWMutex{}, + validUsers: make([]vmess.User, 0, 16), + userHash: make(map[string]indexTimePair, 512), + access: sync.RWMutex{}, } + tus.userHashDeleteQueue = collect.NewTimedQueue(updateIntervalSec, tus.removeEntry) go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second)) - go tus.removeEntries(tus.userHashDeleteQueue.RemovedEntries()) return tus } -func (us *TimedUserSet) removeEntries(entries <-chan interface{}) { - for entry := range entries { - us.access.Lock() - delete(us.userHash, entry.(string)) - us.access.Unlock() - } +func (us *TimedUserSet) removeEntry(entry interface{}) { + us.access.Lock() + delete(us.userHash, entry.(string)) + us.access.Unlock() } func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id *vmess.ID) {