package loadbalancer

import (
	"context"
	"math/rand"
	"net"
	"slices"
	"time"

	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/apimachinery/pkg/util/wait"
)

func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
	serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress)
	if len(serverAddresses) == 0 {
		return false
	}

	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	newAddresses := sets.NewString(serverAddresses...)
	curAddresses := sets.NewString(lb.serverAddresses...)
	if newAddresses.Equal(curAddresses) {
		return false
	}

	for addedServer := range newAddresses.Difference(curAddresses) {
		logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
		lb.servers[addedServer] = &server{
			address:     addedServer,
			connections: make(map[net.Conn]struct{}),
			healthCheck: func() bool { return true },
		}
	}

	for removedServer := range curAddresses.Difference(newAddresses) {
		server := lb.servers[removedServer]
		if server != nil {
			logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer)
			// Defer closing connections until after the new server list has been put into place.
			// Closing open connections ensures that anything stuck retrying on a stale server is forced
			// over to a valid endpoint.
			defer server.closeAll()
			// Don't delete the default server from the server map, in case we need to fall back to it.
			if removedServer != lb.defaultServerAddress {
				delete(lb.servers, removedServer)
			}
		}
	}

	lb.serverAddresses = serverAddresses
	lb.randomServers = append([]string{}, lb.serverAddresses...)
	rand.Shuffle(len(lb.randomServers), func(i, j int) {
		lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
	})

	// If the current server list does not contain the default server address,
	// we want to include it in the random server list so that it can be tried if necessary.
	// However, it should be treated as always failing health checks so that it is only
	// used if all other endpoints are unavailable.
	if !hasDefaultServer {
		lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
		if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
			defaultServer.healthCheck = func() bool { return false }
			lb.servers[lb.defaultServerAddress] = defaultServer
		}
	}

	lb.currentServerAddress = lb.randomServers[0]
	lb.nextServerIndex = 1

	return true
}

// nextServer attempts to get the next server in the loadbalancer server list.
// If another goroutine has already updated the current server address to point at
// a different address than just failed, nothing is changed. Otherwise, a new server address
// is stored to the currentServerAddress field, and returned for use.
// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
	// note: these fields are not protected by the mutex, so we clamp the index value and update
	// the index/current address using local variables, to avoid time-of-check vs time-of-use
	// race conditions caused by goroutine A incrementing it in between the time goroutine B
	// validates its value, and uses it as a list index.
	currentServerAddress := lb.currentServerAddress
	nextServerIndex := lb.nextServerIndex

	if len(lb.randomServers) == 0 {
		return "", errors.New("No servers in load balancer proxy list")
	}
	if len(lb.randomServers) == 1 {
		return currentServerAddress, nil
	}
	if failedServer != currentServerAddress {
		return currentServerAddress, nil
	}
	if nextServerIndex >= len(lb.randomServers) {
		nextServerIndex = 0
	}

	currentServerAddress = lb.randomServers[nextServerIndex]
	nextServerIndex++

	lb.currentServerAddress = currentServerAddress
	lb.nextServerIndex = nextServerIndex

	return currentServerAddress, nil
}

// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map
func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := defaultDialer.Dial(network, address)
	if err != nil {
		return nil, err
	}

	// Wrap the connection and add it to the server's connection map
	s.mutex.Lock()
	defer s.mutex.Unlock()
	wrappedConn := &serverConn{server: s, Conn: conn}
	s.connections[wrappedConn] = struct{}{}
	return wrappedConn, nil
}

// closeAll closes all connections to the server, and removes their entries from the map
func (s *server) closeAll() {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if l := len(s.connections); l > 0 {
		logrus.Infof("Closing %d connections to load balancer server %s", l, s.address)
		for conn := range s.connections {
			// Close the connection in a goroutine so that we don't hold the lock while doing so.
			go conn.Close()
		}
	}
}

// Close removes the connection entry from the server's connection map, and
// closes the wrapped connection.
func (sc *serverConn) Close() error {
	sc.server.mutex.Lock()
	defer sc.server.mutex.Unlock()

	delete(sc.server.connections, sc)
	return sc.Conn.Close()
}

// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
	// if the old default server is not currently in use, remove it from the server map
	if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer {
		defer server.closeAll()
		delete(lb.servers, lb.defaultServerAddress)
	}
	// if the new default server doesn't have an entry in the map, add one - but
	// with a failing health check so that it is only used as a last resort.
	if _, ok := lb.servers[serverAddress]; !ok {
		lb.servers[serverAddress] = &server{
			address:     serverAddress,
			healthCheck: func() bool { return false },
			connections: make(map[net.Conn]struct{}),
		}
	}

	lb.defaultServerAddress = serverAddress
	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}

// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
	lb.mutex.Lock()
	defer lb.mutex.Unlock()

	if server := lb.servers[address]; server != nil {
		logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
		server.healthCheck = healthCheck
	} else {
		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
	}
}
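// exampleTCPHealthCheck is an illustrative sketch of the kind of callback a caller might
// register via SetHealthCheck: a plain TCP probe of the server address. The helper name
// and the one-second timeout are assumptions for this example only; actual callers may
// supply a very different check.
func exampleTCPHealthCheck(lb *LoadBalancer, address string) {
	lb.SetHealthCheck(address, func() bool {
		// Consider the server healthy if a TCP connection can be opened within one second.
		conn, err := net.DialTimeout("tcp", address, time.Second)
		if err != nil {
			return false
		}
		conn.Close()
		return true
	})
}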
// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
	previousStatus := map[string]bool{}
	wait.Until(func() {
		lb.mutex.RLock()
		defer lb.mutex.RUnlock()
		var healthyServerExists bool
		for address, server := range lb.servers {
			status := server.healthCheck()
			healthyServerExists = healthyServerExists || status
			if !status && previousStatus[address] {
				// Only close connections when the server transitions from healthy to unhealthy;
				// we don't want to re-close all the connections every time as we might be ignoring
				// health checks due to all servers being marked unhealthy.
				defer server.closeAll()
			}
			previousStatus[address] = status
		}

		// If there is at least one healthy server, and the default server is not in the server list,
		// close all the connections to the default server so that clients reconnect and switch over
		// to a preferred server.
		hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
		if healthyServerExists && !hasDefaultServer {
			if server, ok := lb.servers[lb.defaultServerAddress]; ok {
				defer server.closeAll()
			}
		}
	}, time.Second, ctx.Done())
	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
}
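// exampleDialLoop is an illustrative sketch of one way a caller's dial loop might combine
// dialContext and nextServer: try the current server first, and rotate through the shuffled
// list after each failed dial. The method name, loop structure, and final error are
// assumptions for this example; it is not the package's actual dialer.
func (lb *LoadBalancer) exampleDialLoop(ctx context.Context, network string) (net.Conn, error) {
	// nextServer must be called while holding a read lock on the load balancer mutex.
	lb.mutex.RLock()
	defer lb.mutex.RUnlock()

	targetServer := lb.currentServerAddress
	for range lb.randomServers {
		if server, ok := lb.servers[targetServer]; ok {
			if conn, err := server.dialContext(ctx, network, targetServer); err == nil {
				return conn, nil
			}
		}
		// The dial failed; advance to the next server in the shuffled list.
		nextServer, err := lb.nextServer(targetServer)
		if err != nil {
			return nil, err
		}
		targetServer = nextServer
	}
	return nil, errors.New("all servers failed")
}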