Refactor load balancer server list and health checking

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
(cherry picked from commit 911ee19a93)
Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
pull/11460/head
Brad Davidson 2024-11-15 22:11:47 +00:00 committed by Brad Davidson
parent 867ca25412
commit ba4237aaf7
9 changed files with 837 additions and 489 deletions

View File

@ -15,8 +15,8 @@ type lbConfig struct {
func (lb *LoadBalancer) writeConfig() error {
config := &lbConfig{
ServerURL: lb.serverURL,
ServerAddresses: lb.serverAddresses,
ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(),
ServerAddresses: lb.servers.getAddresses(),
}
configOut, err := json.MarshalIndent(config, "", " ")
if err != nil {
@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
}
func (lb *LoadBalancer) updateConfig() error {
writeConfig := true
if configBytes, err := os.ReadFile(lb.configFile); err == nil {
config := &lbConfig{}
if err := json.Unmarshal(configBytes, config); err == nil {
if config.ServerURL == lb.serverURL {
writeConfig = false
lb.setServers(config.ServerAddresses)
// if the default server from the config matches our current default,
// load the rest of the addresses as well.
if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
lb.Update(config.ServerAddresses)
return nil
}
}
}
if writeConfig {
if err := lb.writeConfig(); err != nil {
return err
}
}
return nil
// config didn't exist or used a different default server, write the current config to disk.
return lb.writeConfig()
}

View File

@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Create a new HTTP proxy dialer
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
return httpProxyDialer, nil
} else if proxyURL.Scheme == "socks5" {
// For SOCKS5 proxies, use the proxy package's FromURL

View File

@ -2,15 +2,16 @@ package loadbalancer
import (
"fmt"
"net"
"os"
"strings"
"testing"
"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
var originalDialer proxy.Dialer
var defaultEnv map[string]string
var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}
@ -19,7 +20,7 @@ func init() {
}
func prepareEnv(env ...string) {
defaultDialer = &net.Dialer{}
originalDialer = defaultDialer
defaultEnv = map[string]string{}
for _, e := range proxyEnvs {
if v, ok := os.LookupEnv(e); ok {
@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
}
func restoreEnv() {
defaultDialer = originalDialer
for _, e := range proxyEnvs {
if v, ok := defaultEnv[e]; ok {
os.Setenv(e, v)

View File

@ -2,55 +2,29 @@ package loadbalancer
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"sync"
"time"
"strings"
"github.com/inetaf/tcpproxy"
"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
)
// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
server *server
net.Conn
}
// LoadBalancer holds data for a local listener which forwards connections to a
// pool of remote servers. It is not a proper load-balancer in that it does not
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy
serviceName string
configFile string
localAddress string
localServerURL string
defaultServerAddress string
serverURL string
serverAddresses []string
randomServers []string
servers map[string]*server
currentServerAddress string
nextServerIndex int
serviceName string
configFile string
scheme string
localAddress string
servers serverList
proxy *tcpproxy.Proxy
}
const RandomPort = 0
@ -63,7 +37,7 @@ var (
// New constructs a new LoadBalancer instance. The default server URL, and
// currently active servers, are stored in a file within the dataDir.
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
config := net.ListenConfig{Control: reusePort}
var localAddress string
if isIPv6 {
@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
return nil, err
}
// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
localAddress = listener.Addr().String()
defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
serverURL, err := url.Parse(defaultServerURL)
if err != nil {
return nil, err
}
if serverURL == localServerURL {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
defaultServerAddress = ""
// Set explicit port from scheme
if serverURL.Port() == "" {
if strings.ToLower(serverURL.Scheme) == "http" {
serverURL.Host += ":80"
}
if strings.ToLower(serverURL.Scheme) == "https" {
serverURL.Host += ":443"
}
}
lb := &LoadBalancer{
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
localAddress: localAddress,
localServerURL: localServerURL,
defaultServerAddress: defaultServerAddress,
servers: make(map[string]*server),
serverURL: serverURL,
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
scheme: serverURL.Scheme,
localAddress: listener.Addr().String(),
}
lb.setServers([]string{lb.defaultServerAddress})
// if starting pointing at ourselves, don't set a default server address,
// which will cause all dials to fail until servers are added.
if serverURL.Host == lb.localAddress {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
} else {
lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
}
lb.proxy = &tcpproxy.Proxy{
ListenFunc: func(string, string) (net.Listener, error) {
@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
}
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
Addr: serviceName,
DialContext: lb.dialContext,
DialContext: lb.servers.dialContext,
OnDialError: onDialError,
})
@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
if err := lb.proxy.Start(); err != nil {
return nil, err
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
go lb.runHealthChecks(ctx)
go lb.servers.runHealthChecks(ctx, lb.serviceName)
return lb, nil
}
// Update updates the list of server addresses to contain only the listed servers.
func (lb *LoadBalancer) Update(serverAddresses []string) {
if lb == nil {
if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
return
}
if !lb.setServers(serverAddresses) {
return
}
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
}
func (lb *LoadBalancer) LoadBalancerServerURL() string {
if lb == nil {
return ""
// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.servers.setDefaultAddress(lb.serviceName, serverAddress)
if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
return lb.localServerURL
}
// SetHealthCheck registers healthCheck as the health-check callback for the
// given server address, replacing the default no-op function. The outcome is
// only logged: a failure is reported at error level, success at debug level.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
	err := lb.servers.setHealthCheck(address, healthCheck)
	if err != nil {
		logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
		return
	}
	logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
}
// LocalURL returns the URL of the load balancer's local listener, formed by
// joining the configured scheme and the local listen address.
func (lb *LoadBalancer) LocalURL() string {
	localURL := lb.scheme + "://" + lb.localAddress
	return localURL
}
func (lb *LoadBalancer) ServerAddresses() []string {
if lb == nil {
return nil
}
return lb.serverAddresses
}
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
var allChecksFailed bool
startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress
server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else if allChecksFailed || server.healthCheck() {
dialTime := time.Now()
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
}
logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
// Don't close connections to the failed server if we're retrying with health checks ignored.
// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}
newServer, err := lb.nextServer(targetServer)
if err != nil {
return nil, err
}
if targetServer != newServer {
logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
maxIndex := len(lb.randomServers)
if startIndex > maxIndex {
startIndex = maxIndex
}
if lb.nextServerIndex == startIndex {
if allChecksFailed {
return nil, errors.New("all servers failed")
}
logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
allChecksFailed = true
}
}
return lb.servers.getAddresses()
}
func onDialError(src net.Conn, dstDialErr error) {
@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
}
// ResetLoadBalancer will delete the local state file for the load balancer on disk
func ResetLoadBalancer(dataDir, serviceName string) error {
func ResetLoadBalancer(dataDir, serviceName string) {
stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
if err := os.Remove(stateFile); err != nil {
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
logrus.Warn(err)
}
return nil
}

View File

@ -5,19 +5,29 @@ import (
"context"
"fmt"
"net"
"net/url"
"strconv"
"strings"
"testing"
"time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus"
)
func Test_UnitLoadBalancer(t *testing.T) {
_, reporterConfig := GinkgoConfiguration()
reporterConfig.Verbose = testing.Verbose()
RegisterFailHandler(Fail)
RunSpecs(t, "LoadBalancer Suite", reporterConfig)
}
// init sets the logrus level to debug so that debug-level log messages are
// emitted while the tests in this package run.
func init() {
logrus.SetLevel(logrus.DebugLevel)
}
type testServer struct {
address string
listener net.Listener
conns []net.Conn
prefix string
@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) {
s := &testServer{
prefix: prefix,
listener: listener,
address: listener.Addr().String(),
}
go s.serve()
go func() {
@ -53,6 +64,7 @@ func (s *testServer) serve() {
func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix)
s.address = ""
s.listener.Close()
for _, conn := range s.conns {
conn.Close()
@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) {
}
}
func (s *testServer) address() string {
return s.listener.Addr().String()
}
func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n')
@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) {
return strings.TrimSpace(result), nil
}
// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
func Test_UnitFailOver(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var _ = Describe("LoadBalancer", func() {
// creates a LB using a default server (ie fixed registration endpoint)
// and then adds a new server (a node). The node server is then closed, and it is confirmed
// that new connections use the default server.
When("loadbalancer is running", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var defaultServer, node1Server, node2Server *testServer
var conn1, conn2, conn3, conn4 net.Conn
var lb *LoadBalancer
var err error
defaultServer, err := createServer(ctx, "default")
if err != nil {
t.Fatalf("createServer(default) failed: %v", err)
}
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
node1Server, err := createServer(ctx, "node1")
if err != nil {
t.Fatalf("createServer(node1) failed: %v", err)
}
defaultServer, err = createServer(ctx, "default")
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
node2Server, err := createServer(ctx, "node2")
if err != nil {
t.Fatalf("createServer(node2) failed: %v", err)
}
node1Server, err = createServer(ctx, "node1")
Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
// start the loadbalancer with the default server as the only server
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
node2Server, err = createServer(ctx, "node2")
Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
if err != nil {
t.Fatalf("url.Parse failed: %v", err)
}
localAddress := parsedURL.Host
// start the loadbalancer with the default server as the only server
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
// add the node as a new server address.
lb.Update([]string{node1Server.address()})
AfterAll(func() {
cancel()
})
// make sure connections go to the node
conn1, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
if result, err := ping(conn1); err != nil {
t.Fatalf("ping(conn1) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn1) result: %v", result)
}
It("adds node1 as a server", func() {
// add the node as a new server address.
lb.Update([]string{node1Server.address})
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
t.Log("conn1 tested OK")
By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
// set failing health check for node 1
lb.SetHealthCheck(node1Server.address(), func() bool { return false })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node1Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
time.Sleep(2 * time.Second)
It("connects to node1", func() {
// make sure connections go to the node
conn1, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
if _, err := ping(conn1); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn1")
}
By("conn1 tested OK")
})
t.Log("conn1 closed on failure OK")
It("changes node1 state to failed", func() {
// set failing health check for node 1
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
// make sure connection still goes to the first node - it is failing health checks but so
// is the default endpoint, so it should be tried first with health checks disabled,
// before failing back to the default.
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node1Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(stateFailed))
})
}
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
It("disconnects from node1", func() {
// Server connections are checked every second, now that node 1 is failed
// the connections to it should be closed.
Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
t.Log("conn2 tested OK")
By("conn1 closed on failure OK")
// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)
// connections should go to the default now that node 1 is failed
conn2, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
By("conn2 tested OK")
})
t.Log("conn2 tested OK again")
It("does not close connections unexpectedly", func() {
// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from healthy to unhealthy.
time.Sleep(2 * time.Second)
// shut down the first node server to force failover to the default
node1Server.close()
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
// make sure new connections go to the default, and existing connections are closed
conn3, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
By("conn2 tested OK again")
})
}
if result, err := ping(conn3); err != nil {
t.Fatalf("ping(conn3) failed: %v", err)
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn3) result: %v", result)
}
It("closes connections when dial fails", func() {
// shut down the first node server to force failover to the default
node1Server.close()
t.Log("conn3 tested OK")
// make sure new connections go to the default, and existing connections are closed
conn3, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
if _, err := ping(conn2); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn2")
}
Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
t.Log("conn2 closed on failure OK")
By("conn3 tested OK")
})
// add the second node as a new server address.
lb.Update([]string{node2Server.address()})
It("replaces node2 as a server", func() {
// add the second node as a new server address.
lb.Update([]string{node2Server.address})
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
}
if result, err := ping(conn4); err != nil {
t.Fatalf("ping(conn4) failed: %v", err)
} else if result != "node2:ping" {
t.Fatalf("Unexpected ping(conn4) result: %v", result)
}
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(node2Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
t.Log("conn4 tested OK")
It("connects to node2", func() {
// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result")
if _, err := ping(conn3); err == nil {
t.Fatal("Unexpected successful ping on connection conn3")
}
By("conn4 tested OK")
})
t.Log("conn3 closed on failure OK")
}
It("does not close connections unexpectedly", func() {
// Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
func Test_UnitFailFast(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
serverURL := "http://127.0.0.1:0/"
lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
By("conn2 closed on failure OK")
conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
By("conn3 closed on failure OK")
})
select {
case err := <-done:
if err == nil {
t.Fatal("Unexpected successful ping from invalid address")
}
case <-timeout:
t.Fatal("Test timed out")
}
}
It("adds default as a server", func() {
// add the default as a full server
lb.Update([]string{node2Server.address, defaultServer.address})
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
// within the expected duration
func Test_UnitFailUnreachable(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test in short mode.")
}
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(defaultServer.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
serverAddr := "192.0.2.1:6443"
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
})
// Set failing health check to reduce retries
lb.SetHealthCheck(serverAddr, func() bool { return false })
It("returns the default server in the address list", func() {
// confirm that both servers are listed in the address list
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
// confirm that the default is still listed as default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
})
select {
case err := <-done:
if err == nil {
t.Fatal("Unexpected successful ping from unreachable address")
}
case <-timeout:
t.Fatal("Test timed out")
}
}
It("does not return the default server in the address list after removing it", func() {
// remove the default as a server
lb.Update([]string{node2Server.address})
By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
// confirm that it is not listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
// but is still listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
It("removes default server when no longer default", func() {
// set node2 as the default
lb.SetDefault(node2Server.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
})
It("sets all three servers", func() {
// set the original default server as the default again
lb.SetDefault(defaultServer.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(defaultServer.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
})
// confirms that the loadbalancer will not dial itself
When("the default server is the loadbalancer", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var defaultServer *testServer
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
defaultServer, err = createServer(ctx, "default")
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
address := defaultServer.address
defaultServer.close()
_, port, _ := net.SplitHostPort(address)
intPort, _ := strconv.Atoi(port)
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails immediately", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
_, err = ping(conn)
Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
})
})
// confirms that connections to invalid addresses fail quickly
When("there are no valid addresses", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails fast", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from invalid address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
// confirms that connections to unreachable addresses do fail within the
// expected duration
When("the server is unreachable", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails with the correct timeout", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from unreachable address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
})

View File

@ -1,118 +1,421 @@
package loadbalancer
import (
"cmp"
"context"
"math/rand"
"errors"
"fmt"
"net"
"slices"
"sync"
"time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
)
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
if len(serverAddresses) == 0 {
return false
}
type HealthCheckFunc func() HealthCheckResult
lb.mutex.Lock()
defer lb.mutex.Unlock()
// HealthCheckResult indicates the status of a server health check poll.
// For health-checks that poll in the background, Unknown should be returned
// if a poll has not occurred since the last check.
type HealthCheckResult int
newAddresses := sets.NewString(serverAddresses...)
curAddresses := sets.NewString(lb.serverAddresses...)
// Possible results of a server health check poll.
const (
// HealthCheckResultUnknown is the zero value of HealthCheckResult.
HealthCheckResultUnknown HealthCheckResult = iota
// HealthCheckResultFailed reports that the health check failed.
HealthCheckResultFailed
// HealthCheckResultOK reports that the health check passed.
HealthCheckResultOK
)
// serverList tracks potential backend servers for use by a loadbalancer.
type serverList struct {
// This mutex protects access to the server list. All direct access to the list should be protected by it.
mutex sync.Mutex
// servers is the list of tracked backend servers.
servers []*server
}
// setAddresses updates the server list to contain only the selected addresses.
func (sl *serverList) setAddresses(serviceName string, addresses []string) bool {
newAddresses := sets.New(addresses...)
curAddresses := sets.New(sl.getAddresses()...)
if newAddresses.Equal(curAddresses) {
return false
}
for addedServer := range newAddresses.Difference(curAddresses) {
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
lb.servers[addedServer] = &server{
address: addedServer,
connections: make(map[net.Conn]struct{}),
healthCheck: func() bool { return true },
sl.mutex.Lock()
defer sl.mutex.Unlock()
var closeAllFuncs []func()
var defaultServer *server
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
defaultServer = sl.servers[i]
}
// add new servers
for addedAddress := range newAddresses.Difference(curAddresses) {
if defaultServer != nil && defaultServer.address == addedAddress {
// make default server go through the same health check promotions as a new server when added
logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName)
defaultServer.state = stateUnchecked
defaultServer.lastTransition = time.Now()
} else {
s := newServer(addedAddress, false)
logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address)
sl.servers = append(sl.servers, s)
}
}
for removedServer := range curAddresses.Difference(newAddresses) {
server := lb.servers[removedServer]
if server != nil {
logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
// Defer closing connections until after the new server list has been put into place.
// Closing open connections ensures that anything stuck retrying on a stale server is forced
// over to a valid endpoint.
defer server.closeAll()
// Don't delete the default server from the server map, in case we need to fall back to it.
if removedServer != lb.defaultServerAddress {
delete(lb.servers, removedServer)
}
// remove old servers
for removedAddress := range curAddresses.Difference(newAddresses) {
if defaultServer != nil && defaultServer.address == removedAddress {
// demote the default server down to standby, instead of deleting it
defaultServer.state = stateStandby
closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll)
} else {
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
if s.address == removedAddress {
logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address)
// set state to invalid to prevent server from making additional connections
s.state = stateInvalid
closeAllFuncs = append(closeAllFuncs, s.closeAll)
return true
}
return false
})
}
}
lb.serverAddresses = serverAddresses
lb.randomServers = append([]string{}, lb.serverAddresses...)
rand.Shuffle(len(lb.randomServers), func(i, j int) {
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
})
// If the current server list does not contain the default server address,
// we want to include it in the random server list so that it can be tried if necessary.
// However, it should be treated as always failing health checks so that it is only
// used if all other endpoints are unavailable.
if !hasDefaultServer {
lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
defaultServer.healthCheck = func() bool { return false }
lb.servers[lb.defaultServerAddress] = defaultServer
}
slices.SortFunc(sl.servers, compareServers)
// Close all connections to servers that were removed
for _, closeAll := range closeAllFuncs {
closeAll()
}
lb.currentServerAddress = lb.randomServers[0]
lb.nextServerIndex = 1
return true
}
// nextServer attempts to get the next server in the loadbalancer server list.
// If another goroutine has already updated the current server address to point at
// a different address than just failed, nothing is changed. Otherwise, a new server address
// is stored to the currentServerAddress field, and returned for use.
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
// note: these fields are not protected by the mutex, so we clamp the index value and update
// the index/current address using local variables, to avoid time-of-check vs time-of-use
// race conditions caused by goroutine A incrementing it in between the time goroutine B
// validates its value, and uses it as a list index.
currentServerAddress := lb.currentServerAddress
nextServerIndex := lb.nextServerIndex
// getAddresses returns the addresses of all servers.
// If the default server is in standby state, indicating it is only present
// because it is the default, it is not returned in this list.
func (sl *serverList) getAddresses() []string {
sl.mutex.Lock()
defer sl.mutex.Unlock()
if len(lb.randomServers) == 0 {
return "", errors.New("No servers in load balancer proxy list")
addresses := make([]string, 0, len(sl.servers))
for _, s := range sl.servers {
if s.isDefault && s.state == stateStandby {
continue
}
addresses = append(addresses, s.address)
}
if len(lb.randomServers) == 1 {
return currentServerAddress, nil
}
if failedServer != currentServerAddress {
return currentServerAddress, nil
}
if nextServerIndex >= len(lb.randomServers) {
nextServerIndex = 0
}
currentServerAddress = lb.randomServers[nextServerIndex]
nextServerIndex++
lb.currentServerAddress = currentServerAddress
lb.nextServerIndex = nextServerIndex
return currentServerAddress, nil
return addresses
}
// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := defaultDialer.Dial(network, address)
// setDefaultAddress sets the server with the provided address as the default server.
// The default flag is cleared on all other servers, and if the server was previously
// only kept in the list because it was the default, it is removed.
// The serviceName is only used for logging.
func (sl *serverList) setDefaultAddress(serviceName, address string) {
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	// Closing of connections to a removed server is postponed until the list has
	// been fully updated. A defer placed inside the DeleteFunc predicate would fire
	// as soon as the predicate returns, not when this function returns.
	var closeAllFuncs []func()

	// deal with existing default first
	sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
		if !s.isDefault || s.address == address {
			return false
		}
		s.isDefault = false
		if s.state != stateStandby {
			// the old default is also a regular member of the list; keep it
			return false
		}
		// the old default was only present as a fallback; mark it invalid to
		// prevent it from making additional connections, and drop it
		s.state = stateInvalid
		closeAllFuncs = append(closeAllFuncs, s.closeAll)
		return true
	})

	// update or create server with selected address
	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
		sl.servers[i].isDefault = true
	} else {
		sl.servers = append(sl.servers, newServer(address, true))
	}

	logrus.Infof("Updated load balancer %s default server: %s", serviceName, address)
	slices.SortFunc(sl.servers, compareServers)

	// Close all connections to servers that were removed
	for _, closeAll := range closeAllFuncs {
		closeAll()
	}
}
// getDefaultAddress returns the address of the default server,
// or an empty string if the list has no default server.
func (sl *serverList) getDefaultAddress() string {
	def := sl.getDefaultServer()
	if def == nil {
		return ""
	}
	return def.address
}
// getDefaultServer returns the first server flagged as the default,
// or nil if no server in the list carries the flag.
func (sl *serverList) getDefaultServer() *server {
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	for _, candidate := range sl.servers {
		if candidate.isDefault {
			return candidate
		}
	}
	return nil
}
// getServers returns a snapshot of the server list. The returned slice is a copy,
// so callers may iterate over it without holding the list lock; the servers it
// points at are shared with the list.
func (sl *serverList) getServers() []*server {
	sl.mutex.Lock()
	snapshot := slices.Clone(sl.servers)
	sl.mutex.Unlock()
	return snapshot
}
// getServer looks up a server by address. The first match in list order is
// returned; nil is returned when no server has the requested address.
func (sl *serverList) getServer(address string) *server {
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	for _, candidate := range sl.servers {
		if candidate.address == address {
			return candidate
		}
	}
	return nil
}
// setHealthCheck replaces the health check function of the server with the given
// address. An error is returned if the list does not contain such a server.
func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
	target := sl.getServer(address)
	if target == nil {
		return fmt.Errorf("no server found for %s", address)
	}
	target.healthCheck = healthCheck
	return nil
}
// recordSuccess records a successful check of a server, either via health-check or dial.
// The server's state is adjusted accordingly:
//   - failed or unchecked servers improve to recovering,
//   - recovering servers improve to active or preferred on a successful health check,
//   - healthy and preferred servers improve to active when dialed,
//   - preferred servers that have not been dialed for over a minute drop back to healthy.
//
// When a server becomes active, connections to other servers are closed so that
// clients fail over to it, and the server list is re-sorted.
func (sl *serverList) recordSuccess(srv *server, r reason) {
	// The zero value (stateInvalid) is used to indicate that no transition is required.
	var newState state
	switch srv.state {
	case stateFailed, stateUnchecked:
		// dialed or health checked OK once, improve to recovering
		newState = stateRecovering
	case stateRecovering:
		if r == reasonHealthCheck {
			// was recovering due to successful dial or first health check, can now improve
			if len(srv.connections) > 0 {
				// server accepted connections while recovering, attempt to go straight to active
				newState = stateActive
			} else {
				// no connections, just make it preferred
				newState = statePreferred
			}
		}
	case stateHealthy:
		if r == reasonDial {
			// improve from healthy to active by being dialed
			newState = stateActive
		}
	case statePreferred:
		if r == reasonDial {
			// improve from preferred to active by being dialed
			newState = stateActive
		} else if time.Since(srv.lastTransition) > time.Minute {
			// has been preferred for a while without being dialed, demote to healthy
			newState = stateHealthy
		}
	}

	// no-op if state did not change
	if newState == stateInvalid {
		return
	}

	// handle active transition and sort the server list while holding the lock
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	// handle states of other servers when attempting to make this one active
	if newState == stateActive {
		for _, s := range sl.servers {
			if srv.address == s.address {
				continue
			}
			// The defers below are intentionally registered inside the loop: they run when
			// this function returns, after the state changes below have been applied.
			switch s.state {
			case stateFailed, stateStandby, stateRecovering, stateHealthy:
				// close connections to other non-active servers whenever we have a new active server
				defer s.closeAll()
			case stateActive:
				if len(s.connections) > len(srv.connections) {
					// if there is a currently active server that has more connections than we do,
					// close our connections and go to preferred instead
					newState = statePreferred
					defer srv.closeAll()
				} else {
					// otherwise, close its connections and demote it to preferred
					s.state = statePreferred
					defer s.closeAll()
				}
			}
		}
	}

	// ensure some other routine didn't already make the transition
	if srv.state == newState {
		return
	}

	logrus.Infof("Server %s->%s from successful %s", srv, newState, r)
	srv.state = newState
	srv.lastTransition = time.Now()
	slices.SortFunc(sl.servers, compareServers)
}
// recordFailure records a failed check of a server, either via health-check or dial.
// The server's state is adjusted accordingly:
//   - unchecked or recovering servers are demoted to failed only when a dial fails,
//   - healthy, preferred, and active servers are demoted to failed and have all of
//     their connections closed.
//
// Servers in any other state are left unchanged. The server list is re-sorted
// whenever a transition is made.
func (sl *serverList) recordFailure(srv *server, r reason) {
	// The zero value (stateInvalid) is used to indicate that no transition is required.
	var newState state
	switch srv.state {
	case stateUnchecked, stateRecovering:
		if r == reasonDial {
			// only demote from unchecked or recovering if a dial fails, health checks may
			// continue to fail despite it being dialable. just leave it where it is
			// and don't close any connections.
			newState = stateFailed
		}
	case stateHealthy, statePreferred, stateActive:
		// should not have any connections when in any state other than active or
		// recovering, but close them all anyway to force failover.
		// This defer is registered before the lock below is taken, so it runs
		// after the lock has been released.
		defer srv.closeAll()
		newState = stateFailed
	}

	// no-op if state did not change
	if newState == stateInvalid {
		return
	}

	// sort the server list while holding the lock
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	// ensure some other routine didn't already make the transition
	if srv.state == newState {
		return
	}

	logrus.Infof("Server %s->%s from failed %s", srv, newState, r)
	srv.state = newState
	srv.lastTransition = time.Now()
	slices.SortFunc(sl.servers, compareServers)
}
// state enumerates the health states a server can be in. Values are declared in
// increasing order of preference, so a numerically larger state is a better one.
// The server list is kept sorted in descending order by this state value.
type state int

const (
	// stateInvalid is the zero value.
	stateInvalid state = iota
	// stateFailed: failed a health check or dial.
	stateFailed
	// stateStandby: reserved for use by default server if not in server list.
	stateStandby
	// stateUnchecked: just added, has not been health checked.
	stateUnchecked
	// stateRecovering: successfully health checked once, or dialed when failed.
	stateRecovering
	// stateHealthy: normal state.
	stateHealthy
	// statePreferred: recently transitioned from recovering; should be preferred
	// as others may go down for maintenance.
	statePreferred
	// stateActive: currently active server.
	stateActive
)

// stateNames maps each known state to its display name.
var stateNames = [...]string{
	stateInvalid:    "INVALID",
	stateFailed:     "FAILED",
	stateStandby:    "STANDBY",
	stateUnchecked:  "UNCHECKED",
	stateRecovering: "RECOVERING",
	stateHealthy:    "HEALTHY",
	statePreferred:  "PREFERRED",
	stateActive:     "ACTIVE",
}

// String returns the upper-case display name of the state, or "UNKNOWN" for
// values outside the declared range.
func (s state) String() string {
	if s < stateInvalid || s > stateActive {
		return "UNKNOWN"
	}
	return stateNames[s]
}
// reason identifies what produced a success or failure report for a server.
type reason int

const (
	// reasonDial: the report came from an attempt to dial the server.
	reasonDial reason = iota
	// reasonHealthCheck: the report came from a health check.
	reasonHealthCheck
)

// String returns a lower-case description of the reason, suitable for log messages.
func (r reason) String() string {
	if r == reasonDial {
		return "dial"
	}
	if r == reasonHealthCheck {
		return "health check"
	}
	return "unknown reason"
}
// server tracks the connections to a server, so that they can be closed when the server is removed.
// It also carries the server's health state and health check function.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
// address is the address that is dialed to reach this server.
address string
// isDefault marks the default (fallback) server.
isDefault bool
// state is the current health state; the server list is sorted by it.
state state
// lastTransition is the time at which state was last changed.
lastTransition time.Time
// healthCheck is called periodically to determine the server's health.
healthCheck HealthCheckFunc
// connections is the set of currently tracked connections to this server.
connections map[net.Conn]struct{}
}
// newServer builds a server for the given address with an empty connection set
// and a placeholder health check that always reports an unknown result.
// A regular server starts out unchecked; a server created only to act as the
// fallback default starts out in standby.
func newServer(address string, isDefault bool) *server {
	srv := &server{
		address:        address,
		isDefault:      isDefault,
		state:          stateUnchecked,
		lastTransition: time.Now(),
		healthCheck:    func() HealthCheckResult { return HealthCheckResultUnknown },
		connections:    make(map[net.Conn]struct{}),
	}
	if isDefault {
		srv.state = stateStandby
	}
	return srv
}
// String renders the server as "address@STATE", with a trailing "*" appended
// when the server is the default server.
func (s *server) String() string {
	out := fmt.Sprintf("%s@%s", s.address, s.state)
	if s.isDefault {
		out += "*"
	}
	return out
}
// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map
func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) {
if s.state == stateInvalid {
return nil, fmt.Errorf("server %s is stopping", s.address)
}
conn, err := defaultDialer.Dial(network, s.address)
if err != nil {
return nil, err
}
@ -132,7 +435,7 @@ func (s *server) closeAll() {
defer s.mutex.Unlock()
if l := len(s.connections); l > 0 {
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s)
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
@ -140,6 +443,12 @@ func (s *server) closeAll() {
}
}
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
// server is the server whose connection map tracks this connection.
server *server
// Conn is the wrapped connection; its methods are promoted to serverConn.
net.Conn
}
// Close removes the connection entry from the server's connection map, and
// closes the wrapped connection.
func (sc *serverConn) Close() error {
@ -150,73 +459,43 @@ func (sc *serverConn) Close() error {
return sc.Conn.Close()
}
// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one - but
// with a failing health check so that it is only used as a last resort.
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{
address: serverAddress,
healthCheck: func() bool { return false },
connections: make(map[net.Conn]struct{}),
}
}
lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
if server := lb.servers[address]; server != nil {
logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
server.healthCheck = healthCheck
} else {
logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
}
}
// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
previousStatus := map[string]bool{}
// runHealthChecks periodically health-checks all servers.
func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
wait.Until(func() {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
var healthyServerExists bool
for address, server := range lb.servers {
status := server.healthCheck()
healthyServerExists = healthyServerExists || status
if status == false && previousStatus[address] == true {
// Only close connections when the server transitions from healthy to unhealthy;
// we don't want to re-close all the connections every time as we might be ignoring
// health checks due to all servers being marked unhealthy.
defer server.closeAll()
}
previousStatus[address] = status
}
// If there is at least one healthy server, and the default server is not in the server list,
// close all the connections to the default server so that clients reconnect and switch over
// to a preferred server.
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
if healthyServerExists && !hasDefaultServer {
if server, ok := lb.servers[lb.defaultServerAddress]; ok {
defer server.closeAll()
for _, s := range sl.getServers() {
switch s.healthCheck() {
case HealthCheckResultOK:
sl.recordSuccess(s, reasonHealthCheck)
case HealthCheckResultFailed:
sl.recordFailure(s, reasonHealthCheck)
}
}
}, time.Second, ctx.Done())
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
}
// dialContext attempts to dial a connection to a server from the server list.
// Servers are tried in list order, using a snapshot of the list taken when the
// call starts. The first successful connection is returned.
// Success or failure is recorded to ensure that server state is updated appropriately.
// The address argument is ignored; the address of the selected server is used instead.
// An error is returned if no server could be dialed.
func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
	for _, s := range sl.getServers() {
		dialTime := time.Now()
		conn, err := s.dialContext(ctx, network)
		if err == nil {
			sl.recordSuccess(s, reasonDial)
			return conn, nil
		}
		logrus.Debugf("Dial error from server %s after %s: %s", s, time.Since(dialTime), err)
		sl.recordFailure(s, reasonDial)
	}
	return nil, errors.New("all servers failed")
}
// compareServers orders servers for sorting the server list: servers in a more
// preferred (numerically higher) state come first, and servers in the same state
// are ordered by number of connections, highest first.
func compareServers(a, b *server) int {
	if byState := cmp.Compare(b.state, a.state); byState != 0 {
		return byState
	}
	return cmp.Compare(len(b.connections), len(a.connections))
}

