diff --git a/consul/client.go b/consul/client.go index adca9da5b7..612d23c376 100644 --- a/consul/client.go +++ b/consul/client.go @@ -77,7 +77,7 @@ func NewClient(config *Config) (*Client, error) { // Create server c := &Client{ config: config, - connPool: NewPool(8, clientRPCCache), + connPool: NewPool(clientRPCCache), eventCh: make(chan serf.Event, 256), logger: logger, shutdownCh: make(chan struct{}), diff --git a/consul/pool.go b/consul/pool.go index 57ae1523ac..6749f61ac6 100644 --- a/consul/pool.go +++ b/consul/pool.go @@ -2,23 +2,25 @@ package consul import ( "fmt" + "github.com/inconshreveable/muxado" "github.com/ugorji/go/codec" "net" "net/rpc" "sync" + "sync/atomic" "time" ) // Conn is a pooled connection to a Consul server type Conn struct { + refCount int32 addr net.Addr - conn *net.TCPConn - client *rpc.Client + session muxado.Session lastUsed time.Time } func (c *Conn) Close() error { - return c.conn.Close() + return c.session.Close() } // ConnPool is used to maintain a connection pool to other @@ -29,14 +31,11 @@ func (c *Conn) Close() error { type ConnPool struct { sync.Mutex - // The maximum connectsion to maintain per server - maxConns int - // The maximum time to keep a connection open maxTime time.Duration - // Pool maps an address to a list of connections - pool map[string][]*Conn + // Pool maps an address to a open connection + pool map[string]*Conn // Used to indicate the pool is shutdown shutdown bool @@ -44,13 +43,12 @@ type ConnPool struct { } // NewPool is used to make a new connection pool -// Maintain at most maxConns per host, for up to maxTime. +// Maintain at most one connection per host, for up to maxTime. // Set maxTime to 0 to disable reaping. -func NewPool(maxConns int, maxTime time.Duration) *ConnPool { +func NewPool(maxTime time.Duration) *ConnPool { pool := &ConnPool{ - maxConns: maxConns, maxTime: maxTime, - pool: make(map[string][]*Conn), + pool: make(map[string]*Conn), shutdownCh: make(chan struct{}), } if maxTime > 0 { @@ -64,12 +62,10 @@ func (p *ConnPool) Shutdown() error { p.Lock() defer p.Unlock() - for _, conns := range p.pool { - for _, c := range conns { - c.Close() - } + for _, conn := range p.pool { + conn.Close() } - p.pool = make(map[string][]*Conn) + p.pool = make(map[string]*Conn) if p.shutdown { return nil @@ -97,16 +93,12 @@ func (p *ConnPool) getPooled(addr net.Addr) *Conn { defer p.Unlock() // Look for an existing connection - conns := p.pool[addr.String()] - if len(conns) == 0 { - return nil + c := p.pool[addr.String()] + if c != nil { + c.lastUsed = time.Now() + atomic.AddInt32(&c.refCount, 1) } - - // Remove the last conn from the pool - conn := conns[len(conns)-1] - conns = conns[:len(conns)-1] - p.pool[addr.String()] = conns - return conn + return c } // getNewConn is used to return a new connection @@ -124,67 +116,94 @@ func (p *ConnPool) getNewConn(addr net.Addr) (*Conn, error) { conn.SetKeepAlive(true) conn.SetNoDelay(true) - // Write the Consul RPC byte to set the mode - conn.Write([]byte{byte(rpcConsul)}) + // Write the Consul multiplex byte to set the mode + conn.Write([]byte{byte(rpcMultiplex)}) - // Create the RPC client - cc := codec.GoRpc.ClientCodec(conn, &codec.MsgpackHandle{}) - client := rpc.NewClientWithCodec(cc) + // Create a multiplexed session + session := muxado.Client(conn) // Wrap the connection c := &Conn{ - addr: addr, - conn: conn, - client: client, + refCount: 1, + addr: addr, + session: session, + lastUsed: time.Now(), } - return c, nil -} -// Return is used to return a connection once done. Connections -// that are in an error state should not be returned -func (p *ConnPool) returnConn(conn *Conn) { + // Monitor the session + go func() { + session.Wait() + p.Lock() + defer p.Unlock() + if conn, ok := p.pool[addr.String()]; ok && conn.session == session { + delete(p.pool, addr.String()) + } + }() + + // Track this connection, handle potential race condition p.Lock() defer p.Unlock() - - // Set the last used time - conn.lastUsed = time.Now() - - // Look for existing connections - conns := p.pool[conn.addr.String()] - - // Check for limit on connections or shutdown - if p.shutdown || len(conns) >= p.maxConns { - conn.Close() - return + if existing := p.pool[addr.String()]; existing != nil { + session.Close() + return existing, nil + } else { + p.pool[addr.String()] = c + return c, nil } +} - // Retain the connection - conns = append(conns, conn) - p.pool[conn.addr.String()] = conns +// clearConn is used to clear any cached connection, potentially in response to an erro +func (p *ConnPool) clearConn(addr net.Addr) { + p.Lock() + defer p.Unlock() + delete(p.pool, addr.String()) +} + +// releaseConn is invoked when we are done with a conn to reduce the ref count +func (p *ConnPool) releaseConn(conn *Conn) { + atomic.AddInt32(&conn.refCount, -1) } // RPC is used to make an RPC call to a remote host func (p *ConnPool) RPC(addr net.Addr, method string, args interface{}, reply interface{}) error { + retries := 0 +START: // Try to get a conn first conn, err := p.acquire(addr) if err != nil { return fmt.Errorf("failed to get conn: %v", err) } + defer p.releaseConn(conn) + + // Create a new stream + stream, err := conn.session.Open() + if err != nil { + p.clearConn(addr) + + // Try to redial, possible that the TCP session closed due to timeout + if retries == 0 { + retries++ + goto START + } + return fmt.Errorf("failed to start stream: %v", err) + } + defer stream.Close() + + // Create the RPC client + cc := codec.GoRpc.ClientCodec(stream, &codec.MsgpackHandle{}) + client := rpc.NewClientWithCodec(cc) // Make the RPC call - err = conn.client.Call(method, args, reply) + err = client.Call(method, args, reply) // Fast path the non-error case if err == nil { - p.returnConn(conn) return nil } - // If not a network error, save the connection - if _, ok := err.(net.Error); !ok { - p.returnConn(conn) - } else { - conn.Close() + // If its a network error, nuke the connection + if _, ok := err.(net.Error); ok { + p.clearConn(addr) } return fmt.Errorf("rpc error: %v", err) } @@ -201,25 +220,27 @@ func (p *ConnPool) reap() { // Reap all old conns p.Lock() + var removed []string now := time.Now() - for host, conns := range p.pool { - n := len(conns) - for i := 0; i < n; i++ { - // Skip new connections - conn := conns[i] - if now.Sub(conn.lastUsed) < p.maxTime { - continue - } - - // Close the conn - conn.Close() - - // Remove from pool - conns[i], conns[n-1] = conns[n-1], nil - conns = conns[:n-1] - p.pool[host] = conns - n-- + for host, conn := range p.pool { + // Skip recently used connections + if now.Sub(conn.lastUsed) < p.maxTime { + continue } + + // Skip connections with active streams + if atomic.LoadInt32(&conn.refCount) > 0 { + continue + } + + // Close the conn + conn.Close() + + // Remove from pool + removed = append(removed, host) + } + for _, host := range removed { + delete(p.pool, host) } p.Unlock() } diff --git a/consul/server.go b/consul/server.go index 7ad58fee7c..b976e2caa3 100644 --- a/consul/server.go +++ b/consul/server.go @@ -112,7 +112,7 @@ func NewServer(config *Config) (*Server, error) { // Create server s := &Server{ config: config, - connPool: NewPool(16, time.Minute), + connPool: NewPool(time.Minute), eventChLAN: make(chan serf.Event, 256), eventChWAN: make(chan serf.Event, 256), logger: logger, @@ -337,7 +337,9 @@ func (s *Server) Shutdown() error { s.connPool.Shutdown() // Close the fsm - s.fsm.Close() + if s.fsm != nil { + s.fsm.Close() + } return nil }