diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 8dfb648bab..116cc246c4 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -2,17 +2,17 @@ package loadbalancer import ( "context" - "errors" "fmt" "math/rand" "net" - "net/http" "net/url" "os" "strconv" "github.com/k3s-io/k3s/pkg/version" http_dialer "github.com/mwitkow/go-http-dialer" + "github.com/pkg/errors" + "golang.org/x/net/http/httpproxy" "golang.org/x/net/proxy" "github.com/sirupsen/logrus" @@ -21,28 +21,39 @@ import ( var defaultDialer proxy.Dialer = &net.Dialer{} -func init() { +// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, +// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured +// to indicate use of a HTTP proxy for the server URL. +func SetHTTPProxy(address string) error { // Check if env variable for proxy is set - address := os.Getenv(version.ProgramUpper + "_URL") - - if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy { - return + if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy || address == "" { + return nil } - req, err := http.NewRequest("GET", "https://"+address, nil) + serverURL, err := url.Parse(address) if err != nil { - logrus.Errorf("Error creating request for address %s: %v", address, err) + return errors.Wrapf(err, "failed to parse address %s", address) } - proxyURL, err := http.ProxyFromEnvironment(req) + + // Call this directly instead of using the cached environment used by http.ProxyFromEnvironment to allow for testing + proxyFromEnvironment := httpproxy.FromEnvironment().ProxyFunc() + proxyURL, err := proxyFromEnvironment(serverURL) if err != nil { - logrus.Errorf("Error getting the proxy for address %s: %v", address, err) + return errors.Wrapf(err, "failed to get proxy for address %s", address) + } + if proxyURL == nil { + logrus.Debug(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED is true but no proxy is configured for URL " + serverURL.String()) + return nil } - if dialer, err := proxyDialer(proxyURL); err != nil { - logrus.Errorf("Error creating the proxyDialer for %s: %v", address, err) - } else { - defaultDialer = dialer + dialer, err := proxyDialer(proxyURL) + if err != nil { + return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) } + + defaultDialer = dialer + logrus.Debugf("Using proxy %s for agent connection to %s", proxyURL, serverURL) + return nil } func (lb *LoadBalancer) setServers(serverAddresses []string) bool { diff --git a/pkg/agent/loadbalancer/servers_test.go b/pkg/agent/loadbalancer/servers_test.go new file mode 100644 index 0000000000..c8b8b5b924 --- /dev/null +++ b/pkg/agent/loadbalancer/servers_test.go @@ -0,0 +1,156 @@ +package loadbalancer + +import ( + "fmt" + "net" + "os" + "strings" + "testing" + + "github.com/k3s-io/k3s/pkg/version" + "github.com/sirupsen/logrus" +) + +var defaultEnv map[string]string +var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"} + +func init() { + logrus.SetLevel(logrus.DebugLevel) +} + +func prepareEnv(env ...string) { + defaultDialer = &net.Dialer{} + defaultEnv = map[string]string{} + for _, e := range proxyEnvs { + if v, ok := os.LookupEnv(e); ok { + defaultEnv[e] = v + os.Unsetenv(e) + } + } + for _, e := range env { + k, v, _ := strings.Cut(e, "=") + os.Setenv(k, v) + } +} + +func restoreEnv() { + for _, e := range proxyEnvs { + if v, ok := defaultEnv[e]; ok { + os.Setenv(e, v) + } else { + os.Unsetenv(e) + } + } +} + +func Test_UnitSetHTTPProxy(t *testing.T) { + type args struct { + address string + } + tests := []struct { + name string + args args + setup func() error + teardown func() error + wantErr bool + wantDialer string + }{ + { + name: "Default Proxy", + args: args{address: "https://1.2.3.4:6443"}, + wantDialer: "*net.Dialer", + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=", "HTTP_PROXY=", "HTTPS_PROXY=", "NO_PROXY=") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + { + name: "Agent Proxy Enabled", + args: args{address: "https://1.2.3.4:6443"}, + wantDialer: "*http_dialer.HttpTunnel", + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=true", "HTTP_PROXY=http://proxy:8080", "HTTPS_PROXY=http://proxy:8080", "NO_PROXY=") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + { + name: "Agent Proxy Enabled with Bogus Proxy", + args: args{address: "https://1.2.3.4:6443"}, + wantDialer: "*net.Dialer", + wantErr: true, + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=true", "HTTP_PROXY=proxy proxy", "HTTPS_PROXY=proxy proxy", "NO_PROXY=") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + { + name: "Agent Proxy Enabled with Bogus Server", + args: args{address: "https://1.2.3.4:k3s"}, + wantDialer: "*net.Dialer", + wantErr: true, + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=true", "HTTP_PROXY=http://proxy:8080", "HTTPS_PROXY=http://proxy:8080", "NO_PROXY=") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + { + name: "Agent Proxy Enabled but IP Excluded", + args: args{address: "https://1.2.3.4:6443"}, + wantDialer: "*net.Dialer", + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=true", "HTTP_PROXY=http://proxy:8080", "HTTPS_PROXY=http://proxy:8080", "NO_PROXY=1.2.0.0/16") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + { + name: "Agent Proxy Enabled but Domain Excluded", + args: args{address: "https://server.example.com:6443"}, + wantDialer: "*net.Dialer", + setup: func() error { + prepareEnv(version.ProgramUpper+"_AGENT_HTTP_PROXY_ALLOWED=true", "HTTP_PROXY=http://proxy:8080", "HTTPS_PROXY=http://proxy:8080", "NO_PROXY=*.example.com") + return nil + }, + teardown: func() error { + restoreEnv() + return nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer tt.teardown() + if err := tt.setup(); err != nil { + t.Errorf("Setup for SetHTTPProxy() failed = %v", err) + return + } + err := SetHTTPProxy(tt.args.address) + t.Logf("SetHTTPProxy() error = %v", err) + if (err != nil) != tt.wantErr { + t.Errorf("SetHTTPProxy() error = %v, wantErr %v", err, tt.wantErr) + } + if dialerType := fmt.Sprintf("%T", defaultDialer); dialerType != tt.wantDialer { + t.Errorf("Got wrong dialer type %s, wanted %s", dialerType, tt.wantDialer) + } + }) + } +} diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index eef4d0e634..0cdc583d26 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -41,6 +41,9 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor } if lbEnabled { + if err := loadbalancer.SetHTTPProxy(supervisorURL); err != nil { + return nil, err + } lb, err := loadbalancer.New(ctx, dataDir, loadbalancer.SupervisorServiceName, supervisorURL, p.lbServerPort, isIPv6) if err != nil { return nil, err