Refactor load balancer server list and health checking

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
Branch: pull/11430/head
Brad Davidson, 2024-11-15 22:11:47 +00:00 (committed by Brad Davidson)
parent 95797c4a79
commit 911ee19a93
9 changed files with 837 additions and 489 deletions


@@ -15,8 +15,8 @@ type lbConfig struct {

 func (lb *LoadBalancer) writeConfig() error {
 	config := &lbConfig{
-		ServerURL:       lb.serverURL,
-		ServerAddresses: lb.serverAddresses,
+		ServerURL:       lb.scheme + "://" + lb.servers.getDefaultAddress(),
+		ServerAddresses: lb.servers.getAddresses(),
 	}
 	configOut, err := json.MarshalIndent(config, "", " ")
 	if err != nil {
@@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
 }

 func (lb *LoadBalancer) updateConfig() error {
-	writeConfig := true
 	if configBytes, err := os.ReadFile(lb.configFile); err == nil {
 		config := &lbConfig{}
 		if err := json.Unmarshal(configBytes, config); err == nil {
-			if config.ServerURL == lb.serverURL {
-				writeConfig = false
-				lb.setServers(config.ServerAddresses)
+			// if the default server from the config matches our current default,
+			// load the rest of the addresses as well.
+			if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
+				lb.Update(config.ServerAddresses)
+				return nil
 			}
 		}
 	}
-	if writeConfig {
-		if err := lb.writeConfig(); err != nil {
-			return err
-		}
-	}
-	return nil
+	// config didn't exist or used a different default server, write the current config to disk.
+	return lb.writeConfig()
 }

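Note (not part of the commit): a minimal sketch of what writeConfig now persists, assuming the exported lbConfig field names are used as the JSON keys; the addresses shown are made up.

	config := &lbConfig{
		ServerURL:       "https://172.16.0.10:6443",
		ServerAddresses: []string{"172.16.0.10:6443", "172.16.0.11:6443"},
	}
	// json.MarshalIndent produces the state file written to <dataDir>/etc/<service>.json:
	// {
	//   "ServerURL": "https://172.16.0.10:6443",
	//   "ServerAddresses": [
	//     "172.16.0.10:6443",
	//     "172.16.0.11:6443"
	//   ]
	// }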

@@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
 func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
 	if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
 		// Create a new HTTP proxy dialer
-		httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
+		httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
 		return httpProxyDialer, nil
 	} else if proxyURL.Scheme == "socks5" {
 		// For SOCKS5 proxies, use the proxy package's FromURL

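Note (not part of the commit): a sketch of the same dialer construction in isolation, showing how the new WithConnectionTimeout option bounds the CONNECT handshake. The proxy and target addresses are placeholders; http_dialer is the package already imported by this file.

	proxyURL, err := url.Parse("http://proxy.example.com:3128")
	if err != nil {
		return nil, err
	}
	// A hung proxy now fails the dial after 10 seconds instead of blocking forever.
	httpProxyDialer := http_dialer.New(proxyURL,
		http_dialer.WithConnectionTimeout(10*time.Second),
		http_dialer.WithDialer(&net.Dialer{}))
	conn, err := httpProxyDialer.Dial("tcp", "server.example.com:6443")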

@@ -2,15 +2,16 @@ package loadbalancer

 import (
 	"fmt"
-	"net"
 	"os"
 	"strings"
 	"testing"

 	"github.com/k3s-io/k3s/pkg/version"
 	"github.com/sirupsen/logrus"
+	"golang.org/x/net/proxy"
 )

+var originalDialer proxy.Dialer
 var defaultEnv map[string]string
 var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}
@@ -19,7 +20,7 @@ func init() {
 }

 func prepareEnv(env ...string) {
-	defaultDialer = &net.Dialer{}
+	originalDialer = defaultDialer
 	defaultEnv = map[string]string{}
 	for _, e := range proxyEnvs {
 		if v, ok := os.LookupEnv(e); ok {
@@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
 }

 func restoreEnv() {
+	defaultDialer = originalDialer
 	for _, e := range proxyEnvs {
 		if v, ok := defaultEnv[e]; ok {
 			os.Setenv(e, v)

@@ -2,55 +2,29 @@ package loadbalancer

 import (
 	"context"
-	"errors"
 	"fmt"
 	"net"
+	"net/url"
 	"os"
 	"path/filepath"
-	"sync"
-	"time"
+	"strings"

 	"github.com/inetaf/tcpproxy"
 	"github.com/k3s-io/k3s/pkg/version"
 	"github.com/sirupsen/logrus"
 )

-// server tracks the connections to a server, so that they can be closed when the server is removed.
-type server struct {
-	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
-	mutex       sync.Mutex
-	address     string
-	healthCheck func() bool
-	connections map[net.Conn]struct{}
-}
-
-// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
-type serverConn struct {
-	server *server
-	net.Conn
-}
-
 // LoadBalancer holds data for a local listener which forwards connections to a
 // pool of remote servers. It is not a proper load-balancer in that it does not
 // actually balance connections, but instead fails over to a new server only
 // when a connection attempt to the currently selected server fails.
 type LoadBalancer struct {
-	// This mutex protects access to servers map and randomServers list.
-	// All direct access to the servers map/list should be protected by it.
-	mutex sync.RWMutex
-	proxy *tcpproxy.Proxy
-
-	serviceName          string
-	configFile           string
-	localAddress         string
-	localServerURL       string
-	defaultServerAddress string
-	serverURL            string
-	serverAddresses      []string
-	randomServers        []string
-	servers              map[string]*server
-	currentServerAddress string
-	nextServerIndex      int
+	serviceName  string
+	configFile   string
+	scheme       string
+	localAddress string
+	servers      serverList
+	proxy        *tcpproxy.Proxy
 }

 const RandomPort = 0
@@ -63,7 +37,7 @@ var (

 // New constructs a new LoadBalancer instance. The default server URL, and
 // currently active servers, are stored in a file within the dataDir.
-func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
+func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	config := net.ListenConfig{Control: reusePort}
 	var localAddress string
 	if isIPv6 {
@@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 		return nil, err
 	}

-	// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
-	localAddress = listener.Addr().String()
-
-	defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
+	serverURL, err := url.Parse(defaultServerURL)
 	if err != nil {
 		return nil, err
 	}

-	if serverURL == localServerURL {
-		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
-		defaultServerAddress = ""
+	// Set explicit port from scheme
+	if serverURL.Port() == "" {
+		if strings.ToLower(serverURL.Scheme) == "http" {
+			serverURL.Host += ":80"
+		}
+		if strings.ToLower(serverURL.Scheme) == "https" {
+			serverURL.Host += ":443"
+		}
 	}

 	lb := &LoadBalancer{
-		serviceName:          serviceName,
-		configFile:           filepath.Join(dataDir, "etc", serviceName+".json"),
-		localAddress:         localAddress,
-		localServerURL:       localServerURL,
-		defaultServerAddress: defaultServerAddress,
-		servers:              make(map[string]*server),
-		serverURL:            serverURL,
+		serviceName:  serviceName,
+		configFile:   filepath.Join(dataDir, "etc", serviceName+".json"),
+		scheme:       serverURL.Scheme,
+		localAddress: listener.Addr().String(),
 	}

-	lb.setServers([]string{lb.defaultServerAddress})
+	// if starting out pointing at ourselves, don't set a default server address,
+	// which will cause all dials to fail until servers are added.
+	if serverURL.Host == lb.localAddress {
+		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
+	} else {
+		lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
+	}

 	lb.proxy = &tcpproxy.Proxy{
 		ListenFunc: func(string, string) (net.Listener, error) {
@@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	}
 	lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
 		Addr:        serviceName,
-		DialContext: lb.dialContext,
+		DialContext: lb.servers.dialContext,
 		OnDialError: onDialError,
 	})

@@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	if err := lb.proxy.Start(); err != nil {
 		return nil, err
 	}
-	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
+	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

-	go lb.runHealthChecks(ctx)
+	go lb.servers.runHealthChecks(ctx, lb.serviceName)

 	return lb, nil
 }

+// Update updates the list of server addresses to contain only the listed servers.
 func (lb *LoadBalancer) Update(serverAddresses []string) {
-	if lb == nil {
-		return
-	}
-	if !lb.setServers(serverAddresses) {
+	if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
 		return
 	}
-	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)
+
+	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

 	if err := lb.writeConfig(); err != nil {
 		logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
 	}
 }

-func (lb *LoadBalancer) LoadBalancerServerURL() string {
-	if lb == nil {
-		return ""
-	}
-	return lb.localServerURL
+// SetDefault sets the selected address as the default / fallback address
+func (lb *LoadBalancer) SetDefault(serverAddress string) {
+	lb.servers.setDefaultAddress(lb.serviceName, serverAddress)
+
+	if err := lb.writeConfig(); err != nil {
+		logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
+	}
+}
+
+// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
+func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
+	if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
+		logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
+	} else {
+		logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
+	}
+}
+
+func (lb *LoadBalancer) LocalURL() string {
+	return lb.scheme + "://" + lb.localAddress
 }

 func (lb *LoadBalancer) ServerAddresses() []string {
-	if lb == nil {
-		return nil
-	}
-	return lb.serverAddresses
-}
-
-func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
-	lb.mutex.RLock()
-	defer lb.mutex.RUnlock()
-
-	var allChecksFailed bool
-	startIndex := lb.nextServerIndex
-	for {
-		targetServer := lb.currentServerAddress
-
-		server := lb.servers[targetServer]
-		if server == nil || targetServer == "" {
-			logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
-		} else if allChecksFailed || server.healthCheck() {
-			dialTime := time.Now()
-			conn, err := server.dialContext(ctx, network, targetServer)
-			if err == nil {
-				return conn, nil
-			}
-			logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
-			// Don't close connections to the failed server if we're retrying with health checks ignored.
-			// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
-			if !allChecksFailed {
-				defer server.closeAll()
-			}
-		} else {
-			logrus.Debugf("Dial health check failed for %s", targetServer)
-		}
-
-		newServer, err := lb.nextServer(targetServer)
-		if err != nil {
-			return nil, err
-		}
-		if targetServer != newServer {
-			logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
-		}
-		if ctx.Err() != nil {
-			return nil, ctx.Err()
-		}
-
-		maxIndex := len(lb.randomServers)
-		if startIndex > maxIndex {
-			startIndex = maxIndex
-		}
-		if lb.nextServerIndex == startIndex {
-			if allChecksFailed {
-				return nil, errors.New("all servers failed")
-			}
-			logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
-			allChecksFailed = true
-		}
-	}
+	return lb.servers.getAddresses()
 }

 func onDialError(src net.Conn, dstDialErr error) {
@@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
 }

 // ResetLoadBalancer will delete the local state file for the load balancer on disk
-func ResetLoadBalancer(dataDir, serviceName string) error {
+func ResetLoadBalancer(dataDir, serviceName string) {
 	stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
-	if err := os.Remove(stateFile); err != nil {
+	if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
 		logrus.Warn(err)
 	}
-	return nil
 }

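Note (not part of the commit): the default-port normalization in New() can be read as a standalone helper. normalizeHost is hypothetical; only the url.Parse call and the ":80"/":443" rules come from the change.

	func normalizeHost(rawURL string) (string, error) {
		u, err := url.Parse(rawURL)
		if err != nil {
			return "", err
		}
		// If no port was given, derive one from the scheme, as New() does.
		if u.Port() == "" {
			switch strings.ToLower(u.Scheme) {
			case "http":
				u.Host += ":80"
			case "https":
				u.Host += ":443"
			}
		}
		return u.Host, nil
	}

	// normalizeHost("https://server.example.com") returns "server.example.com:443",
	// matching the address stored as the load balancer's default server.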

@@ -5,19 +5,29 @@ import (
 	"context"
 	"fmt"
 	"net"
-	"net/url"
+	"strconv"
 	"strings"
 	"testing"
 	"time"

+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
 	"github.com/sirupsen/logrus"
 )

+func Test_UnitLoadBalancer(t *testing.T) {
+	_, reporterConfig := GinkgoConfiguration()
+	reporterConfig.Verbose = testing.Verbose()
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "LoadBalancer Suite", reporterConfig)
+}
+
 func init() {
 	logrus.SetLevel(logrus.DebugLevel)
 }

 type testServer struct {
+	address  string
 	listener net.Listener
 	conns    []net.Conn
 	prefix   string
@@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) {
 	s := &testServer{
 		prefix:   prefix,
 		listener: listener,
+		address:  listener.Addr().String(),
 	}
 	go s.serve()
 	go func() {
@@ -53,6 +64,7 @@ func (s *testServer) serve() {

 func (s *testServer) close() {
 	logrus.Printf("testServer %s closing", s.prefix)
+	s.address = ""
 	s.listener.Close()
 	for _, conn := range s.conns {
 		conn.Close()
@@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) {
 	}
 }

-func (s *testServer) address() string {
-	return s.listener.Addr().String()
-}
-
 func ping(conn net.Conn) (string, error) {
 	fmt.Fprintf(conn, "ping\n")
 	result, err := bufio.NewReader(conn).ReadString('\n')
@@ -82,166 +90,285 @@ func ping(conn net.Conn) (string, error) {
 	return strings.TrimSpace(result), nil
 }

-// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
-// and then adds a new server (a node). The node server is then closed, and it is confirmed
-// that new connections use the default server.
-func Test_UnitFailOver(t *testing.T) {
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	defaultServer, err := createServer(ctx, "default")
-	if err != nil {
-		t.Fatalf("createServer(default) failed: %v", err)
-	}
-
-	node1Server, err := createServer(ctx, "node1")
-	if err != nil {
-		t.Fatalf("createServer(node1) failed: %v", err)
-	}
-
-	node2Server, err := createServer(ctx, "node2")
-	if err != nil {
-		t.Fatalf("createServer(node2) failed: %v", err)
-	}
-
-	// start the loadbalancer with the default server as the only server
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
-	if err != nil {
-		t.Fatalf("url.Parse failed: %v", err)
-	}
-	localAddress := parsedURL.Host
-
-	// add the node as a new server address.
-	lb.Update([]string{node1Server.address()})
-
-	// make sure connections go to the node
-	conn1, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn1); err != nil {
-		t.Fatalf("ping(conn1) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn1) result: %v", result)
-	}
-
-	t.Log("conn1 tested OK")
-
-	// set failing health check for node 1
-	lb.SetHealthCheck(node1Server.address(), func() bool { return false })
-
-	// Server connections are checked every second, now that node 1 is failed
-	// the connections to it should be closed.
-	time.Sleep(2 * time.Second)
-
-	if _, err := ping(conn1); err == nil {
-		t.Fatal("Unexpected successful ping on closed connection conn1")
-	}
-
-	t.Log("conn1 closed on failure OK")
-
-	// make sure connection still goes to the first node - it is failing health checks but so
-	// is the default endpoint, so it should be tried first with health checks disabled,
-	// before failing back to the default.
-	conn2, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn2); err != nil {
-		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn2) result: %v", result)
-	}
-
-	t.Log("conn2 tested OK")
-
-	// make sure the health checks don't close the connection we just made -
-	// connections should only be closed when it transitions from healthy to unhealthy.
-	time.Sleep(2 * time.Second)
-
-	if result, err := ping(conn2); err != nil {
-		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn2) result: %v", result)
-	}
-
-	t.Log("conn2 tested OK again")
-
-	// shut down the first node server to force failover to the default
-	node1Server.close()
-
-	// make sure new connections go to the default, and existing connections are closed
-	conn3, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn3); err != nil {
-		t.Fatalf("ping(conn3) failed: %v", err)
-	} else if result != "default:ping" {
-		t.Fatalf("Unexpected ping(conn3) result: %v", result)
-	}
-
-	t.Log("conn3 tested OK")
-
-	if _, err := ping(conn2); err == nil {
-		t.Fatal("Unexpected successful ping on closed connection conn2")
-	}
-
-	t.Log("conn2 closed on failure OK")
-
-	// add the second node as a new server address.
-	lb.Update([]string{node2Server.address()})
-
-	// make sure connection now goes to the second node,
-	// and connections to the default are closed.
-	conn4, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn4); err != nil {
-		t.Fatalf("ping(conn4) failed: %v", err)
-	} else if result != "node2:ping" {
-		t.Fatalf("Unexpected ping(conn4) result: %v", result)
-	}
-
-	t.Log("conn4 tested OK")
-
-	// Server connections are checked every second, now that we have a healthy
-	// server, connections to the default server should be closed
-	time.Sleep(2 * time.Second)
-
-	if _, err := ping(conn3); err == nil {
-		t.Fatal("Unexpected successful ping on connection conn3")
-	}
-
-	t.Log("conn3 closed on failure OK")
-}
-
-// Test_UnitFailFast confirms that connections to invalid addresses fail quickly
-func Test_UnitFailFast(t *testing.T) {
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	serverURL := "http://127.0.0.1:0/"
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	conn, err := net.Dial("tcp", lb.localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
+var _ = Describe("LoadBalancer", func() {
+	// creates a LB using a default server (ie fixed registration endpoint)
+	// and then adds a new server (a node). The node server is then closed, and it is confirmed
+	// that new connections use the default server.
+	When("loadbalancer is running", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var defaultServer, node1Server, node2Server *testServer
+		var conn1, conn2, conn3, conn4 net.Conn
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+
+			defaultServer, err = createServer(ctx, "default")
+			Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
+
+			node1Server, err = createServer(ctx, "node1")
+			Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
+
+			node2Server, err = createServer(ctx, "node2")
+			Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
+
+			// start the loadbalancer with the default server as the only server
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("adds node1 as a server", func() {
+			// add the node as a new server address.
+			lb.Update([]string{node1Server.address})
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node1Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+		})
+
+		It("connects to node1", func() {
+			// make sure connections go to the node
+			conn1, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+			Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
+
+			By("conn1 tested OK")
+		})
+
+		It("changes node1 state to failed", func() {
+			// set failing health check for node 1
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node1Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(stateFailed))
+		})
+
+		It("disconnects from node1", func() {
+			// Server connections are checked every second, now that node 1 is failed
+			// the connections to it should be closed.
+			Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
+
+			By("conn1 closed on failure OK")
+
+			// connections should go to the default now that node 1 is failed
+			conn2, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+			Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
+
+			By("conn2 tested OK")
+		})
+
+		It("does not close connections unexpectedly", func() {
+			// make sure the health checks don't close the connection we just made -
+			// connections should only be closed when it transitions from healthy to unhealthy.
+			time.Sleep(2 * time.Second)
+
+			Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
+
+			By("conn2 tested OK again")
+		})
+
+		It("closes connections when dial fails", func() {
+			// shut down the first node server to force failover to the default
+			node1Server.close()
+
+			// make sure new connections go to the default, and existing connections are closed
+			conn3, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
+
+			By("conn3 tested OK")
+		})
+
+		It("replaces node2 as a server", func() {
+			// add the second node as a new server address.
+			lb.Update([]string{node2Server.address})
+			lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node2Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+		})
+
+		It("connects to node2", func() {
+			// make sure connection now goes to the second node,
+			// and connections to the default are closed.
+			conn4, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn4) result")
+
+			By("conn4 tested OK")
+		})
+
+		It("does not close connections unexpectedly", func() {
+			// Server connections are checked every second, now that we have a healthy
+			// server, connections to the default server should be closed
+			time.Sleep(2 * time.Second)
+
+			Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn2")
+
+			By("conn2 closed on failure OK")
+
+			Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn3")
+
+			By("conn3 closed on failure OK")
+		})
+
+		It("adds default as a server", func() {
+			// add the default as a full server
+			lb.Update([]string{node2Server.address, defaultServer.address})
+			lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(defaultServer.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+
+			By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
+		})
+
+		It("returns the default server in the address list", func() {
+			// confirm that both servers are listed in the address list
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
+			// confirm that the default is still listed as default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+
+		It("does not return the default server in the address list after removing it", func() {
+			// remove the default as a server
+			lb.Update([]string{node2Server.address})
+			By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
+
+			// confirm that it is not listed as a server
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
+			// but is still listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+
+		It("removes default server when no longer default", func() {
+			// set node2 as the default
+			lb.SetDefault(node2Server.address)
+			By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
+
+			// confirm that it is still listed as a server
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
+			// and is listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
+		})
+
+		It("sets all three servers", func() {
+			// set the original default back as the default
+			lb.SetDefault(defaultServer.address)
+			By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
+
+			lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(defaultServer.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+
+			By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
+
+			// confirm that all three are listed as servers
+			Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
+			// and the default is listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+	})
+
+	// confirms that the loadbalancer will not dial itself
+	When("the default server is the loadbalancer", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var defaultServer *testServer
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+
+			defaultServer, err = createServer(ctx, "default")
+			Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
+			address := defaultServer.address
+			defaultServer.close()
+			_, port, _ := net.SplitHostPort(address)
+			intPort, _ := strconv.Atoi(port)
+
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails immediately", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			_, err = ping(conn)
+			Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
+		})
+	})
+
+	// confirms that connections to invalid addresses fail quickly
+	When("there are no valid addresses", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails fast", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")

 			done := make(chan error)
 			go func() {
@@ -253,36 +380,34 @@ func Test_UnitFailFast(t *testing.T) {
 			select {
 			case err := <-done:
 				if err == nil {
-					t.Fatal("Unexpected successful ping from invalid address")
+					Fail("Unexpected successful ping from invalid address")
 				}
 			case <-timeout:
-				t.Fatal("Test timed out")
+				Fail("Test timed out")
 			}
-}
+		})
+	})

-// Test_UnitFailUnreachable confirms that connections to unreachable addresses do fail
-// within the expected duration
-func Test_UnitFailUnreachable(t *testing.T) {
-	if testing.Short() {
-		t.Skip("skipping slow test in short mode.")
-	}
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	serverAddr := "192.0.2.1:6443"
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	// Set failing health check to reduce retries
-	lb.SetHealthCheck(serverAddr, func() bool { return false })
-
-	conn, err := net.Dial("tcp", lb.localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
+	// confirms that connections to unreachable addresses do fail within the
+	// expected duration
+	When("the server is unreachable", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails with the correct timeout", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")

 			done := make(chan error)
 			go func() {
@@ -294,9 +419,11 @@ func Test_UnitFailUnreachable(t *testing.T) {
 			select {
 			case err := <-done:
 				if err == nil {
-					t.Fatal("Unexpected successful ping from unreachable address")
+					Fail("Unexpected successful ping from unreachable address")
 				}
 			case <-timeout:
-				t.Fatal("Test timed out")
+				Fail("Test timed out")
 			}
-}
+		})
+	})
+})

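Note (not part of the commit): the Eventually blocks repeated in the specs above could be collapsed into one hypothetical helper; every identifier used here comes from the diff.

	func waitForState(lb *LoadBalancer, address string, want state) {
		Eventually(func() state {
			if s := lb.servers.getServer(address); s != nil {
				return s.state
			}
			return stateInvalid
		}, 5, 1).Should(Equal(want))
	}

	// e.g. waitForState(lb, node1Server.address, statePreferred)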

@@ -1,118 +1,421 @@
 package loadbalancer

 import (
+	"cmp"
 	"context"
-	"math/rand"
+	"errors"
+	"fmt"
 	"net"
 	"slices"
+	"sync"
 	"time"

-	"github.com/pkg/errors"
 	"github.com/sirupsen/logrus"
 	"k8s.io/apimachinery/pkg/util/sets"
 	"k8s.io/apimachinery/pkg/util/wait"
 )

-func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
-	serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
-	if len(serverAddresses) == 0 {
-		return false
-	}
-
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	newAddresses := sets.NewString(serverAddresses...)
-	curAddresses := sets.NewString(lb.serverAddresses...)
-	if newAddresses.Equal(curAddresses) {
-		return false
-	}
-
-	for addedServer := range newAddresses.Difference(curAddresses) {
-		logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
-		lb.servers[addedServer] = &server{
-			address:     addedServer,
-			connections: make(map[net.Conn]struct{}),
-			healthCheck: func() bool { return true },
-		}
-	}
-
-	for removedServer := range curAddresses.Difference(newAddresses) {
-		server := lb.servers[removedServer]
-		if server != nil {
-			logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
-			// Defer closing connections until after the new server list has been put into place.
-			// Closing open connections ensures that anything stuck retrying on a stale server is forced
-			// over to a valid endpoint.
-			defer server.closeAll()
-			// Don't delete the default server from the server map, in case we need to fall back to it.
-			if removedServer != lb.defaultServerAddress {
-				delete(lb.servers, removedServer)
-			}
-		}
-	}
-
-	lb.serverAddresses = serverAddresses
-	lb.randomServers = append([]string{}, lb.serverAddresses...)
-	rand.Shuffle(len(lb.randomServers), func(i, j int) {
-		lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
-	})
-	// If the current server list does not contain the default server address,
-	// we want to include it in the random server list so that it can be tried if necessary.
-	// However, it should be treated as always failing health checks so that it is only
-	// used if all other endpoints are unavailable.
-	if !hasDefaultServer {
-		lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
-		if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
-			defaultServer.healthCheck = func() bool { return false }
-			lb.servers[lb.defaultServerAddress] = defaultServer
-		}
-	}
-	lb.currentServerAddress = lb.randomServers[0]
-	lb.nextServerIndex = 1
-
-	return true
-}
-
-// nextServer attempts to get the next server in the loadbalancer server list.
-// If another goroutine has already updated the current server address to point at
-// a different address than just failed, nothing is changed. Otherwise, a new server address
-// is stored to the currentServerAddress field, and returned for use.
-// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
-func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
-	// note: these fields are not protected by the mutex, so we clamp the index value and update
-	// the index/current address using local variables, to avoid time-of-check vs time-of-use
-	// race conditions caused by goroutine A incrementing it in between the time goroutine B
-	// validates its value, and uses it as a list index.
-	currentServerAddress := lb.currentServerAddress
-	nextServerIndex := lb.nextServerIndex
-
-	if len(lb.randomServers) == 0 {
-		return "", errors.New("No servers in load balancer proxy list")
-	}
-	if len(lb.randomServers) == 1 {
-		return currentServerAddress, nil
-	}
-	if failedServer != currentServerAddress {
-		return currentServerAddress, nil
-	}
-	if nextServerIndex >= len(lb.randomServers) {
-		nextServerIndex = 0
-	}
-
-	currentServerAddress = lb.randomServers[nextServerIndex]
-	nextServerIndex++
-
-	lb.currentServerAddress = currentServerAddress
-	lb.nextServerIndex = nextServerIndex
-
-	return currentServerAddress, nil
-}
-
-// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map
-func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
-	conn, err := defaultDialer.Dial(network, address)
+type HealthCheckFunc func() HealthCheckResult
+
+// HealthCheckResult indicates the status of a server health check poll.
+// For health-checks that poll in the background, Unknown should be returned
+// if a poll has not occurred since the last check.
+type HealthCheckResult int
+
+const (
+	HealthCheckResultUnknown HealthCheckResult = iota
+	HealthCheckResultFailed
+	HealthCheckResultOK
+)
+
+// serverList tracks potential backend servers for use by a loadbalancer.
+type serverList struct {
+	// This mutex protects access to the server list. All direct access to the list should be protected by it.
+	mutex   sync.Mutex
+	servers []*server
+}
+
+// setAddresses updates the server list to contain only the selected addresses.
+func (sl *serverList) setAddresses(serviceName string, addresses []string) bool {
+	newAddresses := sets.New(addresses...)
+	curAddresses := sets.New(sl.getAddresses()...)
+	if newAddresses.Equal(curAddresses) {
+		return false
+	}
+
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	var closeAllFuncs []func()
+	var defaultServer *server
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
+		defaultServer = sl.servers[i]
+	}
+
+	// add new servers
+	for addedAddress := range newAddresses.Difference(curAddresses) {
+		if defaultServer != nil && defaultServer.address == addedAddress {
+			// make default server go through the same health check promotions as a new server when added
+			logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName)
+			defaultServer.state = stateUnchecked
+			defaultServer.lastTransition = time.Now()
+		} else {
+			s := newServer(addedAddress, false)
+			logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address)
+			sl.servers = append(sl.servers, s)
+		}
+	}
+
+	// remove old servers
+	for removedAddress := range curAddresses.Difference(newAddresses) {
+		if defaultServer != nil && defaultServer.address == removedAddress {
+			// demote the default server down to standby, instead of deleting it
+			defaultServer.state = stateStandby
+			closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll)
+		} else {
+			sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
+				if s.address == removedAddress {
+					logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address)
+					// set state to invalid to prevent server from making additional connections
+					s.state = stateInvalid
+					closeAllFuncs = append(closeAllFuncs, s.closeAll)
+					return true
+				}
+				return false
+			})
+		}
+	}
+
+	slices.SortFunc(sl.servers, compareServers)
+
+	// Close all connections to servers that were removed
+	for _, closeAll := range closeAllFuncs {
+		closeAll()
+	}
+
+	return true
+}
+
+// getAddresses returns the addresses of all servers.
+// If the default server is in standby state, indicating it is only present
+// because it is the default, it is not returned in this list.
+func (sl *serverList) getAddresses() []string {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	addresses := make([]string, 0, len(sl.servers))
+	for _, s := range sl.servers {
+		if s.isDefault && s.state == stateStandby {
+			continue
+		}
+		addresses = append(addresses, s.address)
+	}
+	return addresses
+}
+
+// setDefaultAddress sets the server with the provided address as the default server.
+// The default flag is cleared on all other servers, and if the server was previously
+// only kept in the list because it was the default, it is removed.
+func (sl *serverList) setDefaultAddress(serviceName, address string) {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	// deal with existing default first
+	sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
+		if s.isDefault && s.address != address {
+			s.isDefault = false
+			if s.state == stateStandby {
+				s.state = stateInvalid
+				defer s.closeAll()
+				return true
+			}
+		}
+		return false
+	})
+
+	// update or create server with selected address
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
+		sl.servers[i].isDefault = true
+	} else {
+		sl.servers = append(sl.servers, newServer(address, true))
+	}
+
+	logrus.Infof("Updated load balancer %s default server: %s", serviceName, address)
+	slices.SortFunc(sl.servers, compareServers)
+}
+
+// getDefaultAddress returns the address of the default server.
+func (sl *serverList) getDefaultAddress() string {
+	if s := sl.getDefaultServer(); s != nil {
+		return s.address
+	}
+	return ""
+}
+
+// getDefaultServer returns the default server.
+func (sl *serverList) getDefaultServer() *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
+func (sl *serverList) getServers() []*server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	return slices.Clone(sl.servers)
+}
+
+// getServer returns the first server with the specified address
+func (sl *serverList) getServer(address string) *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// setHealthCheck updates the health check function for a server, replacing the
+// current function.
+func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
+	if s := sl.getServer(address); s != nil {
+		s.healthCheck = healthCheck
+		return nil
+	}
+	return fmt.Errorf("no server found for %s", address)
+}
+
+// recordSuccess records a successful check of a server, either via health-check or dial.
+// The server's state is adjusted accordingly.
+func (sl *serverList) recordSuccess(srv *server, r reason) {
+	var new_state state
+	switch srv.state {
+	case stateFailed, stateUnchecked:
+		// dialed or health checked OK once, improve to recovering
+		new_state = stateRecovering
+	case stateRecovering:
+		if r == reasonHealthCheck {
+			// was recovering due to successful dial or first health check, can now improve
+			if len(srv.connections) > 0 {
+				// server accepted connections while recovering, attempt to go straight to active
+				new_state = stateActive
+			} else {
+				// no connections, just make it preferred
+				new_state = statePreferred
+			}
+		}
+	case stateHealthy:
+		if r == reasonDial {
+			// improve from healthy to active by being dialed
+			new_state = stateActive
+		}
+	case statePreferred:
+		if r == reasonDial {
+			// improve from preferred to active by being dialed
+			new_state = stateActive
+		} else {
+			if time.Now().Sub(srv.lastTransition) > time.Minute {
+				// has been preferred for a while without being dialed, demote to healthy
+				new_state = stateHealthy
+			}
+		}
+	}
+
+	// no-op if state did not change
+	if new_state == stateInvalid {
+		return
+	}
+
+	// handle active transition and sort the server list while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	// handle states of other servers when attempting to make this one active
+	if new_state == stateActive {
+		for _, s := range sl.servers {
+			if srv.address == s.address {
+				continue
+			}
+			switch s.state {
+			case stateFailed, stateStandby, stateRecovering, stateHealthy:
+				// close connections to other non-active servers whenever we have a new active server
+				defer s.closeAll()
+			case stateActive:
+				if len(s.connections) > len(srv.connections) {
+					// if there is a currently active server that has more connections than we do,
+					// close our connections and go to preferred instead
+					new_state = statePreferred
+					defer srv.closeAll()
+				} else {
+					// otherwise, close its connections and demote it to preferred
+					s.state = statePreferred
+					defer s.closeAll()
+				}
+			}
+		}
+	}
+
+	// ensure some other routine didn't already make the transition
+	if srv.state == new_state {
+		return
+	}
+
+	logrus.Infof("Server %s->%s from successful %s", srv, new_state, r)
+	srv.state = new_state
+	srv.lastTransition = time.Now()
+
+	slices.SortFunc(sl.servers, compareServers)
+}
+
+// recordFailure records a failed check of a server, either via health-check or dial.
+// The server's state is adjusted accordingly.
+func (sl *serverList) recordFailure(srv *server, r reason) {
+	var new_state state
+	switch srv.state {
+	case stateUnchecked, stateRecovering:
+		if r == reasonDial {
+			// only demote from unchecked or recovering if a dial fails, health checks may
+			// continue to fail despite it being dialable. just leave it where it is
+			// and don't close any connections.
+			new_state = stateFailed
+		}
+	case stateHealthy, statePreferred, stateActive:
+		// should not have any connections when in any state other than active or
+		// recovering, but close them all anyway to force failover.
+		defer srv.closeAll()
+		new_state = stateFailed
+	}
+
+	// no-op if state did not change
+	if new_state == stateInvalid {
+		return
+	}
+
+	// sort the server list while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	// ensure some other routine didn't already make the transition
+	if srv.state == new_state {
+		return
+	}
+
+	logrus.Infof("Server %s->%s from failed %s", srv, new_state, r)
+	srv.state = new_state
+	srv.lastTransition = time.Now()
+
+	slices.SortFunc(sl.servers, compareServers)
+}
+
+// state is possible server health states, in increasing order of preference.
+// The server list is kept sorted in descending order by this state value.
+type state int
+
+const (
+	stateInvalid    state = iota
+	stateFailed     // failed a health check or dial
+	stateStandby    // reserved for use by default server if not in server list
+	stateUnchecked  // just added, has not been health checked
+	stateRecovering // successfully health checked once, or dialed when failed
+	stateHealthy    // normal state
+	statePreferred  // recently transitioned from recovering; should be preferred as others may go down for maintenance
+	stateActive     // currently active server
+)
+
+func (s state) String() string {
+	switch s {
+	case stateInvalid:
+		return "INVALID"
+	case stateFailed:
+		return "FAILED"
+	case stateStandby:
+		return "STANDBY"
+	case stateUnchecked:
+		return "UNCHECKED"
+	case stateRecovering:
+		return "RECOVERING"
+	case stateHealthy:
+		return "HEALTHY"
+	case statePreferred:
+		return "PREFERRED"
+	case stateActive:
+		return "ACTIVE"
+	default:
+		return "UNKNOWN"
+	}
+}
+
+// reason specifies the reason for a successful or failed health report
+type reason int
+
+const (
+	reasonDial reason = iota
+	reasonHealthCheck
+)
+
+func (r reason) String() string {
+	switch r {
+	case reasonDial:
+		return "dial"
+	case reasonHealthCheck:
+		return "health check"
+	default:
+		return "unknown reason"
+	}
+}
+
+// server tracks the connections to a server, so that they can be closed when the server is removed.
+type server struct {
+	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
+	mutex          sync.Mutex
+	address        string
+	isDefault      bool
+	state          state
+	lastTransition time.Time
+	healthCheck    HealthCheckFunc
+	connections    map[net.Conn]struct{}
+}
+
+// newServer creates a new server, with a default health check
+// and default/state fields appropriate for whether or not
+// the server is a full server, or just a fallback default.
+func newServer(address string, isDefault bool) *server {
+	state := stateUnchecked
+	if isDefault {
+		state = stateStandby
+	}
+	return &server{
+		address:        address,
+		isDefault:      isDefault,
+		state:          state,
+		lastTransition: time.Now(),
+		healthCheck:    func() HealthCheckResult { return HealthCheckResultUnknown },
+		connections:    make(map[net.Conn]struct{}),
+	}
+}
+
+func (s *server) String() string {
+	format := "%s@%s"
+	if s.isDefault {
+		format += "*"
+	}
+	return fmt.Sprintf(format, s.address, s.state)
+}
+
+// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map
+func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) {
+	if s.state == stateInvalid {
+		return nil, fmt.Errorf("server %s is stopping", s.address)
+	}
+
+	conn, err := defaultDialer.Dial(network, s.address)
 	if err != nil {
 		return nil, err
 	}
@@ -132,7 +435,7 @@ func (s *server) closeAll() {
 	defer s.mutex.Unlock()

 	if l := len(s.connections); l > 0 {
-		logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
+		logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s)
 		for conn := range s.connections {
 			// Close the connection in a goroutine so that we don't hold the lock while doing so.
 			go conn.Close()
@@ -140,6 +443,12 @@ func (s *server) closeAll() {
 	}
 }

+// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
+type serverConn struct {
+	server *server
+	net.Conn
+}
+
 // Close removes the connection entry from the server's connection map, and
 // closes the wrapped connection.
 func (sc *serverConn) Close() error {
@@ -150,73 +459,43 @@ func (sc *serverConn) Close() error {
 	return sc.Conn.Close()
 }

-// SetDefault sets the selected address as the default / fallback address
-func (lb *LoadBalancer) SetDefault(serverAddress string) {
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
-	// if the old default server is not currently in use, remove it from the server map
-	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
-		defer server.closeAll()
-		delete(lb.servers, lb.defaultServerAddress)
-	}
-	// if the new default server doesn't have an entry in the map, add one - but
-	// with a failing health check so that it is only used as a last resort.
-	if _, ok := lb.servers[serverAddress]; !ok {
-		lb.servers[serverAddress] = &server{
-			address:     serverAddress,
-			healthCheck: func() bool { return false },
-			connections: make(map[net.Conn]struct{}),
-		}
-	}
-
-	lb.defaultServerAddress = serverAddress
-	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
-}
-
-// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
-func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	if server := lb.servers[address]; server != nil {
-		logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
-		server.healthCheck = healthCheck
-	} else {
-		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
-	}
-}
-
-// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
-// connections closed, to force clients to switch over to a healthy server.
-func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
-	previousStatus := map[string]bool{}
+// runHealthChecks periodically health-checks all servers.
+func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
 	wait.Until(func() {
-		lb.mutex.RLock()
-		defer lb.mutex.RUnlock()
-		var healthyServerExists bool
-		for address, server := range lb.servers {
-			status := server.healthCheck()
-			healthyServerExists = healthyServerExists || status
-			if status == false && previousStatus[address] == true {
-				// Only close connections when the server transitions from healthy to unhealthy;
-				// we don't want to re-close all the connections every time as we might be ignoring
-				// health checks due to all servers being marked unhealthy.
-				defer server.closeAll()
-			}
-			previousStatus[address] = status
-		}
-		// If there is at least one healthy server, and the default server is not in the server list,
-		// close all the connections to the default server so that clients reconnect and switch over
-		// to a preferred server.
-		hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
-		if healthyServerExists && !hasDefaultServer {
-			if server, ok := lb.servers[lb.defaultServerAddress]; ok {
-				defer server.closeAll()
+		for _, s := range sl.getServers() {
+			switch s.healthCheck() {
+			case HealthCheckResultOK:
+				sl.recordSuccess(s, reasonHealthCheck)
+			case HealthCheckResultFailed:
+				sl.recordFailure(s, reasonHealthCheck)
 			}
 		}
 	}, time.Second, ctx.Done())
-	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
+	logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
+}
+
+// dialContext attempts to dial a connection to a server from the server list.
+// Success or failure is recorded to ensure that server state is updated appropriately.
+func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
+	for _, s := range sl.getServers() {
+		dialTime := time.Now()
+		conn, err := s.dialContext(ctx, network)
+		if err == nil {
+			sl.recordSuccess(s, reasonDial)
+			return conn, nil
+		}
+		logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err)
+		sl.recordFailure(s, reasonDial)
+	}
+	return nil, errors.New("all servers failed")
+}
+
+// compareServers is a comparison function that can be used to sort the server list
+// so that servers with a more preferred state, or higher number of connections, are ordered first.
+func compareServers(a, b *server) int {
+	c := cmp.Compare(b.state, a.state)
+	if c == 0 {
+		return cmp.Compare(len(b.connections), len(a.connections))
+	}
+	return c
 }

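Note (not part of the commit): an illustration of the ordering compareServers maintains, so that dialContext always tries the most preferred server first; the addresses are made up.

	servers := []*server{
		newServer("10.0.0.1:6443", false), // starts in stateUnchecked
		{address: "10.0.0.2:6443", state: stateActive},
		{address: "10.0.0.3:6443", state: statePreferred},
	}
	slices.SortFunc(servers, compareServers)
	// order is now 10.0.0.2 (ACTIVE), 10.0.0.3 (PREFERRED), 10.0.0.1 (UNCHECKED):
	// states compare descending, ties break on descending connection count.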

@@ -22,7 +22,7 @@ type Proxy interface {
 	SupervisorAddresses() []string
 	APIServerURL() string
 	IsAPIServerLBEnabled() bool
-	SetHealthCheck(address string, healthCheck func() bool)
+	SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc)
 }

 // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
@@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
 		return nil, err
 	}
 	p.supervisorLB = lb
-	p.supervisorURL = lb.LoadBalancerServerURL()
+	p.supervisorURL = lb.LocalURL()
 	p.apiServerURL = p.supervisorURL
 }

@@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) {
 	p.supervisorAddresses = supervisorAddresses
 }

-func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
+func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) {
 	if p.supervisorLB != nil {
 		p.supervisorLB.SetHealthCheck(address, healthCheck)
 	}
@@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
 			return err
 		}
 		p.apiServerLB = lb
-		p.apiServerURL = lb.LoadBalancerServerURL()
+		p.apiServerURL = lb.LocalURL()
 	} else {
 		p.apiServerURL = u.String()
 	}


@@ -14,6 +14,7 @@ import (
 	"github.com/gorilla/websocket"
 	agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
+	"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
 	"github.com/k3s-io/k3s/pkg/agent/proxy"
 	daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
 	"github.com/k3s-io/k3s/pkg/util"
@@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
 			if _, ok := disconnect[address]; !ok {
 				conn := a.connect(ctx, wg, address, tlsConfig)
 				disconnect[address] = conn.cancel
-				proxy.SetHealthCheck(address, conn.connected)
+				proxy.SetHealthCheck(address, conn.healthCheck)
 			}
 		}

@@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
 			if _, ok := disconnect[address]; !ok {
 				conn := a.connect(ctx, nil, address, tlsConfig)
 				disconnect[address] = conn.cancel
-				proxy.SetHealthCheck(address, conn.connected)
+				proxy.SetHealthCheck(address, conn.healthCheck)
 			}
 		}

@@ -428,21 +429,19 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo

 type agentConnection struct {
-	cancel    context.CancelFunc
-	connected func() bool
+	cancel      context.CancelFunc
+	healthCheck loadbalancer.HealthCheckFunc
 }

 // connect initiates a connection to the remotedialer server. Incoming dial requests from
 // the server will be checked by the authorizer function prior to being fulfilled.
 func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
+	var status loadbalancer.HealthCheckResult
+
 	wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
 	ws := &websocket.Dialer{
 		TLSClientConfig: tlsConfig,
 	}

-	// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
-	// If we cannot connect, connected will be set to false when the initial connection attempt fails.
-	connected := true
-
 	once := sync.Once{}
 	if waitGroup != nil {
 		waitGroup.Add(1)
@@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 	}

 	onConnect := func(_ context.Context, _ *remotedialer.Session) error {
-		connected = true
+		status = loadbalancer.HealthCheckResultOK
 		logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
 		if waitGroup != nil {
 			once.Do(waitGroup.Done)
@@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 		for {
 			// ConnectToProxy blocks until error or context cancellation
 			err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect)
-			connected = false
+			status = loadbalancer.HealthCheckResultFailed
 			if err != nil && !errors.Is(err, context.Canceled) {
 				logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...")
 				// wait between reconnection attempts to avoid hammering the server
@@ -485,7 +484,9 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 	return agentConnection{
 		cancel: cancel,
-		connected: func() bool { return connected },
+		healthCheck: func() loadbalancer.HealthCheckResult {
+			return status
+		},
 	}
 }


@@ -48,6 +48,10 @@ type etcdproxy struct {
 }

 func (e *etcdproxy) Update(addresses []string) {
+	if e.etcdLB == nil {
+		return
+	}
+
 	e.etcdLB.Update(addresses)

 	validEndpoint := map[string]bool{}
@@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) {

 // start a polling routine that makes periodic requests to the etcd node's supervisor port.
 // If the request fails, the node is marked unhealthy.
-func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool {
-	// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
-	// If we cannot connect, connected will be set to false when the initial connection attempt fails.
-	connected := true
+func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
+	var status loadbalancer.HealthCheckResult

 	host, _, _ := net.SplitHostPort(address)
 	url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort)))
@@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
 		}
 		if err != nil || statusCode != http.StatusOK {
 			logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode)
-			connected = false
+			status = loadbalancer.HealthCheckResultFailed
 		} else {
-			connected = true
+			status = loadbalancer.HealthCheckResultOK
 		}
 	}, 5*time.Second, 1.0, true)

-	return func() bool {
-		return connected
+	return func() loadbalancer.HealthCheckResult {
+		// Reset the status to unknown on reading, until next time it is checked.
+		// This avoids having a health check result alter the server state between active checks.
+		s := status
+		status = loadbalancer.HealthCheckResultUnknown
+		return s
 	}
 }
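
Note (not part of the commit): the reset-to-Unknown pattern above generalizes to any background poller. A sketch under that assumption; newPollingCheck and checkOnce are hypothetical, while wait.Until and the loadbalancer types come from the diff.

	func newPollingCheck(ctx context.Context, checkOnce func() bool) loadbalancer.HealthCheckFunc {
		var status loadbalancer.HealthCheckResult

		// Poll in the background, recording only the most recent result.
		go wait.Until(func() {
			if checkOnce() {
				status = loadbalancer.HealthCheckResultOK
			} else {
				status = loadbalancer.HealthCheckResultFailed
			}
		}, 5*time.Second, ctx.Done())

		return func() loadbalancer.HealthCheckResult {
			// Return Unknown between polls so a stale result cannot
			// change server state twice, per the HealthCheckResult contract.
			s := status
			status = loadbalancer.HealthCheckResultUnknown
			return s
		}
	}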