Use go tcpproxy

pull/676/head
Erik Wilson 5 years ago
parent 1833b65fcd
commit a17e336993

@ -0,0 +1,38 @@
package loadbalancer
import (
"encoding/json"
"io/ioutil"
"github.com/rancher/k3s/pkg/agent/util"
)
func (lb *LoadBalancer) writeConfig() error {
configOut, err := json.MarshalIndent(lb, "", " ")
if err != nil {
return err
}
if err := util.WriteFile(lb.configFile, string(configOut)); err != nil {
return err
}
return nil
}
func (lb *LoadBalancer) updateConfig() error {
writeConfig := true
if configBytes, err := ioutil.ReadFile(lb.configFile); err == nil {
config := &LoadBalancer{}
if err := json.Unmarshal(configBytes, config); err == nil {
if config.ServerURL == lb.ServerURL {
writeConfig = false
lb.setServers(config.ServerAddresses)
}
}
}
if writeConfig {
if err := lb.writeConfig(); err != nil {
return err
}
}
return nil
}

@ -0,0 +1,142 @@
package loadbalancer
import (
"context"
"errors"
"net"
"path/filepath"
"sync"
"github.com/google/tcpproxy"
"github.com/rancher/k3s/pkg/cli/cmds"
"github.com/sirupsen/logrus"
)
type LoadBalancer struct {
mutex sync.Mutex
dialer *net.Dialer
proxy *tcpproxy.Proxy
configFile string
localAddress string
localServerURL string
originalServerAddress string
ServerURL string
ServerAddresses []string
randomServers []string
currentServerAddress string
nextServerIndex int
}
const (
serviceName = "k3s-agent-load-balancer"
)
func Setup(ctx context.Context, cfg cmds.Agent) (_lb *LoadBalancer, _err error) {
if cfg.DisableLoadBalancer {
return nil, nil
}
listener, err := net.Listen("tcp", "127.0.0.1:0")
defer func() {
if _err != nil {
logrus.Warnf("Error starting load balancer: %s", _err)
if listener != nil {
listener.Close()
}
}
}()
if err != nil {
return nil, err
}
localAddress := listener.Addr().String()
originalServerAddress, localServerURL, err := parseURL(cfg.ServerURL, localAddress)
if err != nil {
return nil, err
}
lb := &LoadBalancer{
dialer: &net.Dialer{},
configFile: filepath.Join(cfg.DataDir, "etc", serviceName+".json"),
localAddress: localAddress,
localServerURL: localServerURL,
originalServerAddress: originalServerAddress,
ServerURL: cfg.ServerURL,
}
lb.setServers([]string{lb.originalServerAddress})
lb.proxy = &tcpproxy.Proxy{
ListenFunc: func(string, string) (net.Listener, error) {
return listener, nil
},
}
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
Addr: serviceName,
DialContext: lb.dialContext,
})
if err := lb.updateConfig(); err != nil {
return nil, err
}
if err := lb.proxy.Start(); err != nil {
return nil, err
}
logrus.Infof("Running load balancer %s -> %v", lb.localAddress, lb.randomServers)
return lb, nil
}
func (lb *LoadBalancer) Update(serverAddresses []string) {
if lb == nil {
return
}
if !lb.setServers(serverAddresses) {
return
}
logrus.Infof("Updating load balancer server addresses -> %v", lb.randomServers)
if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer config: %s", err)
}
}
func (lb *LoadBalancer) LoadBalancerServerURL() string {
if lb == nil {
return ""
}
return lb.localServerURL
}
func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress
conn, err := lb.dialer.DialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
}
logrus.Warnf("Dial error from load balancer: %s", err)
newServer, err := lb.nextServer(targetServer)
if err != nil {
return nil, err
}
if targetServer != newServer {
logrus.Warnf("Dial context in load balancer failed over to %s", newServer)
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
maxIndex := len(lb.randomServers)
if startIndex > maxIndex {
startIndex = maxIndex
}
if lb.nextServerIndex == startIndex {
return nil, errors.New("all servers failed")
}
}
}

