diff --git a/app/commander/commander.go b/app/commander/commander.go index 2afc7a75..e63935ff 100644 --- a/app/commander/commander.go +++ b/app/commander/commander.go @@ -72,7 +72,7 @@ func (c *Commander) Start() error { return nil } -func (c *Commander) Close() { +func (c *Commander) Close() error { c.Lock() defer c.Unlock() @@ -80,6 +80,8 @@ func (c *Commander) Close() { c.server.Stop() c.server = nil } + + return nil } func init() { diff --git a/app/commander/outbound.go b/app/commander/outbound.go index 018cef24..9d9f36a5 100644 --- a/app/commander/outbound.go +++ b/app/commander/outbound.go @@ -3,6 +3,7 @@ package commander import ( "context" "net" + "sync" "v2ray.com/core/common/signal" "v2ray.com/core/transport/ray" @@ -43,12 +44,24 @@ func (l *OutboundListener) Addr() net.Addr { type CommanderOutbound struct { tag string listener *OutboundListener + access sync.RWMutex + closed bool } func (co *CommanderOutbound) Dispatch(ctx context.Context, r ray.OutboundRay) { + co.access.RLock() + + if co.closed { + r.OutboundInput().CloseError() + r.OutboundOutput().CloseError() + co.access.RUnlock() + return + } + closeSignal := signal.NewNotifier() c := ray.NewConnection(r.OutboundInput(), r.OutboundOutput(), ray.ConnCloseSignal(closeSignal)) co.listener.add(c) + co.access.RUnlock() <-closeSignal.Wait() return @@ -59,7 +72,17 @@ func (co *CommanderOutbound) Tag() string { } func (co *CommanderOutbound) Start() error { + co.access.Lock() + co.closed = false + co.access.Unlock() return nil } -func (co *CommanderOutbound) Close() {} +func (co *CommanderOutbound) Close() error { + co.access.Lock() + co.closed = true + co.listener.Close() + co.access.Unlock() + + return nil +} diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 5f103a1f..6be84fcf 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -49,7 +49,7 @@ func (*DefaultDispatcher) Start() error { } // Close implements app.Application. -func (*DefaultDispatcher) Close() {} +func (*DefaultDispatcher) Close() error { return nil } // Dispatch implements core.Dispatcher. func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (ray.InboundRay, error) { diff --git a/app/dns/server.go b/app/dns/server.go index c26c7985..ba273b53 100644 --- a/app/dns/server.go +++ b/app/dns/server.go @@ -83,7 +83,8 @@ func (s *Server) Start() error { return nil } -func (*Server) Close() { +func (*Server) Close() error { + return nil } func (s *Server) GetCached(domain string) []net.IP { diff --git a/app/log/log.go b/app/log/log.go index 5614da2b..c2dc5e57 100644 --- a/app/log/log.go +++ b/app/log/log.go @@ -103,11 +103,13 @@ func (g *Instance) Handle(msg log.Message) { } // Close implement app.Application.Close(). -func (g *Instance) Close() { +func (g *Instance) Close() error { g.Lock() defer g.Unlock() g.active = false + + return nil } func init() { diff --git a/app/policy/manager.go b/app/policy/manager.go index d21eec4f..14f7b7b9 100644 --- a/app/policy/manager.go +++ b/app/policy/manager.go @@ -51,7 +51,8 @@ func (m *Instance) Start() error { } // Close implements app.Application.Close(). -func (m *Instance) Close() { +func (m *Instance) Close() error { + return nil } func init() { diff --git a/app/proxyman/command/command.go b/app/proxyman/command/command.go index 7ee973c4..4d93adca 100644 --- a/app/proxyman/command/command.go +++ b/app/proxyman/command/command.go @@ -119,7 +119,9 @@ func (*feature) Start() error { return nil } -func (*feature) Close() {} +func (*feature) Close() error { + return nil +} func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, cfg interface{}) (interface{}, error) { diff --git a/app/proxyman/inbound/always.go b/app/proxyman/inbound/always.go index ed085574..7253fa5e 100644 --- a/app/proxyman/inbound/always.go +++ b/app/proxyman/inbound/always.go @@ -76,10 +76,11 @@ func (h *AlwaysOnInboundHandler) Start() error { return nil } -func (h *AlwaysOnInboundHandler) Close() { +func (h *AlwaysOnInboundHandler) Close() error { for _, worker := range h.workers { worker.Close() } + return nil } func (h *AlwaysOnInboundHandler) GetRandomInboundProxy() (interface{}, net.Port, int) { diff --git a/app/proxyman/inbound/dynamic.go b/app/proxyman/inbound/dynamic.go index 281f8eb6..e6f77972 100644 --- a/app/proxyman/inbound/dynamic.go +++ b/app/proxyman/inbound/dynamic.go @@ -5,17 +5,18 @@ import ( "sync" "time" + "v2ray.com/core" "v2ray.com/core/app/proxyman" "v2ray.com/core/app/proxyman/mux" "v2ray.com/core/common/dice" "v2ray.com/core/common/net" + "v2ray.com/core/common/signal" "v2ray.com/core/proxy" ) type DynamicInboundHandler struct { tag string - ctx context.Context - cancel context.CancelFunc + v *core.Instance proxyConfig interface{} receiverConfig *proxyman.ReceiverConfig portMutex sync.Mutex @@ -24,18 +25,26 @@ type DynamicInboundHandler struct { worker []worker lastRefresh time.Time mux *mux.Server + task *signal.PeriodicTask } func NewDynamicInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*DynamicInboundHandler, error) { - ctx, cancel := context.WithCancel(ctx) + v := core.FromContext(ctx) + if v == nil { + return nil, newError("V is not in context.") + } h := &DynamicInboundHandler{ - ctx: ctx, tag: tag, - cancel: cancel, proxyConfig: proxyConfig, receiverConfig: receiverConfig, portsInUse: make(map[net.Port]bool), mux: mux.NewServer(ctx), + v: v, + } + + h.task = &signal.PeriodicTask{ + Interval: time.Minute * time.Duration(h.receiverConfig.AllocationStrategy.GetRefreshValue()), + Execute: h.refresh, } return h, nil @@ -59,9 +68,7 @@ func (h *DynamicInboundHandler) allocatePort() net.Port { } } -func (h *DynamicInboundHandler) waitAnyCloseWorkers(ctx context.Context, cancel context.CancelFunc, workers []worker, duration time.Duration) { - time.Sleep(duration) - cancel() +func (h *DynamicInboundHandler) closeWorkers(workers []worker) { ports2Del := make([]net.Port, len(workers)) for idx, worker := range workers { ports2Del[idx] = worker.Port() @@ -80,7 +87,6 @@ func (h *DynamicInboundHandler) refresh() error { timeout := time.Minute * time.Duration(h.receiverConfig.AllocationStrategy.GetRefreshValue()) * 2 concurrency := h.receiverConfig.AllocationStrategy.GetConcurrencyValue() - ctx, cancel := context.WithTimeout(h.ctx, timeout) workers := make([]worker, 0, concurrency) address := h.receiverConfig.Listen.AsAddress() @@ -89,11 +95,12 @@ func (h *DynamicInboundHandler) refresh() error { } for i := uint32(0); i < concurrency; i++ { port := h.allocatePort() - p, err := proxy.CreateInboundHandler(ctx, h.proxyConfig) + rawProxy, err := h.v.CreateObject(h.proxyConfig) if err != nil { newError("failed to create proxy instance").Base(err).AtWarning().WriteToLog() continue } + p := rawProxy.(proxy.Inbound) nl := p.Network() if nl.HasNetwork(net.Network_TCP) { worker := &tcpWorker{ @@ -134,33 +141,19 @@ func (h *DynamicInboundHandler) refresh() error { h.worker = workers h.workerMutex.Unlock() - go h.waitAnyCloseWorkers(ctx, cancel, workers, timeout) + time.AfterFunc(timeout, func() { + h.closeWorkers(workers) + }) return nil } -func (h *DynamicInboundHandler) monitor() { - timer := time.NewTicker(time.Minute * time.Duration(h.receiverConfig.AllocationStrategy.GetRefreshValue())) - defer timer.Stop() - - for { - select { - case <-h.ctx.Done(): - return - case <-timer.C: - h.refresh() - } - } -} - func (h *DynamicInboundHandler) Start() error { - err := h.refresh() - go h.monitor() - return err + return h.task.Start() } -func (h *DynamicInboundHandler) Close() { - h.cancel() +func (h *DynamicInboundHandler) Close() error { + return h.task.Close() } func (h *DynamicInboundHandler) GetRandomInboundProxy() (interface{}, net.Port, int) { diff --git a/app/proxyman/inbound/inbound.go b/app/proxyman/inbound/inbound.go index ee1508bc..cfd4cd5c 100644 --- a/app/proxyman/inbound/inbound.go +++ b/app/proxyman/inbound/inbound.go @@ -99,7 +99,7 @@ func (m *Manager) Start() error { return nil } -func (m *Manager) Close() { +func (m *Manager) Close() error { m.access.Lock() defer m.access.Unlock() @@ -111,6 +111,8 @@ func (m *Manager) Close() { for _, handler := range m.untaggedHandler { handler.Close() } + + return nil } func NewHandler(ctx context.Context, config *core.InboundHandlerConfig) (core.InboundHandler, error) { diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 51bdc6ff..95cd882c 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -9,8 +9,10 @@ import ( "v2ray.com/core" "v2ray.com/core/app/proxyman" + "v2ray.com/core/common" "v2ray.com/core/common/buf" "v2ray.com/core/common/net" + "v2ray.com/core/common/signal" "v2ray.com/core/proxy" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tcp" @@ -19,7 +21,7 @@ import ( type worker interface { Start() error - Close() + Close() error Port() net.Port Proxy() proxy.Inbound } @@ -34,13 +36,11 @@ type tcpWorker struct { dispatcher core.Dispatcher sniffers []proxyman.KnownProtocols - ctx context.Context - cancel context.CancelFunc - hub internet.Listener + hub internet.Listener } func (w *tcpWorker) callback(conn internet.Connection) { - ctx, cancel := context.WithCancel(w.ctx) + ctx, cancel := context.WithCancel(context.Background()) if w.recvOrigDest { dest, err := tcp.GetOriginalDestination(conn) if err != nil { @@ -70,45 +70,24 @@ func (w *tcpWorker) Proxy() proxy.Inbound { } func (w *tcpWorker) Start() error { - ctx, cancel := context.WithCancel(context.Background()) - w.ctx = ctx - w.cancel = cancel - ctx = internet.ContextWithStreamSettings(ctx, w.stream) - conns := make(chan internet.Connection, 16) - hub, err := internet.ListenTCP(ctx, w.address, w.port, conns) + ctx := internet.ContextWithStreamSettings(context.Background(), w.stream) + hub, err := internet.ListenTCP(ctx, w.address, w.port, func(conn internet.Connection) { + go w.callback(conn) + }) if err != nil { return newError("failed to listen TCP on ", w.port).AtWarning().Base(err) } - go w.handleConnections(conns) w.hub = hub return nil } -func (w *tcpWorker) handleConnections(conns <-chan internet.Connection) { - for { - select { - case <-w.ctx.Done(): - w.hub.Close() - L: - for { - select { - case conn := <-conns: - conn.Close() - default: - break L - } - } - return - case conn := <-conns: - go w.callback(conn) - } - } -} - -func (w *tcpWorker) Close() { +func (w *tcpWorker) Close() error { if w.hub != nil { - w.cancel() + common.Close(w.hub) + common.Close(w.proxy) } + + return nil } func (w *tcpWorker) Port() net.Port { @@ -121,8 +100,7 @@ type udpConn struct { output func([]byte) (int, error) remote net.Addr local net.Addr - ctx context.Context - cancel context.CancelFunc + done *signal.Done } func (c *udpConn) updateActivity() { @@ -135,7 +113,7 @@ func (c *udpConn) Read(buf []byte) (int, error) { defer in.Release() c.updateActivity() return copy(buf, in.Bytes()), nil - case <-c.ctx.Done(): + case <-c.done.C(): return 0, io.EOF } } @@ -150,6 +128,7 @@ func (c *udpConn) Write(buf []byte) (int, error) { } func (c *udpConn) Close() error { + common.Close(c.done) return nil } @@ -189,8 +168,7 @@ type udpWorker struct { tag string dispatcher core.Dispatcher - ctx context.Context - cancel context.CancelFunc + done *signal.Done activeConn map[connId]*udpConn } @@ -215,6 +193,7 @@ func (w *udpWorker) getConnection(id connId) (*udpConn, bool) { IP: w.address.IP(), Port: int(w.port), }, + done: signal.NewDone(), } w.activeConn[id] = conn @@ -230,16 +209,15 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest conn, existing := w.getConnection(id) select { case conn.input <- b: + case <-conn.done.C(): + b.Release() default: b.Release() } if !existing { go func() { - ctx := w.ctx - ctx, cancel := context.WithCancel(ctx) - conn.ctx = ctx - conn.cancel = cancel + ctx := context.Background() if originalDest.IsValid() { ctx = proxy.ContextWithOriginalTarget(ctx, originalDest) } @@ -251,8 +229,8 @@ func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil { newError("connection ends").Base(err).WriteToLog() } + conn.Close() w.removeConn(id) - cancel() }() } } @@ -265,9 +243,7 @@ func (w *udpWorker) removeConn(id connId) { func (w *udpWorker) Start() error { w.activeConn = make(map[connId]*udpConn, 16) - ctx, cancel := context.WithCancel(context.Background()) - w.ctx = ctx - w.cancel = cancel + w.done = signal.NewDone() h, err := udp.ListenUDP(w.address, w.port, udp.ListenOption{ Callback: w.callback, ReceiveOriginalDest: w.recvOrigDest, @@ -280,11 +256,13 @@ func (w *udpWorker) Start() error { return nil } -func (w *udpWorker) Close() { +func (w *udpWorker) Close() error { if w.hub != nil { w.hub.Close() - w.cancel() + w.done.Close() + common.Close(w.proxy) } + return nil } func (w *udpWorker) monitor() { @@ -293,7 +271,7 @@ func (w *udpWorker) monitor() { for { select { - case <-w.ctx.Done(): + case <-w.done.C(): return case <-timer.C: nowSec := time.Now().Unix() @@ -301,7 +279,7 @@ func (w *udpWorker) monitor() { for addr, conn := range w.activeConn { if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 8 { delete(w.activeConn, addr) - conn.cancel() + conn.Close() } } w.Unlock() diff --git a/app/proxyman/mux/mux.go b/app/proxyman/mux/mux.go index 52b2b885..e2ae68fb 100644 --- a/app/proxyman/mux/mux.go +++ b/app/proxyman/mux/mux.go @@ -291,7 +291,8 @@ func (s *Server) Start() error { return nil } -func (s *Server) Close() { +func (s *Server) Close() error { + return nil } type ServerWorker struct { diff --git a/app/proxyman/mux/session.go b/app/proxyman/mux/session.go index 66e1b729..c828d21c 100644 --- a/app/proxyman/mux/session.go +++ b/app/proxyman/mux/session.go @@ -103,12 +103,12 @@ func (m *SessionManager) CloseIfNoSession() bool { return true } -func (m *SessionManager) Close() { +func (m *SessionManager) Close() error { m.Lock() defer m.Unlock() if m.closed { - return + return nil } m.closed = true @@ -119,6 +119,7 @@ func (m *SessionManager) Close() { } m.sessions = nil + return nil } // Session represents a client connection in a Mux connection. @@ -131,10 +132,11 @@ type Session struct { } // Close closes all resources associated with this session. -func (s *Session) Close() { +func (s *Session) Close() error { s.output.Close() s.input.Close() s.parent.Remove(s.ID) + return nil } // NewReader creates a buf.Reader based on the transfer type of this Session. diff --git a/app/proxyman/mux/writer.go b/app/proxyman/mux/writer.go index 64c290bd..f8cbc21f 100644 --- a/app/proxyman/mux/writer.go +++ b/app/proxyman/mux/writer.go @@ -100,7 +100,7 @@ func (w *Writer) WriteMultiBuffer(mb buf.MultiBuffer) error { return nil } -func (w *Writer) Close() { +func (w *Writer) Close() error { meta := FrameMetadata{ SessionID: w.id, SessionStatus: SessionStatusEnd, @@ -110,4 +110,5 @@ func (w *Writer) Close() { common.Must(frame.Reset(meta.AsSupplier())) w.writer.WriteMultiBuffer(buf.NewMultiBufferValue(frame)) + return nil } diff --git a/app/proxyman/outbound/outbound.go b/app/proxyman/outbound/outbound.go index ccd98845..13671309 100644 --- a/app/proxyman/outbound/outbound.go +++ b/app/proxyman/outbound/outbound.go @@ -39,7 +39,7 @@ func New(ctx context.Context, config *proxyman.OutboundConfig) (*Manager, error) func (*Manager) Start() error { return nil } // Close implements core.Feature -func (*Manager) Close() {} +func (*Manager) Close() error { return nil } // GetDefaultHandler implements core.OutboundHandlerManager. func (m *Manager) GetDefaultHandler() core.OutboundHandler { diff --git a/app/router/router.go b/app/router/router.go index 7d003dc4..5e0109d5 100644 --- a/app/router/router.go +++ b/app/router/router.go @@ -114,7 +114,9 @@ func (*Router) Start() error { return nil } -func (*Router) Close() {} +func (*Router) Close() error { + return nil +} func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { diff --git a/clock.go b/clock.go index 5489cae6..54bcb905 100644 --- a/clock.go +++ b/clock.go @@ -40,15 +40,15 @@ func (c *syncClock) Start() error { return c.Clock.Start() } -func (c *syncClock) Close() { +func (c *syncClock) Close() error { c.RLock() defer c.RUnlock() if c.Clock == nil { - return + return nil } - c.Clock.Close() + return c.Clock.Close() } func (c *syncClock) Set(clock Clock) { diff --git a/commander.go b/commander.go index f65f0072..63209bfd 100644 --- a/commander.go +++ b/commander.go @@ -44,15 +44,15 @@ func (c *syncCommander) Start() error { return c.Commander.Start() } -func (c *syncCommander) Close() { +func (c *syncCommander) Close() error { c.RLock() defer c.RUnlock() if c.Commander == nil { - return + return nil } - c.Commander.Close() + return c.Commander.Close() } func (c *syncCommander) Set(commander Commander) { diff --git a/common/interfaces.go b/common/interfaces.go index 6ce590c2..d5cc177e 100644 --- a/common/interfaces.go +++ b/common/interfaces.go @@ -1,10 +1,23 @@ package common +// Closable is the interface for objects that can release its resources. +type Closable interface { + // Close release all resources used by this object, including goroutines. + Close() error +} + +// Close closes the obj if it is a Closable. +func Close(obj interface{}) error { + if c, ok := obj.(Closable); ok { + return c.Close() + } + return nil +} + // Runnable is the interface for objects that can start to work and stop on demand. type Runnable interface { // Start starts the runnable object. Upon the method returning nil, the object begins to function properly. Start() error - // Close stops the object being working. - Close() + Closable } diff --git a/common/signal/done.go b/common/signal/done.go new file mode 100644 index 00000000..8cffb62e --- /dev/null +++ b/common/signal/done.go @@ -0,0 +1,48 @@ +package signal + +import ( + "sync" +) + +type Done struct { + access sync.Mutex + c chan struct{} + closed bool +} + +func NewDone() *Done { + return &Done{ + c: make(chan struct{}), + } +} + +func (d *Done) Done() bool { + select { + case <-d.c: + return true + default: + return false + } +} + +func (d *Done) C() chan struct{} { + return d.c +} + +func (d *Done) Wait() { + <-d.c +} + +func (d *Done) Close() error { + d.access.Lock() + defer d.access.Unlock() + + if d.closed { + return nil + } + + d.closed = true + close(d.c) + + return nil +} diff --git a/common/signal/notifier.go b/common/signal/notifier.go index 0d98c220..a4c4d5b5 100644 --- a/common/signal/notifier.go +++ b/common/signal/notifier.go @@ -1,22 +1,22 @@ package signal type Notifier struct { - c chan bool + c chan struct{} } func NewNotifier() *Notifier { return &Notifier{ - c: make(chan bool, 1), + c: make(chan struct{}, 1), } } func (n *Notifier) Signal() { select { - case n.c <- true: + case n.c <- struct{}{}: default: } } -func (n *Notifier) Wait() <-chan bool { +func (n *Notifier) Wait() <-chan struct{} { return n.c } diff --git a/common/signal/semaphore.go b/common/signal/semaphore.go index 034a4ee7..f9a80db5 100644 --- a/common/signal/semaphore.go +++ b/common/signal/semaphore.go @@ -1,23 +1,23 @@ package signal type Semaphore struct { - token chan bool + token chan struct{} } func NewSemaphore(n int) *Semaphore { s := &Semaphore{ - token: make(chan bool, n), + token: make(chan struct{}, n), } for i := 0; i < n; i++ { - s.token <- true + s.token <- struct{}{} } return s } -func (s *Semaphore) Wait() <-chan bool { +func (s *Semaphore) Wait() <-chan struct{} { return s.token } func (s *Semaphore) Signal() { - s.token <- true + s.token <- struct{}{} } diff --git a/common/signal/task.go b/common/signal/task.go new file mode 100644 index 00000000..9d876804 --- /dev/null +++ b/common/signal/task.go @@ -0,0 +1,60 @@ +package signal + +import ( + "sync" + "time" +) + +type PeriodicTask struct { + Interval time.Duration + Execute func() error + + access sync.Mutex + timer *time.Timer + closed bool +} + +func (t *PeriodicTask) checkedExecute() error { + t.access.Lock() + defer t.access.Unlock() + + if t.closed { + return nil + } + + if err := t.Execute(); err != nil { + return err + } + + t.timer = time.AfterFunc(t.Interval, func() { + t.checkedExecute() + }) + + return nil +} + +func (t *PeriodicTask) Start() error { + t.access.Lock() + t.closed = false + t.access.Unlock() + + if err := t.checkedExecute(); err != nil { + t.closed = true + return err + } + + return nil +} + +func (t *PeriodicTask) Close() error { + t.access.Lock() + defer t.access.Unlock() + + t.closed = true + if t.timer != nil { + t.timer.Stop() + t.timer = nil + } + + return nil +} diff --git a/common/signal/timer.go b/common/signal/timer.go index 30989b2f..6a63af05 100644 --- a/common/signal/timer.go +++ b/common/signal/timer.go @@ -10,14 +10,14 @@ type ActivityUpdater interface { } type ActivityTimer struct { - updated chan bool + updated chan struct{} timeout chan time.Duration - closing chan bool + closing chan struct{} } func (t *ActivityTimer) Update() { select { - case t.updated <- true: + case t.updated <- struct{}{}: default: } } @@ -72,8 +72,8 @@ func (t *ActivityTimer) run(ctx context.Context, cancel context.CancelFunc) { func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer { timer := &ActivityTimer{ timeout: make(chan time.Duration, 1), - updated: make(chan bool, 1), - closing: make(chan bool), + updated: make(chan struct{}, 1), + closing: make(chan struct{}), } timer.timeout <- timeout go timer.run(ctx, cancel) diff --git a/dns.go b/dns.go index 0c8046e1..5e94d1c2 100644 --- a/dns.go +++ b/dns.go @@ -1,7 +1,11 @@ package core -import "net" -import "sync" +import ( + "net" + "sync" + + "v2ray.com/core/common" +) // DNSClient is a V2Ray feature for querying DNS information. type DNSClient interface { @@ -36,13 +40,11 @@ func (d *syncDNSClient) Start() error { return d.DNSClient.Start() } -func (d *syncDNSClient) Close() { +func (d *syncDNSClient) Close() error { d.RLock() defer d.RUnlock() - if d.DNSClient != nil { - d.DNSClient.Close() - } + return common.Close(d.DNSClient) } func (d *syncDNSClient) Set(client DNSClient) { diff --git a/network.go b/network.go index 1d2f5091..d526a91f 100644 --- a/network.go +++ b/network.go @@ -21,6 +21,7 @@ type InboundHandler interface { // OutboundHandler is the interface for handlers that process outbound connections. type OutboundHandler interface { + common.Runnable Tag() string Dispatch(ctx context.Context, outboundRay ray.OutboundRay) } @@ -75,13 +76,11 @@ func (m *syncInboundHandlerManager) Start() error { return m.InboundHandlerManager.Start() } -func (m *syncInboundHandlerManager) Close() { +func (m *syncInboundHandlerManager) Close() error { m.RLock() defer m.RUnlock() - if m.InboundHandlerManager != nil { - m.InboundHandlerManager.Close() - } + return common.Close(m.InboundHandlerManager) } func (m *syncInboundHandlerManager) Set(manager InboundHandlerManager) { @@ -154,13 +153,11 @@ func (m *syncOutboundHandlerManager) Start() error { return m.OutboundHandlerManager.Start() } -func (m *syncOutboundHandlerManager) Close() { +func (m *syncOutboundHandlerManager) Close() error { m.RLock() defer m.RUnlock() - if m.OutboundHandlerManager != nil { - m.OutboundHandlerManager.Close() - } + return common.Close(m.OutboundHandlerManager) } func (m *syncOutboundHandlerManager) Set(manager OutboundHandlerManager) { diff --git a/policy.go b/policy.go index 12cff0f1..a0717320 100644 --- a/policy.go +++ b/policy.go @@ -3,6 +3,8 @@ package core import ( "sync" "time" + + "v2ray.com/core/common" ) // TimeoutPolicy contains limits for connection timeout. @@ -96,13 +98,11 @@ func (m *syncPolicyManager) Start() error { return m.PolicyManager.Start() } -func (m *syncPolicyManager) Close() { +func (m *syncPolicyManager) Close() error { m.RLock() defer m.RUnlock() - if m.PolicyManager != nil { - m.PolicyManager.Close() - } + return common.Close(m.PolicyManager) } func (m *syncPolicyManager) Set(manager PolicyManager) { diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index ee6b1a1d..75970590 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -48,8 +48,9 @@ func TestRequestSerialization(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) sessionHistory := NewSessionHistory() - userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash) + userValidator := vmess.NewTimedUserValidator(protocol.DefaultIDHash) userValidator.Add(user) + defer common.Close(userValidator) server := NewServerSession(userValidator, sessionHistory) actualRequest, err := server.DecodeRequestHeader(buffer) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 8a81a33a..8806540f 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -30,28 +30,34 @@ type sessionId struct { type SessionHistory struct { sync.RWMutex cache map[sessionId]time.Time - token *signal.Semaphore - timer *time.Timer + task *signal.PeriodicTask } func NewSessionHistory() *SessionHistory { h := &SessionHistory{ cache: make(map[sessionId]time.Time, 128), - token: signal.NewSemaphore(1), } + h.task = &signal.PeriodicTask{ + Interval: time.Second * 30, + Execute: func() error { + h.removeExpiredEntries() + return nil + }, + } + common.Must(h.task.Start()) return h } +// Close implements common.Closable. +func (h *SessionHistory) Close() error { + return h.task.Close() +} + func (h *SessionHistory) add(session sessionId) { h.Lock() defer h.Unlock() h.cache[session] = time.Now().Add(time.Minute * 3) - select { - case <-h.token.Wait(): - h.timer = time.AfterFunc(time.Minute*3, h.removeExpiredEntries) - default: - } } func (h *SessionHistory) has(session sessionId) bool { @@ -75,11 +81,6 @@ func (h *SessionHistory) removeExpiredEntries() { delete(h.cache, session) } } - - if h.timer != nil { - h.timer.Stop() - h.timer = nil - } } type ServerSession struct { diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index 7490cbc7..26deba01 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -90,7 +90,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) { handler := &Handler{ policyManager: v.PolicyManager(), inboundHandlerManager: v.InboundHandlerManager(), - clients: vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash), + clients: vmess.NewTimedUserValidator(protocol.DefaultIDHash), detours: config.Detour, usersByEmail: newUserByEmail(config.User, config.GetDefaultValue()), sessionHistory: encoding.NewSessionHistory(), @@ -105,6 +105,14 @@ func New(ctx context.Context, config *Config) (*Handler, error) { return handler, nil } +// Close implements common.Closable. +func (h *Handler) Close() error { + common.Close(h.clients) + common.Close(h.sessionHistory) + common.Close(h.usersByEmail) + return nil +} + // Network implements proxy.Inbound.Network(). func (*Handler) Network() net.NetworkList { return net.NetworkList{ diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index 9c805fa2..4c1cc5c3 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -8,17 +8,17 @@ package vmess //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg vmess -path Proxy,VMess import ( - "context" "sync" "time" "v2ray.com/core/common" "v2ray.com/core/common/protocol" + "v2ray.com/core/common/signal" ) const ( - updateIntervalSec = 10 - cacheDurationSec = 120 + updateInterval = 10 * time.Second + cacheDurationSec = 120 ) type idEntry struct { @@ -34,6 +34,7 @@ type TimedUserValidator struct { ids []*idEntry hasher protocol.IDHash baseTime protocol.Timestamp + task *signal.PeriodicTask } type indexTimePair struct { @@ -41,16 +42,23 @@ type indexTimePair struct { timeInc uint32 } -func NewTimedUserValidator(ctx context.Context, hasher protocol.IDHash) protocol.UserValidator { - tus := &TimedUserValidator{ +func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator { + tuv := &TimedUserValidator{ validUsers: make([]*protocol.User, 0, 16), userHash: make(map[[16]byte]indexTimePair, 512), ids: make([]*idEntry, 0, 512), hasher: hasher, baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*3), } - go tus.updateUserHash(ctx, updateIntervalSec*time.Second) - return tus + tuv.task = &signal.PeriodicTask{ + Interval: updateInterval, + Execute: func() error { + tuv.updateUserHash() + return nil + }, + } + tuv.task.Start() + return tuv } func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) { @@ -78,24 +86,19 @@ func (v *TimedUserValidator) removeExpiredHashes(expire uint32) { } } -func (v *TimedUserValidator) updateUserHash(ctx context.Context, interval time.Duration) { - for { - select { - case now := <-time.After(interval): - nowSec := protocol.Timestamp(now.Unix() + cacheDurationSec) - v.Lock() - for _, entry := range v.ids { - v.generateNewHashes(nowSec, entry.userIdx, entry) - } - - expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3) - if expire > v.baseTime { - v.removeExpiredHashes(uint32(expire - v.baseTime)) - } - v.Unlock() - case <-ctx.Done(): - return - } +func (v *TimedUserValidator) updateUserHash() { + now := time.Now() + nowSec := protocol.Timestamp(now.Unix() + cacheDurationSec) + v.Lock() + defer v.Unlock() + + for _, entry := range v.ids { + v.generateNewHashes(nowSec, entry.userIdx, entry) + } + + expire := protocol.Timestamp(now.Unix() - cacheDurationSec*3) + if expire > v.baseTime { + v.removeExpiredHashes(uint32(expire - v.baseTime)) } } @@ -145,3 +148,8 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time } return nil, 0, false } + +// Close implements common.Closable. +func (v *TimedUserValidator) Close() error { + return v.task.Close() +} diff --git a/router.go b/router.go index 26080620..d3a8031f 100644 --- a/router.go +++ b/router.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "v2ray.com/core/common" "v2ray.com/core/common/errors" "v2ray.com/core/common/net" "v2ray.com/core/transport/ray" @@ -45,13 +46,11 @@ func (d *syncDispatcher) Start() error { return d.Dispatcher.Start() } -func (d *syncDispatcher) Close() { +func (d *syncDispatcher) Close() error { d.RLock() defer d.RUnlock() - if d.Dispatcher != nil { - d.Dispatcher.Close() - } + return common.Close(d.Dispatcher) } func (d *syncDispatcher) Set(disp Dispatcher) { @@ -101,13 +100,11 @@ func (r *syncRouter) Start() error { return r.Router.Start() } -func (r *syncRouter) Close() { +func (r *syncRouter) Close() error { r.RLock() defer r.RUnlock() - if r.Router != nil { - r.Router.Close() - } + return common.Close(r.Router) } func (r *syncRouter) Set(router Router) { diff --git a/testing/servers/http/http.go b/testing/servers/http/http.go index 230e973c..6bada232 100644 --- a/testing/servers/http/http.go +++ b/testing/servers/http/http.go @@ -36,6 +36,6 @@ func (s *Server) Start() (net.Destination, error) { return net.TCPDestination(net.LocalHostIP, net.Port(s.Port)), nil } -func (s *Server) Close() { - s.server.Close() +func (s *Server) Close() error { + return s.server.Close() } diff --git a/testing/servers/tcp/tcp.go b/testing/servers/tcp/tcp.go index 48c819c0..a54e4b39 100644 --- a/testing/servers/tcp/tcp.go +++ b/testing/servers/tcp/tcp.go @@ -69,6 +69,6 @@ func (server *Server) handleConnection(conn net.Conn) { conn.Close() } -func (server *Server) Close() { - server.listener.Close() +func (server *Server) Close() error { + return server.listener.Close() } diff --git a/testing/servers/udp/udp.go b/testing/servers/udp/udp.go index ff4c789d..22e47689 100644 --- a/testing/servers/udp/udp.go +++ b/testing/servers/udp/udp.go @@ -46,7 +46,7 @@ func (server *Server) handleConnection(conn *net.UDPConn) { } } -func (server *Server) Close() { +func (server *Server) Close() error { server.accepting = false - server.conn.Close() + return server.conn.Close() } diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index c46de60c..3fdcbf9f 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -23,7 +23,6 @@ type ConnectionID struct { // Listener defines a server listening for connections type Listener struct { sync.Mutex - ctx context.Context sessions map[ConnectionID]*Connection hub *udp.Hub tlsConfig *tls.Config @@ -31,10 +30,10 @@ type Listener struct { reader PacketReader header internet.PacketHeader security cipher.AEAD - addConn internet.AddConnection + addConn internet.ConnHandler } -func NewListener(ctx context.Context, address net.Address, port net.Port, addConn internet.AddConnection) (*Listener, error) { +func NewListener(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (*Listener, error) { networkSettings := internet.TransportSettingsFromContext(ctx) kcpSettings := networkSettings.(*Config) @@ -54,7 +53,6 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon Security: security, }, sessions: make(map[ConnectionID]*Connection), - ctx: ctx, config: kcpSettings, addConn: addConn, } @@ -86,12 +84,6 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD l.Lock() defer l.Unlock() - select { - case <-l.ctx.Done(): - return - default: - } - if l.hub == nil { return } @@ -136,23 +128,16 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination, originalD netConn = tlsConn } - if !l.addConn(context.Background(), netConn) { - return - } + l.addConn(netConn) l.sessions[id] = conn } conn.Input(segments) } func (l *Listener) Remove(id ConnectionID) { - select { - case <-l.ctx.Done(): - return - default: - l.Lock() - delete(l.sessions, id) - l.Unlock() - } + l.Lock() + delete(l.sessions, id) + l.Unlock() } // Close stops listening on the UDP address. Already Accepted connections are not closed. @@ -197,7 +182,7 @@ func (w *Writer) Close() error { return nil } -func ListenKCP(ctx context.Context, address net.Address, port net.Port, addConn internet.AddConnection) (internet.Listener, error) { +func ListenKCP(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { return NewListener(ctx, address, port, addConn) } diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 4047c71a..05c7e4eb 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -3,10 +3,10 @@ package tcp import ( "context" gotls "crypto/tls" + "strings" "v2ray.com/core/common" "v2ray.com/core/common/net" - "v2ray.com/core/common/retry" "v2ray.com/core/transport/internet" "v2ray.com/core/transport/internet/tls" ) @@ -17,11 +17,11 @@ type Listener struct { tlsConfig *gotls.Config authConfig internet.ConnectionAuthenticator config *Config - addConn internet.AddConnection + addConn internet.ConnHandler } // ListenTCP creates a new Listener based on configurations. -func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn internet.AddConnection) (internet.Listener, error) { +func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: address.IP(), Port: int(port), @@ -36,7 +36,7 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn l := &Listener{ listener: listener, config: tcpSettings, - addConn: addConn, + addConn: handler, } if config := tls.ConfigFromContext(ctx, tls.WithNextProto("h2")); config != nil { @@ -54,27 +54,17 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn } l.authConfig = auth } - go l.keepAccepting(ctx) + go l.keepAccepting() return l, nil } -func (v *Listener) keepAccepting(ctx context.Context) { +func (v *Listener) keepAccepting() { for { - select { - case <-ctx.Done(): - return - default: - } - var conn net.Conn - err := retry.ExponentialBackoff(5, 200).On(func() error { - rawConn, err := v.listener.Accept() - if err != nil { - return err - } - conn = rawConn - return nil - }) + conn, err := v.listener.Accept() if err != nil { + if strings.Contains(err.Error(), "closed") { + break + } newError("failed to accepted raw connections").Base(err).AtWarning().WriteToLog() continue } @@ -86,7 +76,7 @@ func (v *Listener) keepAccepting(ctx context.Context) { conn = v.authConfig.Server(conn) } - v.addConn(context.Background(), internet.Connection(conn)) + v.addConn(internet.Connection(conn)) } } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 5f1a1d29..081b04aa 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -2,7 +2,6 @@ package internet import ( "context" - "time" "v2ray.com/core/common/net" ) @@ -19,16 +18,16 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc) return nil } -type AddConnection func(context.Context, Connection) bool +type ConnHandler func(Connection) -type ListenFunc func(ctx context.Context, address net.Address, port net.Port, addConn AddConnection) (Listener, error) +type ListenFunc func(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) type Listener interface { Close() error Addr() net.Addr } -func ListenTCP(ctx context.Context, address net.Address, port net.Port, conns chan<- Connection) (Listener, error) { +func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler ConnHandler) (Listener, error) { settings := StreamSettingsFromContext(ctx) protocol := settings.GetEffectiveProtocol() transportSettings, err := settings.GetEffectiveTransportSettings() @@ -47,26 +46,7 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, conns ch if listenFunc == nil { return nil, newError(protocol, " listener not registered.").AtError() } - listener, err := listenFunc(ctx, address, port, func(ctx context.Context, conn Connection) bool { - select { - case <-ctx.Done(): - conn.Close() - return false - case conns <- conn: - return true - default: - select { - case <-ctx.Done(): - conn.Close() - return false - case conns <- conn: - return true - case <-time.After(time.Second * 5): - conn.Close() - return false - } - } - }) + listener, err := listenFunc(ctx, address, port, handler) if err != nil { return nil, newError("failed to listen on address: ", address, ":", port).Base(err) } diff --git a/transport/internet/udp/dispatcher_test.go b/transport/internet/udp/dispatcher_test.go index 989cb651..e6c1958a 100644 --- a/transport/internet/udp/dispatcher_test.go +++ b/transport/internet/udp/dispatcher_test.go @@ -25,7 +25,9 @@ func (d *TestDispatcher) Start() error { return nil } -func (d *TestDispatcher) Close() {} +func (d *TestDispatcher) Close() error { + return nil +} func TestSameDestinationDispatching(t *testing.T) { assert := With(t) diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index d6b15178..cc96ca68 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -60,10 +60,11 @@ func (q *PayloadQueue) Dequeue(queue <-chan Payload) { } } -func (q *PayloadQueue) Close() { +func (q *PayloadQueue) Close() error { for _, queue := range q.queue { close(queue) } + return nil } type ListenOption struct { @@ -116,9 +117,10 @@ func ListenUDP(address net.Address, port net.Port, option ListenOption) (*Hub, e return hub, nil } -func (h *Hub) Close() { +func (h *Hub) Close() error { h.cancel() h.conn.Close() + return nil } func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) { diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index a5ec0eb0..36275fe2 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -44,24 +44,22 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req remoteAddr.(*net.TCPAddr).IP = forwardedAddrs[0].IP() } - h.ln.addConn(h.ln.ctx, newConnection(conn, remoteAddr)) + h.ln.addConn(newConnection(conn, remoteAddr)) } type Listener struct { sync.Mutex - ctx context.Context listener net.Listener tlsConfig *tls.Config config *Config - addConn internet.AddConnection + addConn internet.ConnHandler } -func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn internet.AddConnection) (internet.Listener, error) { +func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn internet.ConnHandler) (internet.Listener, error) { networkSettings := internet.TransportSettingsFromContext(ctx) wsSettings := networkSettings.(*Config) l := &Listener{ - ctx: ctx, config: wsSettings, addConn: addConn, } diff --git a/transport/ray/direct.go b/transport/ray/direct.go index efcb71cd..53af0f4d 100644 --- a/transport/ray/direct.go +++ b/transport/ray/direct.go @@ -209,12 +209,13 @@ func (s *Stream) WriteMultiBuffer(data buf.MultiBuffer) error { } // Close closes the stream for writing. Read() still works until EOF. -func (s *Stream) Close() { +func (s *Stream) Close() error { s.access.Lock() s.close = true s.readSignal.Signal() s.writeSignal.Signal() s.access.Unlock() + return nil } // CloseError closes the Stream with error. Read() will return an error afterwards. diff --git a/transport/ray/ray.go b/transport/ray/ray.go index 46d41f7b..557bc64e 100644 --- a/transport/ray/ray.go +++ b/transport/ray/ray.go @@ -1,6 +1,9 @@ package ray -import "v2ray.com/core/common/buf" +import ( + "v2ray.com/core/common" + "v2ray.com/core/common/buf" +) // OutboundRay is a transport interface for outbound connections. type OutboundRay interface { @@ -34,7 +37,7 @@ type Ray interface { } type RayStream interface { - Close() + common.Closable CloseError() } diff --git a/v2ray.go b/v2ray.go index 7bca7bab..af710fd3 100644 --- a/v2ray.go +++ b/v2ray.go @@ -101,7 +101,7 @@ func (s *Instance) ID() uuid.UUID { } // Close shutdown the V2Ray instance. -func (s *Instance) Close() { +func (s *Instance) Close() error { s.access.Lock() defer s.access.Unlock() @@ -109,6 +109,8 @@ func (s *Instance) Close() { for _, f := range s.features { f.Close() } + + return nil } // Start starts the V2Ray instance, including all registered features. When Start returns error, the state of the instance is unknown.