diff --git a/common/log/internal/log_writer.go b/common/log/internal/log_writer.go index 65c4d560..d89920da 100644 --- a/common/log/internal/log_writer.go +++ b/common/log/internal/log_writer.go @@ -62,6 +62,9 @@ func (this *FileLogWriter) Log(log LogEntry) { } func (this *FileLogWriter) run() { + this.cancel.WaitThread() + defer this.cancel.FinishThread() + for { entry, open := <-this.queue if !open { @@ -69,12 +72,11 @@ func (this *FileLogWriter) run() { } this.logger.Print(entry + platform.LineSeparator()) } - this.cancel.Done() } func (this *FileLogWriter) Close() { close(this.queue) - <-this.cancel.WaitForDone() + this.cancel.WaitForDone() this.file.Close() } diff --git a/common/signal/close.go b/common/signal/cancel.go similarity index 60% rename from common/signal/close.go rename to common/signal/cancel.go index c654ff17..53ad7fa8 100644 --- a/common/signal/close.go +++ b/common/signal/cancel.go @@ -1,35 +1,51 @@ package signal +import ( + "sync" +) + // CancelSignal is a signal passed to goroutine, in order to cancel the goroutine on demand. type CancelSignal struct { cancel chan struct{} - done chan struct{} + done sync.WaitGroup } // NewCloseSignal creates a new CancelSignal. func NewCloseSignal() *CancelSignal { return &CancelSignal{ cancel: make(chan struct{}), - done: make(chan struct{}), } } +func (this *CancelSignal) WaitThread() { + this.done.Add(1) +} + // Cancel signals the goroutine to stop. func (this *CancelSignal) Cancel() { close(this.cancel) } +func (this *CancelSignal) Cancelled() bool { + select { + case <-this.cancel: + return true + default: + return false + } +} + // WaitForCancel should be monitored by the goroutine for when to stop. func (this *CancelSignal) WaitForCancel() <-chan struct{} { return this.cancel } -// Done signals the caller that the goroutine has completely finished. -func (this *CancelSignal) Done() { - close(this.done) +// FinishThread signals that current goroutine has finished. +func (this *CancelSignal) FinishThread() { + this.done.Done() } // WaitForDone is used by caller to wait for the goroutine finishes. -func (this *CancelSignal) WaitForDone() <-chan struct{} { - return this.done +func (this *CancelSignal) WaitForDone() { + this.done.Wait() } diff --git a/proxy/dokodemo/dokodemo.go b/proxy/dokodemo/dokodemo.go index 00550312..15c3c4c0 100644 --- a/proxy/dokodemo/dokodemo.go +++ b/proxy/dokodemo/dokodemo.go @@ -95,6 +95,7 @@ func (this *DokodemoDoor) ListenUDP() error { this.meta.Address, this.meta.Port, udp.ListenOption{ Callback: this.handleUDPPackets, ReceiveOriginalDest: this.config.FollowRedirect, + Concurrency: 2, }) if err != nil { log.Error("Dokodemo failed to listen on ", this.meta.Address, ":", this.meta.Port, ": ", err) diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index fa871a0d..4d95c888 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -59,7 +59,7 @@ func (this *TimedUserValidator) Release() { } this.cancel.Cancel() - <-this.cancel.WaitForDone() + this.cancel.WaitForDone() this.Lock() defer this.Unlock() @@ -100,7 +100,9 @@ func (this *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx } func (this *TimedUserValidator) updateUserHash(interval time.Duration) { -L: + this.cancel.WaitThread() + defer this.cancel.FinishThread() + for { select { case now := <-time.After(interval): @@ -109,10 +111,9 @@ L: this.generateNewHashes(nowSec, entry.userIdx, entry) } case <-this.cancel.WaitForCancel(): - break L + return } } - this.cancel.Done() } func (this *TimedUserValidator) Add(user *protocol.User) error { diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 6250b750..7b9174bc 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -58,7 +58,7 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen l.tlsConfig = securitySettings.GetTLSConfig() } } - hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive}) + hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2}) if err != nil { return nil, err } diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index 3c8b9fd5..a0fdf295 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -5,28 +5,95 @@ import ( "sync" "v2ray.com/core/common/alloc" + "v2ray.com/core/common/dice" "v2ray.com/core/common/log" v2net "v2ray.com/core/common/net" + "v2ray.com/core/common/signal" "v2ray.com/core/proxy" "v2ray.com/core/transport/internet/internal" ) +type UDPPayload struct { + payload *alloc.Buffer + session *proxy.SessionInfo +} + type UDPPayloadHandler func(*alloc.Buffer, *proxy.SessionInfo) -type UDPHub struct { - sync.RWMutex - conn *net.UDPConn - option ListenOption - accepting bool - pool *alloc.BufferPool +type UDPPayloadQueue struct { + queue []chan UDPPayload + callback UDPPayloadHandler + cancel *signal.CancelSignal +} + +func NewUDPPayloadQueue(option ListenOption) *UDPPayloadQueue { + queue := &UDPPayloadQueue{ + callback: option.Callback, + cancel: signal.NewCloseSignal(), + queue: make([]chan UDPPayload, option.Concurrency), + } + for i := range queue.queue { + queue.queue[i] = make(chan UDPPayload, 64) + go queue.Dequeue(queue.queue[i]) + } + return queue +} + +func (this *UDPPayloadQueue) Enqueue(payload UDPPayload) { + size := len(this.queue) + for i := 0; i < size; i++ { + idx := 0 + if size > 1 { + idx = dice.Roll(size) + } + select { + case this.queue[idx] <- payload: + return + default: + } + } +} + +func (this *UDPPayloadQueue) Dequeue(queue <-chan UDPPayload) { + this.cancel.WaitThread() + defer this.cancel.FinishThread() + + for !this.cancel.Cancelled() { + payload, open := <-queue + if !open { + return + } + this.callback(payload.payload, payload.session) + } +} + +func (this *UDPPayloadQueue) Close() { + this.cancel.Cancel() + for _, queue := range this.queue { + close(queue) + } + this.cancel.WaitForDone() } type ListenOption struct { Callback UDPPayloadHandler ReceiveOriginalDest bool + Concurrency int +} + +type UDPHub struct { + sync.RWMutex + conn *net.UDPConn + pool *alloc.BufferPool + cancel *signal.CancelSignal + queue *UDPPayloadQueue + option ListenOption } func ListenUDP(address v2net.Address, port v2net.Port, option ListenOption) (*UDPHub, error) { + if option.Concurrency < 1 { + option.Concurrency = 1 + } udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ IP: address.IP(), Port: int(port), @@ -48,8 +115,10 @@ func ListenUDP(address v2net.Address, port v2net.Port, option ListenOption) (*UD } hub := &UDPHub{ conn: udpConn, - option: option, pool: alloc.NewBufferPool(2048, 64), + queue: NewUDPPayloadQueue(option), + option: option, + cancel: signal.NewCloseSignal(), } go hub.start() return hub, nil @@ -59,8 +128,10 @@ func (this *UDPHub) Close() { this.Lock() defer this.Unlock() - this.accepting = false + this.cancel.Cancel() this.conn.Close() + this.cancel.WaitForDone() + this.queue.Close() } func (this *UDPHub) WriteTo(payload []byte, dest v2net.Destination) (int, error) { @@ -71,9 +142,8 @@ func (this *UDPHub) WriteTo(payload []byte, dest v2net.Destination) (int, error) } func (this *UDPHub) start() { - this.Lock() - this.accepting = true - this.Unlock() + this.cancel.WaitThread() + defer this.cancel.FinishThread() oobBytes := make([]byte, 256) for this.Running() { @@ -91,15 +161,15 @@ func (this *UDPHub) start() { if this.option.ReceiveOriginalDest && noob > 0 { session.Destination = RetrieveOriginalDest(oobBytes[:noob]) } - go this.option.Callback(buffer, session) + this.queue.Enqueue(UDPPayload{ + payload: buffer, + session: session, + }) } } func (this *UDPHub) Running() bool { - this.RLock() - defer this.RUnlock() - - return this.accepting + return !this.cancel.Cancelled() } // Connection return the net.Conn underneath this hub.