k3s/pkg/agent/tunnel/tunnel.go

201 lines
4.5 KiB
Go
Raw Normal View History

2019-01-01 08:23:01 +00:00
package tunnel
import (
"context"
"crypto/tls"
"fmt"
"net"
2019-07-18 12:00:07 +00:00
"reflect"
"strconv"
2019-01-01 08:23:01 +00:00
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/rancher/k3s/pkg/agent/proxy"
2019-01-09 16:54:15 +00:00
"github.com/rancher/k3s/pkg/daemons/config"
"github.com/rancher/k3s/pkg/version"
2019-05-09 22:05:51 +00:00
"github.com/rancher/remotedialer"
2019-01-01 08:23:01 +00:00
"github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
2019-07-14 07:29:21 +00:00
watchtypes "k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
2019-01-01 08:23:01 +00:00
"k8s.io/client-go/tools/clientcmd"
)
// ports is the set of local TCP ports, keyed by their decimal string form,
// that the remote end of a tunnel is allowed to dial on this agent (checked
// by the dial authorizer in connect, together with host 127.0.0.1).
//
// NOTE(review): 10250 looks like the kubelet API port and 10010 like the
// container runtime streaming port — confirm before adding or removing entries.
var ports = map[string]bool{
	"10250": true,
	"10010": true,
}
// getAddresses flattens an Endpoints object into "host:port" strings, one per
// address in each subset, in subset order.
//
// The port for a subset is the first port it lists, or "443" when it lists
// none. A nil endpoint yields an empty slice. The result is always non-nil
// (never a nil slice) so that it compares consistently with reflect.DeepEqual.
func getAddresses(endpoint *v1.Endpoints) []string {
	addrs := []string{}
	if endpoint == nil {
		return addrs
	}

	for i := range endpoint.Subsets {
		subset := &endpoint.Subsets[i]

		port := "443"
		if len(subset.Ports) != 0 {
			port = strconv.Itoa(int(subset.Ports[0].Port))
		}

		for _, a := range subset.Addresses {
			addrs = append(addrs, net.JoinHostPort(a.IP, port))
		}
	}
	return addrs
}
func Setup(ctx context.Context, config *config.Node, proxy proxy.Proxy) error {
restConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigK3sController)
2019-01-01 08:23:01 +00:00
if err != nil {
return err
}
client, err := kubernetes.NewForConfig(restConfig)
2019-01-01 08:23:01 +00:00
if err != nil {
return err
}
nodeRestConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigKubelet)
if err != nil {
return err
}
tlsConfig, err := rest.TLSConfigFor(nodeRestConfig)
if err != nil {
return err
}
2020-03-26 21:08:47 +00:00
endpoint, _ := client.CoreV1().Endpoints("default").Get(ctx, "kubernetes", metav1.GetOptions{})
if endpoint != nil {
addresses := getAddresses(endpoint)
if len(addresses) > 0 {
proxy.Update(getAddresses(endpoint))
}
}
disconnect := map[string]context.CancelFunc{}
wg := &sync.WaitGroup{}
for _, address := range proxy.SupervisorAddresses() {
if _, ok := disconnect[address]; !ok {
disconnect[address] = connect(ctx, wg, address, tlsConfig)
}
}
go func() {
connect:
for {
2019-07-14 07:29:21 +00:00
time.Sleep(5 * time.Second)
2020-03-26 21:08:47 +00:00
watch, err := client.CoreV1().Endpoints("default").Watch(ctx, metav1.ListOptions{
FieldSelector: fields.Set{"metadata.name": "kubernetes"}.String(),
2019-07-14 07:29:21 +00:00
ResourceVersion: "0",
})
if err != nil {
logrus.Warnf("Unable to watch for tunnel endpoints: %v", err)
continue connect
}
watching:
for {
select {
case ev, ok := <-watch.ResultChan():
2019-07-14 07:29:21 +00:00
if !ok || ev.Type == watchtypes.Error {
2019-07-18 12:00:07 +00:00
if ok {
logrus.Errorf("Tunnel endpoint watch channel closed: %v", ev)
}
watch.Stop()
continue connect
}
endpoint, ok := ev.Object.(*v1.Endpoints)
if !ok {
2019-07-14 07:29:21 +00:00
logrus.Errorf("Tunnel could not case event object to endpoint: %v", ev)
continue watching
}
2019-07-18 12:00:07 +00:00
newAddresses := getAddresses(endpoint)
if reflect.DeepEqual(newAddresses, proxy.SupervisorAddresses()) {
2019-07-18 12:00:07 +00:00
continue watching
}
proxy.Update(newAddresses)
2019-07-10 21:50:23 +00:00
validEndpoint := map[string]bool{}
for _, address := range proxy.SupervisorAddresses() {
validEndpoint[address] = true
if _, ok := disconnect[address]; !ok {
disconnect[address] = connect(ctx, nil, address, tlsConfig)
}
}
for address, cancel := range disconnect {
if !validEndpoint[address] {
cancel()
delete(disconnect, address)
2019-07-18 12:00:07 +00:00
logrus.Infof("Stopped tunnel to %s", address)
}
}
}
}
}
}()
2019-07-18 01:15:15 +00:00
wait := make(chan int, 1)
go func() {
wg.Wait()
wait <- 0
}()
select {
case <-ctx.Done():
logrus.Error("Tunnel context canceled while waiting for connection")
2019-07-18 01:15:15 +00:00
return ctx.Err()
case <-wait:
}
return nil
}
func connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) context.CancelFunc {
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
2019-01-01 08:23:01 +00:00
}
once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
}
2019-07-18 12:00:07 +00:00
ctx, cancel := context.WithCancel(rootCtx)
2019-01-01 08:23:01 +00:00
go func() {
for {
remotedialer.ClientConnect(ctx, wsURL, nil, ws, func(proto, address string) bool {
2019-01-01 08:23:01 +00:00
host, port, err := net.SplitHostPort(address)
return err == nil && proto == "tcp" && ports[port] && host == "127.0.0.1"
}, func(_ context.Context) error {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
2019-01-01 08:23:01 +00:00
return nil
})
if ctx.Err() != nil {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return
}
2019-01-01 08:23:01 +00:00
}
}()
return cancel
2019-01-01 08:23:01 +00:00
}