Browse Source

Fix issue with loadbalancer failover to default server

The loadbalancer should only fail over to the default server if all other server have failed, and it should force fail-back to a preferred server as soon as one passes health checks.

The loadbalancer tests have been improved to ensure that this occurs.

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
pull/8882/merge
Brad Davidson 1 week ago committed by Brad Davidson
parent
commit
cd4ddedbc9
  1. 2
      pkg/agent/loadbalancer/loadbalancer.go
  2. 186
      pkg/agent/loadbalancer/loadbalancer_test.go
  3. 47
      pkg/agent/loadbalancer/servers.go
  4. 3
      pkg/agent/loadbalancer/utility.go
  5. 3
      pkg/util/apierrors.go

2
pkg/agent/loadbalancer/loadbalancer.go

@ -179,6 +179,8 @@ func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net
if !allChecksFailed { if !allChecksFailed {
defer server.closeAll() defer server.closeAll()
} }
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
} }
newServer, err := lb.nextServer(targetServer) newServer, err := lb.nextServer(targetServer)

186
pkg/agent/loadbalancer/loadbalancer_test.go

@ -10,7 +10,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/k3s-io/k3s/pkg/cli/cmds"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -24,7 +23,7 @@ type testServer struct {
prefix string prefix string
} }
func createServer(prefix string) (*testServer, error) { func createServer(ctx context.Context, prefix string) (*testServer, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return nil, err return nil, err
@ -34,6 +33,10 @@ func createServer(prefix string) (*testServer, error) {
listener: listener, listener: listener,
} }
go s.serve() go s.serve()
go func() {
<-ctx.Done()
s.close()
}()
return s, nil return s, nil
} }
@ -49,6 +52,7 @@ func (s *testServer) serve() {
} }
func (s *testServer) close() { func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix)
s.listener.Close() s.listener.Close()
for _, conn := range s.conns { for _, conn := range s.conns {
conn.Close() conn.Close()
@ -65,6 +69,10 @@ func (s *testServer) echo(conn net.Conn) {
} }
} }
func (s *testServer) address() string {
return s.listener.Addr().String()
}
func ping(conn net.Conn) (string, error) { func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n") fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n') result, err := bufio.NewReader(conn).ReadString('\n')
@ -74,25 +82,31 @@ func ping(conn net.Conn) (string, error) {
return strings.TrimSpace(result), nil return strings.TrimSpace(result), nil
} }
// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
func Test_UnitFailOver(t *testing.T) { func Test_UnitFailOver(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ogServe, err := createServer("og") defaultServer, err := createServer(ctx, "default")
if err != nil { if err != nil {
t.Fatalf("createServer(og) failed: %v", err) t.Fatalf("createServer(default) failed: %v", err)
} }
lbServe, err := createServer("lb") node1Server, err := createServer(ctx, "node1")
if err != nil { if err != nil {
t.Fatalf("createServer(lb) failed: %v", err) t.Fatalf("createServer(node1) failed: %v", err)
} }
cfg := cmds.Agent{ node2Server, err := createServer(ctx, "node2")
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()), if err != nil {
DataDir: tmpDir, t.Fatalf("createServer(node2) failed: %v", err)
} }
lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false) // start the loadbalancer with the default server as the only server
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
if err != nil { if err != nil {
t.Fatalf("New() failed: %v", err) t.Fatalf("New() failed: %v", err)
} }
@ -103,50 +117,123 @@ func Test_UnitFailOver(t *testing.T) {
} }
localAddress := parsedURL.Host localAddress := parsedURL.Host
lb.Update([]string{lbServe.listener.Addr().String()}) // add the node as a new server address.
lb.Update([]string{node1Server.address()})
// make sure connections go to the node
conn1, err := net.Dial("tcp", localAddress) conn1, err := net.Dial("tcp", localAddress)
if err != nil { if err != nil {
t.Fatalf("net.Dial failed: %v", err) t.Fatalf("net.Dial failed: %v", err)
} }
result1, err := ping(conn1) if result, err := ping(conn1); err != nil {
if err != nil {
t.Fatalf("ping(conn1) failed: %v", err) t.Fatalf("ping(conn1) failed: %v", err)
} } else if result != "node1:ping" {
if result1 != "lb:ping" { t.Fatalf("Unexpected ping(conn1) result: %v", result)
t.Fatalf("Unexpected ping result: %v", result1)
} }
lbServe.close() t.Log("conn1 tested OK")
_, err = ping(conn1) // set failing health check for node 1
if err == nil { lb.SetHealthCheck(node1Server.address(), func() bool { return false })
// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
time.Sleep(2 * time.Second)
if _, err := ping(conn1); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn1") t.Fatal("Unexpected successful ping on closed connection conn1")
} }
t.Log("conn1 closed on failure OK")
// make sure connection still goes to the first node - it is failing health checks but so
// is the default endpoint, so it should be tried first with health checks disabled,
// before failing back to the default.
conn2, err := net.Dial("tcp", localAddress) conn2, err := net.Dial("tcp", localAddress)
if err != nil { if err != nil {
t.Fatalf("net.Dial failed: %v", err) t.Fatalf("net.Dial failed: %v", err)
} }
result2, err := ping(conn2) if result, err := ping(conn2); err != nil {
if err != nil {
t.Fatalf("ping(conn2) failed: %v", err) t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
t.Log("conn2 tested OK")
// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
t.Log("conn2 tested OK again")
// shut down the first node server to force failover to the default
node1Server.close()
// make sure new connections go to the default, and existing connections are closed
conn3, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
if result, err := ping(conn3); err != nil {
t.Fatalf("ping(conn3) failed: %v", err)
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn3) result: %v", result)
}
t.Log("conn3 tested OK")
if _, err := ping(conn2); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn2")
}
t.Log("conn2 closed on failure OK")
// add the second node as a new server address.
lb.Update([]string{node2Server.address()})
// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
} }
if result2 != "og:ping" { if result, err := ping(conn4); err != nil {
t.Fatalf("Unexpected ping result: %v", result2) t.Fatalf("ping(conn4) failed: %v", err)
} else if result != "node2:ping" {
t.Fatalf("Unexpected ping(conn4) result: %v", result)
} }
t.Log("conn4 tested OK")
// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
if _, err := ping(conn3); err == nil {
t.Fatal("Unexpected successful ping on connection conn3")
}
t.Log("conn3 closed on failure OK")
} }
// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
func Test_UnitFailFast(t *testing.T) { func Test_UnitFailFast(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cfg := cmds.Agent{ serverURL := "http://127.0.0.1:0/"
ServerURL: "http://127.0.0.1:0/", lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
DataDir: tmpDir,
}
lb, err := New(context.TODO(), cfg.DataDir, SupervisorServiceName, cfg.ServerURL, RandomPort, false)
if err != nil { if err != nil {
t.Fatalf("New() failed: %v", err) t.Fatalf("New() failed: %v", err)
} }
@ -172,3 +259,44 @@ func Test_UnitFailFast(t *testing.T) {
t.Fatal("Test timed out") t.Fatal("Test timed out")
} }
} }
// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
// within the expected duration
func Test_UnitFailUnreachable(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test in short mode.")
}
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverAddr := "192.0.2.1:6443"
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
// Set failing health check to reduce retries
lb.SetHealthCheck(serverAddr, func() bool { return false })
conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
select {
case err := <-done:
if err == nil {
t.Fatal("Unexpected successful ping from unreachable address")
}
case <-timeout:
t.Fatal("Test timed out")
}
}

