Refactor load balancer server list and health checking

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
pull/11430/head
Brad Davidson 2024-11-15 22:11:47 +00:00 committed by Brad Davidson
parent 95797c4a79
commit 911ee19a93
9 changed files with 837 additions and 489 deletions

View File

@ -15,8 +15,8 @@ type lbConfig struct {
func (lb *LoadBalancer) writeConfig() error { func (lb *LoadBalancer) writeConfig() error {
config := &lbConfig{ config := &lbConfig{
ServerURL: lb.serverURL, ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(),
ServerAddresses: lb.serverAddresses, ServerAddresses: lb.servers.getAddresses(),
} }
configOut, err := json.MarshalIndent(config, "", " ") configOut, err := json.MarshalIndent(config, "", " ")
if err != nil { if err != nil {
@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
} }
func (lb *LoadBalancer) updateConfig() error { func (lb *LoadBalancer) updateConfig() error {
writeConfig := true
if configBytes, err := os.ReadFile(lb.configFile); err == nil { if configBytes, err := os.ReadFile(lb.configFile); err == nil {
config := &lbConfig{} config := &lbConfig{}
if err := json.Unmarshal(configBytes, config); err == nil { if err := json.Unmarshal(configBytes, config); err == nil {
if config.ServerURL == lb.serverURL { // if the default server from the config matches our current default,
writeConfig = false // load the rest of the addresses as well.
lb.setServers(config.ServerAddresses) if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
lb.Update(config.ServerAddresses)
return nil
} }
} }
} }
if writeConfig { // config didn't exist or used a different default server, write the current config to disk.
if err := lb.writeConfig(); err != nil { return lb.writeConfig()
return err
}
}
return nil
} }

View File

@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Create a new HTTP proxy dialer // Create a new HTTP proxy dialer
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
return httpProxyDialer, nil return httpProxyDialer, nil
} else if proxyURL.Scheme == "socks5" { } else if proxyURL.Scheme == "socks5" {
// For SOCKS5 proxies, use the proxy package's FromURL // For SOCKS5 proxies, use the proxy package's FromURL

View File

@ -2,15 +2,16 @@ package loadbalancer
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"strings" "strings"
"testing" "testing"
"github.com/k3s-io/k3s/pkg/version" "github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
) )
var originalDialer proxy.Dialer
var defaultEnv map[string]string var defaultEnv map[string]string
var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"} var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}
@ -19,7 +20,7 @@ func init() {
} }
func prepareEnv(env ...string) { func prepareEnv(env ...string) {
defaultDialer = &net.Dialer{} originalDialer = defaultDialer
defaultEnv = map[string]string{} defaultEnv = map[string]string{}
for _, e := range proxyEnvs { for _, e := range proxyEnvs {
if v, ok := os.LookupEnv(e); ok { if v, ok := os.LookupEnv(e); ok {
@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
} }
func restoreEnv() { func restoreEnv() {
defaultDialer = originalDialer
for _, e := range proxyEnvs { for _, e := range proxyEnvs {
if v, ok := defaultEnv[e]; ok { if v, ok := defaultEnv[e]; ok {
os.Setenv(e, v) os.Setenv(e, v)

View File

@ -2,55 +2,29 @@ package loadbalancer
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"sync" "strings"
"time"
"github.com/inetaf/tcpproxy" "github.com/inetaf/tcpproxy"
"github.com/k3s-io/k3s/pkg/version" "github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
server *server
net.Conn
}
// LoadBalancer holds data for a local listener which forwards connections to a // LoadBalancer holds data for a local listener which forwards connections to a
// pool of remote servers. It is not a proper load-balancer in that it does not // pool of remote servers. It is not a proper load-balancer in that it does not
// actually balance connections, but instead fails over to a new server only // actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails. // when a connection attempt to the currently selected server fails.
type LoadBalancer struct { type LoadBalancer struct {
// This mutex protects access to servers map and randomServers list. serviceName string
// All direct access to the servers map/list should be protected by it. configFile string
mutex sync.RWMutex scheme string
proxy *tcpproxy.Proxy localAddress string
servers serverList
serviceName string proxy *tcpproxy.Proxy
configFile string
localAddress string
localServerURL string
defaultServerAddress string
serverURL string
serverAddresses []string
randomServers []string
servers map[string]*server
currentServerAddress string
nextServerIndex int
} }
const RandomPort = 0 const RandomPort = 0
@ -63,7 +37,7 @@ var (
// New contstructs a new LoadBalancer instance. The default server URL, and // New contstructs a new LoadBalancer instance. The default server URL, and
// currently active servers, are stored in a file within the dataDir. // currently active servers, are stored in a file within the dataDir.
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
config := net.ListenConfig{Control: reusePort} config := net.ListenConfig{Control: reusePort}
var localAddress string var localAddress string
if isIPv6 { if isIPv6 {
@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
return nil, err return nil, err
} }
// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with. serverURL, err := url.Parse(defaultServerURL)
localAddress = listener.Addr().String()
defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if serverURL == localServerURL { // Set explicit port from scheme
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) if serverURL.Port() == "" {
defaultServerAddress = "" if strings.ToLower(serverURL.Scheme) == "http" {
serverURL.Host += ":80"
}
if strings.ToLower(serverURL.Scheme) == "https" {
serverURL.Host += ":443"
}
} }
lb := &LoadBalancer{ lb := &LoadBalancer{
serviceName: serviceName, serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"), configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
localAddress: localAddress, scheme: serverURL.Scheme,
localServerURL: localServerURL, localAddress: listener.Addr().String(),
defaultServerAddress: defaultServerAddress,
servers: make(map[string]*server),
serverURL: serverURL,
} }
lb.setServers([]string{lb.defaultServerAddress}) // if starting pointing at ourselves, don't set a default server address,
// which will cause all dials to fail until servers are added.
if serverURL.Host == lb.localAddress {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
} else {
lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
}
lb.proxy = &tcpproxy.Proxy{ lb.proxy = &tcpproxy.Proxy{
ListenFunc: func(string, string) (net.Listener, error) { ListenFunc: func(string, string) (net.Listener, error) {
@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
} }
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
Addr: serviceName, Addr: serviceName,
DialContext: lb.dialContext, DialContext: lb.servers.dialContext,
OnDialError: onDialError, OnDialError: onDialError,
}) })
@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
if err := lb.proxy.Start(); err != nil { if err := lb.proxy.Start(); err != nil {
return nil, err return nil, err
} }
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress) logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
go lb.runHealthChecks(ctx) go lb.servers.runHealthChecks(ctx, lb.serviceName)
return lb, nil return lb, nil
} }
// Update updates the list of server addresses to contain only the listed servers.
func (lb *LoadBalancer) Update(serverAddresses []string) { func (lb *LoadBalancer) Update(serverAddresses []string) {
if lb == nil { if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
return return
} }
if !lb.setServers(serverAddresses) {
return logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
}
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)
if err := lb.writeConfig(); err != nil { if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
} }
} }
func (lb *LoadBalancer) LoadBalancerServerURL() string { // SetDefault sets the selected address as the default / fallback address
if lb == nil { func (lb *LoadBalancer) SetDefault(serverAddress string) {
return "" lb.servers.setDefaultAddress(lb.serviceName, serverAddress)
if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
} }
return lb.localServerURL }
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
} else {
logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
}
}
func (lb *LoadBalancer) LocalURL() string {
return lb.scheme + "://" + lb.localAddress
} }
func (lb *LoadBalancer) ServerAddresses() []string { func (lb *LoadBalancer) ServerAddresses() []string {
if lb == nil { return lb.servers.getAddresses()
return nil
}
return lb.serverAddresses
}
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
var allChecksFailed bool
startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress
server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else if allChecksFailed || server.healthCheck() {
dialTime := time.Now()
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
}
logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
// Don't close connections to the failed server if we're retrying with health checks ignored.
// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}
newServer, err := lb.nextServer(targetServer)
if err != nil {
return nil, err
}
if targetServer != newServer {
logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
maxIndex := len(lb.randomServers)
if startIndex > maxIndex {
startIndex = maxIndex
}
if lb.nextServerIndex == startIndex {
if allChecksFailed {
return nil, errors.New("all servers failed")
}
logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
allChecksFailed = true
}
}
} }
func onDialError(src net.Conn, dstDialErr error) { func onDialError(src net.Conn, dstDialErr error) {
@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
} }
// ResetLoadBalancer will delete the local state file for the load balancer on disk // ResetLoadBalancer will delete the local state file for the load balancer on disk
func ResetLoadBalancer(dataDir, serviceName string) error { func ResetLoadBalancer(dataDir, serviceName string) {
stateFile := filepath.Join(dataDir, "etc", serviceName+".json") stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
if err := os.Remove(stateFile); err != nil { if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
logrus.Warn(err) logrus.Warn(err)
} }
return nil
} }

