mirror of https://github.com/v2ray/v2ray-core
refactor tcp worker
parent
122461647a
commit
d93ff628bc
|
@ -37,7 +37,7 @@ type tcpWorker struct {
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
hub *internet.TCPHub
|
hub internet.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *tcpWorker) callback(conn internet.Connection) {
|
func (w *tcpWorker) callback(conn internet.Connection) {
|
||||||
|
@ -73,17 +73,41 @@ func (w *tcpWorker) Start() error {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
w.ctx = ctx
|
w.ctx = ctx
|
||||||
w.cancel = cancel
|
w.cancel = cancel
|
||||||
hub, err := internet.ListenTCP(w.address, w.port, w.callback, w.stream)
|
ctx = internet.ContextWithStreamSettings(ctx, w.stream)
|
||||||
|
conns := make(chan internet.Connection, 16)
|
||||||
|
hub, err := internet.ListenTCP(ctx, w.address, w.port, conns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
go w.handleConnections(conns)
|
||||||
w.hub = hub
|
w.hub = hub
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *tcpWorker) handleConnections(conns <-chan internet.Connection) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
w.hub.Close()
|
||||||
|
nconns := len(conns)
|
||||||
|
L:
|
||||||
|
for i := 0; i < nconns; i++ {
|
||||||
|
select {
|
||||||
|
case conn := <-conns:
|
||||||
|
conn.Close()
|
||||||
|
default:
|
||||||
|
break L
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case conn := <-conns:
|
||||||
|
go w.callback(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *tcpWorker) Close() {
|
func (w *tcpWorker) Close() {
|
||||||
if w.hub != nil {
|
if w.hub != nil {
|
||||||
w.hub.Close()
|
|
||||||
w.cancel()
|
w.cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,12 @@ func ContextWithStreamSettings(ctx context.Context, streamSettings *StreamConfig
|
||||||
return context.WithValue(ctx, streamSettingsKey, streamSettings)
|
return context.WithValue(ctx, streamSettingsKey, streamSettings)
|
||||||
}
|
}
|
||||||
|
|
||||||
func StreamSettingsFromContext(ctx context.Context) (*StreamConfig, bool) {
|
func StreamSettingsFromContext(ctx context.Context) *StreamConfig {
|
||||||
ss, ok := ctx.Value(streamSettingsKey).(*StreamConfig)
|
ss := ctx.Value(streamSettingsKey)
|
||||||
return ss, ok
|
if ss == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ss.(*StreamConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context {
|
func ContextWithDialerSource(ctx context.Context, addr net.Address) context.Context {
|
||||||
|
|
|
@ -24,7 +24,7 @@ func RegisterTransportDialer(protocol TransportProtocol, dialer Dialer) error {
|
||||||
|
|
||||||
func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
|
func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
|
||||||
if dest.Network == v2net.Network_TCP {
|
if dest.Network == v2net.Network_TCP {
|
||||||
streamSettings, _ := StreamSettingsFromContext(ctx)
|
streamSettings := StreamSettingsFromContext(ctx)
|
||||||
protocol := streamSettings.GetEffectiveProtocol()
|
protocol := streamSettings.GetEffectiveProtocol()
|
||||||
transportSettings, err := streamSettings.GetEffectiveTransportSettings()
|
transportSettings, err := streamSettings.GetEffectiveTransportSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -18,30 +18,27 @@ import (
|
||||||
func TestDialAndListen(t *testing.T) {
|
func TestDialAndListen(t *testing.T) {
|
||||||
assert := assert.On(t)
|
assert := assert.On(t)
|
||||||
|
|
||||||
listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0))
|
conns := make(chan internet.Connection, 16)
|
||||||
|
listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0), conns)
|
||||||
assert.Error(err).IsNil()
|
assert.Error(err).IsNil()
|
||||||
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
|
port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for conn := range conns {
|
||||||
conn, err := listerner.Accept()
|
go func(c internet.Connection) {
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
payload := make([]byte, 4096)
|
payload := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
nBytes, err := conn.Read(payload)
|
nBytes, err := c.Read(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
for idx, b := range payload[:nBytes] {
|
for idx, b := range payload[:nBytes] {
|
||||||
payload[idx] = b ^ 'c'
|
payload[idx] = b ^ 'c'
|
||||||
}
|
}
|
||||||
conn.Write(payload[:nBytes])
|
c.Write(payload[:nBytes])
|
||||||
}
|
}
|
||||||
conn.Close()
|
c.Close()
|
||||||
}()
|
}(conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -79,4 +76,5 @@ func TestDialAndListen(t *testing.T) {
|
||||||
assert.Int(listerner.ActiveConnections()).Equals(0)
|
assert.Int(listerner.ActiveConnections()).Equals(0)
|
||||||
|
|
||||||
listerner.Close()
|
listerner.Close()
|
||||||
|
close(conns)
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,18 +80,18 @@ func (o *ServerConnection) Id() internal.ConnectionID {
|
||||||
// Listener defines a server listening for connections
|
// Listener defines a server listening for connections
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
closed chan bool
|
closed chan bool
|
||||||
sessions map[ConnectionID]*Connection
|
sessions map[ConnectionID]*Connection
|
||||||
awaitingConns chan *Connection
|
hub *udp.Hub
|
||||||
hub *udp.Hub
|
tlsConfig *tls.Config
|
||||||
tlsConfig *tls.Config
|
config *Config
|
||||||
config *Config
|
reader PacketReader
|
||||||
reader PacketReader
|
header internet.PacketHeader
|
||||||
header internet.PacketHeader
|
security cipher.AEAD
|
||||||
security cipher.AEAD
|
conns chan<- internet.Connection
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) {
|
func NewListener(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (*Listener, error) {
|
||||||
networkSettings := internet.TransportSettingsFromContext(ctx)
|
networkSettings := internet.TransportSettingsFromContext(ctx)
|
||||||
kcpSettings := networkSettings.(*Config)
|
kcpSettings := networkSettings.(*Config)
|
||||||
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
|
kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
|
||||||
|
@ -111,10 +111,10 @@ func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*
|
||||||
Header: header,
|
Header: header,
|
||||||
Security: security,
|
Security: security,
|
||||||
},
|
},
|
||||||
sessions: make(map[ConnectionID]*Connection),
|
sessions: make(map[ConnectionID]*Connection),
|
||||||
awaitingConns: make(chan *Connection, 64),
|
closed: make(chan bool),
|
||||||
closed: make(chan bool),
|
config: kcpSettings,
|
||||||
config: kcpSettings,
|
conns: conns,
|
||||||
}
|
}
|
||||||
securitySettings := internet.SecuritySettingsFromContext(ctx)
|
securitySettings := internet.SecuritySettingsFromContext(ctx)
|
||||||
if securitySettings != nil {
|
if securitySettings != nil {
|
||||||
|
@ -194,8 +194,14 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
|
||||||
closer: writer,
|
closer: writer,
|
||||||
}
|
}
|
||||||
conn = NewConnection(conv, sConn, v, v.config)
|
conn = NewConnection(conv, sConn, v, v.config)
|
||||||
|
var netConn internet.Connection = conn
|
||||||
|
if v.tlsConfig != nil {
|
||||||
|
tlsConn := tls.Server(conn, v.tlsConfig)
|
||||||
|
netConn = UnreusableConnection{Conn: tlsConn}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case v.awaitingConns <- conn:
|
case v.conns <- netConn:
|
||||||
case <-time.After(time.Second * 5):
|
case <-time.After(time.Second * 5):
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return
|
return
|
||||||
|
@ -216,27 +222,6 @@ func (v *Listener) Remove(id ConnectionID) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
|
|
||||||
func (v *Listener) Accept() (internet.Connection, error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-v.closed:
|
|
||||||
return nil, ErrClosedListener
|
|
||||||
case conn, open := <-v.awaitingConns:
|
|
||||||
if !open {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if v.tlsConfig != nil {
|
|
||||||
tlsConn := tls.Server(conn, v.tlsConfig)
|
|
||||||
return UnreusableConnection{Conn: tlsConn}, nil
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops listening on the UDP address. Already Accepted connections are not closed.
|
// Close stops listening on the UDP address. Already Accepted connections are not closed.
|
||||||
func (v *Listener) Close() error {
|
func (v *Listener) Close() error {
|
||||||
|
|
||||||
|
@ -249,7 +234,6 @@ func (v *Listener) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
close(v.closed)
|
close(v.closed)
|
||||||
close(v.awaitingConns)
|
|
||||||
for _, conn := range v.sessions {
|
for _, conn := range v.sessions {
|
||||||
go conn.Terminate()
|
go conn.Terminate()
|
||||||
}
|
}
|
||||||
|
@ -288,8 +272,8 @@ func (v *Writer) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
|
func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
|
||||||
return NewListener(ctx, address, port)
|
return NewListener(ctx, address, port, conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
package internet
|
package internet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"context"
|
|
||||||
|
|
||||||
v2net "v2ray.com/core/common/net"
|
v2net "v2ray.com/core/common/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -20,22 +20,17 @@ var (
|
||||||
ErrClosedListener = errors.New("Listener is closed.")
|
ErrClosedListener = errors.New("Listener is closed.")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionWithError struct {
|
|
||||||
conn net.Conn
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type TCPListener struct {
|
type TCPListener struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
acccepting bool
|
acccepting bool
|
||||||
listener *net.TCPListener
|
listener *net.TCPListener
|
||||||
awaitingConns chan *ConnectionWithError
|
tlsConfig *tls.Config
|
||||||
tlsConfig *tls.Config
|
authConfig internet.ConnectionAuthenticator
|
||||||
authConfig internet.ConnectionAuthenticator
|
config *Config
|
||||||
config *Config
|
conns chan<- internet.Connection
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
|
func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
|
||||||
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||||
IP: address.IP(),
|
IP: address.IP(),
|
||||||
Port: int(port),
|
Port: int(port),
|
||||||
|
@ -48,10 +43,10 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
|
||||||
tcpSettings := networkSettings.(*Config)
|
tcpSettings := networkSettings.(*Config)
|
||||||
|
|
||||||
l := &TCPListener{
|
l := &TCPListener{
|
||||||
acccepting: true,
|
acccepting: true,
|
||||||
listener: listener,
|
listener: listener,
|
||||||
awaitingConns: make(chan *ConnectionWithError, 32),
|
config: tcpSettings,
|
||||||
config: tcpSettings,
|
conns: conns,
|
||||||
}
|
}
|
||||||
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
|
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
|
||||||
tlsConfig, ok := securitySettings.(*v2tls.Config)
|
tlsConfig, ok := securitySettings.(*v2tls.Config)
|
||||||
|
@ -74,24 +69,6 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (int
|
||||||
return l, nil
|
return l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *TCPListener) Accept() (internet.Connection, error) {
|
|
||||||
for v.acccepting {
|
|
||||||
select {
|
|
||||||
case connErr, open := <-v.awaitingConns:
|
|
||||||
if !open {
|
|
||||||
return nil, ErrClosedListener
|
|
||||||
}
|
|
||||||
if connErr.err != nil {
|
|
||||||
return nil, connErr.err
|
|
||||||
}
|
|
||||||
conn := connErr.conn
|
|
||||||
return internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())), nil
|
|
||||||
case <-time.After(time.Second * 2):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, ErrClosedListener
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *TCPListener) KeepAccepting() {
|
func (v *TCPListener) KeepAccepting() {
|
||||||
for v.acccepting {
|
for v.acccepting {
|
||||||
conn, err := v.listener.Accept()
|
conn, err := v.listener.Accept()
|
||||||
|
@ -100,22 +77,22 @@ func (v *TCPListener) KeepAccepting() {
|
||||||
v.Unlock()
|
v.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if conn != nil && v.tlsConfig != nil {
|
if err != nil {
|
||||||
|
log.Warning("TCP|Listener: Failed to accepted raw connections: ", err)
|
||||||
|
v.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if v.tlsConfig != nil {
|
||||||
conn = tls.Server(conn, v.tlsConfig)
|
conn = tls.Server(conn, v.tlsConfig)
|
||||||
}
|
}
|
||||||
if conn != nil && v.authConfig != nil {
|
if v.authConfig != nil {
|
||||||
conn = v.authConfig.Server(conn)
|
conn = v.authConfig.Server(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case v.awaitingConns <- &ConnectionWithError{
|
case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
|
||||||
conn: conn,
|
case <-time.After(time.Second * 5):
|
||||||
err: err,
|
conn.Close()
|
||||||
}:
|
|
||||||
default:
|
|
||||||
if conn != nil {
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
v.Unlock()
|
v.Unlock()
|
||||||
|
@ -129,8 +106,8 @@ func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case v.awaitingConns <- &ConnectionWithError{conn: conn}:
|
case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())):
|
||||||
default:
|
case <-time.After(time.Second * 5):
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -144,12 +121,6 @@ func (v *TCPListener) Close() error {
|
||||||
defer v.Unlock()
|
defer v.Unlock()
|
||||||
v.acccepting = false
|
v.acccepting = false
|
||||||
v.listener.Close()
|
v.listener.Close()
|
||||||
close(v.awaitingConns)
|
|
||||||
for connErr := range v.awaitingConns {
|
|
||||||
if connErr.conn != nil {
|
|
||||||
connErr.conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,11 @@
|
||||||
package internet
|
package internet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"v2ray.com/core/app/log"
|
|
||||||
"v2ray.com/core/common/errors"
|
"v2ray.com/core/common/errors"
|
||||||
v2net "v2ray.com/core/common/net"
|
v2net "v2ray.com/core/common/net"
|
||||||
"v2ray.com/core/common/retry"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -23,10 +20,9 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error)
|
type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error)
|
||||||
|
|
||||||
type Listener interface {
|
type Listener interface {
|
||||||
Accept() (Connection, error)
|
|
||||||
Close() error
|
Close() error
|
||||||
Addr() net.Addr
|
Addr() net.Addr
|
||||||
}
|
}
|
||||||
|
@ -37,8 +33,8 @@ type TCPHub struct {
|
||||||
closed chan bool
|
closed chan bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
|
func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) {
|
||||||
ctx := context.Background()
|
settings := StreamSettingsFromContext(ctx)
|
||||||
protocol := settings.GetEffectiveProtocol()
|
protocol := settings.GetEffectiveProtocol()
|
||||||
transportSettings, err := settings.GetEffectiveTransportSettings()
|
transportSettings, err := settings.GetEffectiveTransportSettings()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -56,65 +52,9 @@ func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandle
|
||||||
if listenFunc == nil {
|
if listenFunc == nil {
|
||||||
return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
|
return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
|
||||||
}
|
}
|
||||||
listener, err := listenFunc(ctx, address, port)
|
listener, err := listenFunc(ctx, address, port, conns)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
|
return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
|
||||||
}
|
}
|
||||||
|
return listener, nil
|
||||||
hub := &TCPHub{
|
|
||||||
listener: listener,
|
|
||||||
connCallback: callback,
|
|
||||||
}
|
|
||||||
|
|
||||||
go hub.start()
|
|
||||||
return hub, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *TCPHub) Close() {
|
|
||||||
defer func() {
|
|
||||||
recover()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-v.closed:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
v.listener.Close()
|
|
||||||
close(v.closed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *TCPHub) start() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-v.closed:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
var newConn Connection
|
|
||||||
err := retry.ExponentialBackoff(10, 500).On(func() error {
|
|
||||||
select {
|
|
||||||
case <-v.closed:
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
conn, err := v.listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return errors.Base(err).RequireUserAction().Message("Internet|Listener: Failed to accept new TCP connection.")
|
|
||||||
}
|
|
||||||
newConn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
if errors.IsActionRequired(err) {
|
|
||||||
log.Warning(err)
|
|
||||||
} else {
|
|
||||||
log.Info(err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if newConn != nil {
|
|
||||||
go v.connCallback(newConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,14 +23,9 @@ var (
|
||||||
ErrClosedListener = errors.New("Listener is closed.")
|
ErrClosedListener = errors.New("Listener is closed.")
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionWithError struct {
|
|
||||||
conn net.Conn
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type requestHandler struct {
|
type requestHandler struct {
|
||||||
path string
|
path string
|
||||||
conns chan *ConnectionWithError
|
ln *Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
@ -45,29 +40,29 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case h.conns <- &ConnectionWithError{conn: conn}:
|
case h.ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, h.ln, internal.ReuseConnection(h.ln.config.IsConnectionReuse())):
|
||||||
default:
|
case <-time.After(time.Second * 5):
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Listener struct {
|
type Listener struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
closed chan bool
|
closed chan bool
|
||||||
awaitingConns chan *ConnectionWithError
|
listener net.Listener
|
||||||
listener net.Listener
|
tlsConfig *tls.Config
|
||||||
tlsConfig *tls.Config
|
config *Config
|
||||||
config *Config
|
conns chan<- internet.Connection
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
|
func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- internet.Connection) (internet.Listener, error) {
|
||||||
networkSettings := internet.TransportSettingsFromContext(ctx)
|
networkSettings := internet.TransportSettingsFromContext(ctx)
|
||||||
wsSettings := networkSettings.(*Config)
|
wsSettings := networkSettings.(*Config)
|
||||||
|
|
||||||
l := &Listener{
|
l := &Listener{
|
||||||
closed: make(chan bool),
|
closed: make(chan bool),
|
||||||
awaitingConns: make(chan *ConnectionWithError, 32),
|
config: wsSettings,
|
||||||
config: wsSettings,
|
conns: conns,
|
||||||
}
|
}
|
||||||
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
|
if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
|
||||||
tlsConfig, ok := securitySettings.(*v2tls.Config)
|
tlsConfig, ok := securitySettings.(*v2tls.Config)
|
||||||
|
@ -101,8 +96,8 @@ func (ln *Listener) listenws(address v2net.Address, port v2net.Port) error {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
http.Serve(listener, &requestHandler{
|
http.Serve(listener, &requestHandler{
|
||||||
path: ln.config.GetNormailzedPath(),
|
path: ln.config.GetNormailzedPath(),
|
||||||
conns: ln.awaitingConns,
|
ln: ln,
|
||||||
})
|
})
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -123,24 +118,6 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) {
|
||||||
return &connection{wsc: conn}, nil
|
return &connection{wsc: conn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) Accept() (internet.Connection, error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ln.closed:
|
|
||||||
return nil, ErrClosedListener
|
|
||||||
case connErr, open := <-ln.awaitingConns:
|
|
||||||
if !open {
|
|
||||||
return nil, ErrClosedListener
|
|
||||||
}
|
|
||||||
if connErr.err != nil {
|
|
||||||
return nil, connErr.err
|
|
||||||
}
|
|
||||||
return internal.NewConnection(internal.ConnectionID{}, connErr.conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())), nil
|
|
||||||
case <-time.After(time.Second * 2):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
|
func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
|
||||||
ln.Lock()
|
ln.Lock()
|
||||||
defer ln.Unlock()
|
defer ln.Unlock()
|
||||||
|
@ -150,8 +127,8 @@ func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case ln.awaitingConns <- &ConnectionWithError{conn: conn}:
|
case ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())):
|
||||||
default:
|
case <-time.After(time.Second * 5):
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,12 +147,6 @@ func (ln *Listener) Close() error {
|
||||||
}
|
}
|
||||||
close(ln.closed)
|
close(ln.closed)
|
||||||
ln.listener.Close()
|
ln.listener.Close()
|
||||||
close(ln.awaitingConns)
|
|
||||||
for connErr := range ln.awaitingConns {
|
|
||||||
if connErr.conn != nil {
|
|
||||||
connErr.conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,31 +16,28 @@ import (
|
||||||
|
|
||||||
func Test_listenWSAndDial(t *testing.T) {
|
func Test_listenWSAndDial(t *testing.T) {
|
||||||
assert := assert.On(t)
|
assert := assert.On(t)
|
||||||
|
conns := make(chan internet.Connection, 16)
|
||||||
listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
|
listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
|
||||||
Path: "ws",
|
Path: "ws",
|
||||||
}), v2net.DomainAddress("localhost"), 13146)
|
}), v2net.DomainAddress("localhost"), 13146, conns)
|
||||||
assert.Error(err).IsNil()
|
assert.Error(err).IsNil()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for conn := range conns {
|
||||||
conn, err := listen.Accept()
|
go func(c internet.Connection) {
|
||||||
if err != nil {
|
defer c.Close()
|
||||||
break
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
var b [1024]byte
|
var b [1024]byte
|
||||||
n, err := conn.Read(b[:])
|
n, err := c.Read(b[:])
|
||||||
//assert.Error(err).IsNil()
|
//assert.Error(err).IsNil()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.SetReusable(false)
|
c.SetReusable(false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue()
|
assert.Bool(bytes.HasPrefix(b[:n], []byte("Test connection"))).IsTrue()
|
||||||
|
|
||||||
_, err = conn.Write([]byte("Response"))
|
_, err = c.Write([]byte("Response"))
|
||||||
assert.Error(err).IsNil()
|
assert.Error(err).IsNil()
|
||||||
}()
|
}(conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -77,6 +74,8 @@ func Test_listenWSAndDial(t *testing.T) {
|
||||||
assert.Error(conn.Close()).IsNil()
|
assert.Error(conn.Close()).IsNil()
|
||||||
|
|
||||||
assert.Error(listen.Close()).IsNil()
|
assert.Error(listen.Close()).IsNil()
|
||||||
|
|
||||||
|
close(conns)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_listenWSAndDial_TLS(t *testing.T) {
|
func Test_listenWSAndDial_TLS(t *testing.T) {
|
||||||
|
@ -96,11 +95,11 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
|
||||||
AllowInsecure: true,
|
AllowInsecure: true,
|
||||||
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
|
Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
|
||||||
})
|
})
|
||||||
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143)
|
conns := make(chan internet.Connection, 16)
|
||||||
|
listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143, conns)
|
||||||
assert.Error(err).IsNil()
|
assert.Error(err).IsNil()
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := listen.Accept()
|
conn := <-conns
|
||||||
assert.Error(err).IsNil()
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
listen.Close()
|
listen.Close()
|
||||||
}()
|
}()
|
||||||
|
|
Loading…
Reference in New Issue