mirror of https://github.com/k3s-io/k3s
parent
1833b65fcd
commit
a17e336993
@ -0,0 +1,38 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/rancher/k3s/pkg/agent/util"
|
||||
)
|
||||
|
||||
func (lb *LoadBalancer) writeConfig() error {
|
||||
configOut, err := json.MarshalIndent(lb, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := util.WriteFile(lb.configFile, string(configOut)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) updateConfig() error {
|
||||
writeConfig := true
|
||||
if configBytes, err := ioutil.ReadFile(lb.configFile); err == nil {
|
||||
config := &LoadBalancer{}
|
||||
if err := json.Unmarshal(configBytes, config); err == nil {
|
||||
if config.ServerURL == lb.ServerURL {
|
||||
writeConfig = false
|
||||
lb.setServers(config.ServerAddresses)
|
||||
}
|
||||
}
|
||||
}
|
||||
if writeConfig {
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,142 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/google/tcpproxy"
|
||||
"github.com/rancher/k3s/pkg/cli/cmds"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type LoadBalancer struct {
|
||||
mutex sync.Mutex
|
||||
dialer *net.Dialer
|
||||
proxy *tcpproxy.Proxy
|
||||
|
||||
configFile string
|
||||
localAddress string
|
||||
localServerURL string
|
||||
originalServerAddress string
|
||||
ServerURL string
|
||||
ServerAddresses []string
|
||||
randomServers []string
|
||||
currentServerAddress string
|
||||
nextServerIndex int
|
||||
}
|
||||
|
||||
const (
|
||||
serviceName = "k3s-agent-load-balancer"
|
||||
)
|
||||
|
||||
func Setup(ctx context.Context, cfg cmds.Agent) (_lb *LoadBalancer, _err error) {
|
||||
if cfg.DisableLoadBalancer {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer func() {
|
||||
if _err != nil {
|
||||
logrus.Warnf("Error starting load balancer: %s", _err)
|
||||
if listener != nil {
|
||||
listener.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
localAddress := listener.Addr().String()
|
||||
|
||||
originalServerAddress, localServerURL, err := parseURL(cfg.ServerURL, localAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lb := &LoadBalancer{
|
||||
dialer: &net.Dialer{},
|
||||
configFile: filepath.Join(cfg.DataDir, "etc", serviceName+".json"),
|
||||
localAddress: localAddress,
|
||||
localServerURL: localServerURL,
|
||||
originalServerAddress: originalServerAddress,
|
||||
ServerURL: cfg.ServerURL,
|
||||
}
|
||||
|
||||
lb.setServers([]string{lb.originalServerAddress})
|
||||
|
||||
lb.proxy = &tcpproxy.Proxy{
|
||||
ListenFunc: func(string, string) (net.Listener, error) {
|
||||
return listener, nil
|
||||
},
|
||||
}
|
||||
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
|
||||
Addr: serviceName,
|
||||
DialContext: lb.dialContext,
|
||||
})
|
||||
|
||||
if err := lb.updateConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := lb.proxy.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logrus.Infof("Running load balancer %s -> %v", lb.localAddress, lb.randomServers)
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) Update(serverAddresses []string) {
|
||||
if lb == nil {
|
||||
return
|
||||
}
|
||||
if !lb.setServers(serverAddresses) {
|
||||
return
|
||||
}
|
||||
logrus.Infof("Updating load balancer server addresses -> %v", lb.randomServers)
|
||||
|
||||
if err := lb.writeConfig(); err != nil {
|
||||
logrus.Warnf("Error updating load balancer config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) LoadBalancerServerURL() string {
|
||||
if lb == nil {
|
||||
return ""
|
||||
}
|
||||
return lb.localServerURL
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
startIndex := lb.nextServerIndex
|
||||
for {
|
||||
targetServer := lb.currentServerAddress
|
||||
|
||||
conn, err := lb.dialer.DialContext(ctx, network, targetServer)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
logrus.Warnf("Dial error from load balancer: %s", err)
|
||||
|
||||
newServer, err := lb.nextServer(targetServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if targetServer != newServer {
|
||||
logrus.Warnf("Dial context in load balancer failed over to %s", newServer)
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
maxIndex := len(lb.randomServers)
|
||||
if startIndex > maxIndex {
|
||||
startIndex = maxIndex
|
||||
}
|
||||
if lb.nextServerIndex == startIndex {
|
||||
return nil, errors.New("all servers failed")
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,183 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rancher/k3s/pkg/cli/cmds"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
listener net.Listener
|
||||
conns []net.Conn
|
||||
prefix string
|
||||
}
|
||||
|
||||
func createServer(prefix string) (*server, error) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &server{
|
||||
prefix: prefix,
|
||||
listener: listener,
|
||||
}
|
||||
go s.serve()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *server) serve() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.conns = append(s.conns, conn)
|
||||
go s.echo(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) close() {
|
||||
s.listener.Close()
|
||||
for _, conn := range s.conns {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) echo(conn net.Conn) {
|
||||
for {
|
||||
result, err := bufio.NewReader(conn).ReadString('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Write([]byte(s.prefix + ":" + result))
|
||||
}
|
||||
}
|
||||
|
||||
func ping(conn net.Conn) (string, error) {
|
||||
fmt.Fprintf(conn, "ping\n")
|
||||
result, err := bufio.NewReader(conn).ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimSpace(result), nil
|
||||
}
|
||||
|
||||
func assertEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
if a != b {
|
||||
t.Fatalf("[ %v != %v ]", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNotEqual(t *testing.T, a interface{}, b interface{}) {
|
||||
if a == b {
|
||||
t.Fatalf("[ %v == %v ]", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailOver(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir("", "lb-test")
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
ogServe, err := createServer("og")
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
lbServe, err := createServer("lb")
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
cfg := cmds.Agent{
|
||||
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()),
|
||||
DataDir: tmpDir,
|
||||
}
|
||||
|
||||
lb, err := Setup(context.Background(), cfg)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
localAddress := parsedURL.Host
|
||||
|
||||
lb.Update([]string{lbServe.listener.Addr().String()})
|
||||
|
||||
conn1, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
result1, err := ping(conn1)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
assertEqual(t, result1, "lb:ping")
|
||||
|
||||
lbServe.close()
|
||||
|
||||
_, err = ping(conn1)
|
||||
assertNotEqual(t, err, nil)
|
||||
|
||||
conn2, err := net.Dial("tcp", localAddress)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
result2, err := ping(conn2)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
assertEqual(t, result2, "og:ping")
|
||||
}
|
||||
|
||||
func TestFailFast(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir("", "lb-test")
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := cmds.Agent{
|
||||
ServerURL: "http://127.0.0.1:-1/",
|
||||
DataDir: tmpDir,
|
||||
}
|
||||
|
||||
lb, err := Setup(context.Background(), cfg)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", lb.localAddress)
|
||||
if err != nil {
|
||||
assertEqual(t, err, nil)
|
||||
}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err = ping(conn)
|
||||
done <- err
|
||||
}()
|
||||
timeout := time.After(10 * time.Millisecond)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assertNotEqual(t, err, nil)
|
||||
case <-timeout:
|
||||
t.Fatal(errors.New("time out"))
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
|
||||
serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.originalServerAddress)
|
||||
if len(serverAddresses) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
if reflect.DeepEqual(serverAddresses, lb.ServerAddresses) {
|
||||
return false
|
||||
}
|
||||
|
||||
lb.ServerAddresses = serverAddresses
|
||||
lb.randomServers = append([]string{}, lb.ServerAddresses...)
|
||||
rand.Shuffle(len(lb.randomServers), func(i, j int) {
|
||||
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
|
||||
})
|
||||
if !hasOriginalServer {
|
||||
lb.randomServers = append(lb.randomServers, lb.originalServerAddress)
|
||||
}
|
||||
lb.currentServerAddress = lb.randomServers[0]
|
||||
lb.nextServerIndex = 1
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
if len(lb.randomServers) == 0 {
|
||||
return "", errors.New("No servers in load balancer proxy list")
|
||||
}
|
||||
if len(lb.randomServers) == 1 {
|
||||
return lb.currentServerAddress, nil
|
||||
}
|
||||
if failedServer != lb.currentServerAddress {
|
||||
return lb.currentServerAddress, nil
|
||||
}
|
||||
if lb.nextServerIndex >= len(lb.randomServers) {
|
||||
lb.nextServerIndex = 0
|
||||
}
|
||||
|
||||
lb.currentServerAddress = lb.randomServers[lb.nextServerIndex]
|
||||
lb.nextServerIndex++
|
||||
|
||||
return lb.currentServerAddress, nil
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package loadbalancer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseURL(serverURL, newHost string) (string, string, error) {
|
||||
parsedURL, err := url.Parse(serverURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if parsedURL.Host == "" {
|
||||
return "", "", errors.New("Initial server URL host is not defined for load balancer")
|
||||
}
|
||||
address := parsedURL.Host
|
||||
if parsedURL.Port() == "" {
|
||||
if strings.ToLower(parsedURL.Scheme) == "http" {
|
||||
address += ":80"
|
||||
}
|
||||
if strings.ToLower(parsedURL.Scheme) == "https" {
|
||||
address += ":443"
|
||||
}
|
||||
}
|
||||
parsedURL.Host = newHost
|
||||
return address, parsedURL.String(), nil
|
||||
}
|
||||
|
||||
func sortServers(input []string, search string) ([]string, bool) {
|
||||
result := []string{}
|
||||
found := false
|
||||
skip := map[string]bool{"": true}
|
||||
|
||||
for _, entry := range input {
|
||||
if skip[entry] {
|
||||
continue
|
||||
}
|
||||
if search == entry {
|
||||
found = true
|
||||
}
|
||||
skip[entry] = true
|
||||
result = append(result, entry)
|
||||
}
|
||||
|
||||
sort.Strings(result)
|
||||
return result, found
|
||||
}
|
Loading…
Reference in new issue