47
pkg/agent/loadbalancer/servers.go

@ -7,6 +7,7 @@ import (
"net" "net"
"net/url" "net/url"
"os" "os"
"slices"
"strconv" "strconv"
"time" "time"
@ -21,7 +22,10 @@ import (
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
) )
var defaultDialer proxy.Dialer = &net.Dialer{} var defaultDialer proxy.Dialer = &net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 30 * time.Second,
}
// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, // SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections,
// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured // if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured
@ -48,7 +52,7 @@ func SetHTTPProxy(address string) error {
return nil return nil
} }
dialer, err := proxyDialer(proxyURL) dialer, err := proxyDialer(proxyURL, defaultDialer)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL)
} }
@ -59,7 +63,7 @@ func SetHTTPProxy(address string) error {
} }
func (lb *LoadBalancer) setServers(serverAddresses []string) bool { func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.defaultServerAddress) serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
if len(serverAddresses) == 0 { if len(serverAddresses) == 0 {
return false return false
} }
@ -102,8 +106,16 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
rand.Shuffle(len(lb.randomServers), func(i, j int) { rand.Shuffle(len(lb.randomServers), func(i, j int) {
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
}) })
if !hasOriginalServer { // If the current server list does not contain the default server address,
// we want to include it in the random server list so that it can be tried if necessary.
// However, it should be treated as always failing health checks so that it is only
// used if all other endpoints are unavailable.
if !hasDefaultServer {
lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
defaultServer.healthCheck = func() bool { return false }
lb.servers[lb.defaultServerAddress] = defaultServer
}
} }
lb.currentServerAddress = lb.randomServers[0] lb.currentServerAddress = lb.randomServers[0]
lb.nextServerIndex = 1 lb.nextServerIndex = 1
@ -163,14 +175,14 @@ func (s *server) dialContext(ctx context.Context, network, address string) (net.
} }
// proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. // proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy.
func proxyDialer(proxyURL *url.URL) (proxy.Dialer, error) { func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Create a new HTTP proxy dialer // Create a new HTTP proxy dialer
httpProxyDialer := http_dialer.New(proxyURL) httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
return httpProxyDialer, nil return httpProxyDialer, nil
} else if proxyURL.Scheme == "socks5" { } else if proxyURL.Scheme == "socks5" {
// For SOCKS5 proxies, use the proxy package's FromURL // For SOCKS5 proxies, use the proxy package's FromURL
return proxy.FromURL(proxyURL, proxy.Direct) return proxy.FromURL(proxyURL, forward)
} }
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
} }
@ -204,17 +216,18 @@ func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock() lb.mutex.Lock()
defer lb.mutex.Unlock() defer lb.mutex.Unlock()
_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress) hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map // if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer { if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
defer server.closeAll() defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress) delete(lb.servers, lb.defaultServerAddress)
} }
// if the new default server doesn't have an entry in the map, add one // if the new default server doesn't have an entry in the map, add one - but
// with a failing health check so that it is only used as a last resort.
if _, ok := lb.servers[serverAddress]; !ok { if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{ lb.servers[serverAddress] = &server{
address: serverAddress, address: serverAddress,
healthCheck: func() bool { return true }, healthCheck: func() bool { return false },
connections: make(map[net.Conn]struct{}), connections: make(map[net.Conn]struct{}),
} }
} }
@ -243,8 +256,10 @@ func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
wait.Until(func() { wait.Until(func() {
lb.mutex.RLock() lb.mutex.RLock()
defer lb.mutex.RUnlock() defer lb.mutex.RUnlock()
var healthyServerExists bool
for address, server := range lb.servers { for address, server := range lb.servers {
status := server.healthCheck() status := server.healthCheck()
healthyServerExists = healthyServerExists || status
if status == false && previousStatus[address] == true { if status == false && previousStatus[address] == true {
// Only close connections when the server transitions from healthy to unhealthy; // Only close connections when the server transitions from healthy to unhealthy;
// we don't want to re-close all the connections every time as we might be ignoring // we don't want to re-close all the connections every time as we might be ignoring
@ -253,6 +268,16 @@ func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
} }
previousStatus[address] = status previousStatus[address] = status
} }
// If there is at least one healthy server, and the default server is not in the server list,
// close all the connections to the default server so that clients reconnect and switch over
// to a preferred server.
hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress)
if healthyServerExists && !hasDefaultServer {
if server, ok := lb.servers[lb.defaultServerAddress]; ok {
defer server.closeAll()
}
}
}, time.Second, ctx.Done()) }, time.Second, ctx.Done())
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
} }

