mirror of https://github.com/k3s-io/k3s
Refactor load balancer server list and health checking
Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
(cherry picked from commit 911ee19a93
)
Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
pull/11460/head
parent
867ca25412
commit
ba4237aaf7
|
@ -15,8 +15,8 @@ type lbConfig struct {
|
|||
|
||||
func (lb *LoadBalancer) writeConfig() error {
|
||||
config := &lbConfig{
|
||||
ServerURL: lb.serverURL,
|
||||
ServerAddresses: lb.serverAddresses,
|
||||
ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(),
|
||||
ServerAddresses: lb.servers.getAddresses(),
|
||||
}
|
||||
configOut, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
|
@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
|
|||
}
|
||||
|
||||
func (lb *LoadBalancer) updateConfig() error {
|
||||
writeConfig := true
|
||||
if configBytes, err := os.ReadFile(lb.configFile); err == nil {
|
||||
config := &lbConfig{}
|
||||
if err := json.Unmarshal(configBytes, config); err == nil {
|
||||
if config.ServerURL == lb.serverURL {
|
||||
writeConfig = false
|
||||
lb.setServers(config.ServerAddresses)
|
||||
// if the default server from the config matches our current default,
|
||||
// load the rest of the addresses as well.
|
||||
if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
|
||||
lb.Update(config.ServerAddresses)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if writeConfig {
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
// config didn't exist or used a different default server, write the current config to disk.
|
||||
return lb.writeConfig()
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
|
|||
func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
|
||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
// Create a new HTTP proxy dialer
|
||||
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
|
||||
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
|
||||
return httpProxyDialer, nil
|
||||
} else if proxyURL.Scheme == "socks5" {
|
||||
// For SOCKS5 proxies, use the proxy package's FromURL
|
||||
|
|
|
@ -2,15 +2,16 @@ package loadbalancer
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/k3s-io/k3s/pkg/version"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
var originalDialer proxy.Dialer
|
||||
var defaultEnv map[string]string
|
||||
var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}
|
||||
|
||||
|
@ -19,7 +20,7 @@ func init() {
|
|||
}
|
||||
|
||||
func prepareEnv(env ...string) {
|
||||
defaultDialer = &net.Dialer{}
|
||||
originalDialer = defaultDialer
|
||||
defaultEnv = map[string]string{}
|
||||
for _, e := range proxyEnvs {
|
||||
if v, ok := os.LookupEnv(e); ok {
|
||||
|
@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
|
|||
}
|
||||
|
||||
func restoreEnv() {
|
||||
defaultDialer = originalDialer
|
||||
for _, e := range proxyEnvs {
|
||||
if v, ok := defaultEnv[e]; ok {
|
||||
os.Setenv(e, v)
|
||||
|
|
|
@ -2,55 +2,29 @@ package loadbalancer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
"strings"
|
||||
|
||||
"github.com/inetaf/tcpproxy"
|
||||
"github.com/k3s-io/k3s/pkg/version"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// server tracks the connections to a server, so that they can be closed when the server is removed.
|
||||
type server struct {
|
||||
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
|
||||
mutex sync.Mutex
|
||||
address string
|
||||
healthCheck func() bool
|
||||
connections map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
|
||||
type serverConn struct {
|
||||
server *server
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// LoadBalancer holds data for a local listener which forwards connections to a
|
||||
// pool of remote servers. It is not a proper load-balancer in that it does not
|
||||
// actually balance connections, but instead fails over to a new server only
|
||||
// when a connection attempt to the currently selected server fails.
|
||||
type LoadBalancer struct {
|
||||
// This mutex protects access to servers map and randomServers list.
|
||||
// All direct access to the servers map/list should be protected by it.
|
||||
mutex sync.RWMutex
|
||||
proxy *tcpproxy.Proxy
|
||||
|
||||
serviceName string
|
||||
configFile string
|
||||
localAddress string
|
||||
localServerURL string
|
||||
defaultServerAddress string
|
||||
serverURL string
|
||||
serverAddresses []string
|
||||
randomServers []string
|
||||
servers map[string]*server
|
||||
currentServerAddress string
|
||||
nextServerIndex int
|
||||
serviceName string
|
||||
configFile string
|
||||
scheme string
|
||||
localAddress string
|
||||
servers serverList
|
||||
proxy *tcpproxy.Proxy
|
||||
}
|
||||
|
||||
const RandomPort = 0
|
||||
|
@ -63,7 +37,7 @@ var (
|
|||
|
||||
// New contstructs a new LoadBalancer instance. The default server URL, and
|
||||
// currently active servers, are stored in a file within the dataDir.
|
||||
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
|
||||
func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
|
||||
config := net.ListenConfig{Control: reusePort}
|
||||
var localAddress string
|
||||
if isIPv6 {
|
||||
|
@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
|
||||
localAddress = listener.Addr().String()
|
||||
|
||||
defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
|
||||
serverURL, err := url.Parse(defaultServerURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serverURL == localServerURL {
|
||||
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
|
||||
defaultServerAddress = ""
|
||||
// Set explicit port from scheme
|
||||
if serverURL.Port() == "" {
|
||||
if strings.ToLower(serverURL.Scheme) == "http" {
|
||||
serverURL.Host += ":80"
|
||||
}
|
||||
if strings.ToLower(serverURL.Scheme) == "https" {
|
||||
serverURL.Host += ":443"
|
||||
}
|
||||
}
|
||||
|
||||
lb := &LoadBalancer{
|
||||
serviceName: serviceName,
|
||||
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
|
||||
localAddress: localAddress,
|
||||
localServerURL: localServerURL,
|
||||
defaultServerAddress: defaultServerAddress,
|
||||
servers: make(map[string]*server),
|
||||
serverURL: serverURL,
|
||||
serviceName: serviceName,
|
||||
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
|
||||
scheme: serverURL.Scheme,
|
||||
localAddress: listener.Addr().String(),
|
||||
}
|
||||
|
||||
lb.setServers([]string{lb.defaultServerAddress})
|
||||
// if starting pointing at ourselves, don't set a default server address,
|
||||
// which will cause all dials to fail until servers are added.
|
||||
if serverURL.Host == lb.localAddress {
|
||||
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
|
||||
} else {
|
||||
lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
|
||||
}
|
||||
|
||||
lb.proxy = &tcpproxy.Proxy{
|
||||
ListenFunc: func(string, string) (net.Listener, error) {
|
||||
|
@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
|
|||
}
|
||||
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
|
||||
Addr: serviceName,
|
||||
DialContext: lb.dialContext,
|
||||
DialContext: lb.servers.dialContext,
|
||||
OnDialError: onDialError,
|
||||
})
|
||||
|
||||
|
@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
|
|||
if err := lb.proxy.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
|
||||
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
|
||||
|
||||
go lb.runHealthChecks(ctx)
|
||||
go lb.servers.runHealthChecks(ctx, lb.serviceName)
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
// Update updates the list of server addresses to contain only the listed servers.
|
||||
func (lb *LoadBalancer) Update(serverAddresses []string) {
|
||||
if lb == nil {
|
||||
if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
|
||||
return
|
||||
}
|
||||
if !lb.setServers(serverAddresses) {
|
||||
return
|
||||
}
|
||||
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)
|
||||
|
||||
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
|
||||
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) LoadBalancerServerURL() string {
|
||||
if lb == nil {
|
||||
return ""
|
||||
// SetDefault sets the selected address as the default / fallback address
|
||||
func (lb *LoadBalancer) SetDefault(serverAddress string) {
|
||||
lb.servers.setDefaultAddress(lb.serviceName, serverAddress)
|
||||
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
|
||||
}
|
||||
return lb.localServerURL
|
||||
}
|
||||
|
||||
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
|
||||
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
|
||||
if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
|
||||
logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
|
||||
} else {
|
||||
logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) LocalURL() string {
|
||||
return lb.scheme + "://" + lb.localAddress
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) ServerAddresses() []string {
|
||||
if lb == nil {
|
||||
return nil
|
||||
}
|
||||
return lb.serverAddresses
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
|
||||
var allChecksFailed bool
|
||||
startIndex := lb.nextServerIndex
|
||||
for {
|
||||
targetServer := lb.currentServerAddress
|
||||
|
||||
server := lb.servers[targetServer]
|
||||
if server == nil || targetServer == "" {
|
||||
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
|
||||
} else if allChecksFailed || server.healthCheck() {
|
||||
dialTime := time.Now()
|
||||
conn, err := server.dialContext(ctx, network, targetServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
|
||||
// Don't close connections to the failed server if we're retrying with health checks ignored.
|
||||
// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
|
||||
if !allChecksFailed {
|
||||
defer server.closeAll()
|
||||
}
|
||||
} else {
|
||||
logrus.Debugf("Dial health check failed for %s", targetServer)
|
||||
}
|
||||
|
||||
newServer, err := lb.nextServer(targetServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if targetServer != newServer {
|
||||
logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
maxIndex := len(lb.randomServers)
|
||||
if startIndex > maxIndex {
|
||||
startIndex = maxIndex
|
||||
}
|
||||
if lb.nextServerIndex == startIndex {
|
||||
if allChecksFailed {
|
||||
return nil, errors.New("all servers failed")
|
||||
}
|
||||
logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
|
||||
allChecksFailed = true
|
||||
}
|
||||
}
|
||||
return lb.servers.getAddresses()
|
||||
}
|
||||
|
||||
func onDialError(src net.Conn, dstDialErr error) {
|
||||
|
@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
|
|||
}
|
||||
|
||||
// ResetLoadBalancer will delete the local state file for the load balancer on disk
|
||||
func ResetLoadBalancer(dataDir, serviceName string) error {
|
||||
func ResetLoadBalancer(dataDir, serviceName string) {
|
||||
stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
|
||||
if err := os.Remove(stateFile); err != nil {
|
||||
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
|
||||
logrus.Warn(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,19 +5,29 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Test_UnitLoadBalancer(t *testing.T) {
|
||||
_, reporterConfig := GinkgoConfiguration()
|
||||
reporterConfig.Verbose = testing.Verbose()
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "LoadBalancer Suite", reporterConfig)
|
||||
}
|
||||
|
||||
func init() {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
address string
|
||||
listener net.Listener
|
||||
conns []net.Conn
|
||||
prefix string
|
||||
|
@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) {
|
|||
s := &testServer{
|
||||
prefix: prefix,
|
||||
listener: listener,
|
||||
address: listener.Addr().String(),
|
||||
}
|
||||
go s.serve()
|
||||
go func() {
|
||||
|
@ -53,6 +64,7 @@ func (s *testServer) serve() {
|
|||
|
||||
func (s *testServer) close() {
|
||||
logrus.Printf("testServer %s closing", s.prefix)
|
||||
s.address = ""
|
||||
s.listener.Close()
|
||||
for _, conn := range s.conns {
|
||||
conn.Close()
|
||||
|
@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *testServer) address() string {
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
|
||||
func ping(conn net.Conn) (string, error) {
|
||||
fmt.Fprintf(conn, "ping\n")
|
||||
result, err := bufio.NewReader(conn).ReadString('\n')
|
||||
|
@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) {
|
|||
return strings.TrimSpace(result), nil
|
||||
}
|
||||
|
||||
// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
|
||||
// and then adds a new server (a node). The node server is then closed, and it is confirmed
|
||||
// that new connections use the default server.
|
||||
func Test_UnitFailOver(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
var _ = Describe("LoadBalancer", func() {
|
||||
// creates a LB using a default server (ie fixed registration endpoint)
|
||||
// and then adds a new server (a node). The node server is then closed, and it is confirmed
|
||||
// that new connections use the default server.
|
||||
When("loadbalancer is running", Ordered, func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var defaultServer, node1Server, node2Server *testServer
|
||||
var conn1, conn2, conn3, conn4 net.Conn
|
||||
var lb *LoadBalancer
|
||||
var err error
|
||||
|
||||
defaultServer, err := createServer(ctx, "default")
|
||||
if err != nil {
|
||||
t.Fatalf("createServer(default) failed: %v", err)
|
||||
}
|
||||
BeforeAll(func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
node1Server, err := createServer(ctx, "node1")
|
||||
if err != nil {
|
||||
t.Fatalf("createServer(node1) failed: %v", err)
|
||||
}
|
||||
defaultServer, err = createServer(ctx, "default")
|
||||
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
|
||||
|
||||
node2Server, err := createServer(ctx, "node2")
|
||||
if err != nil {
|
||||
t.Fatalf("createServer(node2) failed: %v", err)
|
||||
}
|
||||
node1Server, err = createServer(ctx, "node1")
|
||||
Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
|
||||
|
||||
// start the loadbalancer with the default server as the only server
|
||||
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
node2Server, err = createServer(ctx, "node2")
|
||||
Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
|
||||
|
||||
parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse failed: %v", err)
|
||||
}
|
||||
localAddress := parsedURL.Host
|
||||
// start the loadbalancer with the default server as the only server
|
||||
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
|
||||
Expect(err).NotTo(HaveOccurred(), "New() failed")
|
||||
})
|
||||
|
||||
// add the node as a new server address.
|
||||
lb.Update([]string{node1Server.address()})
|
||||
AfterAll(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
// make sure connections go to the node
|
||||
conn1, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
}
|
||||
if result, err := ping(conn1); err != nil {
|
||||
t.Fatalf("ping(conn1) failed: %v", err)
|
||||
} else if result != "node1:ping" {
|
||||
t.Fatalf("Unexpected ping(conn1) result: %v", result)
|
||||
}
|
||||
It("adds node1 as a server", func() {
|
||||
// add the node as a new server address.
|
||||
lb.Update([]string{node1Server.address})
|
||||
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
|
||||
t.Log("conn1 tested OK")
|
||||
By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
|
||||
|
||||
// set failing health check for node 1
|
||||
lb.SetHealthCheck(node1Server.address(), func() bool { return false })
|
||||
// wait for state to change
|
||||
Eventually(func() state {
|
||||
if s := lb.servers.getServer(node1Server.address); s != nil {
|
||||
return s.state
|
||||
}
|
||||
return stateInvalid
|
||||
}, 5, 1).Should(Equal(statePreferred))
|
||||
})
|
||||
|
||||
// Server connections are checked every second, now that node 1 is failed
|
||||
// the connections to it should be closed.
|
||||
time.Sleep(2 * time.Second)
|
||||
It("connects to node1", func() {
|
||||
// make sure connections go to the node
|
||||
conn1, err = net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
|
||||
|
||||
if _, err := ping(conn1); err == nil {
|
||||
t.Fatal("Unexpected successful ping on closed connection conn1")
|
||||
}
|
||||
By("conn1 tested OK")
|
||||
})
|
||||
|
||||
t.Log("conn1 closed on failure OK")
|
||||
It("changes node1 state to failed", func() {
|
||||
// set failing health check for node 1
|
||||
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
|
||||
|
||||
// make sure connection still goes to the first node - it is failing health checks but so
|
||||
// is the default endpoint, so it should be tried first with health checks disabled,
|
||||
// before failing back to the default.
|
||||
conn2, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
// wait for state to change
|
||||
Eventually(func() state {
|
||||
if s := lb.servers.getServer(node1Server.address); s != nil {
|
||||
return s.state
|
||||
}
|
||||
return stateInvalid
|
||||
}, 5, 1).Should(Equal(stateFailed))
|
||||
})
|
||||
|
||||
}
|
||||
if result, err := ping(conn2); err != nil {
|
||||
t.Fatalf("ping(conn2) failed: %v", err)
|
||||
} else if result != "node1:ping" {
|
||||
t.Fatalf("Unexpected ping(conn2) result: %v", result)
|
||||
}
|
||||
It("disconnects from node1", func() {
|
||||
// Server connections are checked every second, now that node 1 is failed
|
||||
// the connections to it should be closed.
|
||||
Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
|
||||
|
||||
t.Log("conn2 tested OK")
|
||||
By("conn1 closed on failure OK")
|
||||
|
||||
// make sure the health checks don't close the connection we just made -
|
||||
// connections should only be closed when it transitions from health to unhealthy.
|
||||
time.Sleep(2 * time.Second)
|
||||
// connections shoould go to the default now that node 1 is failed
|
||||
conn2, err = net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
|
||||
|
||||
if result, err := ping(conn2); err != nil {
|
||||
t.Fatalf("ping(conn2) failed: %v", err)
|
||||
} else if result != "node1:ping" {
|
||||
t.Fatalf("Unexpected ping(conn2) result: %v", result)
|
||||
}
|
||||
By("conn2 tested OK")
|
||||
})
|
||||
|
||||
t.Log("conn2 tested OK again")
|
||||
It("does not close connections unexpectedly", func() {
|
||||
// make sure the health checks don't close the connection we just made -
|
||||
// connections should only be closed when it transitions from health to unhealthy.
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// shut down the first node server to force failover to the default
|
||||
node1Server.close()
|
||||
Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
|
||||
|
||||
// make sure new connections go to the default, and existing connections are closed
|
||||
conn3, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
By("conn2 tested OK again")
|
||||
})
|
||||
|
||||
}
|
||||
if result, err := ping(conn3); err != nil {
|
||||
t.Fatalf("ping(conn3) failed: %v", err)
|
||||
} else if result != "default:ping" {
|
||||
t.Fatalf("Unexpected ping(conn3) result: %v", result)
|
||||
}
|
||||
It("closes connections when dial fails", func() {
|
||||
// shut down the first node server to force failover to the default
|
||||
node1Server.close()
|
||||
|
||||
t.Log("conn3 tested OK")
|
||||
// make sure new connections go to the default, and existing connections are closed
|
||||
conn3, err = net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
|
||||
if _, err := ping(conn2); err == nil {
|
||||
t.Fatal("Unexpected successful ping on closed connection conn2")
|
||||
}
|
||||
Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
|
||||
|
||||
t.Log("conn2 closed on failure OK")
|
||||
By("conn3 tested OK")
|
||||
})
|
||||
|
||||
// add the second node as a new server address.
|
||||
lb.Update([]string{node2Server.address()})
|
||||
It("replaces node2 as a server", func() {
|
||||
// add the second node as a new server address.
|
||||
lb.Update([]string{node2Server.address})
|
||||
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
|
||||
// make sure connection now goes to the second node,
|
||||
// and connections to the default are closed.
|
||||
conn4, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
|
||||
|
||||
}
|
||||
if result, err := ping(conn4); err != nil {
|
||||
t.Fatalf("ping(conn4) failed: %v", err)
|
||||
} else if result != "node2:ping" {
|
||||
t.Fatalf("Unexpected ping(conn4) result: %v", result)
|
||||
}
|
||||
// wait for state to change
|
||||
Eventually(func() state {
|
||||
if s := lb.servers.getServer(node2Server.address); s != nil {
|
||||
return s.state
|
||||
}
|
||||
return stateInvalid
|
||||
}, 5, 1).Should(Equal(statePreferred))
|
||||
})
|
||||
|
||||
t.Log("conn4 tested OK")
|
||||
It("connects to node2", func() {
|
||||
// make sure connection now goes to the second node,
|
||||
// and connections to the default are closed.
|
||||
conn4, err = net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
|
||||
// Server connections are checked every second, now that we have a healthy
|
||||
// server, connections to the default server should be closed
|
||||
time.Sleep(2 * time.Second)
|
||||
Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result")
|
||||
|
||||
if _, err := ping(conn3); err == nil {
|
||||
t.Fatal("Unexpected successful ping on connection conn3")
|
||||
}
|
||||
By("conn4 tested OK")
|
||||
})
|
||||
|
||||
t.Log("conn3 closed on failure OK")
|
||||
}
|
||||
It("does not close connections unexpectedly", func() {
|
||||
// Server connections are checked every second, now that we have a healthy
|
||||
// server, connections to the default server should be closed
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
|
||||
func Test_UnitFailFast(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
|
||||
|
||||
serverURL := "http://127.0.0.1:0/"
|
||||
lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
By("conn2 closed on failure OK")
|
||||
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
}
|
||||
Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err = ping(conn)
|
||||
done <- err
|
||||
}()
|
||||
timeout := time.After(10 * time.Millisecond)
|
||||
By("conn3 closed on failure OK")
|
||||
})
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
t.Fatal("Unexpected successful ping from invalid address")
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("Test timed out")
|
||||
}
|
||||
}
|
||||
It("adds default as a server", func() {
|
||||
// add the default as a full server
|
||||
lb.Update([]string{node2Server.address, defaultServer.address})
|
||||
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
|
||||
// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
|
||||
// within the expected duration
|
||||
func Test_UnitFailUnreachable(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping slow test in short mode.")
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// wait for state to change
|
||||
Eventually(func() state {
|
||||
if s := lb.servers.getServer(defaultServer.address); s != nil {
|
||||
return s.state
|
||||
}
|
||||
return stateInvalid
|
||||
}, 5, 1).Should(Equal(statePreferred))
|
||||
|
||||
serverAddr := "192.0.2.1:6443"
|
||||
lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
|
||||
if err != nil {
|
||||
t.Fatalf("New() failed: %v", err)
|
||||
}
|
||||
By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
|
||||
})
|
||||
|
||||
// Set failing health check to reduce retries
|
||||
lb.SetHealthCheck(serverAddr, func() bool { return false })
|
||||
It("returns the default server in the address list", func() {
|
||||
// confirm that both servers are listed in the address list
|
||||
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
|
||||
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("net.Dial failed: %v", err)
|
||||
}
|
||||
// confirm that the default is still listed as default
|
||||
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err = ping(conn)
|
||||
done <- err
|
||||
}()
|
||||
timeout := time.After(11 * time.Second)
|
||||
})
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
t.Fatal("Unexpected successful ping from unreachable address")
|
||||
}
|
||||
case <-timeout:
|
||||
t.Fatal("Test timed out")
|
||||
}
|
||||
}
|
||||
It("does not return the default server in the address list after removing it", func() {
|
||||
// remove the default as a server
|
||||
lb.Update([]string{node2Server.address})
|
||||
By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
|
||||
|
||||
// confirm that it is not listed as a server
|
||||
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
|
||||
|
||||
// but is still listed as the default
|
||||
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
|
||||
})
|
||||
|
||||
It("removes default server when no longer default", func() {
|
||||
// set node2 as the default
|
||||
lb.SetDefault(node2Server.address)
|
||||
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
|
||||
|
||||
// confirm that it is still listed as a server
|
||||
Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
|
||||
|
||||
// and is listed as the default
|
||||
Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
|
||||
})
|
||||
|
||||
It("sets all three servers", func() {
|
||||
// set node2 as the default
|
||||
lb.SetDefault(defaultServer.address)
|
||||
By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
|
||||
|
||||
lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
|
||||
lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
|
||||
|
||||
// wait for state to change
|
||||
Eventually(func() state {
|
||||
if s := lb.servers.getServer(defaultServer.address); s != nil {
|
||||
return s.state
|
||||
}
|
||||
return stateInvalid
|
||||
}, 5, 1).Should(Equal(statePreferred))
|
||||
|
||||
By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
|
||||
|
||||
// confirm that it is still listed as a server
|
||||
Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
|
||||
|
||||
// and is listed as the default
|
||||
Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
|
||||
})
|
||||
})
|
||||
|
||||
// confirms that the loadbalancer will not dial itself
|
||||
When("the default server is the loadbalancer", Ordered, func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var defaultServer *testServer
|
||||
var lb *LoadBalancer
|
||||
var err error
|
||||
|
||||
BeforeAll(func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
|
||||
defaultServer, err = createServer(ctx, "default")
|
||||
Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
|
||||
address := defaultServer.address
|
||||
defaultServer.close()
|
||||
_, port, _ := net.SplitHostPort(address)
|
||||
intPort, _ := strconv.Atoi(port)
|
||||
|
||||
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
|
||||
Expect(err).NotTo(HaveOccurred(), "New() failed")
|
||||
})
|
||||
|
||||
AfterAll(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
It("fails immediately", func() {
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
|
||||
_, err = ping(conn)
|
||||
Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
|
||||
})
|
||||
})
|
||||
|
||||
// confirms that connnections to invalid addresses fail quickly
|
||||
When("there are no valid addresses", Ordered, func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var lb *LoadBalancer
|
||||
var err error
|
||||
|
||||
BeforeAll(func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
|
||||
Expect(err).NotTo(HaveOccurred(), "New() failed")
|
||||
})
|
||||
|
||||
AfterAll(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
It("fails fast", func() {
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err = ping(conn)
|
||||
done <- err
|
||||
}()
|
||||
timeout := time.After(10 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
Fail("Unexpected successful ping from invalid address")
|
||||
}
|
||||
case <-timeout:
|
||||
Fail("Test timed out")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// confirms that connnections to unreachable addresses do fail within the
|
||||
// expected duration
|
||||
When("the server is unreachable", Ordered, func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var lb *LoadBalancer
|
||||
var err error
|
||||
|
||||
BeforeAll(func() {
|
||||
tmpDir := GinkgoT().TempDir()
|
||||
lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
|
||||
Expect(err).NotTo(HaveOccurred(), "New() failed")
|
||||
})
|
||||
|
||||
AfterAll(func() {
|
||||
cancel()
|
||||
})
|
||||
|
||||
It("fails with the correct timeout", func() {
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err = ping(conn)
|
||||
done <- err
|
||||
}()
|
||||
timeout := time.After(11 * time.Second)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err == nil {
|
||||
Fail("Unexpected successful ping from unreachable address")
|
||||
}
|
||||
case <-timeout:
|
||||
Fail("Test timed out")
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
|
@ -1,118 +1,421 @@
|
|||
package loadbalancer
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"math/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
)
|
||||
|
||||
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
|
||||
serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
|
||||
if len(serverAddresses) == 0 {
|
||||
return false
|
||||
}
|
||||
type HealthCheckFunc func() HealthCheckResult
|
||||
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
// HealthCheckResult indicates the status of a server health check poll.
|
||||
// For health-checks that poll in the background, Unknown should be returned
|
||||
// if a poll has not occurred since the last check.
|
||||
type HealthCheckResult int
|
||||
|
||||
newAddresses := sets.NewString(serverAddresses...)
|
||||
curAddresses := sets.NewString(lb.serverAddresses...)
|
||||
const (
|
||||
HealthCheckResultUnknown HealthCheckResult = iota
|
||||
HealthCheckResultFailed
|
||||
HealthCheckResultOK
|
||||
)
|
||||
|
||||
// serverList tracks potential backend servers for use by a loadbalancer.
|
||||
type serverList struct {
|
||||
// This mutex protects access to the server list. All direct access to the list should be protected by it.
|
||||
mutex sync.Mutex
|
||||
servers []*server
|
||||
}
|
||||
|
||||
// setServers updates the server list to contain only the selected addresses.
|
||||
func (sl *serverList) setAddresses(serviceName string, addresses []string) bool {
|
||||
newAddresses := sets.New(addresses...)
|
||||
curAddresses := sets.New(sl.getAddresses()...)
|
||||
if newAddresses.Equal(curAddresses) {
|
||||
return false
|
||||
}
|
||||
|
||||
for addedServer := range newAddresses.Difference(curAddresses) {
|
||||
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
|
||||
lb.servers[addedServer] = &server{
|
||||
address: addedServer,
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
healthCheck: func() bool { return true },
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
var closeAllFuncs []func()
|
||||
var defaultServer *server
|
||||
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
|
||||
defaultServer = sl.servers[i]
|
||||
}
|
||||
|
||||
// add new servers
|
||||
for addedAddress := range newAddresses.Difference(curAddresses) {
|
||||
if defaultServer != nil && defaultServer.address == addedAddress {
|
||||
// make default server go through the same health check promotions as a new server when added
|
||||
logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName)
|
||||
defaultServer.state = stateUnchecked
|
||||
defaultServer.lastTransition = time.Now()
|
||||
} else {
|
||||
s := newServer(addedAddress, false)
|
||||
logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address)
|
||||
sl.servers = append(sl.servers, s)
|
||||
}
|
||||
}
|
||||
|
||||
for removedServer := range curAddresses.Difference(newAddresses) {
|
||||
server := lb.servers[removedServer]
|
||||
if server != nil {
|
||||
logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
|
||||
// Defer closing connections until after the new server list has been put into place.
|
||||
// Closing open connections ensures that anything stuck retrying on a stale server is forced
|
||||
// over to a valid endpoint.
|
||||
defer server.closeAll()
|
||||
// Don't delete the default server from the server map, in case we need to fall back to it.
|
||||
if removedServer != lb.defaultServerAddress {
|
||||
delete(lb.servers, removedServer)
|
||||
}
|
||||
// remove old servers
|
||||
for removedAddress := range curAddresses.Difference(newAddresses) {
|
||||
if defaultServer != nil && defaultServer.address == removedAddress {
|
||||
// demote the default server down to standby, instead of deleting it
|
||||
defaultServer.state = stateStandby
|
||||
closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll)
|
||||
} else {
|
||||
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
|
||||
if s.address == removedAddress {
|
||||
logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address)
|
||||
// set state to invalid to prevent server from making additional connections
|
||||
s.state = stateInvalid
|
||||
closeAllFuncs = append(closeAllFuncs, s.closeAll)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
lb.serverAddresses = serverAddresses
|
||||
lb.randomServers = append([]string{}, lb.serverAddresses...)
|
||||
rand.Shuffle(len(lb.randomServers), func(i, j int) {
|
||||
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
|
||||
})
|
||||
// If the current server list does not contain the default server address,
|
||||
// we want to include it in the random server list so that it can be tried if necessary.
|
||||
// However, it should be treated as always failing health checks so that it is only
|
||||
// used if all other endpoints are unavailable.
|
||||
if !hasDefaultServer {
|
||||
lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
|
||||
if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
|
||||
defaultServer.healthCheck = func() bool { return false }
|
||||
lb.servers[lb.defaultServerAddress] = defaultServer
|
||||
}
|
||||
slices.SortFunc(sl.servers, compareServers)
|
||||
|
||||
// Close all connections to servers that were removed
|
||||
for _, closeAll := range closeAllFuncs {
|
||||
closeAll()
|
||||
}
|
||||
lb.currentServerAddress = lb.randomServers[0]
|
||||
lb.nextServerIndex = 1
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// nextServer attempts to get the next server in the loadbalancer server list.
|
||||
// If another goroutine has already updated the current server address to point at
|
||||
// a different address than just failed, nothing is changed. Otherwise, a new server address
|
||||
// is stored to the currentServerAddress field, and returned for use.
|
||||
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
|
||||
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
|
||||
// note: these fields are not protected by the mutex, so we clamp the index value and update
|
||||
// the index/current address using local variables, to avoid time-of-check vs time-of-use
|
||||
// race conditions caused by goroutine A incrementing it in between the time goroutine B
|
||||
// validates its value, and uses it as a list index.
|
||||
currentServerAddress := lb.currentServerAddress
|
||||
nextServerIndex := lb.nextServerIndex
|
||||
// getAddresses returns the addresses of all servers.
|
||||
// If the default server is in standby state, indicating it is only present
|
||||
// because it is the default, it is not returned in this list.
|
||||
func (sl *serverList) getAddresses() []string {
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
if len(lb.randomServers) == 0 {
|
||||
return "", errors.New("No servers in load balancer proxy list")
|
||||
addresses := make([]string, 0, len(sl.servers))
|
||||
for _, s := range sl.servers {
|
||||
if s.isDefault && s.state == stateStandby {
|
||||
continue
|
||||
}
|
||||
addresses = append(addresses, s.address)
|
||||
}
|
||||
if len(lb.randomServers) == 1 {
|
||||
return currentServerAddress, nil
|
||||
}
|
||||
if failedServer != currentServerAddress {
|
||||
return currentServerAddress, nil
|
||||
}
|
||||
if nextServerIndex >= len(lb.randomServers) {
|
||||
nextServerIndex = 0
|
||||
}
|
||||
|
||||
currentServerAddress = lb.randomServers[nextServerIndex]
|
||||
nextServerIndex++
|
||||
|
||||
lb.currentServerAddress = currentServerAddress
|
||||
lb.nextServerIndex = nextServerIndex
|
||||
|
||||
return currentServerAddress, nil
|
||||
return addresses
|
||||
}
|
||||
|
||||
// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map
|
||||
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := defaultDialer.Dial(network, address)
|
||||
// setDefault sets the server with the provided address as the default server.
|
||||
// The default flag is cleared on all other servers, and if the server was previously
|
||||
// only kept in the list because it was the default, it is removed.
|
||||
func (sl *serverList) setDefaultAddress(serviceName, address string) {
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
// deal with existing default first
|
||||
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
|
||||
if s.isDefault && s.address != address {
|
||||
s.isDefault = false
|
||||
if s.state == stateStandby {
|
||||
s.state = stateInvalid
|
||||
defer s.closeAll()
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
// update or create server with selected address
|
||||
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
|
||||
sl.servers[i].isDefault = true
|
||||
} else {
|
||||
sl.servers = append(sl.servers, newServer(address, true))
|
||||
}
|
||||
|
||||
logrus.Infof("Updated load balancer %s default server: %s", serviceName, address)
|
||||
slices.SortFunc(sl.servers, compareServers)
|
||||
}
|
||||
|
||||
// getDefault returns the address of the default server.
|
||||
func (sl *serverList) getDefaultAddress() string {
|
||||
if s := sl.getDefaultServer(); s != nil {
|
||||
return s.address
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getDefault returns the default server.
|
||||
func (sl *serverList) getDefaultServer() *server {
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
|
||||
return sl.servers[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
|
||||
func (sl *serverList) getServers() []*server {
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
return slices.Clone(sl.servers)
|
||||
}
|
||||
|
||||
// getServer returns the first server with the specified address
|
||||
func (sl *serverList) getServer(address string) *server {
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
|
||||
return sl.servers[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setHealthCheck updates the health check function for a server, replacing the
|
||||
// current function.
|
||||
func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
|
||||
if s := sl.getServer(address); s != nil {
|
||||
s.healthCheck = healthCheck
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("no server found for %s", address)
|
||||
}
|
||||
|
||||
// recordSuccess records a successful check of a server, either via health-check or dial.
|
||||
// The server's state is adjusted accordingly.
|
||||
func (sl *serverList) recordSuccess(srv *server, r reason) {
|
||||
var new_state state
|
||||
switch srv.state {
|
||||
case stateFailed, stateUnchecked:
|
||||
// dialed or health checked OK once, improve to recovering
|
||||
new_state = stateRecovering
|
||||
case stateRecovering:
|
||||
if r == reasonHealthCheck {
|
||||
// was recovering due to successful dial or first health check, can now improve
|
||||
if len(srv.connections) > 0 {
|
||||
// server accepted connections while recovering, attempt to go straight to active
|
||||
new_state = stateActive
|
||||
} else {
|
||||
// no connections, just make it preferred
|
||||
new_state = statePreferred
|
||||
}
|
||||
}
|
||||
case stateHealthy:
|
||||
if r == reasonDial {
|
||||
// improve from healthy to active by being dialed
|
||||
new_state = stateActive
|
||||
}
|
||||
case statePreferred:
|
||||
if r == reasonDial {
|
||||
// improve from healthy to active by being dialed
|
||||
new_state = stateActive
|
||||
} else {
|
||||
if time.Now().Sub(srv.lastTransition) > time.Minute {
|
||||
// has been preferred for a while without being dialed, demote to healthy
|
||||
new_state = stateHealthy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// no-op if state did not change
|
||||
if new_state == stateInvalid {
|
||||
return
|
||||
}
|
||||
|
||||
// handle active transition and sort the server list while holding the lock
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
// handle states of other servers when attempting to make this one active
|
||||
if new_state == stateActive {
|
||||
for _, s := range sl.servers {
|
||||
if srv.address == s.address {
|
||||
continue
|
||||
}
|
||||
switch s.state {
|
||||
case stateFailed, stateStandby, stateRecovering, stateHealthy:
|
||||
// close connections to other non-active servers whenever we have a new active server
|
||||
defer s.closeAll()
|
||||
case stateActive:
|
||||
if len(s.connections) > len(srv.connections) {
|
||||
// if there is a currently active server that has more connections than we do,
|
||||
// close our connections and go to preferred instead
|
||||
new_state = statePreferred
|
||||
defer srv.closeAll()
|
||||
} else {
|
||||
// otherwise, close its connections and demote it to preferred
|
||||
s.state = statePreferred
|
||||
defer s.closeAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensure some other routine didn't already make the transition
|
||||
if srv.state == new_state {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Infof("Server %s->%s from successful %s", srv, new_state, r)
|
||||
srv.state = new_state
|
||||
srv.lastTransition = time.Now()
|
||||
|
||||
slices.SortFunc(sl.servers, compareServers)
|
||||
}
|
||||
|
||||
// recordFailure records a failed check of a server, either via health-check or dial.
|
||||
// The server's state is adjusted accordingly.
|
||||
func (sl *serverList) recordFailure(srv *server, r reason) {
|
||||
var new_state state
|
||||
switch srv.state {
|
||||
case stateUnchecked, stateRecovering:
|
||||
if r == reasonDial {
|
||||
// only demote from unchecked or recovering if a dial fails, health checks may
|
||||
// continue to fail despite it being dialable. just leave it where it is
|
||||
// and don't close any connections.
|
||||
new_state = stateFailed
|
||||
}
|
||||
case stateHealthy, statePreferred, stateActive:
|
||||
// should not have any connections when in any state other than active or
|
||||
// recovering, but close them all anyway to force failover.
|
||||
defer srv.closeAll()
|
||||
new_state = stateFailed
|
||||
}
|
||||
|
||||
// no-op if state did not change
|
||||
if new_state == stateInvalid {
|
||||
return
|
||||
}
|
||||
|
||||
// sort the server list while holding the lock
|
||||
sl.mutex.Lock()
|
||||
defer sl.mutex.Unlock()
|
||||
|
||||
// ensure some other routine didn't already make the transition
|
||||
if srv.state == new_state {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Infof("Server %s->%s from failed %s", srv, new_state, r)
|
||||
srv.state = new_state
|
||||
srv.lastTransition = time.Now()
|
||||
|
||||
slices.SortFunc(sl.servers, compareServers)
|
||||
}
|
||||
|
||||
// state is possible server health states, in increasing order of preference.
|
||||
// The server list is kept sorted in descending order by this state value.
|
||||
type state int
|
||||
|
||||
const (
|
||||
stateInvalid state = iota
|
||||
stateFailed // failed a health check or dial
|
||||
stateStandby // reserved for use by default server if not in server list
|
||||
stateUnchecked // just added, has not been health checked
|
||||
stateRecovering // successfully health checked once, or dialed when failed
|
||||
stateHealthy // normal state
|
||||
statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance
|
||||
stateActive // currently active server
|
||||
)
|
||||
|
||||
func (s state) String() string {
|
||||
switch s {
|
||||
case stateInvalid:
|
||||
return "INVALID"
|
||||
case stateFailed:
|
||||
return "FAILED"
|
||||
case stateStandby:
|
||||
return "STANDBY"
|
||||
case stateUnchecked:
|
||||
return "UNCHECKED"
|
||||
case stateRecovering:
|
||||
return "RECOVERING"
|
||||
case stateHealthy:
|
||||
return "HEALTHY"
|
||||
case statePreferred:
|
||||
return "PREFERRED"
|
||||
case stateActive:
|
||||
return "ACTIVE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// reason specifies the reason for a successful or failed health report
|
||||
type reason int
|
||||
|
||||
const (
|
||||
reasonDial reason = iota
|
||||
reasonHealthCheck
|
||||
)
|
||||
|
||||
func (r reason) String() string {
|
||||
switch r {
|
||||
case reasonDial:
|
||||
return "dial"
|
||||
case reasonHealthCheck:
|
||||
return "health check"
|
||||
default:
|
||||
return "unknown reason"
|
||||
}
|
||||
}
|
||||
|
||||
// server tracks the connections to a server, so that they can be closed when the server is removed.
|
||||
type server struct {
|
||||
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
|
||||
mutex sync.Mutex
|
||||
address string
|
||||
isDefault bool
|
||||
state state
|
||||
lastTransition time.Time
|
||||
healthCheck HealthCheckFunc
|
||||
connections map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
// newServer creates a new server, with a default health check
|
||||
// and default/state fields appropriate for whether or not
|
||||
// the server is a full server, or just a fallback default.
|
||||
func newServer(address string, isDefault bool) *server {
|
||||
state := stateUnchecked
|
||||
if isDefault {
|
||||
state = stateStandby
|
||||
}
|
||||
return &server{
|
||||
address: address,
|
||||
isDefault: isDefault,
|
||||
state: state,
|
||||
lastTransition: time.Now(),
|
||||
healthCheck: func() HealthCheckResult { return HealthCheckResultUnknown },
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) String() string {
|
||||
format := "%s@%s"
|
||||
if s.isDefault {
|
||||
format += "*"
|
||||
}
|
||||
return fmt.Sprintf(format, s.address, s.state)
|
||||
}
|
||||
|
||||
// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map
|
||||
func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) {
|
||||
if s.state == stateInvalid {
|
||||
return nil, fmt.Errorf("server %s is stopping", s.address)
|
||||
}
|
||||
|
||||
conn, err := defaultDialer.Dial(network, s.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -132,7 +435,7 @@ func (s *server) closeAll() {
|
|||
defer s.mutex.Unlock()
|
||||
|
||||
if l := len(s.connections); l > 0 {
|
||||
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
|
||||
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s)
|
||||
for conn := range s.connections {
|
||||
// Close the connection in a goroutine so that we don't hold the lock while doing so.
|
||||
go conn.Close()
|
||||
|
@ -140,6 +443,12 @@ func (s *server) closeAll() {
|
|||
}
|
||||
}
|
||||
|
||||
// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
|
||||
type serverConn struct {
|
||||
server *server
|
||||
net.Conn
|
||||
}
|
||||
|
||||
// Close removes the connection entry from the server's connection map, and
|
||||
// closes the wrapped connection.
|
||||
func (sc *serverConn) Close() error {
|
||||
|
@ -150,73 +459,43 @@ func (sc *serverConn) Close() error {
|
|||
return sc.Conn.Close()
|
||||
}
|
||||
|
||||
// SetDefault sets the selected address as the default / fallback address
|
||||
func (lb *LoadBalancer) SetDefault(serverAddress string) {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
|
||||
// if the old default server is not currently in use, remove it from the server map
|
||||
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
|
||||
defer server.closeAll()
|
||||
delete(lb.servers, lb.defaultServerAddress)
|
||||
}
|
||||
// if the new default server doesn't have an entry in the map, add one - but
|
||||
// with a failing health check so that it is only used as a last resort.
|
||||
if _, ok := lb.servers[serverAddress]; !ok {
|
||||
lb.servers[serverAddress] = &server{
|
||||
address: serverAddress,
|
||||
healthCheck: func() bool { return false },
|
||||
connections: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
lb.defaultServerAddress = serverAddress
|
||||
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
|
||||
}
|
||||
|
||||
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
|
||||
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
if server := lb.servers[address]; server != nil {
|
||||
logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
|
||||
server.healthCheck = healthCheck
|
||||
} else {
|
||||
logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
|
||||
}
|
||||
}
|
||||
|
||||
// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
|
||||
// connections closed, to force clients to switch over to a healthy server.
|
||||
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
|
||||
previousStatus := map[string]bool{}
|
||||
// runHealthChecks periodically health-checks all servers.
|
||||
func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
|
||||
wait.Until(func() {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
var healthyServerExists bool
|
||||
for address, server := range lb.servers {
|
||||
status := server.healthCheck()
|
||||
healthyServerExists = healthyServerExists || status
|
||||
if status == false && previousStatus[address] == true {
|
||||
// Only close connections when the server transitions from healthy to unhealthy;
|
||||
// we don't want to re-close all the connections every time as we might be ignoring
|
||||
// health checks due to all servers being marked unhealthy.
|
||||
defer server.closeAll()
|
||||
}
|
||||
previousStatus[address] = status
|
||||
}
|
||||
|
||||
// If there is at least one healthy server, and the default server is not in the server list,
|
||||
// close all the connections to the default server so that clients reconnect and switch over
|
||||
// to a preferred server.
|
||||
hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
|
||||
if healthyServerExists && !hasDefaultServer {
|
||||
if server, ok := lb.servers[lb.defaultServerAddress]; ok {
|
||||
defer server.closeAll()
|
||||
for _, s := range sl.getServers() {
|
||||
switch s.healthCheck() {
|
||||
case HealthCheckResultOK:
|
||||
sl.recordSuccess(s, reasonHealthCheck)
|
||||
case HealthCheckResultFailed:
|
||||
sl.recordFailure(s, reasonHealthCheck)
|
||||
}
|
||||
}
|
||||
}, time.Second, ctx.Done())
|
||||
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
|
||||
logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
|
||||
}
|
||||
|
||||
// dialContext attemps to dial a connection to a server from the server list.
|
||||
// Success or failure is recorded to ensure that server state is updated appropriately.
|
||||
func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||
for _, s := range sl.getServers() {
|
||||
dialTime := time.Now()
|
||||
conn, err := s.dialContext(ctx, network)
|
||||
if err == nil {
|
||||
sl.recordSuccess(s, reasonDial)
|
||||
return conn, nil
|
||||
}
|
||||
logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err)
|
||||
sl.recordFailure(s, reasonDial)
|
||||
}
|
||||
return nil, errors.New("all servers failed")
|
||||
}
|
||||
|
||||
// compareServers is a comparison function that can be used to sort the server list
|
||||
// so that servers with a more preferred state, or higher number of connections, are ordered first.
|
||||
func compareServers(a, b *server) int {
|
||||
c := cmp.Compare(b.state, a.state)
|
||||
if c == 0 {
|
||||
return cmp.Compare(len(b.connections), len(a.connections))
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ type Proxy interface {
|
|||
SupervisorAddresses() []string
|
||||
APIServerURL() string
|
||||
IsAPIServerLBEnabled() bool
|
||||
SetHealthCheck(address string, healthCheck func() bool)
|
||||
SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc)
|
||||
}
|
||||
|
||||
// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
|
||||
|
@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
|
|||
return nil, err
|
||||
}
|
||||
p.supervisorLB = lb
|
||||
p.supervisorURL = lb.LoadBalancerServerURL()
|
||||
p.supervisorURL = lb.LocalURL()
|
||||
p.apiServerURL = p.supervisorURL
|
||||
}
|
||||
|
||||
|
@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) {
|
|||
p.supervisorAddresses = supervisorAddresses
|
||||
}
|
||||
|
||||
func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
|
||||
func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) {
|
||||
if p.supervisorLB != nil {
|
||||
p.supervisorLB.SetHealthCheck(address, healthCheck)
|
||||
}
|
||||
|
@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
|
|||
return err
|
||||
}
|
||||
p.apiServerLB = lb
|
||||
p.apiServerURL = lb.LoadBalancerServerURL()
|
||||
p.apiServerURL = lb.LocalURL()
|
||||
} else {
|
||||
p.apiServerURL = u.String()
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/gorilla/websocket"
|
||||
agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
|
||||
"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
|
||||
"github.com/k3s-io/k3s/pkg/agent/proxy"
|
||||
daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
|
||||
"github.com/k3s-io/k3s/pkg/util"
|
||||
|
@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
|
|||
if _, ok := disconnect[address]; !ok {
|
||||
conn := a.connect(ctx, wg, address, tlsConfig)
|
||||
disconnect[address] = conn.cancel
|
||||
proxy.SetHealthCheck(address, conn.connected)
|
||||
proxy.SetHealthCheck(address, conn.healthCheck)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
|
|||
if _, ok := disconnect[address]; !ok {
|
||||
conn := a.connect(ctx, nil, address, tlsConfig)
|
||||
disconnect[address] = conn.cancel
|
||||
proxy.SetHealthCheck(address, conn.connected)
|
||||
proxy.SetHealthCheck(address, conn.healthCheck)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -427,22 +428,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
|
|||
}
|
||||
|
||||
type agentConnection struct {
|
||||
cancel context.CancelFunc
|
||||
connected func() bool
|
||||
cancel context.CancelFunc
|
||||
healthCheck loadbalancer.HealthCheckFunc
|
||||
}
|
||||
|
||||
// connect initiates a connection to the remotedialer server. Incoming dial requests from
|
||||
// the server will be checked by the authorizer function prior to being fulfilled.
|
||||
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
|
||||
var status loadbalancer.HealthCheckResult
|
||||
|
||||
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
|
||||
ws := &websocket.Dialer{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
|
||||
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
|
||||
connected := true
|
||||
|
||||
once := sync.Once{}
|
||||
if waitGroup != nil {
|
||||
waitGroup.Add(1)
|
||||
|
@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
|
|||
}
|
||||
|
||||
onConnect := func(_ context.Context, _ *remotedialer.Session) error {
|
||||
connected = true
|
||||
status = loadbalancer.HealthCheckResultOK
|
||||
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
|
||||
if waitGroup != nil {
|
||||
once.Do(waitGroup.Done)
|
||||
|
@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
|
|||
for {
|
||||
// ConnectToProxy blocks until error or context cancellation
|
||||
err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect)
|
||||
connected = false
|
||||
status = loadbalancer.HealthCheckResultFailed
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...")
|
||||
// wait between reconnection attempts to avoid hammering the server
|
||||
|
@ -484,8 +483,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
|
|||
}()
|
||||
|
||||
return agentConnection{
|
||||
cancel: cancel,
|
||||
connected: func() bool { return connected },
|
||||
cancel: cancel,
|
||||
healthCheck: func() loadbalancer.HealthCheckResult {
|
||||
return status
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,10 @@ type etcdproxy struct {
|
|||
}
|
||||
|
||||
func (e *etcdproxy) Update(addresses []string) {
|
||||
if e.etcdLB == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.etcdLB.Update(addresses)
|
||||
|
||||
validEndpoint := map[string]bool{}
|
||||
|
@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) {
|
|||
|
||||
// start a polling routine that makes periodic requests to the etcd node's supervisor port.
|
||||
// If the request fails, the node is marked unhealthy.
|
||||
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool {
|
||||
// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
|
||||
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
|
||||
connected := true
|
||||
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
|
||||
var status loadbalancer.HealthCheckResult
|
||||
|
||||
host, _, _ := net.SplitHostPort(address)
|
||||
url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort)))
|
||||
|
@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func()
|
|||
}
|
||||
if err != nil || statusCode != http.StatusOK {
|
||||
logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode)
|
||||
connected = false
|
||||
status = loadbalancer.HealthCheckResultFailed
|
||||
} else {
|
||||
connected = true
|
||||
status = loadbalancer.HealthCheckResultOK
|
||||
}
|
||||
}, 5*time.Second, 1.0, true)
|
||||
|
||||
return func() bool {
|
||||
return connected
|
||||
return func() loadbalancer.HealthCheckResult {
|
||||
// Reset the status to unknown on reading, until next time it is checked.
|
||||
// This avoids having a health check result alter the server state between active checks.
|
||||
s := status
|
||||
status = loadbalancer.HealthCheckResultUnknown
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue