mirror of https://github.com/k3s-io/k3s
Fix issue with stale connections to removed LB server
Track LB connections through each server so that they can be closed when it is removed. Signed-off-by: Brad Davidson <brad.davidson@rancher.com>pull/7215/head
parent
5dece799df
commit
e54ceaa497
|
@ -14,10 +14,25 @@ import (
|
|||
"inet.af/tcpproxy"
|
||||
)
|
||||
|
||||
// server tracks the connections to a server, so that they can be closed when the server is removed.
|
||||
type server struct {
|
||||
mutex sync.Mutex
|
||||
connections map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
|
||||
type serverConn struct {
|
||||
server *server
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// LoadBalancer holds data for a local listener which forwards connections to a
|
||||
// pool of remote servers. It is not a proper load-balancer in that it does not
|
||||
// actually balance connections, but instead fails over to a new server only
|
||||
// when a connection attempt to the currently selected server fails.
|
||||
type LoadBalancer struct {
|
||||
mutex sync.Mutex
|
||||
dialer *net.Dialer
|
||||
proxy *tcpproxy.Proxy
|
||||
mutex sync.Mutex
|
||||
proxy *tcpproxy.Proxy
|
||||
|
||||
serviceName string
|
||||
configFile string
|
||||
|
@ -27,6 +42,7 @@ type LoadBalancer struct {
|
|||
ServerURL string
|
||||
ServerAddresses []string
|
||||
randomServers []string
|
||||
servers map[string]*server
|
||||
currentServerAddress string
|
||||
nextServerIndex int
|
||||
Listener net.Listener
|
||||
|
@ -40,6 +56,8 @@ var (
|
|||
ETCDServerServiceName = version.Program + "-etcd-server-load-balancer"
|
||||
)
|
||||
|
||||
// New contstructs a new LoadBalancer instance. The default server URL, and
|
||||
// currently active servers, are stored in a file within the dataDir.
|
||||
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
|
||||
config := net.ListenConfig{Control: reusePort}
|
||||
var localAddress string
|
||||
|
@ -76,11 +94,11 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
|
|||
|
||||
lb := &LoadBalancer{
|
||||
serviceName: serviceName,
|
||||
dialer: &net.Dialer{},
|
||||
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
|
||||
localAddress: localAddress,
|
||||
localServerURL: localServerURL,
|
||||
defaultServerAddress: defaultServerAddress,
|
||||
servers: make(map[string]*server),
|
||||
ServerURL: serverURL,
|
||||
}
|
||||
|
||||
|
@ -103,14 +121,28 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
|
|||
if err := lb.proxy.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("Running load balancer %s %s -> %v", serviceName, lb.localAddress, lb.randomServers)
|
||||
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) SetDefault(serverAddress string) {
|
||||
logrus.Infof("Updating load balancer %s default server address -> %s", lb.serviceName, serverAddress)
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
|
||||
// if the old default server is not currently in use, remove it from the server map
|
||||
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
|
||||
defer server.closeAll()
|
||||
delete(lb.servers, lb.defaultServerAddress)
|
||||
}
|
||||
// if the new default server doesn't have an entry in the map, add one
|
||||
if _, ok := lb.servers[serverAddress]; !ok {
|
||||
lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
|
||||
}
|
||||
|
||||
lb.defaultServerAddress = serverAddress
|
||||
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Update(serverAddresses []string) {
|
||||
|
@ -120,7 +152,7 @@ func (lb *LoadBalancer) Update(serverAddresses []string) {
|
|||
if !lb.setServers(serverAddresses) {
|
||||
return
|
||||
}
|
||||
logrus.Infof("Updating load balancer %s server addresses -> %v", lb.serviceName, lb.randomServers)
|
||||
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress)
|
||||
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
|
||||
|
@ -139,18 +171,23 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string
|
|||
for {
|
||||
targetServer := lb.currentServerAddress
|
||||
|
||||
conn, err := lb.dialer.DialContext(ctx, network, targetServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
server := lb.servers[targetServer]
|
||||
if server == nil || targetServer == "" {
|
||||
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
|
||||
} else {
|
||||
conn, err := server.dialContext(ctx, network, targetServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, err)
|
||||
}
|
||||
logrus.Debugf("Dial error from load balancer %s: %s", lb.serviceName, err)
|
||||
|
||||
newServer, err := lb.nextServer(targetServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if targetServer != newServer {
|
||||
logrus.Debugf("Dial server in load balancer %s failed over to %s", lb.serviceName, newServer)
|
||||
logrus.Debugf("Failed over to new server for load balancer %s: %s", lb.serviceName, newServer)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
|
@ -167,7 +204,7 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string
|
|||
}
|
||||
|
||||
func onDialError(src net.Conn, dstDialErr error) {
|
||||
logrus.Debugf("Incoming conn %v, error dialing load balancer servers: %v", src.RemoteAddr().String(), dstDialErr)
|
||||
logrus.Debugf("Incoming conn %s, error dialing load balancer servers: %v", src.RemoteAddr(), dstDialErr)
|
||||
src.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -15,18 +15,18 @@ import (
|
|||
"github.com/k3s-io/k3s/pkg/cli/cmds"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
type testServer struct {
|
||||
listener net.Listener
|
||||
conns []net.Conn
|
||||
prefix string
|
||||
}
|
||||
|
||||
func createServer(prefix string) (*server, error) {
|
||||
func createServer(prefix string) (*testServer, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &server{
|
||||
s := &testServer{
|
||||
prefix: prefix,
|
||||
listener: listener,
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ func createServer(prefix string) (*server, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) serve() {
|
||||
func (s *testServer) serve() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
|
@ -45,14 +45,14 @@ func (s *server) serve() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *server) close() {
|
||||
func (s *testServer) close() {
|
||||
s.listener.Close()
|
||||
for _, conn := range s.conns {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) echo(conn net.Conn) {
|
||||
func (s *testServer) echo(conn net.Conn) {
|
||||
for {
|
||||
result, err := bufio.NewReader(conn).ReadString('\n')
|
||||
if err != nil {
|
||||
|
|
|
@ -1,11 +1,17 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"net"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
)
|
||||
|
||||
var defaultDialer = &net.Dialer{}
|
||||
|
||||
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
|
||||
serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.defaultServerAddress)
|
||||
if len(serverAddresses) == 0 {
|
||||
|
@ -15,10 +21,32 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
|
|||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
if reflect.DeepEqual(serverAddresses, lb.ServerAddresses) {
|
||||
newAddresses := sets.NewString(serverAddresses...)
|
||||
curAddresses := sets.NewString(lb.ServerAddresses...)
|
||||
if newAddresses.Equal(curAddresses) {
|
||||
return false
|
||||
}
|
||||
|
||||
for addedServer := range newAddresses.Difference(curAddresses) {
|
||||
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
|
||||
lb.servers[addedServer] = &server{connections: make(map[net.Conn]struct{})}
|
||||
}
|
||||
|
||||
for removedServer := range curAddresses.Difference(newAddresses) {
|
||||
server := lb.servers[removedServer]
|
||||
if server != nil {
|
||||
logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
|
||||
// Defer closing connections until after the new server list has been put into place.
|
||||
// Closing open connections ensures that anything stuck retrying on a stale server is forced
|
||||
// over to a valid endpoint.
|
||||
defer server.closeAll()
|
||||
// Don't delete the default server from the server map, in case we need to fall back to it.
|
||||
if removedServer != lb.defaultServerAddress {
|
||||
delete(lb.servers, removedServer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lb.ServerAddresses = serverAddresses
|
||||
lb.randomServers = append([]string{}, lb.ServerAddresses...)
|
||||
rand.Shuffle(len(lb.randomServers), func(i, j int) {
|
||||
|
@ -55,3 +83,41 @@ func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
|
|||
|
||||
return lb.currentServerAddress, nil
|
||||
}
|
||||
|
||||
// dialContext dials a new connection, and adds its wrapped connection to the map
|
||||
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := defaultDialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// don't lock until adding the connection to the map, otherwise we may block
|
||||
// while waiting for the dial to time out
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
conn = &serverConn{server: s, Conn: conn}
|
||||
s.connections[conn] = struct{}{}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// closeAll closes all connections to the server, and removes their entries from the map
|
||||
func (s *server) closeAll() {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
logrus.Debugf("Closing %d connections to load balancer server", len(s.connections))
|
||||
for conn := range s.connections {
|
||||
// Close the connection in a goroutine so that we don't hold the lock while doing so.
|
||||
go conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Close removes the connection entry from the server's connection map, and
|
||||
// closes the wrapped connection.
|
||||
func (sc *serverConn) Close() error {
|
||||
sc.server.mutex.Lock()
|
||||
defer sc.server.mutex.Unlock()
|
||||
|
||||
delete(sc.server.connections, sc)
|
||||
return sc.Conn.Close()
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue