Use Dial with context

pull/8/head
Mikhail Mazurskiy 2018-05-19 08:14:37 +10:00
parent 77a08ee2d7
commit 5e8e570dbd
No known key found for this signature in database
GPG Key ID: 93551ECC96E2F568
25 changed files with 111 additions and 110 deletions

View File

@ -261,7 +261,7 @@ func CreateNodeDialer(s completedServerRunOptions) (tunneler.Tunneler, *http.Tra
// Proxying to pods and services is IP-based... don't expect to be able to verify the hostname
proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true}
proxyTransport := utilnet.SetTransportDefaults(&http.Transport{
Dial: proxyDialerFn,
DialContext: proxyDialerFn,
TLSClientConfig: proxyTLSClientConfig,
})
return nodeTunneler, proxyTransport, nil
@ -522,8 +522,8 @@ func BuildGenericConfig(
if err != nil {
return nil, err
}
if proxyTransport != nil && proxyTransport.Dial != nil {
ret.Dial = proxyTransport.Dial
if proxyTransport != nil && proxyTransport.DialContext != nil {
ret.Dial = proxyTransport.DialContext
}
return ret, err
},

View File

@ -74,7 +74,7 @@ func MakeTransport(config *KubeletClientConfig) (http.RoundTripper, error) {
rt := http.DefaultTransport
if config.Dial != nil || tlsConfig != nil {
rt = utilnet.SetOldTransportDefaults(&http.Transport{
Dial: config.Dial,
DialContext: config.Dial,
TLSClientConfig: tlsConfig,
})
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package master
import (
"context"
"crypto/tls"
"encoding/json"
"io/ioutil"
@ -108,7 +109,7 @@ func setUp(t *testing.T) (*etcdtesting.EtcdTestServer, Config, informers.SharedI
config.GenericConfig.LoopbackClientConfig = &restclient.Config{APIPath: "/api", ContentConfig: restclient.ContentConfig{NegotiatedSerializer: legacyscheme.Codecs}}
config.ExtraConfig.KubeletClientConfig = kubeletclient.KubeletClientConfig{Port: 10250}
config.ExtraConfig.ProxyTransport = utilnet.SetTransportDefaults(&http.Transport{
Dial: func(network, addr string) (net.Conn, error) { return nil, nil },
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, nil },
TLSClientConfig: &tls.Config{},
})

View File

@ -43,7 +43,7 @@ type AddressFunc func() (addresses []string, err error)
type Tunneler interface {
Run(AddressFunc)
Stop()
Dial(net, addr string) (net.Conn, error)
Dial(ctx context.Context, net, addr string) (net.Conn, error)
SecondsSinceSync() int64
SecondsSinceSSHKeySync() int64
}
@ -149,8 +149,8 @@ func (c *SSHTunneler) Stop() {
}
}
func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) {
return c.tunnels.Dial(net, addr)
func (c *SSHTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
return c.tunnels.Dial(ctx, net, addr)
}
func (c *SSHTunneler) SecondsSinceSync() int64 {

View File

@ -17,6 +17,7 @@ limitations under the License.
package tunneler
import (
"context"
"fmt"
"net"
"os"
@ -111,11 +112,11 @@ type FakeTunneler struct {
SecondsSinceSSHKeySyncValue int64
}
func (t *FakeTunneler) Run(AddressFunc) {}
func (t *FakeTunneler) Stop() {}
func (t *FakeTunneler) Dial(net, addr string) (net.Conn, error) { return nil, nil }
func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue }
func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue }
func (t *FakeTunneler) Run(AddressFunc) {}
func (t *FakeTunneler) Stop() {}
func (t *FakeTunneler) Dial(ctx context.Context, net, addr string) (net.Conn, error) { return nil, nil }
func (t *FakeTunneler) SecondsSinceSync() int64 { return t.SecondsSinceSyncValue }
func (t *FakeTunneler) SecondsSinceSSHKeySync() int64 { return t.SecondsSinceSSHKeySyncValue }
// TestIsTunnelSyncHealthy verifies that the 600 second lag test
// is honored.

View File

@ -18,6 +18,7 @@ package ssh
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
@ -121,10 +122,11 @@ func (s *SSHTunnel) Open() error {
return err
}
func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) {
func (s *SSHTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if s.client == nil {
return nil, errors.New("tunnel is not opened.")
}
// This Dial method does not allow to pass a context unfortunately
return s.client.Dial(network, address)
}
@ -294,7 +296,7 @@ func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
type tunnel interface {
Open() error
Close() error
Dial(network, address string) (net.Conn, error)
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
type sshTunnelEntry struct {
@ -361,7 +363,7 @@ func (l *SSHTunnelList) delayedHealthCheck(e sshTunnelEntry, delay time.Duration
func (l *SSHTunnelList) healthCheck(e sshTunnelEntry) error {
// GET the healthcheck path using the provided tunnel's dial function.
transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: e.Tunnel.Dial,
DialContext: e.Tunnel.Dial,
// TODO(cjcullen): Plumb real TLS options through.
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
// We don't reuse the clients, so disable the keep-alive to properly
@ -394,7 +396,7 @@ func (l *SSHTunnelList) removeAndReAdd(e sshTunnelEntry) {
go l.createAndAddTunnel(e.Address)
}
func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
func (l *SSHTunnelList) Dial(ctx context.Context, net, addr string) (net.Conn, error) {
start := time.Now()
id := mathrand.Int63() // So you can match begins/ends in the log.
glog.Infof("[%x: %v] Dialing...", id, addr)
@ -405,7 +407,7 @@ func (l *SSHTunnelList) Dial(net, addr string) (net.Conn, error) {
if err != nil {
return nil, err
}
return tunnel.Dial(net, addr)
return tunnel.Dial(ctx, net, addr)
}
func (l *SSHTunnelList) pickTunnel(addr string) (tunnel, error) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package ssh
import (
"context"
"fmt"
"io"
"net"
@ -145,7 +146,7 @@ func TestSSHTunnel(t *testing.T) {
t.FailNow()
}
_, err = tunnel.Dial("tcp", "127.0.0.1:8080")
_, err = tunnel.Dial(context.Background(), "tcp", "127.0.0.1:8080")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
@ -176,7 +177,7 @@ func (*fakeTunnel) Close() error {
return nil
}
func (*fakeTunnel) Dial(network, address string) (net.Conn, error) {
func (*fakeTunnel) Dial(ctx context.Context, network, address string) (net.Conn, error) {
return nil, nil
}

View File

@ -19,6 +19,7 @@ package spdy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
@ -118,7 +119,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
}
if proxyURL == nil {
return s.dialWithoutProxy(req.URL)
return s.dialWithoutProxy(req.Context(), req.URL)
}
// ensure we use a canonical host with proxyReq
@ -136,7 +137,7 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
proxyReq.Header.Set("Proxy-Authorization", pa)
}
proxyDialConn, err := s.dialWithoutProxy(proxyURL)
proxyDialConn, err := s.dialWithoutProxy(req.Context(), proxyURL)
if err != nil {
return nil, err
}
@ -187,14 +188,15 @@ func (s *SpdyRoundTripper) dial(req *http.Request) (net.Conn, error) {
}
// dialWithoutProxy dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dialWithoutProxy(url *url.URL) (net.Conn, error) {
func (s *SpdyRoundTripper) dialWithoutProxy(ctx context.Context, url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
if url.Scheme == "http" {
if s.Dialer == nil {
return net.Dial("tcp", dialAddr)
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
} else {
return s.Dialer.Dial("tcp", dialAddr)
return s.Dialer.DialContext(ctx, "tcp", dialAddr)
}
}

View File

@ -19,6 +19,7 @@ package net
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
@ -90,8 +91,8 @@ func SetOldTransportDefaults(t *http.Transport) *http.Transport {
// ProxierWithNoProxyCIDR allows CIDR rules in NO_PROXY
t.Proxy = NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}
if t.Dial == nil {
t.Dial = defaultTransport.Dial
if t.DialContext == nil {
t.DialContext = defaultTransport.DialContext
}
if t.TLSHandshakeTimeout == 0 {
t.TLSHandshakeTimeout = defaultTransport.TLSHandshakeTimeout
@ -119,7 +120,7 @@ type RoundTripperWrapper interface {
WrappedRoundTripper() http.RoundTripper
}
type DialFunc func(net, addr string) (net.Conn, error)
type DialFunc func(ctx context.Context, net, addr string) (net.Conn, error)
func DialerFor(transport http.RoundTripper) (DialFunc, error) {
if transport == nil {
@ -128,7 +129,7 @@ func DialerFor(transport http.RoundTripper) (DialFunc, error) {
switch transport := transport.(type) {
case *http.Transport:
return transport.Dial, nil
return transport.DialContext, nil
case RoundTripperWrapper:
return DialerFor(transport.WrappedRoundTripper())
default:

View File

@ -17,6 +17,7 @@ limitations under the License.
package proxy
import (
"context"
"crypto/tls"
"fmt"
"net"
@ -29,7 +30,7 @@ import (
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
func DialURL(ctx context.Context, url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, err := utilnet.DialerFor(transport)
@ -40,9 +41,10 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
switch url.Scheme {
case "http":
if dialer != nil {
return dialer("tcp", dialAddr)
return dialer(ctx, "tcp", dialAddr)
}
return net.Dial("tcp", dialAddr)
var d net.Dialer
return d.DialContext(ctx, "tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
@ -56,7 +58,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr)
netConn, err := dialer(ctx, "tcp", dialAddr)
if err != nil {
return nil, err
}
@ -86,7 +88,7 @@ func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
}
} else {
// Dial
// Dial. This Dial method does not allow to pass a context unfortunately
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err

View File

@ -17,6 +17,7 @@ limitations under the License.
package proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
@ -42,6 +43,7 @@ func TestDialURL(t *testing.T) {
if err != nil {
t.Fatal(err)
}
var d net.Dialer
testcases := map[string]struct {
TLSConfig *tls.Config
@ -68,25 +70,25 @@ func TestDialURL(t *testing.T) {
"insecure, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: true},
Dial: net.Dial,
Dial: d.DialContext,
},
"secure, no roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false},
Dial: net.Dial,
Dial: d.DialContext,
ExpectError: "unknown authority",
},
"secure with roots, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots},
Dial: net.Dial,
Dial: d.DialContext,
},
"secure with mismatched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "bogus.com"},
Dial: net.Dial,
Dial: d.DialContext,
ExpectError: "not bogus.com",
},
"secure with matched server, custom dial": {
TLSConfig: &tls.Config{InsecureSkipVerify: false, RootCAs: roots, ServerName: "example.com"},
Dial: net.Dial,
Dial: d.DialContext,
},
}
@ -102,7 +104,7 @@ func TestDialURL(t *testing.T) {
// Clone() mutates the receiver (!), so also call it on the copy
tlsConfigCopy.Clone()
transport := &http.Transport{
Dial: tc.Dial,
DialContext: tc.Dial,
TLSClientConfig: tlsConfigCopy,
}
@ -125,7 +127,7 @@ func TestDialURL(t *testing.T) {
u, _ := url.Parse(ts.URL)
_, p, _ := net.SplitHostPort(u.Host)
u.Host = net.JoinHostPort("127.0.0.1", p)
conn, err := DialURL(u, transport)
conn, err := DialURL(context.Background(), u, transport)
// Make sure dialing doesn't mutate the transport's TLSConfig
if !reflect.DeepEqual(tc.TLSConfig, tlsConfigCopy) {

View File

@ -347,7 +347,7 @@ func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error
// dial dials the backend at req.URL and writes req to it.
func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
conn, err := DialURL(req.URL, transport)
conn, err := DialURL(req.Context(), req.URL, transport)
if err != nil {
return nil, fmt.Errorf("error dialing backend: %v", err)
}

View File

@ -19,6 +19,7 @@ package proxy
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"crypto/x509"
"errors"
@ -341,6 +342,7 @@ func TestProxyUpgrade(t *testing.T) {
if !localhostPool.AppendCertsFromPEM(localhostCert) {
t.Errorf("error setting up localhostCert pool")
}
var d net.Dialer
testcases := map[string]struct {
ServerFunc func(http.Handler) *httptest.Server
@ -395,7 +397,7 @@ func TestProxyUpgrade(t *testing.T) {
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
},
"https (valid hostname + RootCAs + custom dialer + bearer token)": {
ServerFunc: func(h http.Handler) *httptest.Server {
@ -410,9 +412,9 @@ func TestProxyUpgrade(t *testing.T) {
ts.StartTLS()
return ts
},
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
ProxyTransport: utilnet.SetTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
UpgradeTransport: NewUpgradeRequestRoundTripper(
utilnet.SetOldTransportDefaults(&http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
utilnet.SetOldTransportDefaults(&http.Transport{DialContext: d.DialContext, TLSClientConfig: &tls.Config{RootCAs: localhostPool}}),
RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = utilnet.CloneRequest(req)
req.Header.Set("Authorization", "Bearer 1234")
@ -496,9 +498,15 @@ func TestProxyUpgradeErrorResponse(t *testing.T) {
expectedErr = errors.New("EXPECTED")
)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
transport := http.DefaultTransport.(*http.Transport)
transport.Dial = func(network, addr string) (net.Conn, error) {
return &fakeConn{err: expectedErr}, nil
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return &fakeConn{err: expectedErr}, nil
},
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
responder = &fakeResponder{t: t, w: w}
proxyHandler := NewUpgradeAwareHandler(

View File

@ -17,6 +17,7 @@ limitations under the License.
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
@ -147,9 +148,10 @@ func (cm *ClientManager) HookClient(h *v1beta1.Webhook) (*rest.RESTClient, error
delegateDialer := cfg.Dial
if delegateDialer == nil {
delegateDialer = net.Dial
var d net.Dialer
delegateDialer = d.DialContext
}
cfg.Dial = func(network, addr string) (net.Conn, error) {
cfg.Dial = func(ctx context.Context, network, addr string) (net.Conn, error) {
if addr == host {
u, err := cm.serviceResolver.ResolveEndpoint(svc.Namespace, svc.Name)
if err != nil {
@ -157,7 +159,7 @@ func (cm *ClientManager) HookClient(h *v1beta1.Webhook) (*rest.RESTClient, error
}
addr = u.Host
}
return delegateDialer(network, addr)
return delegateDialer(ctx, network, addr)
}
return complete(cfg)

View File

@ -69,10 +69,10 @@ func newTransportForETCD2(certFile, keyFile, caFile string) (*http.Transport, er
// TODO: Determine if transport needs optimization
tr := utilnet.SetTransportDefaults(&http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
MaxIdleConnsPerHost: 500,
TLSClientConfig: cfg,

View File

@ -44,12 +44,8 @@ const (
defaultRetries = 2
// protobuf mime type
mimePb = "application/com.github.proto-openapi.spec.v2@v1.0+protobuf"
)
var (
// defaultTimeout is the maximum amount of time per request when no timeout has been set on a RESTClient.
// Defaults to 32s in order to have a distinguishable length of time, relative to other timeouts that exist.
// It's a variable to be able to change it in tests.
defaultTimeout = 32 * time.Second
)

View File

@ -23,12 +23,11 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gogo/protobuf/proto"
"github.com/googleapis/gnostic/OpenAPIv2"
"github.com/stretchr/testify/assert"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
@ -131,31 +130,11 @@ func TestGetServerGroupsWithBrokenServer(t *testing.T) {
}
}
}
func TestGetServerGroupsWithTimeout(t *testing.T) {
done := make(chan bool)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// first we need to write headers, otherwise http client will complain about
// exceeding timeout awaiting headers, only after we can block the call
w.Header().Set("Connection", "keep-alive")
if wf, ok := w.(http.Flusher); ok {
wf.Flush()
}
<-done
}))
defer server.Close()
defer close(done)
client := NewDiscoveryClientForConfigOrDie(&restclient.Config{Host: server.URL, Timeout: 2 * time.Second})
_, err := client.ServerGroups()
// the error we're getting here is wrapped in errors.errorString which makes
// it impossible to unwrap and check it's attributes, so instead we're checking
// the textual output which is presenting http.httpError with timeout set to true
if err == nil {
t.Fatal("missing error")
}
if !strings.Contains(err.Error(), "timeout:true") &&
!strings.Contains(err.Error(), "context.deadlineExceededError") {
t.Fatalf("unexpected error: %v", err)
}
func TestTimeoutIsSet(t *testing.T) {
cfg := &restclient.Config{}
setDiscoveryDefaults(cfg)
assert.Equal(t, defaultTimeout, cfg.Timeout)
}
func TestGetServerResourcesWithV1Server(t *testing.T) {

View File

@ -17,6 +17,7 @@ limitations under the License.
package rest
import (
"context"
"fmt"
"io/ioutil"
"net"
@ -110,7 +111,7 @@ type Config struct {
Timeout time.Duration
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, address string) (net.Conn, error)
// Version forces a specific version to be used (if registered)
// Do we need this?

View File

@ -17,6 +17,8 @@ limitations under the License.
package rest
import (
"context"
"errors"
"io"
"net"
"net/http"
@ -25,8 +27,6 @@ import (
"strings"
"testing"
fuzz "github.com/google/gofuzz"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
@ -35,8 +35,7 @@ import (
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/util/flowcontrol"
"errors"
fuzz "github.com/google/gofuzz"
"github.com/stretchr/testify/assert"
)
@ -208,7 +207,7 @@ func (n *fakeNegotiatedSerializer) DecoderToVersion(serializer runtime.Decoder,
return &fakeCodec{}
}
var fakeDialFunc = func(network, addr string) (net.Conn, error) {
var fakeDialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, fakeDialerError
}
var fakeDialerError = errors.New("fakedialer")
@ -253,7 +252,7 @@ func TestAnonymousConfig(t *testing.T) {
r.Config = map[string]string{}
},
// Dial does not require fuzzer
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {},
func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {},
)
for i := 0; i < 20; i++ {
original := &Config{}
@ -284,10 +283,10 @@ func TestAnonymousConfig(t *testing.T) {
expected.WrapTransport = nil
}
if actual.Dial != nil {
_, actualError := actual.Dial("", "")
_, expectedError := actual.Dial("", "")
_, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
t.Fatalf("CopyConfig dropped the Dial field")
}
} else {
actual.Dial = nil
@ -329,7 +328,7 @@ func TestCopyConfig(t *testing.T) {
func(r *AuthProviderConfigPersister, f fuzz.Continue) {
*r = fakeAuthProviderConfigPersister{}
},
func(r *func(network, addr string) (net.Conn, error), f fuzz.Continue) {
func(r *func(ctx context.Context, network, addr string) (net.Conn, error), f fuzz.Continue) {
*r = fakeDialFunc
},
)
@ -351,8 +350,8 @@ func TestCopyConfig(t *testing.T) {
expected.WrapTransport = nil
}
if actual.Dial != nil {
_, actualError := actual.Dial("", "")
_, expectedError := actual.Dial("", "")
_, actualError := actual.Dial(context.Background(), "", "")
_, expectedError := expected.Dial(context.Background(), "", "")
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
}
@ -361,7 +360,7 @@ func TestCopyConfig(t *testing.T) {
expected.Dial = nil
if actual.AuthConfigPersister != nil {
actualError := actual.AuthConfigPersister.Persist(nil)
expectedError := actual.AuthConfigPersister.Persist(nil)
expectedError := expected.AuthConfigPersister.Persist(nil)
if !reflect.DeepEqual(expectedError, actualError) {
t.Fatalf("CopyConfig dropped the Dial field")
}

View File

@ -85,7 +85,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
dial = (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial
}).DialContext
}
// Cache a single transport for these options
c.transports[key] = utilnet.SetTransportDefaults(&http.Transport{
@ -93,7 +93,7 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) {
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: idleConnsPerHost,
Dial: dial,
DialContext: dial,
})
return c.transports[key], nil
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"net"
"net/http"
"testing"
@ -52,10 +53,11 @@ func TestTLSConfigKey(t *testing.T) {
}
// Make sure config fields that affect the tls config affect the cache key
dialer := net.Dialer{}
uniqueConfigurations := map[string]*Config{
"no tls": {},
"dialer": {Dial: net.Dial},
"dialer2": {Dial: func(network, address string) (net.Conn, error) { return nil, nil }},
"dialer": {Dial: dialer.DialContext},
"dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }},
"insecure": {TLS: TLSConfig{Insecure: true}},
"cadata 1": {TLS: TLSConfig{CAData: []byte{1}}},
"cadata 2": {TLS: TLSConfig{CAData: []byte{2}}},

View File

@ -17,6 +17,7 @@ limitations under the License.
package transport
import (
"context"
"net"
"net/http"
)
@ -53,7 +54,7 @@ type Config struct {
WrapTransport func(rt http.RoundTripper) http.RoundTripper
// Dial specifies the dial function for creating unencrypted TCP connections.
Dial func(network, addr string) (net.Conn, error)
Dial func(ctx context.Context, network, address string) (net.Conn, error)
}
// ImpersonationConfig has all the available impersonation options

View File

@ -209,10 +209,10 @@ func (r *proxyHandler) updateAPIService(apiService *apiregistrationapi.APIServic
serviceAvailable: apiregistrationapi.IsAPIServiceConditionTrue(apiService, apiregistrationapi.Available),
}
newInfo.proxyRoundTripper, newInfo.transportBuildingError = restclient.TransportFor(newInfo.restConfig)
if newInfo.transportBuildingError == nil && r.proxyTransport != nil && r.proxyTransport.Dial != nil {
if newInfo.transportBuildingError == nil && r.proxyTransport != nil && r.proxyTransport.DialContext != nil {
switch transport := newInfo.proxyRoundTripper.(type) {
case *http.Transport:
transport.Dial = r.proxyTransport.Dial
transport.DialContext = r.proxyTransport.DialContext
default:
newInfo.transportBuildingError = fmt.Errorf("unable to set dialer for %s/%s as rest transport is of type %T", apiService.Spec.Service.Namespace, apiService.Spec.Service.Name, newInfo.proxyRoundTripper)
glog.Warning(newInfo.transportBuildingError.Error())

View File

@ -1868,11 +1868,12 @@ func startProxyServer() (int, *exec.Cmd, error) {
}
func curlUnix(url string, path string) (string, error) {
dial := func(proto, addr string) (net.Conn, error) {
return net.Dial("unix", path)
dial := func(ctx context.Context, proto, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", path)
}
transport := utilnet.SetTransportDefaults(&http.Transport{
Dial: dial,
DialContext: dial,
})
return curlTransport(url, transport)
}

View File

@ -373,10 +373,10 @@ func createClients(numberOfClients int) ([]clientset.Interface, []internalclient
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: 100,
Dial: (&net.Dialer{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
}).DialContext,
})
// Overwrite TLS-related fields from config to avoid collision with
// Transport field.