View File

@ -22,7 +22,7 @@ type Proxy interface {
SupervisorAddresses() []string
APIServerURL() string
IsAPIServerLBEnabled() bool
SetHealthCheck(address string, healthCheck func() bool)
SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc)
}
// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
return nil, err
}
p.supervisorLB = lb
p.supervisorURL = lb.LoadBalancerServerURL()
p.supervisorURL = lb.LocalURL()
p.apiServerURL = p.supervisorURL
}
@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) {
p.supervisorAddresses = supervisorAddresses
}
func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) {
if p.supervisorLB != nil {
p.supervisorLB.SetHealthCheck(address, healthCheck)
}
@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
return err
}
p.apiServerLB = lb
p.apiServerURL = lb.LoadBalancerServerURL()
p.apiServerURL = lb.LocalURL()
} else {
p.apiServerURL = u.String()
}

View File

@ -14,6 +14,7 @@ import (
"github.com/gorilla/websocket"
agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
"github.com/k3s-io/k3s/pkg/agent/proxy"
daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
"github.com/k3s-io/k3s/pkg/util"
@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
if _, ok := disconnect[address]; !ok {
conn := a.connect(ctx, wg, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
proxy.SetHealthCheck(address, conn.healthCheck)
}
}
@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
if _, ok := disconnect[address]; !ok {
conn := a.connect(ctx, nil, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
proxy.SetHealthCheck(address, conn.healthCheck)
}
}
@ -427,22 +428,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
}
type agentConnection struct {
cancel context.CancelFunc
connected func() bool
cancel context.CancelFunc
healthCheck loadbalancer.HealthCheckFunc
}
// connect initiates a connection to the remotedialer server. Incoming dial requests from
// the server will be checked by the authorizer function prior to being fulfilled.
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
var status loadbalancer.HealthCheckResult
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true
once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}
onConnect := func(_ context.Context, _ *remotedialer.Session) error {
connected = true
status = loadbalancer.HealthCheckResultOK
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
if waitGroup != nil {
once.Do(waitGroup.Done)
@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
for {
// ConnectToProxy blocks until error or context cancellation
err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect)
connected = false
status = loadbalancer.HealthCheckResultFailed
if err != nil && !errors.Is(err, context.Canceled) {
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...")
// wait between reconnection attempts to avoid hammering the server
@ -484,8 +483,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}()
return agentConnection{
cancel: cancel,
connected: func() bool { return connected },
cancel: cancel,
healthCheck: func() loadbalancer.HealthCheckResult {
return status
},
}
}

View File

@ -48,6 +48,10 @@ type etcdproxy struct {
}
func (e *etcdproxy) Update(addresses []string) {
if e.etcdLB == nil {
return
}
e.etcdLB.Update(addresses)
validEndpoint := map[string]bool{}
@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) {
// start a polling routine that makes periodic requests to the etcd node's supervisor port.
// If the request fails, the node is marked unhealthy.
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool {
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
var status loadbalancer.HealthCheckResult
host, _, _ := net.SplitHostPort(address)
url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort)))
@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func()
}
if err != nil || statusCode != http.StatusOK {
logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode)
connected = false
status = loadbalancer.HealthCheckResultFailed
} else {
connected = true
status = loadbalancer.HealthCheckResultOK
}
}, 5*time.Second, 1.0, true)
return func() bool {
return connected
return func() loadbalancer.HealthCheckResult {
// Reset the status to unknown on reading, until next time it is checked.
// This avoids having a health check result alter the server state between active checks.
s := status
status = loadbalancer.HealthCheckResultUnknown
return s
}
}