@ -0,0 +1,183 @@
package loadbalancer
import (
"bufio"
"context"
"errors"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/rancher/k3s/pkg/cli/cmds"
)
type server struct {
listener net.Listener
conns []net.Conn
prefix string
}
func createServer(prefix string) (*server, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
}
s := &server{
prefix: prefix,
listener: listener,
}
go s.serve()
return s, nil
}
func (s *server) serve() {
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
s.conns = append(s.conns, conn)
go s.echo(conn)
}
}
func (s *server) close() {
s.listener.Close()
for _, conn := range s.conns {
conn.Close()
}
}
func (s *server) echo(conn net.Conn) {
for {
result, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
return
}
conn.Write([]byte(s.prefix + ":" + result))
}
}
func ping(conn net.Conn) (string, error) {
fmt.Fprintf(conn, "ping\n")
result, err := bufio.NewReader(conn).ReadString('\n')
if err != nil {
return "", err
}
return strings.TrimSpace(result), nil
}
func assertEqual(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Fatalf("[ %v != %v ]", a, b)
}
}
func assertNotEqual(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Fatalf("[ %v == %v ]", a, b)
}
}
func TestFailOver(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "lb-test")
if err != nil {
assertEqual(t, err, nil)
}
defer os.RemoveAll(tmpDir)
ogServe, err := createServer("og")
if err != nil {
assertEqual(t, err, nil)
}
lbServe, err := createServer("lb")
if err != nil {
assertEqual(t, err, nil)
}
cfg := cmds.Agent{
ServerURL: fmt.Sprintf("http://%s/", ogServe.listener.Addr().String()),
DataDir: tmpDir,
}
lb, err := Setup(context.Background(), cfg)
if err != nil {
assertEqual(t, err, nil)
}
parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
if err != nil {
assertEqual(t, err, nil)
}
localAddress := parsedURL.Host
lb.Update([]string{lbServe.listener.Addr().String()})
conn1, err := net.Dial("tcp", localAddress)
if err != nil {
assertEqual(t, err, nil)
}
result1, err := ping(conn1)
if err != nil {
assertEqual(t, err, nil)
}
assertEqual(t, result1, "lb:ping")
lbServe.close()
_, err = ping(conn1)
assertNotEqual(t, err, nil)
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
assertEqual(t, err, nil)
}
result2, err := ping(conn2)
if err != nil {
assertEqual(t, err, nil)
}
assertEqual(t, result2, "og:ping")
}
func TestFailFast(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "lb-test")
if err != nil {
assertEqual(t, err, nil)
}
defer os.RemoveAll(tmpDir)
cfg := cmds.Agent{
ServerURL: "http://127.0.0.1:-1/",
DataDir: tmpDir,
}
lb, err := Setup(context.Background(), cfg)
if err != nil {
assertEqual(t, err, nil)
}
conn, err := net.Dial("tcp", lb.localAddress)
if err != nil {
assertEqual(t, err, nil)
}
done := make(chan error)
go func() {
_, err = ping(conn)
done <- err
}()
timeout := time.After(10 * time.Millisecond)
select {
case err := <-done:
assertNotEqual(t, err, nil)
case <-timeout:
t.Fatal(errors.New("time out"))
}
}

@ -0,0 +1,57 @@
package loadbalancer
import (
"errors"
"math/rand"
"reflect"
)
func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
serverAddresses, hasOriginalServer := sortServers(serverAddresses, lb.originalServerAddress)
if len(serverAddresses) == 0 {
return false
}
lb.mutex.Lock()
defer lb.mutex.Unlock()
if reflect.DeepEqual(serverAddresses, lb.ServerAddresses) {
return false
}
lb.ServerAddresses = serverAddresses
lb.randomServers = append([]string{}, lb.ServerAddresses...)
rand.Shuffle(len(lb.randomServers), func(i, j int) {
lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
})
if !hasOriginalServer {
lb.randomServers = append(lb.randomServers, lb.originalServerAddress)
}
lb.currentServerAddress = lb.randomServers[0]
lb.nextServerIndex = 1
return true
}
func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
if len(lb.randomServers) == 0 {
return "", errors.New("No servers in load balancer proxy list")
}
if len(lb.randomServers) == 1 {
return lb.currentServerAddress, nil
}
if failedServer != lb.currentServerAddress {
return lb.currentServerAddress, nil
}
if lb.nextServerIndex >= len(lb.randomServers) {
lb.nextServerIndex = 0
}
lb.currentServerAddress = lb.randomServers[lb.nextServerIndex]
lb.nextServerIndex++
return lb.currentServerAddress, nil
}