View File

@ -5,19 +5,29 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/url" "strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func Test_UnitLoadBalancer(t *testing.T) {
_, reporterConfig := GinkgoConfiguration()
reporterConfig.Verbose = testing.Verbose()
RegisterFailHandler(Fail)
RunSpecs(t, "LoadBalancer Suite", reporterConfig)
}
func init() { func init() {
logrus.SetLevel(logrus.DebugLevel) logrus.SetLevel(logrus.DebugLevel)
} }
type testServer struct { type testServer struct {
address string
listener net.Listener listener net.Listener
conns []net.Conn conns []net.Conn
prefix string prefix string
@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) {
s := &testServer{ s := &testServer{
prefix: prefix, prefix: prefix,
listener: listener, listener: listener,
address: listener.Addr().String(),
} }
go s.serve() go s.serve()
go func() { go func() {
@ -53,6 +64,7 @@ func (s *testServer) serve() {
func (s *testServer) close() { func (s *testServer) close() {
logrus.Printf("testServer %s closing", s.prefix) logrus.Printf("testServer %s closing", s.prefix)
s.address = ""
s.listener.Close() s.listener.Close()
for _, conn := range s.conns { for _, conn := range s.conns {
conn.Close() conn.Close()
@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) {
} }
} }
func (s *testServer) address() string {
return s.listener.Addr().String()
}
func ping(conn net.Conn) (string, error) { func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n") fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n') result, err := bufio.NewReader(conn).ReadString('\n')
@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) {
return strings.TrimSpace(result), nil return strings.TrimSpace(result), nil
} }
// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint) var _ = Describe("LoadBalancer", func() {
// and then adds a new server (a node). The node server is then closed, and it is confirmed // creates a LB using a default server (ie fixed registration endpoint)
// that new connections use the default server. // and then adds a new server (a node). The node server is then closed, and it is confirmed
func Test_UnitFailOver(t *testing.T) { // that new connections use the default server.
tmpDir := t.TempDir() When("loadbalancer is running", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() var defaultServer, node1Server, node2Server *testServer
var conn1, conn2, conn3, conn4 net.Conn
var lb *LoadBalancer
var err error
defaultServer, err := createServer(ctx, "default") BeforeAll(func() {
if err != nil { tmpDir := GinkgoT().TempDir()
t.Fatalf("createServer(default) failed: %v", err)
}
node1Server, err := createServer(ctx, "node1") defaultServer, err = createServer(ctx, "default")
if err != nil { Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
t.Fatalf("createServer(node1) failed: %v", err)
}
node2Server, err := createServer(ctx, "node2") node1Server, err = createServer(ctx, "node1")
if err != nil { Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
t.Fatalf("createServer(node2) failed: %v", err)
}
// start the loadbalancer with the default server as the only server node2Server, err = createServer(ctx, "node2")
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false) Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
if err != nil {
t.Fatalf("New() failed: %v", err)
}
parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) // start the loadbalancer with the default server as the only server
if err != nil { lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
t.Fatalf("url.Parse failed: %v", err) Expect(err).NotTo(HaveOccurred(), "New() failed")
} })
localAddress := parsedURL.Host
// add the node as a new server address. AfterAll(func() {
lb.Update([]string{node1Server.address()}) cancel()
})
// make sure connections go to the node It("adds node1 as a server", func() {
conn1, err := net.Dial("tcp", localAddress) // add the node as a new server address.
if err != nil { lb.Update([]string{node1Server.address})
t.Fatalf("net.Dial failed: %v", err) lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
}
if result, err := ping(conn1); err != nil {
t.Fatalf("ping(conn1) failed: %v", err)
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn1) result: %v", result)
}
t.Log("conn1 tested OK") By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
// set failing health check for node 1 // wait for state to change
lb.SetHealthCheck(node1Server.address(), func() bool { return false }) Eventually(func() state {
if s := lb.servers.getServer(node1Server.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
// Server connections are checked every second, now that node 1 is failed It("connects to node1", func() {
// the connections to it should be closed. // make sure connections go to the node
time.Sleep(2 * time.Second) conn1, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
if _, err := ping(conn1); err == nil { By("conn1 tested OK")
t.Fatal("Unexpected successful ping on closed connection conn1") })
}
t.Log("conn1 closed on failure OK") It("changes node1 state to failed", func() {
// set failing health check for node 1
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
// make sure connection still goes to the first node - it is failing health checks but so // wait for state to change
// is the default endpoint, so it should be tried first with health checks disabled, Eventually(func() state {
// before failing back to the default. if s := lb.servers.getServer(node1Server.address); s != nil {
conn2, err := net.Dial("tcp", localAddress) return s.state
if err != nil { }
t.Fatalf("net.Dial failed: %v", err) return stateInvalid
}, 5, 1).Should(Equal(stateFailed))
})
} It("disconnects from node1", func() {
if result, err := ping(conn2); err != nil { // Server connections are checked every second, now that node 1 is failed
t.Fatalf("ping(conn2) failed: %v", err) // the connections to it should be closed.
} else if result != "node1:ping" { Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
t.Log("conn2 tested OK") By("conn1 closed on failure OK")
// make sure the health checks don't close the connection we just made - // connections shoould go to the default now that node 1 is failed
// connections should only be closed when it transitions from health to unhealthy. conn2, err = net.Dial("tcp", lb.localAddress)
time.Sleep(2 * time.Second) Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
if result, err := ping(conn2); err != nil { By("conn2 tested OK")
t.Fatalf("ping(conn2) failed: %v", err) })
} else if result != "node1:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}
t.Log("conn2 tested OK again") It("does not close connections unexpectedly", func() {
// make sure the health checks don't close the connection we just made -
// connections should only be closed when it transitions from health to unhealthy.
time.Sleep(2 * time.Second)
// shut down the first node server to force failover to the default Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
node1Server.close()
// make sure new connections go to the default, and existing connections are closed By("conn2 tested OK again")
conn3, err := net.Dial("tcp", localAddress) })
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
} It("closes connections when dial fails", func() {
if result, err := ping(conn3); err != nil { // shut down the first node server to force failover to the default
t.Fatalf("ping(conn3) failed: %v", err) node1Server.close()
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn3) result: %v", result)
}
t.Log("conn3 tested OK") // make sure new connections go to the default, and existing connections are closed
conn3, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
if _, err := ping(conn2); err == nil { Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
t.Fatal("Unexpected successful ping on closed connection conn2")
}
t.Log("conn2 closed on failure OK") By("conn3 tested OK")
})
// add the second node as a new server address. It("replaces node2 as a server", func() {
lb.Update([]string{node2Server.address()}) // add the second node as a new server address.
lb.Update([]string{node2Server.address})
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
// make sure connection now goes to the second node, By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
// and connections to the default are closed.
conn4, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
} // wait for state to change
if result, err := ping(conn4); err != nil { Eventually(func() state {
t.Fatalf("ping(conn4) failed: %v", err) if s := lb.servers.getServer(node2Server.address); s != nil {
} else if result != "node2:ping" { return s.state
t.Fatalf("Unexpected ping(conn4) result: %v", result) }
} return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
})
t.Log("conn4 tested OK") It("connects to node2", func() {
// make sure connection now goes to the second node,
// and connections to the default are closed.
conn4, err = net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
// Server connections are checked every second, now that we have a healthy Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result")
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
if _, err := ping(conn3); err == nil { By("conn4 tested OK")
t.Fatal("Unexpected successful ping on connection conn3") })
}
t.Log("conn3 closed on failure OK") It("does not close connections unexpectedly", func() {
} // Server connections are checked every second, now that we have a healthy
// server, connections to the default server should be closed
time.Sleep(2 * time.Second)
// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
func Test_UnitFailFast(t *testing.T) {
tmpDir := t.TempDir()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverURL := "http://127.0.0.1:0/" By("conn2 closed on failure OK")
lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
if err != nil {
t.Fatalf("New() failed: %v", err)
}
conn, err := net.Dial("tcp", lb.localAddress) Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
if err != nil {
t.Fatalf("net.Dial failed: %v", err)
}
done := make(chan error) By("conn3 closed on failure OK")
go func() { })
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select { It("adds default as a server", func() {
case err := <-done: // add the default as a full server
if err == nil { lb.Update([]string{node2Server.address, defaultServer.address})
t.Fatal("Unexpected successful ping from invalid address") lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
}
case <-timeout:
t.Fatal("Test timed out")
}
}
// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail // wait for state to change
// within the expected duration Eventually(func() state {
func Test_UnitFailUnreachable(t *testing.T) { if s := lb.servers.getServer(defaultServer.address); s != nil {
if testing.Short() { return s.state
t.Skip("skipping slow test in short mode.") }
} return stateInvalid
tmpDir := t.TempDir() }, 5, 1).Should(Equal(statePreferred))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverAddr := "192.0.2.1:6443" By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false) })
if err != nil {
t.Fatalf("New() failed: %v", err)
}
// Set failing health check to reduce retries It("returns the default server in the address list", func() {
lb.SetHealthCheck(serverAddr, func() bool { return false }) // confirm that both servers are listed in the address list
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
conn, err := net.Dial("tcp", lb.localAddress) // confirm that the default is still listed as default
if err != nil { Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
t.Fatalf("net.Dial failed: %v", err)
}
done := make(chan error) })
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
select { It("does not return the default server in the address list after removing it", func() {
case err := <-done: // remove the default as a server
if err == nil { lb.Update([]string{node2Server.address})
t.Fatal("Unexpected successful ping from unreachable address") By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
}
case <-timeout: // confirm that it is not listed as a server
t.Fatal("Test timed out") Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
}
} // but is still listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
It("removes default server when no longer default", func() {
// set node2 as the default
lb.SetDefault(node2Server.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
})
It("sets all three servers", func() {
// set node2 as the default
lb.SetDefault(defaultServer.address)
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
// wait for state to change
Eventually(func() state {
if s := lb.servers.getServer(defaultServer.address); s != nil {
return s.state
}
return stateInvalid
}, 5, 1).Should(Equal(statePreferred))
By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
// confirm that it is still listed as a server
Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
// and is listed as the default
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
})
})
// confirms that the loadbalancer will not dial itself
When("the default server is the loadbalancer", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var defaultServer *testServer
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
defaultServer, err = createServer(ctx, "default")
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
address := defaultServer.address
defaultServer.close()
_, port, _ := net.SplitHostPort(address)
intPort, _ := strconv.Atoi(port)
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails immediately", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
_, err = ping(conn)
Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
})
})
// confirms that connnections to invalid addresses fail quickly
When("there are no valid addresses", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails fast", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from invalid address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
// confirms that connnections to unreachable addresses do fail within the
// expected duration
When("the server is unreachable", Ordered, func() {
ctx, cancel := context.WithCancel(context.Background())
var lb *LoadBalancer
var err error
BeforeAll(func() {
tmpDir := GinkgoT().TempDir()
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
Expect(err).NotTo(HaveOccurred(), "New() failed")
})
AfterAll(func() {
cancel()
})
It("fails with the correct timeout", func() {
conn, err := net.Dial("tcp", lb.localAddress)
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(11 * time.Second)
select {
case err := <-done:
if err == nil {
Fail("Unexpected successful ping from unreachable address")
}
case <-timeout:
Fail("Test timed out")
}
})
})
})

