refactor tcp worker

pull/432/head
Darien Raymond 2017-02-26 14:38:41 +01:00
parent 122461647a
commit d93ff628bc
No known key found for this signature in database
GPG Key ID: 7251FFA14BB18169
10 changed files with 127 additions and 238 deletions

View File

@ -37,7 +37,7 @@ type tcpWorker struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
hub *internet.TCPHub hub internet.Listener
} }
func (w *tcpWorker) callback(conn internet.Connection) { func (w *tcpWorker) callback(conn internet.Connection) {
@ -73,17 +73,41 @@ func (w *tcpWorker) Start() error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.ctx = ctx w.ctx = ctx
w.cancel = cancel w.cancel = cancel
hub, err := internet.ListenTCP(w.address, w.port, w.callback, w.stream) ctx = internet.ContextWithStreamSettings(ctx, w.stream)
conns := make(chan internet.Connection, 16)
hub, err := internet.ListenTCP(ctx, w.address, w.port, conns)
if err != nil { if err != nil {
return err return err
} }
go w.handleConnections(conns)
w.hub = hub w.hub = hub
return nil return nil
} }
func (w *tcpWorker) handleConnections(conns <-chan internet.Connection) {
for {
select {
case <-w.ctx.Done():
w.hub.Close()
nconns := len(conns)
L:
for i := 0; i < nconns; i++ {
select {
case conn := <-conns:
conn.Close()
default:
break L
}
}
return
case conn := <-conns:
go w.callback(conn)
}
}
}
func (w *tcpWorker) Close() { func (w *tcpWorker) Close() {
if w.hub != nil { if w.hub != nil {
w.hub.Close()
w.cancel() w.cancel()
} }
} }

View File

