mirror of https://github.com/k3s-io/k3s

Refactor load balancer server list and health checking

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>

branch: pull/11430/head
parent: 95797c4a79
commit: 911ee19a93
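The commit replaces the load balancer's flat address/randomServers bookkeeping with a serverList that tracks per-server health state, exposed through Update, SetDefault, SetHealthCheck, LocalURL and ServerAddresses. A minimal sketch of driving that surface from calling code follows; it assumes the usual k3s import path, and the data directory and addresses are made up for illustration.

package main

import (
	"context"
	"fmt"

	"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
)

func main() {
	ctx := context.Background()
	// Hypothetical data dir and registration endpoint, for illustration only.
	lb, err := loadbalancer.New(ctx, "/tmp/k3s-data", loadbalancer.SupervisorServiceName,
		"https://192.0.2.10:6443", loadbalancer.RandomPort, false)
	if err != nil {
		panic(err)
	}
	// Replace the backend list, mark one address as always healthy, and pick a default.
	lb.Update([]string{"192.0.2.11:6443", "192.0.2.12:6443"})
	lb.SetHealthCheck("192.0.2.11:6443", func() loadbalancer.HealthCheckResult {
		return loadbalancer.HealthCheckResultOK
	})
	lb.SetDefault("192.0.2.12:6443")
	fmt.Println("local listener:", lb.LocalURL())
}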
@@ -15,8 +15,8 @@ type lbConfig struct {
 
 func (lb *LoadBalancer) writeConfig() error {
 	config := &lbConfig{
-		ServerURL:       lb.serverURL,
-		ServerAddresses: lb.serverAddresses,
+		ServerURL:       lb.scheme + "://" + lb.servers.getDefaultAddress(),
+		ServerAddresses: lb.servers.getAddresses(),
 	}
 	configOut, err := json.MarshalIndent(config, "", " ")
 	if err != nil {
@@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
 }
 
 func (lb *LoadBalancer) updateConfig() error {
-	writeConfig := true
 	if configBytes, err := os.ReadFile(lb.configFile); err == nil {
 		config := &lbConfig{}
 		if err := json.Unmarshal(configBytes, config); err == nil {
-			if config.ServerURL == lb.serverURL {
-				writeConfig = false
-				lb.setServers(config.ServerAddresses)
+			// if the default server from the config matches our current default,
+			// load the rest of the addresses as well.
+			if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
+				lb.Update(config.ServerAddresses)
+				return nil
 			}
 		}
 	}
-	if writeConfig {
-		if err := lb.writeConfig(); err != nil {
-			return err
-		}
-	}
-	return nil
+	// config didn't exist or used a different default server, write the current config to disk.
+	return lb.writeConfig()
 }
@@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
 func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
 	if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
 		// Create a new HTTP proxy dialer
-		httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
+		httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
 		return httpProxyDialer, nil
 	} else if proxyURL.Scheme == "socks5" {
 		// For SOCKS5 proxies, use the proxy package's FromURL
@@ -2,15 +2,16 @@ package loadbalancer
 
 import (
 	"fmt"
-	"net"
 	"os"
 	"strings"
 	"testing"
 
 	"github.com/k3s-io/k3s/pkg/version"
 	"github.com/sirupsen/logrus"
+	"golang.org/x/net/proxy"
 )
 
+var originalDialer proxy.Dialer
 var defaultEnv map[string]string
 var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}
 
@@ -19,7 +20,7 @@ func init() {
 }
 
 func prepareEnv(env ...string) {
-	defaultDialer = &net.Dialer{}
+	originalDialer = defaultDialer
 	defaultEnv = map[string]string{}
 	for _, e := range proxyEnvs {
 		if v, ok := os.LookupEnv(e); ok {
@@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
 }
 
 func restoreEnv() {
+	defaultDialer = originalDialer
 	for _, e := range proxyEnvs {
 		if v, ok := defaultEnv[e]; ok {
 			os.Setenv(e, v)
@@ -2,55 +2,29 @@ package loadbalancer
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"net"
+	"net/url"
 	"os"
 	"path/filepath"
-	"sync"
-	"time"
+	"strings"
 
 	"github.com/inetaf/tcpproxy"
 	"github.com/k3s-io/k3s/pkg/version"
 	"github.com/sirupsen/logrus"
 )
 
-// server tracks the connections to a server, so that they can be closed when the server is removed.
-type server struct {
-	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
-	mutex       sync.Mutex
-	address     string
-	healthCheck func() bool
-	connections map[net.Conn]struct{}
-}
-
-// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
-type serverConn struct {
-	server *server
-	net.Conn
-}
-
 // LoadBalancer holds data for a local listener which forwards connections to a
 // pool of remote servers. It is not a proper load-balancer in that it does not
 // actually balance connections, but instead fails over to a new server only
 // when a connection attempt to the currently selected server fails.
 type LoadBalancer struct {
-	// This mutex protects access to servers map and randomServers list.
-	// All direct access to the servers map/list should be protected by it.
-	mutex sync.RWMutex
-	proxy *tcpproxy.Proxy
-
 	serviceName  string
 	configFile   string
+	scheme       string
 	localAddress string
-	localServerURL       string
-	defaultServerAddress string
-	serverURL            string
-	serverAddresses      []string
-	randomServers        []string
-	servers              map[string]*server
-	currentServerAddress string
-	nextServerIndex      int
+	servers      serverList
+	proxy        *tcpproxy.Proxy
 }
 
 const RandomPort = 0
@@ -63,7 +37,7 @@ var (
 
 // New contstructs a new LoadBalancer instance. The default server URL, and
 // currently active servers, are stored in a file within the dataDir.
-func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
+func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	config := net.ListenConfig{Control: reusePort}
 	var localAddress string
 	if isIPv6 {
@@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 		return nil, err
 	}
 
-	// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
-	localAddress = listener.Addr().String()
-
-	defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
+	serverURL, err := url.Parse(defaultServerURL)
 	if err != nil {
 		return nil, err
 	}
 
-	if serverURL == localServerURL {
-		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
-		defaultServerAddress = ""
+	// Set explicit port from scheme
+	if serverURL.Port() == "" {
+		if strings.ToLower(serverURL.Scheme) == "http" {
+			serverURL.Host += ":80"
+		}
+		if strings.ToLower(serverURL.Scheme) == "https" {
+			serverURL.Host += ":443"
+		}
 	}
 
 	lb := &LoadBalancer{
 		serviceName:  serviceName,
 		configFile:   filepath.Join(dataDir, "etc", serviceName+".json"),
-		localAddress:         localAddress,
-		localServerURL:       localServerURL,
-		defaultServerAddress: defaultServerAddress,
-		servers:              make(map[string]*server),
-		serverURL:            serverURL,
+		scheme:       serverURL.Scheme,
+		localAddress: listener.Addr().String(),
 	}
 
-	lb.setServers([]string{lb.defaultServerAddress})
+	// if starting pointing at ourselves, don't set a default server address,
+	// which will cause all dials to fail until servers are added.
+	if serverURL.Host == lb.localAddress {
+		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
+	} else {
+		lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
+	}
 
 	lb.proxy = &tcpproxy.Proxy{
 		ListenFunc: func(string, string) (net.Listener, error) {
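The hunk above has New() parse the default server URL itself and fill in an explicit port from the scheme when none is given. The same normalization, as a standalone sketch (the helper name and example host are illustrative, and it assumes "net/url" and "strings" are imported), is:

// defaultPort mirrors the scheme-based port defaulting that New() applies to the supervisor URL.
func defaultPort(rawURL string) (string, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	if u.Port() == "" {
		switch strings.ToLower(u.Scheme) {
		case "http":
			u.Host += ":80"
		case "https":
			u.Host += ":443"
		}
	}
	return u.Host, nil // e.g. "server.example:443" for "https://server.example"
}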
@@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	}
 	lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
 		Addr:        serviceName,
-		DialContext: lb.dialContext,
+		DialContext: lb.servers.dialContext,
 		OnDialError: onDialError,
 	})
 
@@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	if err := lb.proxy.Start(); err != nil {
 		return nil, err
 	}
-	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
+	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
 
-	go lb.runHealthChecks(ctx)
+	go lb.servers.runHealthChecks(ctx, lb.serviceName)
 
 	return lb, nil
 }
 
+// Update updates the list of server addresses to contain only the listed servers.
 func (lb *LoadBalancer) Update(serverAddresses []string) {
-	if lb == nil {
+	if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
 		return
 	}
-	if !lb.setServers(serverAddresses) {
-		return
-	}
-	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)
+	logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
 
 	if err := lb.writeConfig(); err != nil {
 		logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
 	}
 }
 
-func (lb *LoadBalancer) LoadBalancerServerURL() string {
-	if lb == nil {
-		return ""
+// SetDefault sets the selected address as the default / fallback address
+func (lb *LoadBalancer) SetDefault(serverAddress string) {
+	lb.servers.setDefaultAddress(lb.serviceName, serverAddress)
 
+	if err := lb.writeConfig(); err != nil {
+		logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
 	}
-	return lb.localServerURL
+}
 
+// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
+func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
+	if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
+		logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
+	} else {
+		logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
+	}
+}
 
+func (lb *LoadBalancer) LocalURL() string {
+	return lb.scheme + "://" + lb.localAddress
 }
 
 func (lb *LoadBalancer) ServerAddresses() []string {
-	if lb == nil {
-		return nil
-	}
-	return lb.serverAddresses
-}
-
-func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
-	lb.mutex.RLock()
-	defer lb.mutex.RUnlock()
-
-	var allChecksFailed bool
-	startIndex := lb.nextServerIndex
-	for {
-		targetServer := lb.currentServerAddress
-
-		server := lb.servers[targetServer]
-		if server == nil || targetServer == "" {
-			logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
-		} else if allChecksFailed || server.healthCheck() {
-			dialTime := time.Now()
-			conn, err := server.dialContext(ctx, network, targetServer)
-			if err == nil {
-				return conn, nil
-			}
-			logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
-			// Don't close connections to the failed server if we're retrying with health checks ignored.
-			// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
-			if !allChecksFailed {
-				defer server.closeAll()
-			}
-		} else {
-			logrus.Debugf("Dial health check failed for %s", targetServer)
-		}
-
-		newServer, err := lb.nextServer(targetServer)
-		if err != nil {
-			return nil, err
-		}
-		if targetServer != newServer {
-			logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
-		}
-		if ctx.Err() != nil {
-			return nil, ctx.Err()
-		}
-
-		maxIndex := len(lb.randomServers)
-		if startIndex > maxIndex {
-			startIndex = maxIndex
-		}
-		if lb.nextServerIndex == startIndex {
-			if allChecksFailed {
-				return nil, errors.New("all servers failed")
-			}
-			logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
-			allChecksFailed = true
-		}
-	}
+	return lb.servers.getAddresses()
 }
 
 func onDialError(src net.Conn, dstDialErr error) {
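SetHealthCheck now takes a HealthCheckFunc returning a HealthCheckResult rather than a plain bool, and the doc comment on HealthCheckResult says background pollers should report Unknown when no poll has happened since the last check. A hedged sketch of wiring a hypothetical background poller to that contract (not code from this commit; assumes an existing *loadbalancer.LoadBalancer named lb and the "sync" import):

// registerPolledCheck adapts a background poller to the new HealthCheckFunc contract.
func registerPolledCheck(lb *loadbalancer.LoadBalancer, address string) func(ok bool) {
	var mu sync.Mutex
	pending := loadbalancer.HealthCheckResultUnknown

	lb.SetHealthCheck(address, func() loadbalancer.HealthCheckResult {
		mu.Lock()
		defer mu.Unlock()
		r := pending
		// report each poll result once, then Unknown until the next poll
		pending = loadbalancer.HealthCheckResultUnknown
		return r
	})

	// the returned function is what the background poller calls after each probe
	return func(ok bool) {
		mu.Lock()
		defer mu.Unlock()
		if ok {
			pending = loadbalancer.HealthCheckResultOK
		} else {
			pending = loadbalancer.HealthCheckResultFailed
		}
	}
}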
@@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
 }
 
 // ResetLoadBalancer will delete the local state file for the load balancer on disk
-func ResetLoadBalancer(dataDir, serviceName string) error {
+func ResetLoadBalancer(dataDir, serviceName string) {
 	stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
-	if err := os.Remove(stateFile); err != nil {
+	if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
 		logrus.Warn(err)
 	}
-	return nil
 }
@@ -5,19 +5,29 @@ import (
 	"context"
 	"fmt"
 	"net"
-	"net/url"
+	"strconv"
 	"strings"
 	"testing"
 	"time"
 
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
 	"github.com/sirupsen/logrus"
 )
 
+func Test_UnitLoadBalancer(t *testing.T) {
+	_, reporterConfig := GinkgoConfiguration()
+	reporterConfig.Verbose = testing.Verbose()
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "LoadBalancer Suite", reporterConfig)
+}
+
 func init() {
 	logrus.SetLevel(logrus.DebugLevel)
 }
 
 type testServer struct {
+	address  string
 	listener net.Listener
 	conns    []net.Conn
 	prefix   string
@@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) {
 	s := &testServer{
 		prefix:   prefix,
 		listener: listener,
+		address:  listener.Addr().String(),
 	}
 	go s.serve()
 	go func() {
@@ -53,6 +64,7 @@ func (s *testServer) serve() {
 
 func (s *testServer) close() {
 	logrus.Printf("testServer %s closing", s.prefix)
+	s.address = ""
 	s.listener.Close()
 	for _, conn := range s.conns {
 		conn.Close()
@@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) {
 	}
 }
 
-func (s *testServer) address() string {
-	return s.listener.Addr().String()
-}
-
 func ping(conn net.Conn) (string, error) {
 	fmt.Fprintf(conn, "ping\n")
 	result, err := bufio.NewReader(conn).ReadString('\n')
@@ -82,166 +90,285 @@ func ping(conn net.Conn) (string, error) {
 	return strings.TrimSpace(result), nil
 }
 
-// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint)
-// and then adds a new server (a node). The node server is then closed, and it is confirmed
-// that new connections use the default server.
-func Test_UnitFailOver(t *testing.T) {
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	defaultServer, err := createServer(ctx, "default")
-	if err != nil {
-		t.Fatalf("createServer(default) failed: %v", err)
-	}
-
-	node1Server, err := createServer(ctx, "node1")
-	if err != nil {
-		t.Fatalf("createServer(node1) failed: %v", err)
-	}
-
-	node2Server, err := createServer(ctx, "node2")
-	if err != nil {
-		t.Fatalf("createServer(node2) failed: %v", err)
-	}
-
-	// start the loadbalancer with the default server as the only server
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
-	if err != nil {
-		t.Fatalf("url.Parse failed: %v", err)
-	}
-	localAddress := parsedURL.Host
-
-	// add the node as a new server address.
-	lb.Update([]string{node1Server.address()})
-
-	// make sure connections go to the node
-	conn1, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn1); err != nil {
-		t.Fatalf("ping(conn1) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn1) result: %v", result)
-	}
-
-	t.Log("conn1 tested OK")
-
-	// set failing health check for node 1
-	lb.SetHealthCheck(node1Server.address(), func() bool { return false })
-
-	// Server connections are checked every second, now that node 1 is failed
-	// the connections to it should be closed.
-	time.Sleep(2 * time.Second)
-
-	if _, err := ping(conn1); err == nil {
-		t.Fatal("Unexpected successful ping on closed connection conn1")
-	}
-
-	t.Log("conn1 closed on failure OK")
-
-	// make sure connection still goes to the first node - it is failing health checks but so
-	// is the default endpoint, so it should be tried first with health checks disabled,
-	// before failing back to the default.
-	conn2, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn2); err != nil {
-		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn2) result: %v", result)
-	}
-
-	t.Log("conn2 tested OK")
-
-	// make sure the health checks don't close the connection we just made -
-	// connections should only be closed when it transitions from health to unhealthy.
-	time.Sleep(2 * time.Second)
-
-	if result, err := ping(conn2); err != nil {
-		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
-		t.Fatalf("Unexpected ping(conn2) result: %v", result)
-	}
-
-	t.Log("conn2 tested OK again")
-
-	// shut down the first node server to force failover to the default
-	node1Server.close()
-
-	// make sure new connections go to the default, and existing connections are closed
-	conn3, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn3); err != nil {
-		t.Fatalf("ping(conn3) failed: %v", err)
-	} else if result != "default:ping" {
-		t.Fatalf("Unexpected ping(conn3) result: %v", result)
-	}
-
-	t.Log("conn3 tested OK")
-
-	if _, err := ping(conn2); err == nil {
-		t.Fatal("Unexpected successful ping on closed connection conn2")
-	}
-
-	t.Log("conn2 closed on failure OK")
-
-	// add the second node as a new server address.
-	lb.Update([]string{node2Server.address()})
-
-	// make sure connection now goes to the second node,
-	// and connections to the default are closed.
-	conn4, err := net.Dial("tcp", localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-	if result, err := ping(conn4); err != nil {
-		t.Fatalf("ping(conn4) failed: %v", err)
-	} else if result != "node2:ping" {
-		t.Fatalf("Unexpected ping(conn4) result: %v", result)
-	}
-
-	t.Log("conn4 tested OK")
-
-	// Server connections are checked every second, now that we have a healthy
-	// server, connections to the default server should be closed
-	time.Sleep(2 * time.Second)
-
-	if _, err := ping(conn3); err == nil {
-		t.Fatal("Unexpected successful ping on connection conn3")
-	}
-
-	t.Log("conn3 closed on failure OK")
-}
-
-// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly
-func Test_UnitFailFast(t *testing.T) {
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	serverURL := "http://127.0.0.1:0/"
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	conn, err := net.Dial("tcp", lb.localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
+var _ = Describe("LoadBalancer", func() {
+	// creates a LB using a default server (ie fixed registration endpoint)
+	// and then adds a new server (a node). The node server is then closed, and it is confirmed
+	// that new connections use the default server.
+	When("loadbalancer is running", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var defaultServer, node1Server, node2Server *testServer
+		var conn1, conn2, conn3, conn4 net.Conn
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+
+			defaultServer, err = createServer(ctx, "default")
+			Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
+
+			node1Server, err = createServer(ctx, "node1")
+			Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed")
+			node2Server, err = createServer(ctx, "node2")
+			Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed")
+
+			// start the loadbalancer with the default server as the only server
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("adds node1 as a server", func() {
+			// add the node as a new server address.
+			lb.Update([]string{node1Server.address})
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers()))
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node1Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+		})
+
+		It("connects to node1", func() {
+			// make sure connections go to the node
+			conn1, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+			Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result")
+
+			By("conn1 tested OK")
+		})
+
+		It("changes node1 state to failed", func() {
+			// set failing health check for node 1
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node1Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(stateFailed))
+		})
+
+		It("disconnects from node1", func() {
+			// Server connections are checked every second, now that node 1 is failed
+			// the connections to it should be closed.
+			Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
+
+			By("conn1 closed on failure OK")
+
+			// connections shoould go to the default now that node 1 is failed
+			conn2, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+			Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
+
+			By("conn2 tested OK")
+		})
+
+		It("does not close connections unexpectedly", func() {
+			// make sure the health checks don't close the connection we just made -
+			// connections should only be closed when it transitions from health to unhealthy.
+			time.Sleep(2 * time.Second)
+
+			Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result")
+
+			By("conn2 tested OK again")
+		})
+
+		It("closes connections when dial fails", func() {
+			// shut down the first node server to force failover to the default
+			node1Server.close()
+
+			// make sure new connections go to the default, and existing connections are closed
+			conn3, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result")
+
+			By("conn3 tested OK")
+		})
+
+		It("replaces node2 as a server", func() {
+			// add the second node as a new server address.
+			lb.Update([]string{node2Server.address})
+			lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers()))
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(node2Server.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+		})
+
+		It("connects to node2", func() {
+			// make sure connection now goes to the second node,
+			// and connections to the default are closed.
+			conn4, err = net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result")
+
+			By("conn4 tested OK")
+		})
+
+		It("does not close connections unexpectedly", func() {
+			// Server connections are checked every second, now that we have a healthy
+			// server, connections to the default server should be closed
+			time.Sleep(2 * time.Second)
+
+			Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
+
+			By("conn2 closed on failure OK")
+
+			Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1")
+
+			By("conn3 closed on failure OK")
+		})
+
+		It("adds default as a server", func() {
+			// add the default as a full server
+			lb.Update([]string{node2Server.address, defaultServer.address})
+			lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(defaultServer.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+
+			By(fmt.Sprintf("Default server added: %v", lb.servers.getServers()))
+		})
+
+		It("returns the default server in the address list", func() {
+			// confirm that both servers are listed in the address list
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address))
+
+			// confirm that the default is still listed as default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+
+		It("does not return the default server in the address list after removing it", func() {
+			// remove the default as a server
+			lb.Update([]string{node2Server.address})
+			By(fmt.Sprintf("Default removed: %v", lb.servers.getServers()))
+
+			// confirm that it is not listed as a server
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
+
+			// but is still listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+
+		It("removes default server when no longer default", func() {
+			// set node2 as the default
+			lb.SetDefault(node2Server.address)
+			By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
+
+			// confirm that it is still listed as a server
+			Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address))
+
+			// and is listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default")
+		})
+
+		It("sets all three servers", func() {
+			// set node2 as the default
+			lb.SetDefault(defaultServer.address)
+			By(fmt.Sprintf("Default set: %v", lb.servers.getServers()))
+
+			lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address})
+			lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK })
+			lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK })
+
+			// wait for state to change
+			Eventually(func() state {
+				if s := lb.servers.getServer(defaultServer.address); s != nil {
+					return s.state
+				}
+				return stateInvalid
+			}, 5, 1).Should(Equal(statePreferred))
+
+			By(fmt.Sprintf("All servers set: %v", lb.servers.getServers()))
+
+			// confirm that it is still listed as a server
+			Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address))
+
+			// and is listed as the default
+			Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default")
+		})
+	})
+
+	// confirms that the loadbalancer will not dial itself
+	When("the default server is the loadbalancer", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var defaultServer *testServer
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+
+			defaultServer, err = createServer(ctx, "default")
+			Expect(err).NotTo(HaveOccurred(), "createServer(default) failed")
+			address := defaultServer.address
+			defaultServer.close()
+			_, port, _ := net.SplitHostPort(address)
+			intPort, _ := strconv.Atoi(port)
+
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails immediately", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
+			_, err = ping(conn)
+			Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection")
+		})
+	})
+
+	// confirms that connnections to invalid addresses fail quickly
+	When("there are no valid addresses", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails fast", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
 			done := make(chan error)
 			go func() {
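Note that the old table-style Test_Unit* functions are folded into a single Ginkgo suite registered by Test_UnitLoadBalancer, so with the standard Go tooling the whole suite can be run on its own with something like the following (assuming the package lives at pkg/agent/loadbalancer):

go test ./pkg/agent/loadbalancer/ -run Test_UnitLoadBalancer -v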
@@ -253,36 +380,34 @@ func Test_UnitFailFast(t *testing.T) {
 			select {
 			case err := <-done:
 				if err == nil {
-					t.Fatal("Unexpected successful ping from invalid address")
+					Fail("Unexpected successful ping from invalid address")
 				}
 			case <-timeout:
-				t.Fatal("Test timed out")
+				Fail("Test timed out")
 			}
-}
+		})
+	})
 
-// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail
-// within the expected duration
-func Test_UnitFailUnreachable(t *testing.T) {
-	if testing.Short() {
-		t.Skip("skipping slow test in short mode.")
-	}
-	tmpDir := t.TempDir()
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
-	serverAddr := "192.0.2.1:6443"
-	lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false)
-	if err != nil {
-		t.Fatalf("New() failed: %v", err)
-	}
-
-	// Set failing health check to reduce retries
-	lb.SetHealthCheck(serverAddr, func() bool { return false })
-
-	conn, err := net.Dial("tcp", lb.localAddress)
-	if err != nil {
-		t.Fatalf("net.Dial failed: %v", err)
-	}
-
+	// confirms that connnections to unreachable addresses do fail within the
+	// expected duration
+	When("the server is unreachable", Ordered, func() {
+		ctx, cancel := context.WithCancel(context.Background())
+		var lb *LoadBalancer
+		var err error
+
+		BeforeAll(func() {
+			tmpDir := GinkgoT().TempDir()
+			lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false)
+			Expect(err).NotTo(HaveOccurred(), "New() failed")
+		})
+
+		AfterAll(func() {
+			cancel()
+		})
+
+		It("fails with the correct timeout", func() {
+			conn, err := net.Dial("tcp", lb.localAddress)
+			Expect(err).NotTo(HaveOccurred(), "net.Dial failed")
+
 			done := make(chan error)
 			go func() {
@@ -294,9 +419,11 @@ func Test_UnitFailUnreachable(t *testing.T) {
 			select {
 			case err := <-done:
 				if err == nil {
-					t.Fatal("Unexpected successful ping from unreachable address")
+					Fail("Unexpected successful ping from unreachable address")
 				}
 			case <-timeout:
-				t.Fatal("Test timed out")
+				Fail("Test timed out")
 			}
-}
+		})
+	})
+})
@ -1,118 +1,421 @@
|
||||||
package loadbalancer
|
package loadbalancer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"math/rand"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"k8s.io/apimachinery/pkg/util/sets"
|
"k8s.io/apimachinery/pkg/util/sets"
|
||||||
"k8s.io/apimachinery/pkg/util/wait"
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
|
type HealthCheckFunc func() HealthCheckResult
|
||||||
serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
|
|
||||||
if len(serverAddresses) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
lb.mutex.Lock()
|
// HealthCheckResult indicates the status of a server health check poll.
|
||||||
defer lb.mutex.Unlock()
|
// For health-checks that poll in the background, Unknown should be returned
|
||||||
|
// if a poll has not occurred since the last check.
|
||||||
|
type HealthCheckResult int
|
||||||
|
|
||||||
newAddresses := sets.NewString(serverAddresses...)
|
const (
|
||||||
curAddresses := sets.NewString(lb.serverAddresses...)
|
HealthCheckResultUnknown HealthCheckResult = iota
|
||||||
|
HealthCheckResultFailed
|
||||||
|
HealthCheckResultOK
|
||||||
|
)
|
||||||
|
|
||||||
|
// serverList tracks potential backend servers for use by a loadbalancer.
|
||||||
|
type serverList struct {
|
||||||
|
// This mutex protects access to the server list. All direct access to the list should be protected by it.
|
||||||
|
mutex sync.Mutex
|
||||||
|
servers []*server
|
||||||
|
}
|
||||||
|
|
||||||
|
// setServers updates the server list to contain only the selected addresses.
|
||||||
|
func (sl *serverList) setAddresses(serviceName string, addresses []string) bool {
|
||||||
|
newAddresses := sets.New(addresses...)
|
||||||
|
curAddresses := sets.New(sl.getAddresses()...)
|
||||||
if newAddresses.Equal(curAddresses) {
|
if newAddresses.Equal(curAddresses) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for addedServer := range newAddresses.Difference(curAddresses) {
|
sl.mutex.Lock()
|
||||||
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
|
defer sl.mutex.Unlock()
|
||||||
lb.servers[addedServer] = &server{
|
|
||||||
address: addedServer,
|
var closeAllFuncs []func()
|
||||||
connections: make(map[net.Conn]struct{}),
|
var defaultServer *server
|
||||||
healthCheck: func() bool { return true },
|
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
|
||||||
|
defaultServer = sl.servers[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// add new servers
|
||||||
|
for addedAddress := range newAddresses.Difference(curAddresses) {
|
||||||
|
if defaultServer != nil && defaultServer.address == addedAddress {
|
||||||
|
// make default server go through the same health check promotions as a new server when added
|
||||||
|
logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName)
|
||||||
|
defaultServer.state = stateUnchecked
|
||||||
|
defaultServer.lastTransition = time.Now()
|
||||||
|
} else {
|
||||||
|
s := newServer(addedAddress, false)
|
||||||
|
logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address)
|
||||||
|
sl.servers = append(sl.servers, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for removedServer := range curAddresses.Difference(newAddresses) {
|
// remove old servers
|
||||||
server := lb.servers[removedServer]
|
for removedAddress := range curAddresses.Difference(newAddresses) {
|
||||||
if server != nil {
|
if defaultServer != nil && defaultServer.address == removedAddress {
|
||||||
logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
|
// demote the default server down to standby, instead of deleting it
|
||||||
// Defer closing connections until after the new server list has been put into place.
|
defaultServer.state = stateStandby
|
||||||
// Closing open connections ensures that anything stuck retrying on a stale server is forced
|
closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll)
|
||||||
// over to a valid endpoint.
|
} else {
|
||||||
defer server.closeAll()
|
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
|
||||||
// Don't delete the default server from the server map, in case we need to fall back to it.
|
if s.address == removedAddress {
|
||||||
if removedServer != lb.defaultServerAddress {
|
logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address)
|
||||||
delete(lb.servers, removedServer)
|
// set state to invalid to prevent server from making additional connections
|
||||||
|
s.state = stateInvalid
|
||||||
|
closeAllFuncs = append(closeAllFuncs, s.closeAll)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
}
|
return false
|
||||||
}
|
|
||||||
|
|
||||||
lb.serverAddresses = serverAddresses
|
|
||||||
lb.randomServers = append([]string{}, lb.serverAddresses...)
|
|
||||||
rand.Shuffle(len(lb.randomServers), func(i, j int) {
|
|
||||||
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
|
|
||||||
})
|
})
|
||||||
// If the current server list does not contain the default server address,
|
|
||||||
// we want to include it in the random server list so that it can be tried if necessary.
|
|
||||||
// However, it should be treated as always failing health checks so that it is only
|
|
||||||
// used if all other endpoints are unavailable.
|
|
||||||
if !hasDefaultServer {
|
|
||||||
lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
|
|
||||||
if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
|
|
||||||
defaultServer.healthCheck = func() bool { return false }
|
|
||||||
lb.servers[lb.defaultServerAddress] = defaultServer
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
lb.currentServerAddress = lb.randomServers[0]
|
|
||||||
lb.nextServerIndex = 1
|
slices.SortFunc(sl.servers, compareServers)
|
||||||
|
|
||||||
|
// Close all connections to servers that were removed
|
||||||
|
for _, closeAll := range closeAllFuncs {
|
||||||
|
closeAll()
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// nextServer attempts to get the next server in the loadbalancer server list.
|
// getAddresses returns the addresses of all servers.
|
||||||
// If another goroutine has already updated the current server address to point at
|
// If the default server is in standby state, indicating it is only present
|
||||||
// a different address than just failed, nothing is changed. Otherwise, a new server address
|
// because it is the default, it is not returned in this list.
|
||||||
// is stored to the currentServerAddress field, and returned for use.
|
func (sl *serverList) getAddresses() []string {
|
||||||
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
|
sl.mutex.Lock()
|
||||||
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
|
defer sl.mutex.Unlock()
|
||||||
// note: these fields are not protected by the mutex, so we clamp the index value and update
|
|
||||||
// the index/current address using local variables, to avoid time-of-check vs time-of-use
|
|
||||||
// race conditions caused by goroutine A incrementing it in between the time goroutine B
|
|
||||||
// validates its value, and uses it as a list index.
|
|
||||||
currentServerAddress := lb.currentServerAddress
|
|
||||||
nextServerIndex := lb.nextServerIndex
|
|
||||||
|
|
||||||
if len(lb.randomServers) == 0 {
|
addresses := make([]string, 0, len(sl.servers))
|
||||||
return "", errors.New("No servers in load balancer proxy list")
|
for _, s := range sl.servers {
|
||||||
|
if s.isDefault && s.state == stateStandby {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if len(lb.randomServers) == 1 {
|
addresses = append(addresses, s.address)
|
||||||
return currentServerAddress, nil
|
|
||||||
}
|
}
|
||||||
if failedServer != currentServerAddress {
|
return addresses
|
||||||
return currentServerAddress, nil
|
|
||||||
}
|
|
||||||
if nextServerIndex >= len(lb.randomServers) {
|
|
||||||
nextServerIndex = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
currentServerAddress = lb.randomServers[nextServerIndex]
|
|
||||||
nextServerIndex++
|
|
||||||
|
|
||||||
lb.currentServerAddress = currentServerAddress
|
|
||||||
lb.nextServerIndex = nextServerIndex
|
|
||||||
|
|
||||||
return currentServerAddress, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map
|
// setDefault sets the server with the provided address as the default server.
|
||||||
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
// The default flag is cleared on all other servers, and if the server was previously
|
||||||
conn, err := defaultDialer.Dial(network, address)
|
// only kept in the list because it was the default, it is removed.
|
||||||
|
func (sl *serverList) setDefaultAddress(serviceName, address string) {
|
||||||
|
sl.mutex.Lock()
|
||||||
|
defer sl.mutex.Unlock()
|
||||||
|
|
||||||
|
// deal with existing default first
|
||||||
|
sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool {
|
||||||
|
if s.isDefault && s.address != address {
|
||||||
|
s.isDefault = false
|
||||||
|
if s.state == stateStandby {
|
||||||
|
s.state = stateInvalid
|
||||||
|
defer s.closeAll()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
// update or create server with selected address
|
||||||
|
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
|
||||||
|
sl.servers[i].isDefault = true
|
||||||
|
} else {
|
||||||
|
sl.servers = append(sl.servers, newServer(address, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Updated load balancer %s default server: %s", serviceName, address)
|
||||||
|
slices.SortFunc(sl.servers, compareServers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDefault returns the address of the default server.
|
||||||
|
func (sl *serverList) getDefaultAddress() string {
|
||||||
|
if s := sl.getDefaultServer(); s != nil {
|
||||||
|
return s.address
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDefault returns the default server.
|
||||||
|
func (sl *serverList) getDefaultServer() *server {
|
||||||
|
sl.mutex.Lock()
|
||||||
|
defer sl.mutex.Unlock()
|
||||||
|
|
||||||
|
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
|
||||||
|
return sl.servers[i]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
|
||||||
|
func (sl *serverList) getServers() []*server {
|
||||||
|
sl.mutex.Lock()
|
||||||
|
defer sl.mutex.Unlock()
|
||||||
|
|
||||||
|
return slices.Clone(sl.servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServer returns the first server with the specified address
|
||||||
|
func (sl *serverList) getServer(address string) *server {
|
||||||
|
sl.mutex.Lock()
|
||||||
|
defer sl.mutex.Unlock()
|
||||||
|
|
||||||
|
if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
|
||||||
|
return sl.servers[i]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setHealthCheck updates the health check function for a server, replacing the
|
||||||
|
// current function.
|
||||||
|
func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
|
||||||
|
if s := sl.getServer(address); s != nil {
|
||||||
|
s.healthCheck = healthCheck
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("no server found for %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordSuccess records a successful check of a server, either via health-check or dial.
|
||||||
|
// The server's state is adjusted accordingly.
|
||||||
|
func (sl *serverList) recordSuccess(srv *server, r reason) {
|
||||||
|
var new_state state
|
||||||
|
switch srv.state {
|
||||||
|
case stateFailed, stateUnchecked:
|
||||||
|
// dialed or health checked OK once, improve to recovering
|
||||||
|
new_state = stateRecovering
|
||||||
|
case stateRecovering:
|
||||||
|
if r == reasonHealthCheck {
|
||||||
|
// was recovering due to successful dial or first health check, can now improve
|
||||||
|
if len(srv.connections) > 0 {
|
||||||
|
// server accepted connections while recovering, attempt to go straight to active
|
||||||
|
new_state = stateActive
|
||||||
|
} else {
|
||||||
|
// no connections, just make it preferred
|
||||||
|
new_state = statePreferred
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case stateHealthy:
|
||||||
|
if r == reasonDial {
|
||||||
|
// improve from healthy to active by being dialed
|
||||||
|
new_state = stateActive
|
||||||
|
}
|
||||||
|
case statePreferred:
|
||||||
|
if r == reasonDial {
|
||||||
|
// improve from healthy to active by being dialed
|
||||||
|
new_state = stateActive
|
||||||
|
} else {
|
||||||
|
if time.Now().Sub(srv.lastTransition) > time.Minute {
|
||||||
|
// has been preferred for a while without being dialed, demote to healthy
|
||||||
|
new_state = stateHealthy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// no-op if state did not change
|
||||||
|
if new_state == stateInvalid {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle active transition and sort the server list while holding the lock
|
||||||
|
sl.mutex.Lock()
|
||||||
|
defer sl.mutex.Unlock()
|
||||||
|
|
||||||
|
// handle states of other servers when attempting to make this one active
|
||||||
|
if new_state == stateActive {
|
||||||
|
for _, s := range sl.servers {
|
||||||
|
if srv.address == s.address {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch s.state {
|
||||||
|
case stateFailed, stateStandby, stateRecovering, stateHealthy:
|
||||||
|
// close connections to other non-active servers whenever we have a new active server
|
||||||
|
defer s.closeAll()
|
||||||
|
case stateActive:
|
||||||
|
if len(s.connections) > len(srv.connections) {
|
||||||
|
// if there is a currently active server that has more connections than we do,
|
||||||
|
// close our connections and go to preferred instead
|
||||||
|
new_state = statePreferred
|
||||||
|
defer srv.closeAll()
|
||||||
|
} else {
|
||||||
|
// otherwise, close its connections and demote it to preferred
|
||||||
|
s.state = statePreferred
|
||||||
|
defer s.closeAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensure some other routine didn't already make the transition
|
||||||
|
if srv.state == new_state {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Infof("Server %s->%s from successful %s", srv, new_state, r)
|
||||||
|
srv.state = new_state
|
||||||
|
srv.lastTransition = time.Now()
|
||||||
|
|
||||||
|
slices.SortFunc(sl.servers, compareServers)
|
||||||
|
}

// recordFailure records a failed check of a server, either via health-check or dial.
// The server's state is adjusted accordingly.
func (sl *serverList) recordFailure(srv *server, r reason) {
	var new_state state
	switch srv.state {
	case stateUnchecked, stateRecovering:
		if r == reasonDial {
			// only demote from unchecked or recovering if a dial fails, health checks may
			// continue to fail despite it being dialable. just leave it where it is
			// and don't close any connections.
			new_state = stateFailed
		}
	case stateHealthy, statePreferred, stateActive:
		// should not have any connections when in any state other than active or
		// recovering, but close them all anyway to force failover.
		defer srv.closeAll()
		new_state = stateFailed
	}

	// no-op if state did not change
	if new_state == stateInvalid {
		return
	}

	// sort the server list while holding the lock
	sl.mutex.Lock()
	defer sl.mutex.Unlock()

	// ensure some other routine didn't already make the transition
	if srv.state == new_state {
		return
	}

	logrus.Infof("Server %s->%s from failed %s", srv, new_state, r)
	srv.state = new_state
	srv.lastTransition = time.Now()

	slices.SortFunc(sl.servers, compareServers)
}

// state is the set of possible server health states, in increasing order of preference.
// The server list is kept sorted in descending order by this state value.
type state int

const (
	stateInvalid    state = iota
	stateFailed     // failed a health check or dial
	stateStandby    // reserved for use by default server if not in server list
	stateUnchecked  // just added, has not been health checked
	stateRecovering // successfully health checked once, or dialed when failed
	stateHealthy    // normal state
	statePreferred  // recently transitioned from recovering; should be preferred as others may go down for maintenance
	stateActive     // currently active server
)
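
For illustration only (not taken from this diff): because the constants above are declared in increasing order of preference, the raw values can be compared directly, which is what makes the descending sort of the server list meaningful.

	// Illustrative sketch: a higher state value means a more preferred server.
	fmt.Println(stateActive > stateHealthy)   // true
	fmt.Println(stateFailed < stateUnchecked) // true
	fmt.Println(stateActive)                  // "ACTIVE", via the String method below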

func (s state) String() string {
	switch s {
	case stateInvalid:
		return "INVALID"
	case stateFailed:
		return "FAILED"
	case stateStandby:
		return "STANDBY"
	case stateUnchecked:
		return "UNCHECKED"
	case stateRecovering:
		return "RECOVERING"
	case stateHealthy:
		return "HEALTHY"
	case statePreferred:
		return "PREFERRED"
	case stateActive:
		return "ACTIVE"
	default:
		return "UNKNOWN"
	}
}

// reason specifies the reason for a successful or failed health report
type reason int

const (
	reasonDial reason = iota
	reasonHealthCheck
)

func (r reason) String() string {
	switch r {
	case reasonDial:
		return "dial"
	case reasonHealthCheck:
		return "health check"
	default:
		return "unknown reason"
	}
}

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
	mutex          sync.Mutex
	address        string
	isDefault      bool
	state          state
	lastTransition time.Time
	healthCheck    HealthCheckFunc
	connections    map[net.Conn]struct{}
}

// newServer creates a new server, with a default health check
// and default/state fields appropriate for whether or not
// the server is a full server, or just a fallback default.
func newServer(address string, isDefault bool) *server {
	state := stateUnchecked
	if isDefault {
		state = stateStandby
	}
	return &server{
		address:        address,
		isDefault:      isDefault,
		state:          state,
		lastTransition: time.Now(),
		healthCheck:    func() HealthCheckResult { return HealthCheckResultUnknown },
		connections:    make(map[net.Conn]struct{}),
	}
}
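
As a rough usage sketch (not part of this commit; the addresses are examples): a newly added server starts out UNCHECKED with a health check that reports unknown, while the fallback default starts out in STANDBY.

	// Illustrative sketch of newServer defaults.
	srv := newServer("10.0.0.1:6443", false)   // srv.state == stateUnchecked
	def := newServer("192.168.1.1:6443", true) // def.state == stateStandby
	fmt.Println(srv, def)                      // "10.0.0.1:6443@UNCHECKED 192.168.1.1:6443@STANDBY*"
	_ = srv.healthCheck()                      // HealthCheckResultUnknown until SetHealthCheck replaces it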

func (s *server) String() string {
	format := "%s@%s"
	if s.isDefault {
		format += "*"
	}
	return fmt.Sprintf(format, s.address, s.state)
}

// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map
func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) {
	if s.state == stateInvalid {
		return nil, fmt.Errorf("server %s is stopping", s.address)
	}

	conn, err := defaultDialer.Dial(network, s.address)
	if err != nil {
		return nil, err
	}

@@ -132,7 +435,7 @@ func (s *server) closeAll() {
 	defer s.mutex.Unlock()
 
 	if l := len(s.connections); l > 0 {
-		logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
+		logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s)
 		for conn := range s.connections {
 			// Close the connection in a goroutine so that we don't hold the lock while doing so.
 			go conn.Close()
@@ -140,6 +443,12 @@ func (s *server) closeAll() {
 	}
 }
 
+// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
+type serverConn struct {
+	server *server
+	net.Conn
+}
+
 // Close removes the connection entry from the server's connection map, and
 // closes the wrapped connection.
 func (sc *serverConn) Close() error {
@@ -150,73 +459,43 @@ func (sc *serverConn) Close() error {
 	return sc.Conn.Close()
 }
 
-// SetDefault sets the selected address as the default / fallback address
-func (lb *LoadBalancer) SetDefault(serverAddress string) {
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
-	// if the old default server is not currently in use, remove it from the server map
-	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
-		defer server.closeAll()
-		delete(lb.servers, lb.defaultServerAddress)
-	}
-	// if the new default server doesn't have an entry in the map, add one - but
-	// with a failing health check so that it is only used as a last resort.
-	if _, ok := lb.servers[serverAddress]; !ok {
-		lb.servers[serverAddress] = &server{
-			address:     serverAddress,
-			healthCheck: func() bool { return false },
-			connections: make(map[net.Conn]struct{}),
-		}
-	}
-
-	lb.defaultServerAddress = serverAddress
-	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
-}
-
-// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
-func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	if server := lb.servers[address]; server != nil {
-		logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
-		server.healthCheck = healthCheck
-	} else {
-		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
-	}
-}
-
-// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
-// connections closed, to force clients to switch over to a healthy server.
-func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
-	previousStatus := map[string]bool{}
+// runHealthChecks periodically health-checks all servers.
+func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
 	wait.Until(func() {
-		lb.mutex.RLock()
-		defer lb.mutex.RUnlock()
-		var healthyServerExists bool
-		for address, server := range lb.servers {
-			status := server.healthCheck()
-			healthyServerExists = healthyServerExists || status
-			if status == false && previousStatus[address] == true {
-				// Only close connections when the server transitions from healthy to unhealthy;
-				// we don't want to re-close all the connections every time as we might be ignoring
-				// health checks due to all servers being marked unhealthy.
-				defer server.closeAll()
-			}
-			previousStatus[address] = status
-		}
-
-		// If there is at least one healthy server, and the default server is not in the server list,
-		// close all the connections to the default server so that clients reconnect and switch over
-		// to a preferred server.
-		hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
-		if healthyServerExists && !hasDefaultServer {
-			if server, ok := lb.servers[lb.defaultServerAddress]; ok {
-				defer server.closeAll()
+		for _, s := range sl.getServers() {
+			switch s.healthCheck() {
+			case HealthCheckResultOK:
+				sl.recordSuccess(s, reasonHealthCheck)
+			case HealthCheckResultFailed:
+				sl.recordFailure(s, reasonHealthCheck)
 			}
 		}
 	}, time.Second, ctx.Done())
-	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
+	logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
 }
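
A hedged sketch of how a caller might start this loop (the context handling and service name are assumptions, with a *serverList value sl in scope; none of this is taken from the diff): the checks run once per second until the context is cancelled.

	// Illustrative sketch: run health checks in the background until shutdown.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	go sl.runHealthChecks(ctx, "k3s-agent-load-balancer")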

// dialContext attempts to dial a connection to a server from the server list.
// Success or failure is recorded to ensure that server state is updated appropriately.
func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
	for _, s := range sl.getServers() {
		dialTime := time.Now()
		conn, err := s.dialContext(ctx, network)
		if err == nil {
			sl.recordSuccess(s, reasonDial)
			return conn, nil
		}
		logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err)
		sl.recordFailure(s, reasonDial)
	}
	return nil, errors.New("all servers failed")
}
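
Because this method matches the standard DialContext signature (the address argument is ignored in favor of the server list), it can stand in anywhere a dialer is expected. A minimal sketch, assuming an *http.Transport consumer rather than the load balancer's own TCP proxy:

	// Illustrative sketch: plug the failover-aware dialer into an HTTP client.
	client := &http.Client{
		Transport: &http.Transport{
			DialContext: sl.dialContext,
		},
	}
	_ = client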

// compareServers is a comparison function that can be used to sort the server list
// so that servers with a more preferred state, or higher number of connections, are ordered first.
func compareServers(a, b *server) int {
	c := cmp.Compare(b.state, a.state)
	if c == 0 {
		return cmp.Compare(len(b.connections), len(a.connections))
	}
	return c
}
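
A small sketch of the tie-break (example data, not from this commit): when two servers share a state, the one tracking more connections sorts first.

	// Illustrative sketch: equal state, so connection count decides the order.
	c1, c2 := net.Pipe()
	defer c1.Close()
	defer c2.Close()
	a := &server{address: "10.0.0.1:6443", state: stateHealthy, connections: map[net.Conn]struct{}{}}
	b := &server{address: "10.0.0.2:6443", state: stateHealthy, connections: map[net.Conn]struct{}{c1: {}, c2: {}}}
	list := []*server{a, b}
	slices.SortFunc(list, compareServers)
	fmt.Println(list[0].address) // "10.0.0.2:6443"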

@@ -22,7 +22,7 @@ type Proxy interface {
 	SupervisorAddresses() []string
 	APIServerURL() string
 	IsAPIServerLBEnabled() bool
-	SetHealthCheck(address string, healthCheck func() bool)
+	SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc)
 }
 
 // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
@@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor
 			return nil, err
 		}
 		p.supervisorLB = lb
-		p.supervisorURL = lb.LoadBalancerServerURL()
+		p.supervisorURL = lb.LocalURL()
 		p.apiServerURL = p.supervisorURL
 	}

@@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) {
 	p.supervisorAddresses = supervisorAddresses
 }
 
-func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
+func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) {
 	if p.supervisorLB != nil {
 		p.supervisorLB.SetHealthCheck(address, healthCheck)
 	}
@@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error {
 			return err
 		}
 		p.apiServerLB = lb
-		p.apiServerURL = lb.LoadBalancerServerURL()
+		p.apiServerURL = lb.LocalURL()
 	} else {
 		p.apiServerURL = u.String()
 	}

@@ -14,6 +14,7 @@ import (
 
 	"github.com/gorilla/websocket"
 	agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
+	"github.com/k3s-io/k3s/pkg/agent/loadbalancer"
 	"github.com/k3s-io/k3s/pkg/agent/proxy"
 	daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config"
 	"github.com/k3s-io/k3s/pkg/util"

@@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
 			if _, ok := disconnect[address]; !ok {
 				conn := a.connect(ctx, wg, address, tlsConfig)
 				disconnect[address] = conn.cancel
-				proxy.SetHealthCheck(address, conn.connected)
+				proxy.SetHealthCheck(address, conn.healthCheck)
 			}
 		}
@@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
 			if _, ok := disconnect[address]; !ok {
 				conn := a.connect(ctx, nil, address, tlsConfig)
 				disconnect[address] = conn.cancel
-				proxy.SetHealthCheck(address, conn.connected)
+				proxy.SetHealthCheck(address, conn.healthCheck)
 			}
 		}

@@ -428,21 +429,19 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
 
 type agentConnection struct {
 	cancel      context.CancelFunc
-	connected   func() bool
+	healthCheck loadbalancer.HealthCheckFunc
 }
 
 // connect initiates a connection to the remotedialer server. Incoming dial requests from
 // the server will be checked by the authorizer function prior to being fulfilled.
 func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
+	var status loadbalancer.HealthCheckResult
+
 	wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
 	ws := &websocket.Dialer{
 		TLSClientConfig: tlsConfig,
 	}
 
-	// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
-	// If we cannot connect, connected will be set to false when the initial connection attempt fails.
-	connected := true
-
 	once := sync.Once{}
 	if waitGroup != nil {
 		waitGroup.Add(1)
@@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 	}
 
 	onConnect := func(_ context.Context, _ *remotedialer.Session) error {
-		connected = true
+		status = loadbalancer.HealthCheckResultOK
 		logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
 		if waitGroup != nil {
 			once.Do(waitGroup.Done)
@@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 		for {
 			// ConnectToProxy blocks until error or context cancellation
 			err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect)
-			connected = false
+			status = loadbalancer.HealthCheckResultFailed
 			if err != nil && !errors.Is(err, context.Canceled) {
 				logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...")
 				// wait between reconnection attempts to avoid hammering the server
@@ -485,7 +484,9 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
 
 	return agentConnection{
 		cancel: cancel,
-		connected: func() bool { return connected },
+		healthCheck: func() loadbalancer.HealthCheckResult {
+			return status
+		},
 	}
 }

@@ -48,6 +48,10 @@ type etcdproxy struct {
 }
 
 func (e *etcdproxy) Update(addresses []string) {
+	if e.etcdLB == nil {
+		return
+	}
+
 	e.etcdLB.Update(addresses)
 
 	validEndpoint := map[string]bool{}
@@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) {
 
 // start a polling routine that makes periodic requests to the etcd node's supervisor port.
 // If the request fails, the node is marked unhealthy.
-func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool {
-	// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
-	// If we cannot connect, connected will be set to false when the initial connection attempt fails.
-	connected := true
+func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc {
+	var status loadbalancer.HealthCheckResult
 
 	host, _, _ := net.SplitHostPort(address)
 	url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort)))
@@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func()
 		}
 		if err != nil || statusCode != http.StatusOK {
 			logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode)
-			connected = false
+			status = loadbalancer.HealthCheckResultFailed
 		} else {
-			connected = true
+			status = loadbalancer.HealthCheckResultOK
 		}
 	}, 5*time.Second, 1.0, true)
 
-	return func() bool {
-		return connected
+	return func() loadbalancer.HealthCheckResult {
+		// Reset the status to unknown on reading, until next time it is checked.
+		// This avoids having a health check result alter the server state between active checks.
+		s := status
+		status = loadbalancer.HealthCheckResultUnknown
+		return s
 	}
 }
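
A short sketch of the read-and-reset behavior above (an assumed call pattern with an etcdproxy value e and a ctx in scope; the address is an example, not taken from this diff): each recorded probe result is consumed by the first read, and later reads report unknown until the next poll writes a new status.

	// Illustrative sketch: the health check result is consumed on read.
	check := e.createHealthCheck(ctx, "10.0.0.1:2379")
	first := check()  // HealthCheckResultOK or HealthCheckResultFailed once a probe has run
	second := check() // HealthCheckResultUnknown until the next probe updates the status
	_, _ = first, second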