View File

@ -1,118 +1,421 @@
package loadbalancer package loadbalancer
import ( import (
"cmp"
"context" "context"
"math/rand" "errors"
"fmt"
"net" "net"
"slices" "slices"
"sync"
"time" "time"
"github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/util/wait"
) )
func (lb *LoadBalancer) setServers(serverAddresses []string) bool { type HealthCheckFunc func() HealthCheckResult
serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
if len(serverAddresses) == 0 {
return false
}
lb.mutex.Lock() // HealthCheckResult indicates the status of a server health check poll.
defer lb.mutex.Unlock() // For health-checks that poll in the background, Unknown should be returned
// if a poll has not occurred since the last check.
type HealthCheckResult int
newAddresses := sets.NewString(serverAddresses...) const (
curAddresses := sets.NewString(lb.serverAddresses...) HealthCheckResultUnknown HealthCheckResult = iota
HealthCheckResultFailed
HealthCheckResultOK
)
// serverList tracks potential backend servers for use by a loadbalancer.
type serverList struct {
// This mutex protects access to the server list. All direct access to the list should be protected by it.
mutex sync.Mutex
servers []*server
}
// setServers updates the server list to contain only the selected addresses.
func (sl *serverList) setAddresses(serviceName string, addresses []string) bool {
newAddresses := sets.New(addresses...)
curAddresses := sets.New(sl.getAddresses()...)
if newAddresses.Equal(curAddresses) { if newAddresses.Equal(curAddresses) {
return false return false
} }
for addedServer := range newAddresses.Difference(curAddresses) { sl.mutex.Lock()
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) defer sl.mutex.Unlock()
lb.servers[addedServer] = &server{
address: addedServer, var closeAllFuncs []func()
connections: make(map[net.Conn]struct{}), var defaultServer *server
healthCheck: func() bool { return true }, if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
defaultServer = sl.servers[i]
}
// add new servers
for addedAddress := range newAddresses.Difference(curAddresses) {
if defaultServer != nil && defaultServer.address == addedAddress {
// make default server go through the same health check promotions as a new server when added
logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName)
defaultServer.state = stateUnchecked
defaultServer.lastTransition = time.Now()
} else {
s := newServer(addedAddress, false)
logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address)
sl.servers = append(sl.servers, s)
} }
} }
for removedServer := range curAddresses.Difference(newAddresses) { // remove old servers
server := lb.servers[removedServer] for removedAddress := range curAddresses.Difference(newAddresses) {
if server != nil { if defaultServer != nil && defaultServer.address == removedAddress {
logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) // demote the default server down to standby, instead of deleting it
// Defer closing connections until after the new server list has been put into place. defaultServer.state = stateStandby
// Closing open connections ensures that anything stuck retrying on a stale server is forced closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll)
// over to a valid endpoint. } else {
defer server.closeAll() sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
// Don't delete the default server from the server map, in case we need to fall back to it. if s.address == removedAddress {
if removedServer != lb.defaultServerAddress { logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address)
delete(lb.servers, removedServer) // set state to invalid to prevent server from making additional connections
} s.state = stateInvalid
closeAllFuncs = append(closeAllFuncs, s.closeAll)
return true
}
return false
})
} }
} }
lb.serverAddresses = serverAddresses slices.SortFunc(sl.servers, compareServers)
lb.randomServers = append([]string{}, lb.serverAddresses...)
rand.Shuffle(len(lb.randomServers), func(i, j int) { // Close all connections to servers that were removed
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] for _, closeAll := range closeAllFuncs {
}) closeAll()
// If the current server list does not contain the default server address,
// we want to include it in the random server list so that it can be tried if necessary.
// However, it should be treated as always failing health checks so that it is only
// used if all other endpoints are unavailable.
if !hasDefaultServer {
lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
defaultServer.healthCheck = func() bool { return false }
lb.servers[lb.defaultServerAddress] = defaultServer
}
} }
lb.currentServerAddress = lb.randomServers[0]
lb.nextServerIndex = 1
return true return true
} }
// nextServer attempts to get the next server in the loadbalancer server list. // getAddresses returns the addresses of all servers.
// If another goroutine has already updated the current server address to point at // If the default server is in standby state, indicating it is only present
// a different address than just failed, nothing is changed. Otherwise, a new server address // because it is the default, it is not returned in this list.
// is stored to the currentServerAddress field, and returned for use. func (sl *serverList) getAddresses() []string {
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex. sl.mutex.Lock()
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { defer sl.mutex.Unlock()
// note: these fields are not protected by the mutex, so we clamp the index value and update
// the index/current address using local variables, to avoid time-of-check vs time-of-use
// race conditions caused by goroutine A incrementing it in between the time goroutine B
// validates its value, and uses it as a list index.
currentServerAddress := lb.currentServerAddress
nextServerIndex := lb.nextServerIndex
if len(lb.randomServers) == 0 { addresses := make([]string, 0, len(sl.servers))
return "", errors.New("No servers in load balancer proxy list") for _, s := range sl.servers {
if s.isDefault && s.state == stateStandby {
continue
}
addresses = append(addresses, s.address)
} }
if len(lb.randomServers) == 1 { return addresses
return currentServerAddress, nil
}
if failedServer != currentServerAddress {
return currentServerAddress, nil
}
if nextServerIndex >= len(lb.randomServers) {
nextServerIndex = 0
}
currentServerAddress = lb.randomServers[nextServerIndex]
nextServerIndex++
lb.currentServerAddress = currentServerAddress
lb.nextServerIndex = nextServerIndex
return currentServerAddress, nil
} }
// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map // setDefault sets the server with the provided address as the default server.
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { // The default flag is cleared on all other servers, and if the server was previously
conn, err := defaultDialer.Dial(network, address) // only kept in the list because it was the default, it is removed.
func (sl *serverList) setDefaultAddress(serviceName, address string) {
sl.mutex.Lock()
defer sl.mutex.Unlock()
// deal with existing default first
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
if s.isDefault && s.address != address {
s.isDefault = false
if s.state == stateStandby {
s.state = stateInvalid
defer s.closeAll()
return true
}
}
return false
})
// update or create server with selected address
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
sl.servers[i].isDefault = true
} else {
sl.servers = append(sl.servers, newServer(address, true))
}
logrus.Infof("Updated load balancer %s default server: %s", serviceName, address)
slices.SortFunc(sl.servers, compareServers)
}
// getDefault returns the address of the default server.
func (sl *serverList) getDefaultAddress() string {
if s := sl.getDefaultServer(); s != nil {
return s.address
}
return ""
}
// getDefault returns the default server.
func (sl *serverList) getDefaultServer() *server {
sl.mutex.Lock()
defer sl.mutex.Unlock()
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
return sl.servers[i]
}
return nil
}
// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
func (sl *serverList) getServers() []*server {
sl.mutex.Lock()
defer sl.mutex.Unlock()
return slices.Clone(sl.servers)
}
// getServer returns the first server with the specified address
func (sl *serverList) getServer(address string) *server {
sl.mutex.Lock()
defer sl.mutex.Unlock()
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
return sl.servers[i]
}
return nil
}
// setHealthCheck updates the health check function for a server, replacing the
// current function.
func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
if s := sl.getServer(address); s != nil {
s.healthCheck = healthCheck
return nil
}
return fmt.Errorf("no server found for %s", address)
}
// recordSuccess records a successful check of a server, either via health-check or dial.
// The server's state is adjusted accordingly.
func (sl *serverList) recordSuccess(srv *server, r reason) {
var new_state state
switch srv.state {
case stateFailed, stateUnchecked:
// dialed or health checked OK once, improve to recovering
new_state = stateRecovering
case stateRecovering:
if r == reasonHealthCheck {
// was recovering due to successful dial or first health check, can now improve
if len(srv.connections) > 0 {
// server accepted connections while recovering, attempt to go straight to active
new_state = stateActive
} else {
// no connections, just make it preferred
new_state = statePreferred
}
}
case stateHealthy:
if r == reasonDial {
// improve from healthy to active by being dialed
new_state = stateActive
}
case statePreferred:
if r == reasonDial {
// improve from healthy to active by being dialed
new_state = stateActive
} else {
if time.Now().Sub(srv.lastTransition) > time.Minute {
// has been preferred for a while without being dialed, demote to healthy
new_state = stateHealthy
}
}
}
// no-op if state did not change
if new_state == stateInvalid {
return
}
// handle active transition and sort the server list while holding the lock
sl.mutex.Lock()
defer sl.mutex.Unlock()
// handle states of other servers when attempting to make this one active
if new_state == stateActive {
for _, s := range sl.servers {
if srv.address == s.address {
continue
}
switch s.state {
case stateFailed, stateStandby, stateRecovering, stateHealthy:
// close connections to other non-active servers whenever we have a new active server
defer s.closeAll()
case stateActive:
if len(s.connections) > len(srv.connections) {
// if there is a currently active server that has more connections than we do,
// close our connections and go to preferred instead
new_state = statePreferred
defer srv.closeAll()
} else {
// otherwise, close its connections and demote it to preferred
s.state = statePreferred
defer s.closeAll()
}
}
}
}
// ensure some other routine didn't already make the transition
if srv.state == new_state {
return
}
logrus.Infof("Server %s->%s from successful %s", srv, new_state, r)
srv.state = new_state
srv.lastTransition = time.Now()
slices.SortFunc(sl.servers, compareServers)
}
// recordFailure records a failed check of a server, either via health-check or dial.
// The server's state is adjusted accordingly.
func (sl *serverList) recordFailure(srv *server, r reason) {
var new_state state
switch srv.state {
case stateUnchecked, stateRecovering:
if r == reasonDial {
// only demote from unchecked or recovering if a dial fails, health checks may
// continue to fail despite it being dialable. just leave it where it is
// and don't close any connections.
new_state = stateFailed
}
case stateHealthy, statePreferred, stateActive:
// should not have any connections when in any state other than active or
// recovering, but close them all anyway to force failover.
defer srv.closeAll()
new_state = stateFailed
}
// no-op if state did not change
if new_state == stateInvalid {
return
}
// sort the server list while holding the lock
sl.mutex.Lock()
defer sl.mutex.Unlock()
// ensure some other routine didn't already make the transition
if srv.state == new_state {
return
}
logrus.Infof("Server %s->%s from failed %s", srv, new_state, r)
srv.state = new_state
srv.lastTransition = time.Now()
slices.SortFunc(sl.servers, compareServers)
}
// state is possible server health states, in increasing order of preference.
// The server list is kept sorted in descending order by this state value.
type state int
const (
stateInvalid state = iota
stateFailed // failed a health check or dial
stateStandby // reserved for use by default server if not in server list
stateUnchecked // just added, has not been health checked
stateRecovering // successfully health checked once, or dialed when failed
stateHealthy // normal state
statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance
stateActive // currently active server
)
func (s state) String() string {
switch s {
case stateInvalid:
return "INVALID"
case stateFailed:
return "FAILED"
case stateStandby:
return "STANDBY"
case stateUnchecked:
return "UNCHECKED"
case stateRecovering:
return "RECOVERING"
case stateHealthy:
return "HEALTHY"
case statePreferred:
return "PREFERRED"
case stateActive:
return "ACTIVE"
default:
return "UNKNOWN"
}
}
// reason specifies the reason for a successful or failed health report
type reason int
const (
reasonDial reason = iota
reasonHealthCheck
)
func (r reason) String() string {
switch r {
case reasonDial:
return "dial"
case reasonHealthCheck:
return "health check"
default:
return "unknown reason"
}
}
// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
isDefault bool
state state
lastTransition time.Time
healthCheck HealthCheckFunc
connections map[net.Conn]struct{}
}
// newServer creates a new server, with a default health check
// and default/state fields appropriate for whether or not
// the server is a full server, or just a fallback default.
func newServer(address string, isDefault bool) *server {
state := stateUnchecked
if isDefault {
state = stateStandby
}
return &server{
address: address,
isDefault: isDefault,
state: state,
lastTransition: time.Now(),
healthCheck: func() HealthCheckResult { return HealthCheckResultUnknown },
connections: make(map[net.Conn]struct{}),
}
}
func (s *server) String() string {
format := "%s@%s"
if s.isDefault {
format += "*"
}
return fmt.Sprintf(format, s.address, s.state)
}
// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map
func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) {
if s.state == stateInvalid {
return nil, fmt.Errorf("server %s is stopping", s.address)
}
conn, err := defaultDialer.Dial(network, s.address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -132,7 +435,7 @@ func (s *server) closeAll() {
defer s.mutex.Unlock() defer s.mutex.Unlock()
if l := len(s.connections); l > 0 { if l := len(s.connections); l > 0 {
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s)
for conn := range s.connections { for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so. // Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close() go conn.Close()
@ -140,6 +443,12 @@ func (s *server) closeAll() {
} }
} }
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
server *server
net.Conn
}
// Close removes the connection entry from the server's connection map, and // Close removes the connection entry from the server's connection map, and
// closes the wrapped connection. // closes the wrapped connection.
func (sc *serverConn) Close() error { func (sc *serverConn) Close() error {
@ -150,73 +459,43 @@ func (sc *serverConn) Close() error {
return sc.Conn.Close() return sc.Conn.Close()
} }
// SetDefault sets the selected address as the default / fallback address // runHealthChecks periodically health-checks all servers.
func (lb *LoadBalancer) SetDefault(serverAddress string) { func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one - but
// with a failing health check so that it is only used as a last resort.
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{
address: serverAddress,
healthCheck: func() bool { return false },
connections: make(map[net.Conn]struct{}),
}
}
lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
if server := lb.servers[address]; server != nil {
logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
server.healthCheck = healthCheck
} else {
logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
}
}
// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
previousStatus := map[string]bool{}
wait.Until(func() { wait.Until(func() {
lb.mutex.RLock() for _, s := range sl.getServers() {
defer lb.mutex.RUnlock() switch s.healthCheck() {
var healthyServerExists bool case HealthCheckResultOK:
for address, server := range lb.servers { sl.recordSuccess(s, reasonHealthCheck)
status := server.healthCheck() case HealthCheckResultFailed:
healthyServerExists = healthyServerExists || status sl.recordFailure(s, reasonHealthCheck)
if status == false && previousStatus[address] == true {
// Only close connections when the server transitions from healthy to unhealthy;
// we don't want to re-close all the connections every time as we might be ignoring
// health checks due to all servers being marked unhealthy.
defer server.closeAll()
}
previousStatus[address] = status
}
// If there is at least one healthy server, and the default server is not in the server list,
// close all the connections to the default server so that clients reconnect and switch over
// to a preferred server.
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
if healthyServerExists && !hasDefaultServer {
if server, ok := lb.servers[lb.defaultServerAddress]; ok {
defer server.closeAll()
} }
} }
}, time.Second, ctx.Done()) }, time.Second, ctx.Done())
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
}
// dialContext attemps to dial a connection to a server from the server list.
// Success or failure is recorded to ensure that server state is updated appropriately.
func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
for _, s := range sl.getServers() {
dialTime := time.Now()
conn, err := s.dialContext(ctx, network)
if err == nil {
sl.recordSuccess(s, reasonDial)
return conn, nil
}
logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err)
sl.recordFailure(s, reasonDial)
}
return nil, errors.New("all servers failed")
}
// compareServers is a comparison function that can be used to sort the server list
// so that servers with a more preferred state, or higher number of connections, are ordered first.
func compareServers(a, b *server) int {
c := cmp.Compare(b.state, a.state)
if c == 0 {
return cmp.Compare(len(b.connections), len(a.connections))
}
return c
} }