@ -19,9 +19,12 @@ func ContextWithStreamSettings(ctx context.Context, streamSettings *StreamConfig
return context.WithValue(ctx, streamSettingsKey, streamSettings) return context.WithValue(ctx, streamSettingsKey, streamSettings)
} }
func StreamSettingsFromContext(ctx context.Context) (*StreamConfig, bool) { func StreamSettingsFromContext(ctx context.Context) *StreamConfig {
ss, ok := ctx.Value(streamSettingsKey).(*StreamConfig) ss := ctx.Value(streamSettingsKey)
return ss, ok if ss == nil {
return nil
}
return ss.(*StreamConfig)
} }
func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context { func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context {

View File

@ -24,7 +24,7 @@ func RegisterTransportDialer(protocol TransportProtocol, dialer Dialer) error {
func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) { func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
if dest.Network == v2net.Network_TCP { if dest.Network == v2net.Network_TCP {
streamSettings, _ := StreamSettingsFromContext(ctx) streamSettings := StreamSettingsFromContext(ctx)
protocol := streamSettings.GetEffectiveProtocol() protocol := streamSettings.GetEffectiveProtocol()
transportSettings, err := streamSettings.GetEffectiveTransportSettings() transportSettings, err := streamSettings.GetEffectiveTransportSettings()
if err != nil { if err != nil {

View File

@ -18,30 +18,27 @@ import (
func TestDialAndListen(t *testing.T) { func TestDialAndListen(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0)) conns := make(chan internet.Connection, 16)
listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0), conns)
assert.Error(err).IsNil() assert.Error(err).IsNil()
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port) port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
go func() { go func() {
for { for conn := range conns {
conn, err := listerner.Accept() go func(c internet.Connection) {
if err != nil {
break
}
go func() {
payload := make([]byte, 4096) payload := make([]byte, 4096)
for { for {
nBytes, err := conn.Read(payload) nBytes, err := c.Read(payload)
if err != nil { if err != nil {
break break
} }
for idx, b := range payload[:nBytes] { for idx, b := range payload[:nBytes] {
payload[idx] = b ^ 'c' payload[idx] = b ^ 'c'
} }
conn.Write(payload[:nBytes]) c.Write(payload[:nBytes])
} }
conn.Close() c.Close()
}() }(conn)
} }
}() }()
@ -79,4 +76,5 @@ func TestDialAndListen(t *testing.T) {
assert.Int(listerner.ActiveConnections()).Equals(0) assert.Int(listerner.ActiveConnections()).Equals(0)
listerner.Close() listerner.Close()
close(conns)
} }

View File

@ -80,18 +80,18 @@ func (o *ServerConnection) Id() internal.ConnectionID {
// Listener defines a server listening for connections // Listener defines a server listening for connections
type Listener struct { type Listener struct {
sync.Mutex sync.Mutex
closed chan bool closed chan bool
sessions map[ConnectionID]*Connection sessions map[ConnectionID]*Connection
awaitingConns chan *Connection hub *udp.Hub
hub *udp.Hub tlsConfig *tls.Config
tlsConfig *tls.Config config *Config
config *Config reader PacketReader
reader PacketReader header internet.PacketHeader
header internet.PacketHeader security cipher.AEAD
security cipher.AEAD conns chan<- internet.Connection
} }
func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) { func NewListener(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (*Listener, error) {
networkSettings := internet.TransportSettingsFromContext(ctx) networkSettings := internet.TransportSettingsFromContext(ctx)
kcpSettings := networkSettings.(*Config) kcpSettings := networkSettings.(*Config)
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false} kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
@ -111,10 +111,10 @@ func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*
Header: header, Header: header,
Security: security, Security: security,
}, },
sessions: make(map[ConnectionID]*Connection), sessions: make(map[ConnectionID]*Connection),
awaitingConns: make(chan *Connection, 64), closed: make(chan bool),
closed: make(chan bool), config: kcpSettings,
config: kcpSettings, conns: conns,
} }
securitySettings := internet.SecuritySettingsFromContext(ctx) securitySettings := internet.SecuritySettingsFromContext(ctx)
if securitySettings != nil { if securitySettings != nil {
@ -194,8 +194,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
closer: writer, closer: writer,
} }
conn = NewConnection(conv, sConn, v, v.config) conn = NewConnection(conv, sConn, v, v.config)
var netConn internet.Connection = conn
if v.tlsConfig != nil {
tlsConn := tls.Server(conn, v.tlsConfig)
netConn = UnreusableConnection{Conn: tlsConn}
}
select { select {
case v.awaitingConns <- conn: case v.conns <- netConn:
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
conn.Close() conn.Close()
return return
@ -216,27 +222,6 @@ func (v *Listener) Remove(id ConnectionID) {
} }
} }
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
func (v *Listener) Accept() (internet.Connection, error) {
for {
select {
case <-v.closed:
return nil, ErrClosedListener
case conn, open := <-v.awaitingConns:
if !open {
break
}
if v.tlsConfig != nil {
tlsConn := tls.Server(conn, v.tlsConfig)
return UnreusableConnection{Conn: tlsConn}, nil
}
return conn, nil
case <-time.After(time.Second):
}
}
}
// Close stops listening on the UDP address. Already Accepted connections are not closed. // Close stops listening on the UDP address. Already Accepted connections are not closed.
func (v *Listener) Close() error { func (v *Listener) Close() error {
@ -249,7 +234,6 @@ func (v *Listener) Close() error {
} }
close(v.closed) close(v.closed)
close(v.awaitingConns)
for _, conn := range v.sessions { for _, conn := range v.sessions {
go conn.Terminate() go conn.Terminate()
} }
@ -288,8 +272,8 @@ func (v *Writer) Close() error {
return nil return nil
} }
func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
return NewListener(ctx, address, port) return NewListener(ctx, address, port, conns)
} }
func init() { func init() {

View File

@ -1,11 +1,10 @@
package internet package internet
import ( import (
"context"
"net" "net"
"time" "time"
"context"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
) )

View File

@ -20,22 +20,17 @@ var (
ErrClosedListener = errors.New("Listener is closed.") ErrClosedListener = errors.New("Listener is closed.")
) )
type ConnectionWithError struct {
conn net.Conn
err error
}
type TCPListener struct { type TCPListener struct {
sync.Mutex sync.Mutex
acccepting bool acccepting bool
listener *net.TCPListener listener *net.TCPListener
awaitingConns chan *ConnectionWithError tlsConfig *tls.Config
tlsConfig *tls.Config authConfig internet.ConnectionAuthenticator
authConfig internet.ConnectionAuthenticator config *Config
config *Config conns chan<- internet.Connection
} }
func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{ listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: address.IP(), IP: address.IP(),
Port: int(port), Port: int(port),
@ -48,10 +43,10 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
tcpSettings := networkSettings.(*Config) tcpSettings := networkSettings.(*Config)
l := &TCPListener{ l := &TCPListener{
acccepting: true, acccepting: true,
listener: listener, listener: listener,
awaitingConns: make(chan *ConnectionWithError, 32), config: tcpSettings,
config: tcpSettings, conns: conns,
} }
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*v2tls.Config) tlsConfig, ok := securitySettings.(*v2tls.Config)
@ -74,24 +69,6 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
return l, nil return l, nil
} }
func (v *TCPListener) Accept() (internet.Connection, error) {
for v.acccepting {
select {
case connErr, open := <-v.awaitingConns:
if !open {
return nil, ErrClosedListener
}
if connErr.err != nil {
return nil, connErr.err
}
conn := connErr.conn
return internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())), nil
case <-time.After(time.Second * 2):
}
}
return nil, ErrClosedListener
}
func (v *TCPListener) KeepAccepting() { func (v *TCPListener) KeepAccepting() {
for v.acccepting { for v.acccepting {
conn, err := v.listener.Accept() conn, err := v.listener.Accept()
@ -100,22 +77,22 @@ func (v *TCPListener) KeepAccepting() {
v.Unlock() v.Unlock()
break break
} }
if conn != nil && v.tlsConfig != nil { if err != nil {
log.Warning("TCP|Listener: Failed to accepted raw connections: ", err)
v.Unlock()
continue
}
if v.tlsConfig != nil {
conn = tls.Server(conn, v.tlsConfig) conn = tls.Server(conn, v.tlsConfig)
} }
if conn != nil && v.authConfig != nil { if v.authConfig != nil {
conn = v.authConfig.Server(conn) conn = v.authConfig.Server(conn)
} }
select { select {
case v.awaitingConns <- &ConnectionWithError{ case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
conn: conn, case <-time.After(time.Second * 5):
err: err, conn.Close()
}:
default:
if conn != nil {
conn.Close()
}
} }
v.Unlock() v.Unlock()
@ -129,8 +106,8 @@ func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) {
return return
} }
select { select {
case v.awaitingConns <- &ConnectionWithError{conn: conn}: case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
default: case <-time.After(time.Second * 5):
conn.Close() conn.Close()
} }
} }
@ -144,12 +121,6 @@ func (v *TCPListener) Close() error {
defer v.Unlock() defer v.Unlock()
v.acccepting = false v.acccepting = false
v.listener.Close() v.listener.Close()
close(v.awaitingConns)
for connErr := range v.awaitingConns {
if connErr.conn != nil {
connErr.conn.Close()
}
}
return nil return nil
} }

