k3s/pkg/agent/loadbalancer/servers.go

package loadbalancer

import (
	"context"
	"math/rand"
	"net"
	"slices"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/apimachinery/pkg/util/wait"
)
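
// setServers replaces the load balancer's server list with serverAddresses.
// It returns true if the set of servers changed, and false if the new list is
// empty or identical to the current one. Connections to removed servers are
// closed once the new list is in place, and the default server is kept in the
// server map (marked as failing health checks) so that it remains available
// as a fallback of last resort.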
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
	serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
	if len(serverAddresses) == 0 {
		return false
	}

	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	newAddresses := sets.NewString(serverAddresses...)
	curAddresses := sets.NewString(lb.ServerAddresses...)
	if newAddresses.Equal(curAddresses) {
		return false
	}

	for addedServer := range newAddresses.Difference(curAddresses) {
		logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
		lb.servers[addedServer] = &server{
			address:     addedServer,
			connections: make(map[net.Conn]struct{}),
			healthCheck: func() bool { return true },
		}
	}

	for removedServer := range curAddresses.Difference(newAddresses) {
		server := lb.servers[removedServer]
		if server != nil {
			logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
			// Defer closing connections until after the new server list has been put into place.
			// Closing open connections ensures that anything stuck retrying on a stale server is forced
			// over to a valid endpoint.
			defer server.closeAll()
			// Don't delete the default server from the server map, in case we need to fall back to it.
			if removedServer != lb.defaultServerAddress {
				delete(lb.servers, removedServer)
			}
		}
	}

	lb.ServerAddresses = serverAddresses
	lb.randomServers = append([]string{}, lb.ServerAddresses...)
	rand.Shuffle(len(lb.randomServers), func(i, j int) {
		lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
	})

	// If the current server list does not contain the default server address,
	// we want to include it in the random server list so that it can be tried if necessary.
	// However, it should be treated as always failing health checks so that it is only
	// used if all other endpoints are unavailable.
	if !hasDefaultServer {
		lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
		if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
			defaultServer.healthCheck = func() bool { return false }
			lb.servers[lb.defaultServerAddress] = defaultServer
		}
	}

	lb.currentServerAddress = lb.randomServers[0]
	lb.nextServerIndex = 1

	return true
}

// nextServer attempts to get the next server in the loadbalancer server list.
// If another goroutine has already updated the current server address to point at
// a different address than the one that just failed, nothing is changed. Otherwise,
// a new server address is stored to the currentServerAddress field, and returned for use.
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
	// Note: these fields are not protected by the mutex, so we clamp the index value and update
	// the index/current address using local variables, to avoid time-of-check vs time-of-use
	// race conditions caused by goroutine A incrementing the index in between the time goroutine B
	// validates its value and uses it as a list index.
	currentServerAddress := lb.currentServerAddress
	nextServerIndex := lb.nextServerIndex

	if len(lb.randomServers) == 0 {
		return "", errors.New("no servers in load balancer proxy list")
	}
	if len(lb.randomServers) == 1 {
		return currentServerAddress, nil
	}
	if failedServer != currentServerAddress {
		return currentServerAddress, nil
	}
	if nextServerIndex >= len(lb.randomServers) {
		nextServerIndex = 0
	}

	currentServerAddress = lb.randomServers[nextServerIndex]
	nextServerIndex++
	lb.currentServerAddress = currentServerAddress
	lb.nextServerIndex = nextServerIndex

	return currentServerAddress, nil
}
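
// Illustrative sketch (not from this file; the surrounding names are assumptions):
// a dial path that fails over to another server might use nextServer roughly as
//
//	lb.mutex.RLock()
//	targetServer, err := lb.nextServer(failedAddress)
//	lb.mutex.RUnlock()
//
// retrying the dial against targetServer when err is nil, holding the read lock
// during the call as the locking contract above requires.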

// dialContext dials a new connection to the server using the environment's proxy
// settings, and adds the wrapped connection to the server's connection map.
// Note that defaultDialer.Dial does not accept a context, so ctx is currently unused here.
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := defaultDialer.Dial(network, address)
	if err != nil {
		return nil, err
	}

	// Wrap the connection and add it to the server's connection map
	s.mutex.Lock()
	defer s.mutex.Unlock()
	wrappedConn := &serverConn{server: s, Conn: conn}
	s.connections[wrappedConn] = struct{}{}
	return wrappedConn, nil
}
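
// The signature above matches the DialContext function type used by net/http's
// Transport, so this method can be plugged in wherever a
// func(ctx context.Context, network, address string) (net.Conn, error) dialer
// is expected; how it is actually wired up is outside this file.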

// closeAll closes all connections to the server, and removes their entries from the map
func (s *server) closeAll() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if l := len(s.connections); l > 0 {
		logrus.Infof("Closing %d connections to load balancer server %s", l, s.address)
		for conn := range s.connections {
			// Close the connection in a goroutine, since serverConn.Close takes the
			// same lock held here and would deadlock if called synchronously.
			go conn.Close()
		}
	}
}

// Close removes the connection entry from the server's connection map, and
// closes the wrapped connection.
func (sc *serverConn) Close() error {
	sc.server.mutex.Lock()
	defer sc.server.mutex.Unlock()

	delete(sc.server.connections, sc)
	return sc.Conn.Close()
}

// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress)
	// if the old default server is not currently in use, remove it from the server map
	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
		defer server.closeAll()
		delete(lb.servers, lb.defaultServerAddress)
	}
	// if the new default server doesn't have an entry in the map, add one - but
	// with a failing health check so that it is only used as a last resort.
	if _, ok := lb.servers[serverAddress]; !ok {
		lb.servers[serverAddress] = &server{
			address:     serverAddress,
			healthCheck: func() bool { return false },
			connections: make(map[net.Conn]struct{}),
		}
	}

	lb.defaultServerAddress = serverAddress
	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}

// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	if server := lb.servers[address]; server != nil {
		logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
		server.healthCheck = healthCheck
	} else {
		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
	}
}
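
// Illustrative sketch (hypothetical caller, not from this file): any func() bool
// can serve as the health check, for example one backed by an atomic flag that a
// connectivity prober flips as probes succeed or fail:
//
//	var healthy atomic.Bool
//	lb.SetHealthCheck("192.168.1.10:6443", healthy.Load)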

// runHealthChecks periodically health-checks all servers. Any servers that fail the health check
// will have their connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
	previousStatus := map[string]bool{}
	wait.Until(func() {
		lb.mutex.RLock()
		defer lb.mutex.RUnlock()

		var healthyServerExists bool
		for address, server := range lb.servers {
			status := server.healthCheck()
			healthyServerExists = healthyServerExists || status
			if !status && previousStatus[address] {
				// Only close connections when the server transitions from healthy to unhealthy;
				// we don't want to re-close all the connections every time, as we might be ignoring
				// health checks due to all servers being marked unhealthy.
				defer server.closeAll()
			}
			previousStatus[address] = status
		}

		// If there is at least one healthy server, and the default server is not in the server list,
		// close all the connections to the default server so that clients reconnect and switch over
		// to a preferred server.
		hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress)
		if healthyServerExists && !hasDefaultServer {
			if server, ok := lb.servers[lb.defaultServerAddress]; ok {
				defer server.closeAll()
			}
		}
	}, time.Second, ctx.Done())
	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
}