View File

@ -22,7 +22,7 @@ type Proxy interface {
SupervisorAddresses() []string SupervisorAddresses() []string
APIServerURL() string APIServerURL() string
IsAPIServerLBEnabled() bool IsAPIServerLBEnabled() bool
SetHealthCheck(address string, healthCheck func() bool) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc)
} }
// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
return nil, err return nil, err
} }
p.supervisorLB = lb p.supervisorLB = lb
p.supervisorURL = lb.LoadBalancerServerURL() p.supervisorURL = lb.LocalURL()
p.apiServerURL = p.supervisorURL p.apiServerURL = p.supervisorURL
} }
@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) {
p.supervisorAddresses = supervisorAddresses p.supervisorAddresses = supervisorAddresses
} }
func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) { func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) {
if p.supervisorLB != nil { if p.supervisorLB != nil {
p.supervisorLB.SetHealthCheck(address, healthCheck) p.supervisorLB.SetHealthCheck(address, healthCheck)
} }
@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
return err return err
} }
p.apiServerLB = lb p.apiServerLB = lb
p.apiServerURL = lb.LoadBalancerServerURL() p.apiServerURL = lb.LocalURL()
} else { } else {
p.apiServerURL = u.String() p.apiServerURL = u.String()
} }

View File

@ -14,6 +14,7 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
agentconfig "github.com/k3s-io/k3s/pkg/agent/config" agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
"github.com/k3s-io/k3s/pkg/agent/proxy" "github.com/k3s-io/k3s/pkg/agent/proxy"
daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
"github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/util"
@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
if _, ok := disconnect[address]; !ok { if _, ok := disconnect[address]; !ok {
conn := a.connect(ctx, wg, address, tlsConfig) conn := a.connect(ctx, wg, address, tlsConfig)
disconnect[address] = conn.cancel disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected) proxy.SetHealthCheck(address, conn.healthCheck)
} }
} }
@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
if _, ok := disconnect[address]; !ok { if _, ok := disconnect[address]; !ok {
conn := a.connect(ctx, nil, address, tlsConfig) conn := a.connect(ctx, nil, address, tlsConfig)
disconnect[address] = conn.cancel disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected) proxy.SetHealthCheck(address, conn.healthCheck)
} }
} }
@ -427,22 +428,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
} }
type agentConnection struct { type agentConnection struct {
cancel context.CancelFunc cancel context.CancelFunc
connected func() bool healthCheck loadbalancer.HealthCheckFunc
} }
// connect initiates a connection to the remotedialer server. Incoming dial requests from // connect initiates a connection to the remotedialer server. Incoming dial requests from
// the server will be checked by the authorizer function prior to being fulfilled. // the server will be checked by the authorizer function prior to being fulfilled.
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection { func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
var status loadbalancer.HealthCheckResult
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address) wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{ ws := &websocket.Dialer{
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
} }
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true
once := sync.Once{} once := sync.Once{}
if waitGroup != nil { if waitGroup != nil {
waitGroup.Add(1) waitGroup.Add(1)
@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
} }
onConnect := func(_ context.Context, _ *remotedialer.Session) error { onConnect := func(_ context.Context, _ *remotedialer.Session) error {
connected = true status = loadbalancer.HealthCheckResultOK
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy") logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
if waitGroup != nil { if waitGroup != nil {
once.Do(waitGroup.Done) once.Do(waitGroup.Done)
@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
for { for {
// ConnectToProxy blocks until error or context cancellation // ConnectToProxy blocks until error or context cancellation
err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect) err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect)
connected = false status = loadbalancer.HealthCheckResultFailed
if err != nil && !errors.Is(err, context.Canceled) { if err != nil && !errors.Is(err, context.Canceled) {
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...") logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...")
// wait between reconnection attempts to avoid hammering the server // wait between reconnection attempts to avoid hammering the server
@ -484,8 +483,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}() }()
return agentConnection{ return agentConnection{
cancel: cancel, cancel: cancel,
connected: func() bool { return connected }, healthCheck: func() loadbalancer.HealthCheckResult {
return status
},
} }
} }

View File

@ -48,6 +48,10 @@ type etcdproxy struct {
} }
func (e *etcdproxy) Update(addresses []string) { func (e *etcdproxy) Update(addresses []string) {
if e.etcdLB == nil {
return
}
e.etcdLB.Update(addresses) e.etcdLB.Update(addresses)
validEndpoint := map[string]bool{} validEndpoint := map[string]bool{}
@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) {
// start a polling routine that makes periodic requests to the etcd node's supervisor port. // start a polling routine that makes periodic requests to the etcd node's supervisor port.
// If the request fails, the node is marked unhealthy. // If the request fails, the node is marked unhealthy.
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. var status loadbalancer.HealthCheckResult
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true
host, _, _ := net.SplitHostPort(address) host, _, _ := net.SplitHostPort(address)
url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort))) url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort)))
@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func()
} }
if err != nil || statusCode != http.StatusOK { if err != nil || statusCode != http.StatusOK {
logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode) logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode)
connected = false status = loadbalancer.HealthCheckResultFailed
} else { } else {
connected = true status = loadbalancer.HealthCheckResultOK
} }
}, 5*time.Second, 1.0, true) }, 5*time.Second, 1.0, true)
return func() bool { return func() loadbalancer.HealthCheckResult {
return connected // Reset the status to unknown on reading, until next time it is checked.
// This avoids having a health check result alter the server state between active checks.
s := status
status = loadbalancer.HealthCheckResultUnknown
return s
} }
} }