View File

@ -1,14 +1,11 @@
package internet package internet
import ( import (
"context"
"net" "net"
"context"
"v2ray.com/core/app/log"
"v2ray.com/core/common/errors" "v2ray.com/core/common/errors"
v2net "v2ray.com/core/common/net" v2net "v2ray.com/core/common/net"
"v2ray.com/core/common/retry"
) )
var ( var (
@ -23,10 +20,9 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
return nil return nil
} }
type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error) type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error)
type Listener interface { type Listener interface {
Accept() (Connection, error)
Close() error Close() error
Addr() net.Addr Addr() net.Addr
} }
@ -37,8 +33,8 @@ type TCPHub struct {
closed chan bool closed chan bool
} }
func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) { func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) {
ctx := context.Background() settings := StreamSettingsFromContext(ctx)
protocol := settings.GetEffectiveProtocol() protocol := settings.GetEffectiveProtocol()
transportSettings, err := settings.GetEffectiveTransportSettings() transportSettings, err := settings.GetEffectiveTransportSettings()
if err != nil { if err != nil {
@ -56,65 +52,9 @@ func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandle
if listenFunc == nil { if listenFunc == nil {
return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.") return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
} }
listener, err := listenFunc(ctx, address, port) listener, err := listenFunc(ctx, address, port, conns)
if err != nil { if err != nil {
return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port) return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
} }
return listener, nil
hub := &TCPHub{
listener: listener,
connCallback: callback,
}
go hub.start()
return hub, nil
}
func (v *TCPHub) Close() {
defer func() {
recover()
}()
select {
case <-v.closed:
return
default:
v.listener.Close()
close(v.closed)
}
}
func (v *TCPHub) start() {
for {
select {
case <-v.closed:
return
default:
}
var newConn Connection
err := retry.ExponentialBackoff(10, 500).On(func() error {
select {
case <-v.closed:
return nil
default:
conn, err := v.listener.Accept()
if err != nil {
return errors.Base(err).RequireUserAction().Message("Internet|Listener: Failed to accept new TCP connection.")
}
newConn = conn
return nil
}
})
if err != nil {
if errors.IsActionRequired(err) {
log.Warning(err)
} else {
log.Info(err)
}
continue
}
if newConn != nil {
go v.connCallback(newConn)
}
}
} }

View File