@ -0,0 +1,49 @@
package loadbalancer
import (
"errors"
"net/url"
"sort"
"strings"
)
func parseURL(serverURL, newHost string) (string, string, error) {
parsedURL, err := url.Parse(serverURL)
if err != nil {
return "", "", err
}
if parsedURL.Host == "" {
return "", "", errors.New("Initial server URL host is not defined for load balancer")
}
address := parsedURL.Host
if parsedURL.Port() == "" {
if strings.ToLower(parsedURL.Scheme) == "http" {
address += ":80"
}
if strings.ToLower(parsedURL.Scheme) == "https" {
address += ":443"
}
}
parsedURL.Host = newHost
return address, parsedURL.String(), nil
}
func sortServers(input []string, search string) ([]string, bool) {
result := []string{}
found := false
skip := map[string]bool{"": true}
for _, entry := range input {
if skip[entry] {
continue
}
if search == entry {
found = true
}
skip[entry] = true
result = append(result, entry)
}
sort.Strings(result)
return result, found
}

@ -12,6 +12,7 @@ import (
"github.com/rancher/k3s/pkg/agent/config"
"github.com/rancher/k3s/pkg/agent/containerd"
"github.com/rancher/k3s/pkg/agent/flannel"
"github.com/rancher/k3s/pkg/agent/loadbalancer"
"github.com/rancher/k3s/pkg/agent/syssetup"
"github.com/rancher/k3s/pkg/agent/tunnel"
"github.com/rancher/k3s/pkg/cli/cmds"
@ -21,7 +22,7 @@ import (
"github.com/sirupsen/logrus"
)
func run(ctx context.Context, cfg cmds.Agent) error {
func run(ctx context.Context, cfg cmds.Agent, lb *loadbalancer.LoadBalancer) error {
nodeConfig := config.Get(ctx, cfg)
if err := config.HostnameCheck(cfg); err != nil {
@ -47,7 +48,7 @@ func run(ctx context.Context, cfg cmds.Agent) error {
return err
}
if err := tunnel.Setup(ctx, nodeConfig); err != nil {
if err := tunnel.Setup(ctx, nodeConfig, lb.Update); err != nil {
return err
}
@ -77,11 +78,20 @@ func Run(ctx context.Context, cfg cmds.Agent) error {
}
cfg.DataDir = filepath.Join(cfg.DataDir, "agent")
os.MkdirAll(cfg.DataDir, 0700)
if cfg.ClusterSecret != "" {
cfg.Token = "K10node:" + cfg.ClusterSecret
}
lb, err := loadbalancer.Setup(ctx, cfg)
if err != nil {
return err
}
if lb != nil {
cfg.ServerURL = lb.LoadBalancerServerURL()
}
for {
tmpFile, err := clientaccess.AgentAccessInfoToTempKubeConfig("", cfg.ServerURL, cfg.Token)
if err != nil {
@ -97,8 +107,7 @@ func Run(ctx context.Context, cfg cmds.Agent) error {
break
}
os.MkdirAll(cfg.DataDir, 0700)
return run(ctx, cfg)
return run(ctx, cfg, lb)
}
func validate() error {

@ -53,7 +53,7 @@ func getAddresses(endpoint *v1.Endpoints) []string {
return serverAddresses
}
func Setup(ctx context.Context, config *config.Node) error {
func Setup(ctx context.Context, config *config.Node, onChange func([]string)) error {
restConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigNode)
if err != nil {
return err
@ -74,6 +74,9 @@ func Setup(ctx context.Context, config *config.Node) error {
endpoint, _ := client.CoreV1().Endpoints("default").Get("kubernetes", metav1.GetOptions{})
if endpoint != nil {
addresses = getAddresses(endpoint)
if onChange != nil {
onChange(addresses)
}
}
disconnect := map[string]context.CancelFunc{}
@ -120,6 +123,9 @@ func Setup(ctx context.Context, config *config.Node) error {
}
addresses = newAddresses
logrus.Infof("Tunnel endpoint watch event: %v", addresses)
if onChange != nil {
onChange(addresses)
}
validEndpoint := map[string]bool{}

@ -11,6 +11,7 @@ type Agent struct {
Token string
TokenFile string
ServerURL string
DisableLoadBalancer bool
ResolvConf string
DataDir string
NodeIP string

@ -214,6 +214,7 @@ func run(app *cli.Context, cfg *cmds.Server) error {
agentConfig.ServerURL = url
agentConfig.Token = token
agentConfig.Labels = append(agentConfig.Labels, "node-role.kubernetes.io/master=true")
agentConfig.DisableLoadBalancer = true
return agent.Run(ctx, agentConfig)
}

Loading…
Cancel
Save