diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go index 9a2de3214f..b7d8f63f9d 100644 --- a/pkg/agent/loadbalancer/config.go +++ b/pkg/agent/loadbalancer/config.go @@ -15,8 +15,8 @@ type lbConfig struct { func (lb *LoadBalancer) writeConfig() error { config := &lbConfig{ - ServerURL: lb.serverURL, - ServerAddresses: lb.serverAddresses, + ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(), + ServerAddresses: lb.servers.getAddresses(), } configOut, err := json.MarshalIndent(config, "", " ") if err != nil { @@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error { } func (lb *LoadBalancer) updateConfig() error { - writeConfig := true if configBytes, err := os.ReadFile(lb.configFile); err == nil { config := &lbConfig{} if err := json.Unmarshal(configBytes, config); err == nil { - if config.ServerURL == lb.serverURL { - writeConfig = false - lb.setServers(config.ServerAddresses) + // if the default server from the config matches our current default, + // load the rest of the addresses as well. + if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() { + lb.Update(config.ServerAddresses) + return nil } } } - if writeConfig { - if err := lb.writeConfig(); err != nil { - return err - } - } - return nil + // config didn't exist or used a different default server, write the current config to disk. + return lb.writeConfig() } diff --git a/pkg/agent/loadbalancer/httpproxy.go b/pkg/agent/loadbalancer/httpproxy.go index f14859bfe7..ea97118249 100644 --- a/pkg/agent/loadbalancer/httpproxy.go +++ b/pkg/agent/loadbalancer/httpproxy.go @@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error { func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { // Create a new HTTP proxy dialer - httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) + httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer))) return httpProxyDialer, nil } else if proxyURL.Scheme == "socks5" { // For SOCKS5 proxies, use the proxy package's FromURL diff --git a/pkg/agent/loadbalancer/httpproxy_test.go b/pkg/agent/loadbalancer/httpproxy_test.go index c8b8b5b924..07f72e927e 100644 --- a/pkg/agent/loadbalancer/httpproxy_test.go +++ b/pkg/agent/loadbalancer/httpproxy_test.go @@ -2,15 +2,16 @@ package loadbalancer import ( "fmt" - "net" "os" "strings" "testing" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" + "golang.org/x/net/proxy" ) +var originalDialer proxy.Dialer var defaultEnv map[string]string var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"} @@ -19,7 +20,7 @@ func init() { } func prepareEnv(env ...string) { - defaultDialer = &net.Dialer{} + originalDialer = defaultDialer defaultEnv = map[string]string{} for _, e := range proxyEnvs { if v, ok := os.LookupEnv(e); ok { @@ -34,6 +35,7 @@ func prepareEnv(env ...string) { } func restoreEnv() { + defaultDialer = originalDialer for _, e := range proxyEnvs { if v, ok := defaultEnv[e]; ok { os.Setenv(e, v) diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index db9fa6f16f..2f6d33fbf4 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -2,55 +2,29 @@ package loadbalancer import ( "context" - "errors" "fmt" "net" + "net/url" "os" "path/filepath" - "sync" - "time" + "strings" "github.com/inetaf/tcpproxy" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" ) -// server tracks the connections to a server, so that they can be closed when the server is removed. -type server struct { - // This mutex protects access to the connections map. All direct access to the map should be protected by it. - mutex sync.Mutex - address string - healthCheck func() bool - connections map[net.Conn]struct{} -} - -// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. -type serverConn struct { - server *server - net.Conn -} - // LoadBalancer holds data for a local listener which forwards connections to a // pool of remote servers. It is not a proper load-balancer in that it does not // actually balance connections, but instead fails over to a new server only // when a connection attempt to the currently selected server fails. type LoadBalancer struct { - // This mutex protects access to servers map and randomServers list. - // All direct access to the servers map/list should be protected by it. - mutex sync.RWMutex - proxy *tcpproxy.Proxy - - serviceName string - configFile string - localAddress string - localServerURL string - defaultServerAddress string - serverURL string - serverAddresses []string - randomServers []string - servers map[string]*server - currentServerAddress string - nextServerIndex int + serviceName string + configFile string + scheme string + localAddress string + servers serverList + proxy *tcpproxy.Proxy } const RandomPort = 0 @@ -63,7 +37,7 @@ var ( // New contstructs a new LoadBalancer instance. The default server URL, and // currently active servers, are stored in a file within the dataDir. -func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { +func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { config := net.ListenConfig{Control: reusePort} var localAddress string if isIPv6 { @@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo return nil, err } - // if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with. - localAddress = listener.Addr().String() - - defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress) + serverURL, err := url.Parse(defaultServerURL) if err != nil { return nil, err } - if serverURL == localServerURL { - logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) - defaultServerAddress = "" + // Set explicit port from scheme + if serverURL.Port() == "" { + if strings.ToLower(serverURL.Scheme) == "http" { + serverURL.Host += ":80" + } + if strings.ToLower(serverURL.Scheme) == "https" { + serverURL.Host += ":443" + } } lb := &LoadBalancer{ - serviceName: serviceName, - configFile: filepath.Join(dataDir, "etc", serviceName+".json"), - localAddress: localAddress, - localServerURL: localServerURL, - defaultServerAddress: defaultServerAddress, - servers: make(map[string]*server), - serverURL: serverURL, + serviceName: serviceName, + configFile: filepath.Join(dataDir, "etc", serviceName+".json"), + scheme: serverURL.Scheme, + localAddress: listener.Addr().String(), } - lb.setServers([]string{lb.defaultServerAddress}) + // if starting pointing at ourselves, don't set a default server address, + // which will cause all dials to fail until servers are added. + if serverURL.Host == lb.localAddress { + logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) + } else { + lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host) + } lb.proxy = &tcpproxy.Proxy{ ListenFunc: func(string, string) (net.Listener, error) { @@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo } lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ Addr: serviceName, - DialContext: lb.dialContext, + DialContext: lb.servers.dialContext, OnDialError: onDialError, }) @@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo if err := lb.proxy.Start(); err != nil { return nil, err } - logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress) + logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) - go lb.runHealthChecks(ctx) + go lb.servers.runHealthChecks(ctx, lb.serviceName) return lb, nil } +// Update updates the list of server addresses to contain only the listed servers. func (lb *LoadBalancer) Update(serverAddresses []string) { - if lb == nil { + if !lb.servers.setAddresses(lb.serviceName, serverAddresses) { return } - if !lb.setServers(serverAddresses) { - return - } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress) + + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } } -func (lb *LoadBalancer) LoadBalancerServerURL() string { - if lb == nil { - return "" +// SetDefault sets the selected address as the default / fallback address +func (lb *LoadBalancer) SetDefault(serverAddress string) { + lb.servers.setDefaultAddress(lb.serviceName, serverAddress) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } - return lb.localServerURL +} + +// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. +func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) { + if err := lb.servers.setHealthCheck(address, healthCheck); err != nil { + logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err) + } else { + logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address) + } +} + +func (lb *LoadBalancer) LocalURL() string { + return lb.scheme + "://" + lb.localAddress } func (lb *LoadBalancer) ServerAddresses() []string { - if lb == nil { - return nil - } - return lb.serverAddresses -} - -func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - - var allChecksFailed bool - startIndex := lb.nextServerIndex - for { - targetServer := lb.currentServerAddress - - server := lb.servers[targetServer] - if server == nil || targetServer == "" { - logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer) - } else if allChecksFailed || server.healthCheck() { - dialTime := time.Now() - conn, err := server.dialContext(ctx, network, targetServer) - if err == nil { - return conn, nil - } - logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err) - // Don't close connections to the failed server if we're retrying with health checks ignored. - // We don't want to disrupt active connections if it is unlikely they will have anywhere to go. - if !allChecksFailed { - defer server.closeAll() - } - } else { - logrus.Debugf("Dial health check failed for %s", targetServer) - } - - newServer, err := lb.nextServer(targetServer) - if err != nil { - return nil, err - } - if targetServer != newServer { - logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer) - } - if ctx.Err() != nil { - return nil, ctx.Err() - } - - maxIndex := len(lb.randomServers) - if startIndex > maxIndex { - startIndex = maxIndex - } - if lb.nextServerIndex == startIndex { - if allChecksFailed { - return nil, errors.New("all servers failed") - } - logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName) - allChecksFailed = true - } - } + return lb.servers.getAddresses() } func onDialError(src net.Conn, dstDialErr error) { @@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) { } // ResetLoadBalancer will delete the local state file for the load balancer on disk -func ResetLoadBalancer(dataDir, serviceName string) error { +func ResetLoadBalancer(dataDir, serviceName string) { stateFile := filepath.Join(dataDir, "etc", serviceName+".json") - if err := os.Remove(stateFile); err != nil { + if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) { logrus.Warn(err) } - return nil } diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index cbfdf982c6..69b4fca10c 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -5,19 +5,29 @@ import ( "context" "fmt" "net" - "net/url" + "strconv" "strings" "testing" "time" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" "github.com/sirupsen/logrus" ) +func Test_UnitLoadBalancer(t *testing.T) { + _, reporterConfig := GinkgoConfiguration() + reporterConfig.Verbose = testing.Verbose() + RegisterFailHandler(Fail) + RunSpecs(t, "LoadBalancer Suite", reporterConfig) +} + func init() { logrus.SetLevel(logrus.DebugLevel) } type testServer struct { + address string listener net.Listener conns []net.Conn prefix string @@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) { s := &testServer{ prefix: prefix, listener: listener, + address: listener.Addr().String(), } go s.serve() go func() { @@ -53,6 +64,7 @@ func (s *testServer) serve() { func (s *testServer) close() { logrus.Printf("testServer %s closing", s.prefix) + s.address = "" s.listener.Close() for _, conn := range s.conns { conn.Close() @@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) { } } -func (s *testServer) address() string { - return s.listener.Addr().String() -} - func ping(conn net.Conn) (string, error) { fmt.Fprintf(conn, "ping\n") result, err := bufio.NewReader(conn).ReadString('\n') @@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) { return strings.TrimSpace(result), nil } -// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint) -// and then adds a new server (a node). The node server is then closed, and it is confirmed -// that new connections use the default server. -func Test_UnitFailOver(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +var _ = Describe("LoadBalancer", func() { + // creates a LB using a default server (ie fixed registration endpoint) + // and then adds a new server (a node). The node server is then closed, and it is confirmed + // that new connections use the default server. + When("loadbalancer is running", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer, node1Server, node2Server *testServer + var conn1, conn2, conn3, conn4 net.Conn + var lb *LoadBalancer + var err error - defaultServer, err := createServer(ctx, "default") - if err != nil { - t.Fatalf("createServer(default) failed: %v", err) - } + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() - node1Server, err := createServer(ctx, "node1") - if err != nil { - t.Fatalf("createServer(node1) failed: %v", err) - } + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") - node2Server, err := createServer(ctx, "node2") - if err != nil { - t.Fatalf("createServer(node2) failed: %v", err) - } + node1Server, err = createServer(ctx, "node1") + Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed") - // start the loadbalancer with the default server as the only server - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } + node2Server, err = createServer(ctx, "node2") + Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed") - parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) - if err != nil { - t.Fatalf("url.Parse failed: %v", err) - } - localAddress := parsedURL.Host + // start the loadbalancer with the default server as the only server + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) - // add the node as a new server address. - lb.Update([]string{node1Server.address()}) + AfterAll(func() { + cancel() + }) - // make sure connections go to the node - conn1, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - if result, err := ping(conn1); err != nil { - t.Fatalf("ping(conn1) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn1) result: %v", result) - } + It("adds node1 as a server", func() { + // add the node as a new server address. + lb.Update([]string{node1Server.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) - t.Log("conn1 tested OK") + By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers())) - // set failing health check for node 1 - lb.SetHealthCheck(node1Server.address(), func() bool { return false }) + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) - // Server connections are checked every second, now that node 1 is failed - // the connections to it should be closed. - time.Sleep(2 * time.Second) + It("connects to node1", func() { + // make sure connections go to the node + conn1, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result") - if _, err := ping(conn1); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn1") - } + By("conn1 tested OK") + }) - t.Log("conn1 closed on failure OK") + It("changes node1 state to failed", func() { + // set failing health check for node 1 + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed }) - // make sure connection still goes to the first node - it is failing health checks but so - // is the default endpoint, so it should be tried first with health checks disabled, - // before failing back to the default. - conn2, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(stateFailed)) + }) - } - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } + It("disconnects from node1", func() { + // Server connections are checked every second, now that node 1 is failed + // the connections to it should be closed. + Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") - t.Log("conn2 tested OK") + By("conn1 closed on failure OK") - // make sure the health checks don't close the connection we just made - - // connections should only be closed when it transitions from health to unhealthy. - time.Sleep(2 * time.Second) + // connections shoould go to the default now that node 1 is failed + conn2, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } + By("conn2 tested OK") + }) - t.Log("conn2 tested OK again") + It("does not close connections unexpectedly", func() { + // make sure the health checks don't close the connection we just made - + // connections should only be closed when it transitions from health to unhealthy. + time.Sleep(2 * time.Second) - // shut down the first node server to force failover to the default - node1Server.close() + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") - // make sure new connections go to the default, and existing connections are closed - conn3, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) + By("conn2 tested OK again") + }) - } - if result, err := ping(conn3); err != nil { - t.Fatalf("ping(conn3) failed: %v", err) - } else if result != "default:ping" { - t.Fatalf("Unexpected ping(conn3) result: %v", result) - } + It("closes connections when dial fails", func() { + // shut down the first node server to force failover to the default + node1Server.close() - t.Log("conn3 tested OK") + // make sure new connections go to the default, and existing connections are closed + conn3, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") - if _, err := ping(conn2); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn2") - } + Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result") - t.Log("conn2 closed on failure OK") + By("conn3 tested OK") + }) - // add the second node as a new server address. - lb.Update([]string{node2Server.address()}) + It("replaces node2 as a server", func() { + // add the second node as a new server address. + lb.Update([]string{node2Server.address}) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) - // make sure connection now goes to the second node, - // and connections to the default are closed. - conn4, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) + By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers())) - } - if result, err := ping(conn4); err != nil { - t.Fatalf("ping(conn4) failed: %v", err) - } else if result != "node2:ping" { - t.Fatalf("Unexpected ping(conn4) result: %v", result) - } + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node2Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) - t.Log("conn4 tested OK") + It("connects to node2", func() { + // make sure connection now goes to the second node, + // and connections to the default are closed. + conn4, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") - // Server connections are checked every second, now that we have a healthy - // server, connections to the default server should be closed - time.Sleep(2 * time.Second) + Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result") - if _, err := ping(conn3); err == nil { - t.Fatal("Unexpected successful ping on connection conn3") - } + By("conn4 tested OK") + }) - t.Log("conn3 closed on failure OK") -} + It("does not close connections unexpectedly", func() { + // Server connections are checked every second, now that we have a healthy + // server, connections to the default server should be closed + time.Sleep(2 * time.Second) -// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly -func Test_UnitFailFast(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") - serverURL := "http://127.0.0.1:0/" - lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } + By("conn2 closed on failure OK") - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } + Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(10 * time.Millisecond) + By("conn3 closed on failure OK") + }) - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from invalid address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} + It("adds default as a server", func() { + // add the default as a full server + lb.Update([]string{node2Server.address, defaultServer.address}) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) -// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail -// within the expected duration -func Test_UnitFailUnreachable(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test in short mode.") - } - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) - serverAddr := "192.0.2.1:6443" - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } + By(fmt.Sprintf("Default server added: %v", lb.servers.getServers())) + }) - // Set failing health check to reduce retries - lb.SetHealthCheck(serverAddr, func() bool { return false }) + It("returns the default server in the address list", func() { + // confirm that both servers are listed in the address list + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address)) - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } + // confirm that the default is still listed as default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(11 * time.Second) + }) - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from unreachable address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} + It("does not return the default server in the address list after removing it", func() { + // remove the default as a server + lb.Update([]string{node2Server.address}) + By(fmt.Sprintf("Default removed: %v", lb.servers.getServers())) + + // confirm that it is not listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // but is still listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + + It("removes default server when no longer default", func() { + // set node2 as the default + lb.SetDefault(node2Server.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default") + }) + + It("sets all three servers", func() { + // set node2 as the default + lb.SetDefault(defaultServer.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + + By(fmt.Sprintf("All servers set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + }) + + // confirms that the loadbalancer will not dial itself + When("the default server is the loadbalancer", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer *testServer + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") + address := defaultServer.address + defaultServer.close() + _, port, _ := net.SplitHostPort(address) + intPort, _ := strconv.Atoi(port) + + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails immediately", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + _, err = ping(conn) + Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection") + }) + }) + + // confirms that connnections to invalid addresses fail quickly + When("there are no valid addresses", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails fast", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(10 * time.Millisecond) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from invalid address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) + + // confirms that connnections to unreachable addresses do fail within the + // expected duration + When("the server is unreachable", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails with the correct timeout", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(11 * time.Second) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from unreachable address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) +}) diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 675bee5c5c..7cdf8466ed 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -1,118 +1,421 @@ package loadbalancer import ( + "cmp" "context" - "math/rand" + "errors" + "fmt" "net" "slices" + "sync" "time" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" ) -func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) - if len(serverAddresses) == 0 { - return false - } +type HealthCheckFunc func() HealthCheckResult - lb.mutex.Lock() - defer lb.mutex.Unlock() +// HealthCheckResult indicates the status of a server health check poll. +// For health-checks that poll in the background, Unknown should be returned +// if a poll has not occurred since the last check. +type HealthCheckResult int - newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.serverAddresses...) +const ( + HealthCheckResultUnknown HealthCheckResult = iota + HealthCheckResultFailed + HealthCheckResultOK +) + +// serverList tracks potential backend servers for use by a loadbalancer. +type serverList struct { + // This mutex protects access to the server list. All direct access to the list should be protected by it. + mutex sync.Mutex + servers []*server +} + +// setServers updates the server list to contain only the selected addresses. +func (sl *serverList) setAddresses(serviceName string, addresses []string) bool { + newAddresses := sets.New(addresses...) + curAddresses := sets.New(sl.getAddresses()...) if newAddresses.Equal(curAddresses) { return false } - for addedServer := range newAddresses.Difference(curAddresses) { - logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) - lb.servers[addedServer] = &server{ - address: addedServer, - connections: make(map[net.Conn]struct{}), - healthCheck: func() bool { return true }, + sl.mutex.Lock() + defer sl.mutex.Unlock() + + var closeAllFuncs []func() + var defaultServer *server + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + defaultServer = sl.servers[i] + } + + // add new servers + for addedAddress := range newAddresses.Difference(curAddresses) { + if defaultServer != nil && defaultServer.address == addedAddress { + // make default server go through the same health check promotions as a new server when added + logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName) + defaultServer.state = stateUnchecked + defaultServer.lastTransition = time.Now() + } else { + s := newServer(addedAddress, false) + logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address) + sl.servers = append(sl.servers, s) } } - for removedServer := range curAddresses.Difference(newAddresses) { - server := lb.servers[removedServer] - if server != nil { - logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) - // Defer closing connections until after the new server list has been put into place. - // Closing open connections ensures that anything stuck retrying on a stale server is forced - // over to a valid endpoint. - defer server.closeAll() - // Don't delete the default server from the server map, in case we need to fall back to it. - if removedServer != lb.defaultServerAddress { - delete(lb.servers, removedServer) - } + // remove old servers + for removedAddress := range curAddresses.Difference(newAddresses) { + if defaultServer != nil && defaultServer.address == removedAddress { + // demote the default server down to standby, instead of deleting it + defaultServer.state = stateStandby + closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll) + } else { + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.address == removedAddress { + logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address) + // set state to invalid to prevent server from making additional connections + s.state = stateInvalid + closeAllFuncs = append(closeAllFuncs, s.closeAll) + return true + } + return false + }) } } - lb.serverAddresses = serverAddresses - lb.randomServers = append([]string{}, lb.serverAddresses...) - rand.Shuffle(len(lb.randomServers), func(i, j int) { - lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] - }) - // If the current server list does not contain the default server address, - // we want to include it in the random server list so that it can be tried if necessary. - // However, it should be treated as always failing health checks so that it is only - // used if all other endpoints are unavailable. - if !hasDefaultServer { - lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) - if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok { - defaultServer.healthCheck = func() bool { return false } - lb.servers[lb.defaultServerAddress] = defaultServer - } + slices.SortFunc(sl.servers, compareServers) + + // Close all connections to servers that were removed + for _, closeAll := range closeAllFuncs { + closeAll() } - lb.currentServerAddress = lb.randomServers[0] - lb.nextServerIndex = 1 return true } -// nextServer attempts to get the next server in the loadbalancer server list. -// If another goroutine has already updated the current server address to point at -// a different address than just failed, nothing is changed. Otherwise, a new server address -// is stored to the currentServerAddress field, and returned for use. -// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex. -func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { - // note: these fields are not protected by the mutex, so we clamp the index value and update - // the index/current address using local variables, to avoid time-of-check vs time-of-use - // race conditions caused by goroutine A incrementing it in between the time goroutine B - // validates its value, and uses it as a list index. - currentServerAddress := lb.currentServerAddress - nextServerIndex := lb.nextServerIndex +// getAddresses returns the addresses of all servers. +// If the default server is in standby state, indicating it is only present +// because it is the default, it is not returned in this list. +func (sl *serverList) getAddresses() []string { + sl.mutex.Lock() + defer sl.mutex.Unlock() - if len(lb.randomServers) == 0 { - return "", errors.New("No servers in load balancer proxy list") + addresses := make([]string, 0, len(sl.servers)) + for _, s := range sl.servers { + if s.isDefault && s.state == stateStandby { + continue + } + addresses = append(addresses, s.address) } - if len(lb.randomServers) == 1 { - return currentServerAddress, nil - } - if failedServer != currentServerAddress { - return currentServerAddress, nil - } - if nextServerIndex >= len(lb.randomServers) { - nextServerIndex = 0 - } - - currentServerAddress = lb.randomServers[nextServerIndex] - nextServerIndex++ - - lb.currentServerAddress = currentServerAddress - lb.nextServerIndex = nextServerIndex - - return currentServerAddress, nil + return addresses } -// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map -func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := defaultDialer.Dial(network, address) +// setDefault sets the server with the provided address as the default server. +// The default flag is cleared on all other servers, and if the server was previously +// only kept in the list because it was the default, it is removed. +func (sl *serverList) setDefaultAddress(serviceName, address string) { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // deal with existing default first + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.isDefault && s.address != address { + s.isDefault = false + if s.state == stateStandby { + s.state = stateInvalid + defer s.closeAll() + return true + } + } + return false + }) + + // update or create server with selected address + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + sl.servers[i].isDefault = true + } else { + sl.servers = append(sl.servers, newServer(address, true)) + } + + logrus.Infof("Updated load balancer %s default server: %s", serviceName, address) + slices.SortFunc(sl.servers, compareServers) +} + +// getDefault returns the address of the default server. +func (sl *serverList) getDefaultAddress() string { + if s := sl.getDefaultServer(); s != nil { + return s.address + } + return "" +} + +// getDefault returns the default server. +func (sl *serverList) getDefaultServer() *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + return sl.servers[i] + } + return nil +} + +// getServers returns a copy of the servers list that can be safely iterated over without holding a lock +func (sl *serverList) getServers() []*server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + return slices.Clone(sl.servers) +} + +// getServer returns the first server with the specified address +func (sl *serverList) getServer(address string) *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + return sl.servers[i] + } + return nil +} + +// setHealthCheck updates the health check function for a server, replacing the +// current function. +func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error { + if s := sl.getServer(address); s != nil { + s.healthCheck = healthCheck + return nil + } + return fmt.Errorf("no server found for %s", address) +} + +// recordSuccess records a successful check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordSuccess(srv *server, r reason) { + var new_state state + switch srv.state { + case stateFailed, stateUnchecked: + // dialed or health checked OK once, improve to recovering + new_state = stateRecovering + case stateRecovering: + if r == reasonHealthCheck { + // was recovering due to successful dial or first health check, can now improve + if len(srv.connections) > 0 { + // server accepted connections while recovering, attempt to go straight to active + new_state = stateActive + } else { + // no connections, just make it preferred + new_state = statePreferred + } + } + case stateHealthy: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } + case statePreferred: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } else { + if time.Now().Sub(srv.lastTransition) > time.Minute { + // has been preferred for a while without being dialed, demote to healthy + new_state = stateHealthy + } + } + } + + // no-op if state did not change + if new_state == stateInvalid { + return + } + + // handle active transition and sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // handle states of other servers when attempting to make this one active + if new_state == stateActive { + for _, s := range sl.servers { + if srv.address == s.address { + continue + } + switch s.state { + case stateFailed, stateStandby, stateRecovering, stateHealthy: + // close connections to other non-active servers whenever we have a new active server + defer s.closeAll() + case stateActive: + if len(s.connections) > len(srv.connections) { + // if there is a currently active server that has more connections than we do, + // close our connections and go to preferred instead + new_state = statePreferred + defer srv.closeAll() + } else { + // otherwise, close its connections and demote it to preferred + s.state = statePreferred + defer s.closeAll() + } + } + } + } + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return + } + + logrus.Infof("Server %s->%s from successful %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) +} + +// recordFailure records a failed check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordFailure(srv *server, r reason) { + var new_state state + switch srv.state { + case stateUnchecked, stateRecovering: + if r == reasonDial { + // only demote from unchecked or recovering if a dial fails, health checks may + // continue to fail despite it being dialable. just leave it where it is + // and don't close any connections. + new_state = stateFailed + } + case stateHealthy, statePreferred, stateActive: + // should not have any connections when in any state other than active or + // recovering, but close them all anyway to force failover. + defer srv.closeAll() + new_state = stateFailed + } + + // no-op if state did not change + if new_state == stateInvalid { + return + } + + // sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return + } + + logrus.Infof("Server %s->%s from failed %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) +} + +// state is possible server health states, in increasing order of preference. +// The server list is kept sorted in descending order by this state value. +type state int + +const ( + stateInvalid state = iota + stateFailed // failed a health check or dial + stateStandby // reserved for use by default server if not in server list + stateUnchecked // just added, has not been health checked + stateRecovering // successfully health checked once, or dialed when failed + stateHealthy // normal state + statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance + stateActive // currently active server +) + +func (s state) String() string { + switch s { + case stateInvalid: + return "INVALID" + case stateFailed: + return "FAILED" + case stateStandby: + return "STANDBY" + case stateUnchecked: + return "UNCHECKED" + case stateRecovering: + return "RECOVERING" + case stateHealthy: + return "HEALTHY" + case statePreferred: + return "PREFERRED" + case stateActive: + return "ACTIVE" + default: + return "UNKNOWN" + } +} + +// reason specifies the reason for a successful or failed health report +type reason int + +const ( + reasonDial reason = iota + reasonHealthCheck +) + +func (r reason) String() string { + switch r { + case reasonDial: + return "dial" + case reasonHealthCheck: + return "health check" + default: + return "unknown reason" + } +} + +// server tracks the connections to a server, so that they can be closed when the server is removed. +type server struct { + // This mutex protects access to the connections map. All direct access to the map should be protected by it. + mutex sync.Mutex + address string + isDefault bool + state state + lastTransition time.Time + healthCheck HealthCheckFunc + connections map[net.Conn]struct{} +} + +// newServer creates a new server, with a default health check +// and default/state fields appropriate for whether or not +// the server is a full server, or just a fallback default. +func newServer(address string, isDefault bool) *server { + state := stateUnchecked + if isDefault { + state = stateStandby + } + return &server{ + address: address, + isDefault: isDefault, + state: state, + lastTransition: time.Now(), + healthCheck: func() HealthCheckResult { return HealthCheckResultUnknown }, + connections: make(map[net.Conn]struct{}), + } +} + +func (s *server) String() string { + format := "%s@%s" + if s.isDefault { + format += "*" + } + return fmt.Sprintf(format, s.address, s.state) +} + +// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map +func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) { + if s.state == stateInvalid { + return nil, fmt.Errorf("server %s is stopping", s.address) + } + + conn, err := defaultDialer.Dial(network, s.address) if err != nil { return nil, err } @@ -132,7 +435,7 @@ func (s *server) closeAll() { defer s.mutex.Unlock() if l := len(s.connections); l > 0 { - logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) + logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s) for conn := range s.connections { // Close the connection in a goroutine so that we don't hold the lock while doing so. go conn.Close() @@ -140,6 +443,12 @@ func (s *server) closeAll() { } } +// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. +type serverConn struct { + server *server + net.Conn +} + // Close removes the connection entry from the server's connection map, and // closes the wrapped connection. func (sc *serverConn) Close() error { @@ -150,73 +459,43 @@ func (sc *serverConn) Close() error { return sc.Conn.Close() } -// SetDefault sets the selected address as the default / fallback address -func (lb *LoadBalancer) SetDefault(serverAddress string) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { - defer server.closeAll() - delete(lb.servers, lb.defaultServerAddress) - } - // if the new default server doesn't have an entry in the map, add one - but - // with a failing health check so that it is only used as a last resort. - if _, ok := lb.servers[serverAddress]; !ok { - lb.servers[serverAddress] = &server{ - address: serverAddress, - healthCheck: func() bool { return false }, - connections: make(map[net.Conn]struct{}), - } - } - - lb.defaultServerAddress = serverAddress - logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress) -} - -// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. -func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - if server := lb.servers[address]; server != nil { - logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address) - server.healthCheck = healthCheck - } else { - logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address) - } -} - -// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their -// connections closed, to force clients to switch over to a healthy server. -func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { - previousStatus := map[string]bool{} +// runHealthChecks periodically health-checks all servers. +func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) { wait.Until(func() { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - var healthyServerExists bool - for address, server := range lb.servers { - status := server.healthCheck() - healthyServerExists = healthyServerExists || status - if status == false && previousStatus[address] == true { - // Only close connections when the server transitions from healthy to unhealthy; - // we don't want to re-close all the connections every time as we might be ignoring - // health checks due to all servers being marked unhealthy. - defer server.closeAll() - } - previousStatus[address] = status - } - - // If there is at least one healthy server, and the default server is not in the server list, - // close all the connections to the default server so that clients reconnect and switch over - // to a preferred server. - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - if healthyServerExists && !hasDefaultServer { - if server, ok := lb.servers[lb.defaultServerAddress]; ok { - defer server.closeAll() + for _, s := range sl.getServers() { + switch s.healthCheck() { + case HealthCheckResultOK: + sl.recordSuccess(s, reasonHealthCheck) + case HealthCheckResultFailed: + sl.recordFailure(s, reasonHealthCheck) } } }, time.Second, ctx.Done()) - logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) + logrus.Debugf("Stopped health checking for load balancer %s", serviceName) +} + +// dialContext attemps to dial a connection to a server from the server list. +// Success or failure is recorded to ensure that server state is updated appropriately. +func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { + for _, s := range sl.getServers() { + dialTime := time.Now() + conn, err := s.dialContext(ctx, network) + if err == nil { + sl.recordSuccess(s, reasonDial) + return conn, nil + } + logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err) + sl.recordFailure(s, reasonDial) + } + return nil, errors.New("all servers failed") +} + +// compareServers is a comparison function that can be used to sort the server list +// so that servers with a more preferred state, or higher number of connections, are ordered first. +func compareServers(a, b *server) int { + c := cmp.Compare(b.state, a.state) + if c == 0 { + return cmp.Compare(len(b.connections), len(a.connections)) + } + return c } diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index e711623e46..56d86a0313 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -22,7 +22,7 @@ type Proxy interface { SupervisorAddresses() []string APIServerURL() string IsAPIServerLBEnabled() bool - SetHealthCheck(address string, healthCheck func() bool) + SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) } // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If @@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor return nil, err } p.supervisorLB = lb - p.supervisorURL = lb.LoadBalancerServerURL() + p.supervisorURL = lb.LocalURL() p.apiServerURL = p.supervisorURL } @@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) { p.supervisorAddresses = supervisorAddresses } -func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) { +func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) { if p.supervisorLB != nil { p.supervisorLB.SetHealthCheck(address, healthCheck) } @@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error { return err } p.apiServerLB = lb - p.apiServerURL = lb.LoadBalancerServerURL() + p.apiServerURL = lb.LocalURL() } else { p.apiServerURL = u.String() } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index a5df415c73..d04f9fdc0b 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/websocket" agentconfig "github.com/k3s-io/k3s/pkg/agent/config" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/agent/proxy" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" @@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, wg, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, nil, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -427,22 +428,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo } type agentConnection struct { - cancel context.CancelFunc - connected func() bool + cancel context.CancelFunc + healthCheck loadbalancer.HealthCheckFunc } // connect initiates a connection to the remotedialer server. Incoming dial requests from // the server will be checked by the authorizer function prior to being fulfilled. func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection { + var status loadbalancer.HealthCheckResult + wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address) ws := &websocket.Dialer{ TLSClientConfig: tlsConfig, } - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true - once := sync.Once{} if waitGroup != nil { waitGroup.Add(1) @@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup } onConnect := func(_ context.Context, _ *remotedialer.Session) error { - connected = true + status = loadbalancer.HealthCheckResultOK logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy") if waitGroup != nil { once.Do(waitGroup.Done) @@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup for { // ConnectToProxy blocks until error or context cancellation err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect) - connected = false + status = loadbalancer.HealthCheckResultFailed if err != nil && !errors.Is(err, context.Canceled) { logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...") // wait between reconnection attempts to avoid hammering the server @@ -484,8 +483,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup }() return agentConnection{ - cancel: cancel, - connected: func() bool { return connected }, + cancel: cancel, + healthCheck: func() loadbalancer.HealthCheckResult { + return status + }, } } diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index ec781e11a3..156834440c 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -48,6 +48,10 @@ type etcdproxy struct { } func (e *etcdproxy) Update(addresses []string) { + if e.etcdLB == nil { + return + } + e.etcdLB.Update(addresses) validEndpoint := map[string]bool{} @@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) { // start a polling routine that makes periodic requests to the etcd node's supervisor port. // If the request fails, the node is marked unhealthy. -func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true +func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc { + var status loadbalancer.HealthCheckResult host, _, _ := net.SplitHostPort(address) url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort))) @@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() } if err != nil || statusCode != http.StatusOK { logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode) - connected = false + status = loadbalancer.HealthCheckResultFailed } else { - connected = true + status = loadbalancer.HealthCheckResultOK } }, 5*time.Second, 1.0, true) - return func() bool { - return connected + return func() loadbalancer.HealthCheckResult { + // Reset the status to unknown on reading, until next time it is checked. + // This avoids having a health check result alter the server state between active checks. + s := status + status = loadbalancer.HealthCheckResultUnknown + return s } }