dont start periodic task until necessary

pull/1269/head
Darien Raymond 2018-08-29 23:00:01 +02:00
parent 5a0a9aa65e
commit eb05a92592
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
10 changed files with 81 additions and 162 deletions

View File

@ -81,12 +81,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
} }
} }
if domainMatcher.Size() > 64 {
server.domainMatcher = strmatcher.NewCachedMatcherGroup(domainMatcher)
} else {
server.domainMatcher = domainMatcher server.domainMatcher = domainMatcher
}
server.domainIndexMap = domainIndexMap server.domainIndexMap = domainIndexMap
} }

View File

@ -59,13 +59,18 @@ func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher, c
Execute: s.Cleanup, Execute: s.Cleanup,
} }
s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
common.Must(s.cleanup.Start())
return s return s
} }
func (s *ClassicNameServer) Cleanup() error { func (s *ClassicNameServer) Cleanup() error {
now := time.Now() now := time.Now()
s.Lock() s.Lock()
defer s.Unlock()
if len(s.ips) == 0 && len(s.requests) == 0 {
return newError("nothing to do. stopping...")
}
for domain, ips := range s.ips { for domain, ips := range s.ips {
newIPs := make([]IPRecord, 0, len(ips)) newIPs := make([]IPRecord, 0, len(ips))
for _, ip := range ips { for _, ip := range ips {
@ -94,7 +99,6 @@ func (s *ClassicNameServer) Cleanup() error {
s.requests = make(map[uint16]pendingRequest) s.requests = make(map[uint16]pendingRequest)
} }
s.Unlock()
return nil return nil
} }
@ -151,7 +155,6 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, payload *buf.Buf
func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) { func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
s.Lock() s.Lock()
defer s.Unlock()
newError("updating IP records for domain:", domain).AtDebug().WriteToLog() newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
now := time.Now() now := time.Now()
@ -163,6 +166,9 @@ func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
} }
s.ips[domain] = ips s.ips[domain] = ips
s.pub.Publish(domain, nil) s.pub.Publish(domain, nil)
s.Unlock()
common.Must(s.cleanup.Start())
} }
func (s *ClassicNameServer) getMsgOptions() *dns.OPT { func (s *ClassicNameServer) getMsgOptions() *dns.OPT {

View File

@ -289,6 +289,8 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest
} }
if !existing { if !existing {
common.Must(w.checker.Start())
go func() { go func() {
ctx := context.Background() ctx := context.Background()
sid := session.NewID() sid := session.NewID()
@ -324,21 +326,13 @@ func (w *udpWorker) handlePackets() {
} }
} }
func (w *udpWorker) Start() error { func (w *udpWorker) clean() error {
w.activeConn = make(map[connID]*udpConn, 16)
h, err := udp.ListenUDP(w.address, w.port, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256))
if err != nil {
return err
}
w.checker = &task.Periodic{
Interval: time.Second * 16,
Execute: func() error {
nowSec := time.Now().Unix() nowSec := time.Now().Unix()
w.Lock() w.Lock()
defer w.Unlock() defer w.Unlock()
if len(w.activeConn) == 0 { if len(w.activeConn) == 0 {
return nil return newError("no more connections. stopping...")
} }
for addr, conn := range w.activeConn { for addr, conn := range w.activeConn {
@ -353,11 +347,20 @@ func (w *udpWorker) Start() error {
} }
return nil return nil
}, }
}
if err := w.checker.Start(); err != nil { func (w *udpWorker) Start() error {
w.activeConn = make(map[connID]*udpConn, 16)
h, err := udp.ListenUDP(w.address, w.port, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256))
if err != nil {
return err return err
} }
w.checker = &task.Periodic{
Interval: time.Second * 16,
Execute: w.clean,
}
w.hub = h w.hub = h
go w.handlePackets() go w.handlePackets()
return nil return nil

View File

@ -100,15 +100,9 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
g.Add(m) g.Add(m)
} }
if len(domains) < 64 {
return &DomainMatcher{ return &DomainMatcher{
matchers: g, matchers: g,
}, nil }, nil
}
return &DomainMatcher{
matchers: strmatcher.NewCachedMatcherGroup(g),
}, nil
} }
func (m *DomainMatcher) ApplyDomain(domain string) bool { func (m *DomainMatcher) ApplyDomain(domain string) bool {

View File

@ -1,6 +1,7 @@
package pubsub package pubsub
import ( import (
"errors"
"sync" "sync"
"time" "time"
@ -47,7 +48,6 @@ func NewService() *Service {
Execute: s.Cleanup, Execute: s.Cleanup,
Interval: time.Second * 30, Interval: time.Second * 30,
} }
common.Must(s.ctask.Start())
return s return s
} }
@ -57,6 +57,10 @@ func (s *Service) Cleanup() error {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
if len(s.subs) == 0 {
return errors.New("nothing to do")
}
for name, subs := range s.subs { for name, subs := range s.subs {
newSub := make([]*Subscriber, 0, len(s.subs)) newSub := make([]*Subscriber, 0, len(s.subs))
for _, sub := range subs { for _, sub := range subs {
@ -86,6 +90,7 @@ func (s *Service) Subscribe(name string) *Subscriber {
subs := append(s.subs[name], sub) subs := append(s.subs[name], sub)
s.subs[name] = subs s.subs[name] = subs
s.Unlock() s.Unlock()
common.Must(s.ctask.Start())
return sub return sub
} }

View File

@ -47,20 +47,3 @@ func BenchmarkMarchGroup(b *testing.B) {
_ = g.Match("0.v2ray.com") _ = g.Match("0.v2ray.com")
} }
} }
func BenchmarkCachedMarchGroup(b *testing.B) {
g := new(MatcherGroup)
for i := 1; i <= 1024; i++ {
m, err := Domain.New(strconv.Itoa(i) + ".v2ray.com")
common.Must(err)
g.Add(m)
}
cg := NewCachedMatcherGroup(g)
_ = cg.Match("0.v2ray.com")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = cg.Match("0.v2ray.com")
}
}

View File

@ -2,11 +2,6 @@ package strmatcher
import ( import (
"regexp" "regexp"
"sync"
"time"
"v2ray.com/core/common"
"v2ray.com/core/common/task"
) )
// Matcher is the interface to determine a string matches a pattern. // Matcher is the interface to determine a string matches a pattern.
@ -114,71 +109,3 @@ func (g *MatcherGroup) Match(pattern string) uint32 {
func (g *MatcherGroup) Size() uint32 { func (g *MatcherGroup) Size() uint32 {
return g.count return g.count
} }
type cacheEntry struct {
timestamp time.Time
result uint32
}
// CachedMatcherGroup is a IndexMatcher with cachable results.
type CachedMatcherGroup struct {
sync.RWMutex
group *MatcherGroup
cache map[string]cacheEntry
cleanup *task.Periodic
}
// NewCachedMatcherGroup creats a new CachedMatcherGroup.
func NewCachedMatcherGroup(g *MatcherGroup) *CachedMatcherGroup {
r := &CachedMatcherGroup{
group: g,
cache: make(map[string]cacheEntry),
}
r.cleanup = &task.Periodic{
Interval: time.Second * 30,
Execute: func() error {
r.Lock()
defer r.Unlock()
if len(r.cache) == 0 {
return nil
}
expire := time.Now().Add(-1 * time.Second * 120)
for p, e := range r.cache {
if e.timestamp.Before(expire) {
delete(r.cache, p)
}
}
if len(r.cache) == 0 {
r.cache = make(map[string]cacheEntry)
}
return nil
},
}
common.Must(r.cleanup.Start())
return r
}
// Match implements IndexMatcher.Match.
func (g *CachedMatcherGroup) Match(pattern string) uint32 {
g.RLock()
r, f := g.cache[pattern]
g.RUnlock()
if f {
return r.result
}
mr := g.group.Match(pattern)
g.Lock()
g.cache[pattern] = cacheEntry{
result: mr,
timestamp: time.Now(),
}
g.Unlock()
return mr
}

View File

@ -11,25 +11,17 @@ type Periodic struct {
Interval time.Duration Interval time.Duration
// Execute is the task function // Execute is the task function
Execute func() error Execute func() error
// OnFailure will be called when Execute returns non-nil error
OnError func(error)
access sync.RWMutex access sync.Mutex
timer *time.Timer timer *time.Timer
closed bool running bool
}
func (t *Periodic) setClosed(f bool) {
t.access.Lock()
t.closed = f
t.access.Unlock()
} }
func (t *Periodic) hasClosed() bool { func (t *Periodic) hasClosed() bool {
t.access.RLock() t.access.Lock()
defer t.access.RUnlock() defer t.access.Unlock()
return t.closed return !t.running
} }
func (t *Periodic) checkedExecute() error { func (t *Periodic) checkedExecute() error {
@ -38,31 +30,39 @@ func (t *Periodic) checkedExecute() error {
} }
if err := t.Execute(); err != nil { if err := t.Execute(); err != nil {
t.access.Lock()
t.running = false
t.access.Unlock()
return err return err
} }
t.access.Lock() t.access.Lock()
defer t.access.Unlock() defer t.access.Unlock()
if t.closed { if !t.running {
return nil return nil
} }
t.timer = time.AfterFunc(t.Interval, func() { t.timer = time.AfterFunc(t.Interval, func() {
if err := t.checkedExecute(); err != nil && t.OnError != nil { t.checkedExecute() // nolint: errcheck
t.OnError(err)
}
}) })
return nil return nil
} }
// Start implements common.Runnable. Start must not be called multiple times without Close being called. // Start implements common.Runnable.
func (t *Periodic) Start() error { func (t *Periodic) Start() error {
t.setClosed(false) t.access.Lock()
if t.running {
return nil
}
t.running = true
t.access.Unlock()
if err := t.checkedExecute(); err != nil { if err := t.checkedExecute(); err != nil {
t.setClosed(true) t.access.Lock()
t.running = false
t.access.Unlock()
return err return err
} }
@ -74,7 +74,7 @@ func (t *Periodic) Close() error {
t.access.Lock() t.access.Lock()
defer t.access.Unlock() defer t.access.Unlock()
t.closed = true t.running = false
if t.timer != nil { if t.timer != nil {
t.timer.Stop() t.timer.Stop()
t.timer = nil t.timer = nil

View File

@ -27,4 +27,10 @@ func TestPeriodicTaskStop(t *testing.T) {
assert(value, Equals, 3) assert(value, Equals, 3)
time.Sleep(time.Second * 4) time.Sleep(time.Second * 4)
assert(value, Equals, 3) assert(value, Equals, 3)
common.Must(task.Start())
time.Sleep(time.Second * 3)
if value != 5 {
t.Fatal("Expected 5, but ", value)
}
common.Must(task.Close())
} }

View File

@ -42,12 +42,8 @@ func NewSessionHistory() *SessionHistory {
} }
h.task = &task.Periodic{ h.task = &task.Periodic{
Interval: time.Second * 30, Interval: time.Second * 30,
Execute: func() error { Execute: h.removeExpiredEntries,
h.removeExpiredEntries()
return nil
},
} }
common.Must(h.task.Start())
return h return h
} }
@ -58,24 +54,26 @@ func (h *SessionHistory) Close() error {
func (h *SessionHistory) addIfNotExits(session sessionId) bool { func (h *SessionHistory) addIfNotExits(session sessionId) bool {
h.Lock() h.Lock()
defer h.Unlock()
if expire, found := h.cache[session]; found && expire.After(time.Now()) { if expire, found := h.cache[session]; found && expire.After(time.Now()) {
h.Unlock()
return false return false
} }
h.cache[session] = time.Now().Add(time.Minute * 3) h.cache[session] = time.Now().Add(time.Minute * 3)
h.Unlock()
common.Must(h.task.Start())
return true return true
} }
func (h *SessionHistory) removeExpiredEntries() { func (h *SessionHistory) removeExpiredEntries() error {
now := time.Now() now := time.Now()
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
if len(h.cache) == 0 { if len(h.cache) == 0 {
return return newError("nothing to do")
} }
for session, expire := range h.cache { for session, expire := range h.cache {
@ -87,6 +85,8 @@ func (h *SessionHistory) removeExpiredEntries() {
if len(h.cache) == 0 { if len(h.cache) == 0 {
h.cache = make(map[sessionId]time.Time, 128) h.cache = make(map[sessionId]time.Time, 128)
} }
return nil
} }
// ServerSession keeps information for a session in VMess server. // ServerSession keeps information for a session in VMess server.