diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 65e90ce0..1ff991e4 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -80,7 +80,7 @@ func (o *ServerConnection) Id() internal.ConnectionID { // Listener defines a server listening for connections type Listener struct { sync.Mutex - closed chan bool + ctx context.Context sessions map[ConnectionID]*Connection hub *udp.Hub tlsConfig *tls.Config @@ -112,7 +112,7 @@ func NewListener(ctx context.Context, address v2net.Address, port v2net.Port, co Security: security, }, sessions: make(map[ConnectionID]*Connection), - closed: make(chan bool), + ctx: ctx, config: kcpSettings, conns: conns, } @@ -143,20 +143,19 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina return } + v.Lock() + defer v.Unlock() + select { - case <-v.closed: + case <-v.ctx.Done(): return default: } - v.Lock() - defer v.Unlock() if v.hub == nil { return } - if payload.Len() < 4 { - return - } + conv := segments[0].Conversation() cmd := segments[0].Command() @@ -213,7 +212,7 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina func (v *Listener) Remove(id ConnectionID) { select { - case <-v.closed: + case <-v.ctx.Done(): return default: v.Lock() @@ -224,20 +223,14 @@ func (v *Listener) Remove(id ConnectionID) { // Close stops listening on the UDP address. Already Accepted connections are not closed. func (v *Listener) Close() error { + v.hub.Close() v.Lock() defer v.Unlock() - select { - case <-v.closed: - return ErrClosedListener - default: - } - close(v.closed) for _, conn := range v.sessions { go conn.Terminate() } - v.hub.Close() return nil } diff --git a/transport/internet/tcp/hub.go b/transport/internet/tcp/hub.go index 2a43891f..999dfabf 100644 --- a/transport/internet/tcp/hub.go +++ b/transport/internet/tcp/hub.go @@ -22,7 +22,7 @@ var ( type TCPListener struct { sync.Mutex - acccepting bool + ctx context.Context listener *net.TCPListener tlsConfig *tls.Config authConfig internet.ConnectionAuthenticator @@ -43,10 +43,10 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conn tcpSettings := networkSettings.(*Config) l := &TCPListener{ - acccepting: true, - listener: listener, - config: tcpSettings, - conns: conns, + ctx: ctx, + listener: listener, + config: tcpSettings, + conns: conns, } if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil { tlsConfig, ok := securitySettings.(*v2tls.Config) @@ -70,13 +70,14 @@ func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conn } func (v *TCPListener) KeepAccepting() { - for v.acccepting { + for { + select { + case <-v.ctx.Done(): + return + default: + } conn, err := v.listener.Accept() v.Lock() - if !v.acccepting { - v.Unlock() - break - } if err != nil { log.Warning("TCP|Listener: Failed to accepted raw connections: ", err) v.Unlock() @@ -100,12 +101,10 @@ func (v *TCPListener) KeepAccepting() { } func (v *TCPListener) Put(id internal.ConnectionID, conn net.Conn) { - v.Lock() - defer v.Unlock() - if !v.acccepting { - return - } select { + case <-v.ctx.Done(): + conn.Close() + return case v.conns <- internal.NewConnection(internal.ConnectionID{}, conn, v, internal.ReuseConnection(v.config.IsConnectionReuse())): case <-time.After(time.Second * 5): conn.Close() @@ -117,9 +116,6 @@ func (v *TCPListener) Addr() net.Addr { } func (v *TCPListener) Close() error { - v.Lock() - defer v.Unlock() - v.acccepting = false v.listener.Close() return nil } diff --git a/transport/internet/tcp_hub.go b/transport/internet/tcp_hub.go index 4d996b1a..72e1ef53 100644 --- a/transport/internet/tcp_hub.go +++ b/transport/internet/tcp_hub.go @@ -27,12 +27,6 @@ type Listener interface { Addr() net.Addr } -type TCPHub struct { - listener Listener - connCallback ConnectionHandler - closed chan bool -} - func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port, conns chan<- Connection) (Listener, error) { settings := StreamSettingsFromContext(ctx) protocol := settings.GetEffectiveProtocol() diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 0031bb99..db272957 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -40,6 +40,8 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } select { + case <-h.ln.ctx.Done(): + conn.Close() case h.ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, h.ln, internal.ReuseConnection(h.ln.config.IsConnectionReuse())): case <-time.After(time.Second * 5): conn.Close() @@ -48,7 +50,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req type Listener struct { sync.Mutex - closed chan bool + ctx context.Context listener net.Listener tlsConfig *tls.Config config *Config @@ -60,7 +62,7 @@ func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port, conns wsSettings := networkSettings.(*Config) l := &Listener{ - closed: make(chan bool), + ctx: ctx, config: wsSettings, conns: conns, } @@ -119,14 +121,9 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) { } func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) { - ln.Lock() - defer ln.Unlock() - select { - case <-ln.closed: - return - default: - } select { + case <-ln.ctx.Done(): + conn.Close() case ln.conns <- internal.NewConnection(internal.ConnectionID{}, conn, ln, internal.ReuseConnection(ln.config.IsConnectionReuse())): case <-time.After(time.Second * 5): conn.Close() @@ -138,16 +135,7 @@ func (ln *Listener) Addr() net.Addr { } func (ln *Listener) Close() error { - ln.Lock() - defer ln.Unlock() - select { - case <-ln.closed: - return ErrClosedListener - default: - } - close(ln.closed) - ln.listener.Close() - return nil + return ln.listener.Close() } func init() {