@ -23,14 +23,9 @@ var (
ErrClosedListener = errors.New("Listener is closed.") ErrClosedListener = errors.New("Listener is closed.")
) )
type ConnectionWithError struct {
conn net.Conn
err error
}
type requestHandler struct { type requestHandler struct {
path string path string
conns chan *ConnectionWithError ln *Listener
} }
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
@ -45,29 +40,29 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
} }
select { select {
case h.conns <- &ConnectionWithError{conn: conn}: case h.ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, h.ln, internal.ReuseConnection(h.ln.config.IsConnectionReuse())):
default: case <-time.After(time.Second * 5):
conn.Close() conn.Close()
} }
} }
type Listener struct { type Listener struct {
sync.Mutex sync.Mutex
closed chan bool closed chan bool
awaitingConns chan *ConnectionWithError listener net.Listener
listener net.Listener tlsConfig *tls.Config
tlsConfig *tls.Config config *Config
config *Config conns chan<- internet.Connection
} }
func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) { func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
networkSettings := internet.TransportSettingsFromContext(ctx) networkSettings := internet.TransportSettingsFromContext(ctx)
wsSettings := networkSettings.(*Config) wsSettings := networkSettings.(*Config)
l := &Listener{ l := &Listener{
closed: make(chan bool), closed: make(chan bool),
awaitingConns: make(chan *ConnectionWithError, 32), config: wsSettings,
config: wsSettings, conns: conns,
} }
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
tlsConfig, ok := securitySettings.(*v2tls.Config) tlsConfig, ok := securitySettings.(*v2tls.Config)
@ -101,8 +96,8 @@ func (ln *Listener) listenws(address v2net.Address, port v2net.Port) error {
go func() { go func() {
http.Serve(listener, &requestHandler{ http.Serve(listener, &requestHandler{
path: ln.config.GetNormailzedPath(), path: ln.config.GetNormailzedPath(),
conns: ln.awaitingConns, ln: ln,
}) })
}() }()
@ -123,24 +118,6 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) {
return &connection{wsc: conn}, nil return &connection{wsc: conn}, nil
} }
func (ln *Listener) Accept() (internet.Connection, error) {
for {
select {
case <-ln.closed:
return nil, ErrClosedListener
case connErr, open := <-ln.awaitingConns:
if !open {
return nil, ErrClosedListener
}
if connErr.err != nil {
return nil, connErr.err
}
return internal.NewConnection(internal.ConnectionID{}, connErr.conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())), nil
case <-time.After(time.Second * 2):
}
}
}
func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) { func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
ln.Lock() ln.Lock()
defer ln.Unlock() defer ln.Unlock()
@ -150,8 +127,8 @@ func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
default: default:
} }
select { select {
case ln.awaitingConns <- &ConnectionWithError{conn: conn}: case ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())):
default: case <-time.After(time.Second * 5):
conn.Close() conn.Close()
} }
} }
@ -170,12 +147,6 @@ func (ln *Listener) Close() error {
} }
close(ln.closed) close(ln.closed)
ln.listener.Close() ln.listener.Close()
close(ln.awaitingConns)
for connErr := range ln.awaitingConns {
if connErr.conn != nil {
connErr.conn.Close()
}
}
return nil return nil
} }

View File

@ -16,31 +16,28 @@ import (
func Test_listenWSAndDial(t *testing.T) { func Test_listenWSAndDial(t *testing.T) {
assert := assert.On(t) assert := assert.On(t)
conns := make(chan internet.Connection, 16)
listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{ listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
Path: "ws", Path: "ws",
}), v2net.DomainAddress("localhost"), 13146) }), v2net.DomainAddress("localhost"), 13146, conns)
assert.Error(err).IsNil() assert.Error(err).IsNil()
go func() { go func() {
for { for conn := range conns {
conn, err := listen.Accept() go func(c internet.Connection) {
if err != nil { defer c.Close()
break
}
go func() {
defer conn.Close()
var b [1024]byte var b [1024]byte
n, err := conn.Read(b[:]) n, err := c.Read(b[:])
//assert.Error(err).IsNil() //assert.Error(err).IsNil()
if err != nil { if err != nil {
conn.SetReusable(false) c.SetReusable(false)
return return
} }
assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue() assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue()
_, err = conn.Write([]byte("Response")) _, err = c.Write([]byte("Response"))
assert.Error(err).IsNil() assert.Error(err).IsNil()
}() }(conn)
} }
}() }()
@ -77,6 +74,8 @@ func Test_listenWSAndDial(t *testing.T) {
assert.Error(conn.Close()).IsNil() assert.Error(conn.Close()).IsNil()
assert.Error(listen.Close()).IsNil() assert.Error(listen.Close()).IsNil()
close(conns)
} }
func Test_listenWSAndDial_TLS(t *testing.T) { func Test_listenWSAndDial_TLS(t *testing.T) {
@ -96,11 +95,11 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
AllowInsecure: true, AllowInsecure: true,
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()}, Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
}) })
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143) conns := make(chan internet.Connection, 16)
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143, conns)
assert.Error(err).IsNil() assert.Error(err).IsNil()
go func() { go func() {
conn, err := listen.Accept() conn := <-conns
assert.Error(err).IsNil()
conn.Close() conn.Close()
listen.Close() listen.Close()
}() }()