k3s/pkg/agent/tunnel/tunnel.go

206 lines
5.7 KiB
Go

package tunnel
import (
"context"
"crypto/tls"
"fmt"
"net"
"reflect"
"sync"
"github.com/gorilla/websocket"
agentconfig "github.com/k3s-io/k3s/pkg/agent/config"
"github.com/k3s-io/k3s/pkg/agent/proxy"
"github.com/k3s-io/k3s/pkg/daemons/config"
"github.com/k3s-io/k3s/pkg/util"
"github.com/k3s-io/k3s/pkg/version"
"github.com/rancher/remotedialer"
"github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/tools/clientcmd"
toolswatch "k8s.io/client-go/tools/watch"
)
// ports is the set of TCP port numbers (as strings) that the tunnel authorizer in connect
// accepts as loopback destinations; a request for any port not listed here is refused.
// NOTE(review): 10250 is presumably the kubelet and 10010 the container runtime streaming
// server - confirm against the agent configuration.
var ports = map[string]bool{
	"10010": true,
	"10250": true,
}
// Setup opens tunnels from this agent to every supervisor address known to proxy and then keeps
// that set of tunnels in sync with the addresses of the "kubernetes" Endpoints object in the
// default namespace.
//
// The Kubernetes client is built from config.AgentConfig.KubeConfigK3sController, and the TLS
// configuration used to dial the tunnels is derived from config.AgentConfig.KubeConfigKubelet.
// Any error building either of them is returned unchanged.
//
// Setup blocks until every tunnel started for the initial address list has released the wait
// group (see connect), and returns nil. If ctx is canceled first, it returns ctx.Err(). The
// endpoint watch keeps running in a background goroutine until ctx is canceled or the watch's
// result channel is closed.
func Setup(ctx context.Context, config *config.Node, proxy proxy.Proxy) error {
	restConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigK3sController)
	if err != nil {
		return err
	}

	client, err := kubernetes.NewForConfig(restConfig)
	if err != nil {
		return err
	}

	nodeRestConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigKubelet)
	if err != nil {
		return err
	}

	tlsConfig, err := rest.TLSConfigFor(nodeRestConfig)
	if err != nil {
		return err
	}

	// The loadbalancer is only disabled when there is a local apiserver. Servers without a local
	// apiserver load-balance to themselves initially, then switch over to an apiserver node as soon
	// as we get some addresses from the code below.
	if proxy.IsSupervisorLBEnabled() && proxy.SupervisorURL() != "" {
		logrus.Info("Getting list of apiserver endpoints from server")
		// If not running an apiserver locally, try to get a list of apiservers from the server we're
		// connecting to. If that fails, fall back to querying the endpoints list from Kubernetes. This
		// fallback requires that the server we're joining be running an apiserver, but is the only safe
		// thing to do if its supervisor is down-level and can't provide us with an endpoint list.
		if addresses := agentconfig.APIServers(ctx, config, proxy); len(addresses) > 0 {
			proxy.SetSupervisorDefault(addresses[0])
			proxy.Update(addresses)
		} else {
			endpoint, getErr := client.CoreV1().Endpoints(metav1.NamespaceDefault).Get(ctx, "kubernetes", metav1.GetOptions{})
			if getErr != nil {
				// Best-effort fallback: record why it failed, but keep going with whatever
				// addresses the proxy already has.
				logrus.Debugf("Failed to get apiserver endpoints from Kubernetes: %v", getErr)
			}
			if endpoint != nil {
				if addresses := util.GetAddresses(endpoint); len(addresses) > 0 {
					proxy.Update(addresses)
				}
			}
		}
	}

	// Attempt to connect to supervisors, storing their cancellation function for later when we
	// need to disconnect.
	disconnect := map[string]context.CancelFunc{}
	wg := &sync.WaitGroup{}
	for _, address := range proxy.SupervisorAddresses() {
		if _, ok := disconnect[address]; !ok {
			disconnect[address] = connect(ctx, wg, address, tlsConfig)
		}
	}

	// Once the apiserver is up, go into a watch loop, adding and removing tunnels as endpoints come
	// and go from the cluster. From here on the disconnect map is only touched by this goroutine.
	go func() {
		if err := util.WaitForAPIServerReady(ctx, client, util.DefaultAPIServerReadyTimeout); err != nil {
			logrus.Warnf("Tunnel endpoint watch failed to wait for apiserver ready: %v", err)
		}

		endpoints := client.CoreV1().Endpoints(metav1.NamespaceDefault)
		// Restrict both list and watch to the single "kubernetes" Endpoints object.
		fieldSelector := fields.Set{metav1.ObjectNameField: "kubernetes"}.String()
		lw := &cache.ListWatch{
			ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
				options.FieldSelector = fieldSelector
				return endpoints.List(ctx, options)
			},
			WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
				options.FieldSelector = fieldSelector
				return endpoints.Watch(ctx, options)
			},
		}

		// Named watcher (not watch) so the imported watch package is not shadowed.
		_, _, watcher, done := toolswatch.NewIndexerInformerWatcher(lw, &v1.Endpoints{})
		defer func() {
			watcher.Stop()
			<-done
		}()

		for {
			select {
			case <-ctx.Done():
				return
			case ev, ok := <-watcher.ResultChan():
				if !ok {
					// A closed channel delivers zero-value events without blocking; bail out
					// instead of spinning on failed type assertions.
					logrus.Error("Tunnel watch failed: event channel closed")
					return
				}
				endpoint, isEndpoints := ev.Object.(*v1.Endpoints)
				if !isEndpoints {
					logrus.Errorf("Tunnel watch failed: event object not of type v1.Endpoints")
					continue
				}

				newAddresses := util.GetAddresses(endpoint)
				if reflect.DeepEqual(newAddresses, proxy.SupervisorAddresses()) {
					continue
				}
				proxy.Update(newAddresses)

				// Start tunnels to addresses we are not yet connected to.
				validEndpoint := map[string]bool{}
				for _, address := range proxy.SupervisorAddresses() {
					validEndpoint[address] = true
					if _, ok := disconnect[address]; !ok {
						disconnect[address] = connect(ctx, nil, address, tlsConfig)
					}
				}

				// Tear down tunnels to addresses that are no longer listed.
				for address, cancel := range disconnect {
					if !validEndpoint[address] {
						cancel()
						delete(disconnect, address)
						logrus.Infof("Stopped tunnel to %s", address)
					}
				}
			}
		}
	}()

	// Turn the wait group into a channel so that it can be raced against ctx cancellation.
	connected := make(chan struct{})
	go func() {
		wg.Wait()
		close(connected)
	}()

	select {
	case <-ctx.Done():
		logrus.Error("Tunnel context canceled while waiting for connection")
		return ctx.Err()
	case <-connected:
	}

	return nil
}
func connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) context.CancelFunc {
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}
once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
}
ctx, cancel := context.WithCancel(rootCtx)
go func() {
for {
remotedialer.ClientConnect(ctx, wsURL, nil, ws, func(proto, address string) bool {
host, port, err := net.SplitHostPort(address)
return err == nil && proto == "tcp" && ports[port] && (host == "127.0.0.1" || host == "::1")
}, func(_ context.Context) error {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
})
if ctx.Err() != nil {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return
}
}
}()
return cancel
}