From 7a97d737378015a1548ba1553a3360398ebde0bb Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 18 Feb 2017 00:28:50 +0100 Subject: [PATCH] fix race condition in transport --- transport/internet/internal/pool.go | 11 +++++++++-- transport/internet/kcp/listener.go | 5 +++-- transport/internet/websocket/hub.go | 22 ++++++++++++++-------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/transport/internet/internal/pool.go b/transport/internet/internal/pool.go index 4269400d..5c84ec83 100644 --- a/transport/internet/internal/pool.go +++ b/transport/internet/internal/pool.go @@ -31,7 +31,7 @@ func (ec *ExpiringConnection) Expired() bool { // Pool is a connection pool. type Pool struct { - sync.Mutex + sync.RWMutex connsByDest map[ConnectionID][]*ExpiringConnection cleanupToken *signal.Semaphore } @@ -74,10 +74,17 @@ func (p *Pool) Get(id ConnectionID) net.Conn { return conn.conn } +func (p *Pool) isEmpty() bool { + p.RLock() + defer p.RUnlock() + + return len(p.connsByDest) == 0 +} + func (p *Pool) cleanup() { defer p.cleanupToken.Signal() - for len(p.connsByDest) > 0 { + for !p.isEmpty() { time.Sleep(time.Second * 5) expiredConns := make([]net.Conn, 0, 16) p.Lock() diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 4b3ebed9..440efa54 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -246,13 +246,14 @@ func (v *Listener) Accept() (internet.Connection, error) { // Close stops listening on the UDP address. Already Accepted connections are not closed. func (v *Listener) Close() error { + + v.Lock() + defer v.Unlock() select { case <-v.closed: return ErrClosedListener default: } - v.Lock() - defer v.Unlock() close(v.closed) close(v.awaitingConns) diff --git a/transport/internet/websocket/hub.go b/transport/internet/websocket/hub.go index 4c1a543c..12b1037b 100644 --- a/transport/internet/websocket/hub.go +++ b/transport/internet/websocket/hub.go @@ -52,7 +52,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req type Listener struct { sync.Mutex - acccepting bool + closed chan bool awaitingConns chan *ConnectionWithError listener net.Listener tlsConfig *tls.Config @@ -67,7 +67,7 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt wsSettings := networkSettings.(*Config) l := &Listener{ - acccepting: true, + closed: make(chan bool), awaitingConns: make(chan *ConnectionWithError, 32), config: wsSettings, } @@ -130,8 +130,10 @@ func converttovws(w http.ResponseWriter, r *http.Request) (*connection, error) { } func (ln *Listener) Accept() (internet.Connection, error) { - for ln.acccepting { + for { select { + case <-ln.closed: + return nil, ErrClosedListener case connErr, open := <-ln.awaitingConns: if !open { return nil, ErrClosedListener @@ -143,14 +145,15 @@ func (ln *Listener) Accept() (internet.Connection, error) { case <-time.After(time.Second * 2): } } - return nil, ErrClosedListener } func (ln *Listener) Put(id internal.ConnectionID, conn net.Conn) { ln.Lock() defer ln.Unlock() - if !ln.acccepting { + select { + case <-ln.closed: return + default: } select { case ln.awaitingConns <- &ConnectionWithError{conn: conn}: @@ -166,10 +169,13 @@ func (ln *Listener) Addr() net.Addr { func (ln *Listener) Close() error { ln.Lock() defer ln.Unlock() - ln.acccepting = false - + select { + case <-ln.closed: + return ErrClosedListener + default: + } + close(ln.closed) ln.listener.Close() - close(ln.awaitingConns) for connErr := range ln.awaitingConns { if connErr.conn != nil {