3
pkg/agent/loadbalancer/utility.go

@ -28,6 +28,9 @@ func parseURL(serverURL, newHost string) (string, string, error) {
return address, parsedURL.String(), nil return address, parsedURL.String(), nil
} }
// sortServers returns a sorted, unique list of strings, with any
// empty values removed. The returned bool is true if the list
// contains the search string.
func sortServers(input []string, search string) ([]string, bool) { func sortServers(input []string, search string) ([]string, bool) {
result := []string{} result := []string{}
found := false found := false

3
pkg/util/apierrors.go

@ -40,7 +40,7 @@ func SendError(err error, resp http.ResponseWriter, req *http.Request, status ..
// Don't log "apiserver not ready" or "apiserver disabled" errors, they are frequent during startup // Don't log "apiserver not ready" or "apiserver disabled" errors, they are frequent during startup
if !errors.Is(err, ErrAPINotReady) && !errors.Is(err, ErrAPIDisabled) { if !errors.Is(err, ErrAPINotReady) && !errors.Is(err, ErrAPIDisabled) {
logrus.Errorf("Sending HTTP %d response to %s: %v", code, req.RemoteAddr, err) logrus.Errorf("Sending %s %d response to %s: %v", req.Proto, code, req.RemoteAddr, err)
} }
var serr *apierrors.StatusError var serr *apierrors.StatusError
@ -61,6 +61,7 @@ func SendError(err error, resp http.ResponseWriter, req *http.Request, status ..
serr = apierrors.NewGenericServerResponse(code, req.Method, schema.GroupResource{}, req.URL.Path, err.Error(), 0, true) serr = apierrors.NewGenericServerResponse(code, req.Method, schema.GroupResource{}, req.URL.Path, err.Error(), 0, true)
} }
resp.Header().Add("Connection", "close")
responsewriters.ErrorNegotiated(serr, scheme.Codecs.WithoutConversion(), schema.GroupVersion{}, resp, req) responsewriters.ErrorNegotiated(serr, scheme.Codecs.WithoutConversion(), schema.GroupVersion{}, resp, req)
} }

Loading…
Cancel
Save