From d93ff628bcda0fa4218508d7eedbbb21fedd5d44 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sun, 26 Feb 2017 14:38:41 +0100 Subject: [PATCH] refactor tcp worker --- app/proxyman/inbound/worker.go | 30 +++++++++- transport/internet/context.go | 9 ++- transport/internet/dialer.go | 2 +- transport/internet/kcp/kcp_test.go | 20 +++---- transport/internet/kcp/listener.go | 62 ++++++++------------ transport/internet/system_dialer.go | 3 +- transport/internet/tcp/hub.go | 75 ++++++++----------------- transport/internet/tcp_hub.go | 72 ++---------------------- transport/internet/websocket/hub.go | 63 ++++++--------------- transport/internet/websocket/ws_test.go | 29 +++++----- 10 files changed, 127 insertions(+), 238 deletions(-) diff --git a/app/proxyman/inbound/worker.go b/app/proxyman/inbound/worker.go index 33149e76..81730b4f 100644 --- a/app/proxyman/inbound/worker.go +++ b/app/proxyman/inbound/worker.go @@ -37,7 +37,7 @@ type tcpWorker struct { ctx context.Context cancel context.CancelFunc - hub *internet.TCPHub + hub internet.Listener } func (w *tcpWorker) callback(conn internet.Connection) { @@ -73,17 +73,41 @@ func (w *tcpWorker) Start() error { ctx, cancel := context.WithCancel(context.Background()) w.ctx = ctx w.cancel = cancel - hub, err := internet.ListenTCP(w.address, w.port, w.callback, w.stream) + ctx = internet.ContextWithStreamSettings(ctx, w.stream) + conns := make(chan internet.Connection, 16) + hub, err := internet.ListenTCP(ctx, w.address, w.port, conns) if err != nil { return err } + go w.handleConnections(conns) w.hub = hub return nil } +func (w *tcpWorker) handleConnections(conns <-chan internet.Connection) { + for { + select { + case <-w.ctx.Done(): + w.hub.Close() + nconns := len(conns) + L: + for i := 0; i < nconns; i++ { + select { + case conn := <-conns: + conn.Close() + default: + break L + } + } + return + case conn := <-conns: + go w.callback(conn) + } + } +} + func (w *tcpWorker) Close() { if w.hub != nil { - w.hub.Close() w.cancel() } } diff --git a/transport/internet/context.go b/transport/internet/context.go index 33f8b862..05de7766 100644 --- a/transport/internet/context.go +++ b/transport/internet/context.go @@ -19,9 +19,12 @@ func ContextWithStreamSettings(ctx context.Context, streamSettings *StreamConfig return context.WithValue(ctx, streamSettingsKey, streamSettings) } -func StreamSettingsFromContext(ctx context.Context) (*StreamConfig, bool) { - ss, ok := ctx.Value(streamSettingsKey).(*StreamConfig) - return ss, ok +func StreamSettingsFromContext(ctx context.Context) *StreamConfig { + ss := ctx.Value(streamSettingsKey) + if ss == nil { + return nil + } + return ss.(*StreamConfig) } func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context { diff --git a/transport/internet/dialer.go b/transport/internet/dialer.go index 547d21d1..0e9c24ca 100644 --- a/transport/internet/dialer.go +++ b/transport/internet/dialer.go @@ -24,7 +24,7 @@ func RegisterTransportDialer(protocol TransportProtocol, dialer Dialer) error { func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) { if dest.Network == v2net.Network_TCP { - streamSettings, _ := StreamSettingsFromContext(ctx) + streamSettings := StreamSettingsFromContext(ctx) protocol := streamSettings.GetEffectiveProtocol() transportSettings, err := streamSettings.GetEffectiveTransportSettings() if err != nil { diff --git a/transport/internet/kcp/kcp_test.go b/transport/internet/kcp/kcp_test.go index 3159f397..66bfcd0e 100644 --- a/transport/internet/kcp/kcp_test.go +++ b/transport/internet/kcp/kcp_test.go @@ -18,30 +18,27 @@ import ( func TestDialAndListen(t *testing.T) { assert := assert.On(t) - listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0)) + conns := make(chan internet.Connection, 16) + listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0), conns) assert.Error(err).IsNil() port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port) go func() { - for { - conn, err := listerner.Accept() - if err != nil { - break - } - go func() { + for conn := range conns { + go func(c internet.Connection) { payload := make([]byte, 4096) for { - nBytes, err := conn.Read(payload) + nBytes, err := c.Read(payload) if err != nil { break } for idx, b := range payload[:nBytes] { payload[idx] = b ^ 'c' } - conn.Write(payload[:nBytes]) + c.Write(payload[:nBytes]) } - conn.Close() - }() + c.Close() + }(conn) } }() @@ -79,4 +76,5 @@ func TestDialAndListen(t *testing.T) { assert.Int(listerner.ActiveConnections()).Equals(0) listerner.Close() + close(conns) } diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 2aa27dae..65e90ce0 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -80,18 +80,18 @@ func (o *ServerConnection) Id() internal.ConnectionID { // Listener defines a server listening for connections type Listener struct { sync.Mutex - closed chan bool - sessions map[ConnectionID]*Connection - awaitingConns chan *Connection - hub *udp.Hub - tlsConfig *tls.Config - config *Config - reader PacketReader - header internet.PacketHeader - security cipher.AEAD + closed chan bool + sessions map[ConnectionID]*Connection + hub *udp.Hub + tlsConfig *tls.Config + config *Config + reader PacketReader + header internet.PacketHeader + security cipher.AEAD + conns chan<- internet.Connection } -func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) { +func NewListener(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (*Listener, error) { networkSettings := internet.TransportSettingsFromContext(ctx) kcpSettings := networkSettings.(*Config) kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false} @@ -111,10 +111,10 @@ func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (* Header: header, Security: security, }, - sessions: make(map[ConnectionID]*Connection), - awaitingConns: make(chan *Connection, 64), - closed: make(chan bool), - config: kcpSettings, + sessions: make(map[ConnectionID]*Connection), + closed: make(chan bool), + config: kcpSettings, + conns: conns, } securitySettings := internet.SecuritySettingsFromContext(ctx) if securitySettings != nil { @@ -194,8 +194,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina closer: writer, } conn = NewConnection(conv, sConn, v, v.config) + var netConn internet.Connection = conn + if v.tlsConfig != nil { + tlsConn := tls.Server(conn, v.tlsConfig) + netConn = UnreusableConnection{Conn: tlsConn} + } + select { - case v.awaitingConns <- conn: + case v.conns <- netConn: case <-time.After(time.Second * 5): conn.Close() return @@ -216,27 +222,6 @@ func (v *Listener) Remove(id ConnectionID) { } } -// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. -func (v *Listener) Accept() (internet.Connection, error) { - for { - select { - case <-v.closed: - return nil, ErrClosedListener - case conn, open := <-v.awaitingConns: - if !open { - break - } - if v.tlsConfig != nil { - tlsConn := tls.Server(conn, v.tlsConfig) - return UnreusableConnection{Conn: tlsConn}, nil - } - return conn, nil - case <-time.After(time.Second): - - } - } -} - // Close stops listening on the UDP address. Already Accepted connections are not closed. func (v *Listener) Close() error { @@ -249,7 +234,6 @@ func (v *Listener) Close() error { } close(v.closed) - close(v.awaitingConns) for _, conn := range v.sessions { go conn.Terminate() } @@ -288,8 +272,8 @@ func (v *Writer) Close() error { return nil } -func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { - return NewListener(ctx, address, port) +func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) { + return NewListener(ctx, address, port, conns) } func init() { diff --git a/transport/internet/system_dialer.go b/transport/internet/system_dialer.go index c843d9a0..68fe3e52 100644 --- a/transport/internet/system_dialer.go +++ b/transport/internet/system_dialer.go @@ -1,11 +1,10 @@ package internet import ( + "context" "net" "time" - "context" - v2net "v2ray.com/core/common/net" ) diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 65a69ee5..2a43891f 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -20,22 +20,17 @@ var ( ErrClosedListener = errors.New("Listener is closed.") ) -type ConnectionWithError struct { - conn net.Conn - err error -} - type TCPListener struct { sync.Mutex - acccepting bool - listener *net.TCPListener - awaitingConns chan *ConnectionWithError - tlsConfig *tls.Config - authConfig internet.ConnectionAuthenticator - config *Config + acccepting bool + listener *net.TCPListener + tlsConfig *tls.Config + authConfig internet.ConnectionAuthenticator + config *Config + conns chan<- internet.Connection } -func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { +func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) { listener, err := net.ListenTCP("tcp", &net.TCPAddr{ IP: address.IP(), Port: int(port), @@ -48,10 +43,10 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int tcpSettings := networkSettings.(*Config) l := &TCPListener{ - acccepting: true, - listener: listener, - awaitingConns: make(chan *ConnectionWithError, 32), - config: tcpSettings, + acccepting: true, + listener: listener, + config: tcpSettings, + conns: conns, } if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*v2tls.Config) @@ -74,24 +69,6 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int return l, nil } -func (v *TCPListener) Accept() (internet.Connection, error) { - for v.acccepting { - select { - case connErr, open := <-v.awaitingConns: - if !open { - return nil, ErrClosedListener - } - if connErr.err != nil { - return nil, connErr.err - } - conn := connErr.conn - return internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())), nil - case <-time.After(time.Second * 2): - } - } - return nil, ErrClosedListener -} - func (v *TCPListener) KeepAccepting() { for v.acccepting { conn, err := v.listener.Accept() @@ -100,22 +77,22 @@ func (v *TCPListener) KeepAccepting() { v.Unlock() break } - if conn != nil && v.tlsConfig != nil { + if err != nil { + log.Warning("TCP|Listener: Failed to accepted raw connections: ", err) + v.Unlock() + continue + } + if v.tlsConfig != nil { conn = tls.Server(conn, v.tlsConfig) } - if conn != nil && v.authConfig != nil { + if v.authConfig != nil { conn = v.authConfig.Server(conn) } select { - case v.awaitingConns <- &ConnectionWithError{ - conn: conn, - err: err, - }: - default: - if conn != nil { - conn.Close() - } + case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())): + case <-time.After(time.Second * 5): + conn.Close() } v.Unlock() @@ -129,8 +106,8 @@ func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) { return } select { - case v.awaitingConns <- &ConnectionWithError{conn: conn}: - default: + case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())): + case <-time.After(time.Second * 5): conn.Close() } } @@ -144,12 +121,6 @@ func (v *TCPListener) Close() error { defer v.Unlock() v.acccepting = false v.listener.Close() - close(v.awaitingConns) - for connErr := range v.awaitingConns { - if connErr.conn != nil { - connErr.conn.Close() - } - } return nil } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 5ab5d3bb..4d996b1a 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -1,14 +1,11 @@ package internet import ( + "context" "net" - "context" - - "v2ray.com/core/app/log" "v2ray.com/core/common/errors" v2net "v2ray.com/core/common/net" - "v2ray.com/core/common/retry" ) var ( @@ -23,10 +20,9 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc) return nil } -type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error) +type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) type Listener interface { - Accept() (Connection, error) Close() error Addr() net.Addr } @@ -37,8 +33,8 @@ type TCPHub struct { closed chan bool } -func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) { - ctx := context.Background() +func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) { + settings := StreamSettingsFromContext(ctx) protocol := settings.GetEffectiveProtocol() transportSettings, err := settings.GetEffectiveTransportSettings() if err != nil { @@ -56,65 +52,9 @@ func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandle if listenFunc == nil { return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.") } - listener, err := listenFunc(ctx, address, port) + listener, err := listenFunc(ctx, address, port, conns) if err != nil { return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port) } - - hub := &TCPHub{ - listener: listener, - connCallback: callback, - } - - go hub.start() - return hub, nil -} - -func (v *TCPHub) Close() { - defer func() { - recover() - }() - - select { - case <-v.closed: - return - default: - v.listener.Close() - close(v.closed) - } -} - -func (v *TCPHub) start() { - for { - select { - case <-v.closed: - return - default: - } - var newConn Connection - err := retry.ExponentialBackoff(10, 500).On(func() error { - select { - case <-v.closed: - return nil - default: - conn, err := v.listener.Accept() - if err != nil { - return errors.Base(err).RequireUserAction().Message("Internet|Listener: Failed to accept new TCP connection.") - } - newConn = conn - return nil - } - }) - if err != nil { - if errors.IsActionRequired(err) { - log.Warning(err) - } else { - log.Info(err) - } - continue - } - if newConn != nil { - go v.connCallback(newConn) - } - } + return listener, nil } diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index f7d1ebb4..0031bb99 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -23,14 +23,9 @@ var ( ErrClosedListener = errors.New("Listener is closed.") ) -type ConnectionWithError struct { - conn net.Conn - err error -} - type requestHandler struct { - path string - conns chan *ConnectionWithError + path string + ln *Listener } func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { @@ -45,29 +40,29 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } select { - case h.conns <- &ConnectionWithError{conn: conn}: - default: + case h.ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, h.ln, internal.ReuseConnection(h.ln.config.IsConnectionReuse())): + case <-time.After(time.Second * 5): conn.Close() } } type Listener struct { sync.Mutex - closed chan bool - awaitingConns chan *ConnectionWithError - listener net.Listener - tlsConfig *tls.Config - config *Config + closed chan bool + listener net.Listener + tlsConfig *tls.Config + config *Config + conns chan<- internet.Connection } -func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { +func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) { networkSettings := internet.TransportSettingsFromContext(ctx) wsSettings := networkSettings.(*Config) l := &Listener{ - closed: make(chan bool), - awaitingConns: make(chan *ConnectionWithError, 32), - config: wsSettings, + closed: make(chan bool), + config: wsSettings, + conns: conns, } if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*v2tls.Config) @@ -101,8 +96,8 @@ func (ln *Listener) listenws(address v2net.Address, port v2net.Port) error { go func() { http.Serve(listener, &requestHandler{ - path: ln.config.GetNormailzedPath(), - conns: ln.awaitingConns, + path: ln.config.GetNormailzedPath(), + ln: ln, }) }() @@ -123,24 +118,6 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) { return &connection{wsc: conn}, nil } -func (ln *Listener) Accept() (internet.Connection, error) { - for { - select { - case <-ln.closed: - return nil, ErrClosedListener - case connErr, open := <-ln.awaitingConns: - if !open { - return nil, ErrClosedListener - } - if connErr.err != nil { - return nil, connErr.err - } - return internal.NewConnection(internal.ConnectionID{}, connErr.conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())), nil - case <-time.After(time.Second * 2): - } - } -} - func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) { ln.Lock() defer ln.Unlock() @@ -150,8 +127,8 @@ func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) { default: } select { - case ln.awaitingConns <- &ConnectionWithError{conn: conn}: - default: + case ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())): + case <-time.After(time.Second * 5): conn.Close() } } @@ -170,12 +147,6 @@ func (ln *Listener) Close() error { } close(ln.closed) ln.listener.Close() - close(ln.awaitingConns) - for connErr := range ln.awaitingConns { - if connErr.conn != nil { - connErr.conn.Close() - } - } return nil } diff --git a/transport/internet/websocket/ws_test.go b/transport/internet/websocket/ws_test.go index 3c6124a1..6a8f0689 100644 --- a/transport/internet/websocket/ws_test.go +++ b/transport/internet/websocket/ws_test.go @@ -16,31 +16,28 @@ import ( func Test_listenWSAndDial(t *testing.T) { assert := assert.On(t) + conns := make(chan internet.Connection, 16) listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{ Path: "ws", - }), v2net.DomainAddress("localhost"), 13146) + }), v2net.DomainAddress("localhost"), 13146, conns) assert.Error(err).IsNil() go func() { - for { - conn, err := listen.Accept() - if err != nil { - break - } - go func() { - defer conn.Close() + for conn := range conns { + go func(c internet.Connection) { + defer c.Close() var b [1024]byte - n, err := conn.Read(b[:]) + n, err := c.Read(b[:]) //assert.Error(err).IsNil() if err != nil { - conn.SetReusable(false) + c.SetReusable(false) return } assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue() - _, err = conn.Write([]byte("Response")) + _, err = c.Write([]byte("Response")) assert.Error(err).IsNil() - }() + }(conn) } }() @@ -77,6 +74,8 @@ func Test_listenWSAndDial(t *testing.T) { assert.Error(conn.Close()).IsNil() assert.Error(listen.Close()).IsNil() + + close(conns) } func Test_listenWSAndDial_TLS(t *testing.T) { @@ -96,11 +95,11 @@ func Test_listenWSAndDial_TLS(t *testing.T) { AllowInsecure: true, Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()}, }) - listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143) + conns := make(chan internet.Connection, 16) + listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143, conns) assert.Error(err).IsNil() go func() { - conn, err := listen.Accept() - assert.Error(err).IsNil() + conn := <-conns conn.Close() listen.Close() }()