Refactor SSH tunneling, fix proxy transport TLS/Dial extraction

pull/6/head
Jordan Liggitt 2015-10-09 01:18:16 -04:00
parent 826459e51e
commit 1043126135
26 changed files with 739 additions and 513 deletions

View File

@ -376,6 +376,30 @@ func (s *APIServer) Run(_ []string) error {
glog.Fatalf("Cloud provider could not be initialized: %v", err)
}
// Setup tunneler if needed
var tunneler master.Tunneler
var proxyDialerFn apiserver.ProxyDialerFunc
if len(s.SSHUser) > 0 {
// Get ssh key distribution func, if supported
var installSSH master.InstallSSHKey
if cloud != nil {
if instances, supported := cloud.Instances(); supported {
installSSH = instances.AddSSHKeyToAllInstances
}
}
// Set up the tunneler
tunneler = master.NewSSHTunneler(s.SSHUser, s.SSHKeyfile, installSSH)
// Use the tunneler's dialer to connect to the kubelet
s.KubeletConfig.Dial = tunneler.Dial
// Use the tunneler's dialer when proxying to pods, services, and nodes
proxyDialerFn = tunneler.Dial
}
// Proxying to pods and services is IP-based... don't expect to be able to verify the hostname
proxyTLSClientConfig := &tls.Config{InsecureSkipVerify: true}
kubeletClient, err := client.NewKubeletClient(&s.KubeletConfig)
if err != nil {
glog.Fatalf("Failure to start kubelet client: %v", err)
@ -508,12 +532,7 @@ func (s *APIServer) Run(_ []string) error {
}
}
}
var installSSH master.InstallSSHKey
if cloud != nil {
if instances, supported := cloud.Instances(); supported {
installSSH = instances.AddSSHKeyToAllInstances
}
}
config := &master.Config{
StorageDestinations: storageDestinations,
StorageVersions: storageVersions,
@ -542,9 +561,9 @@ func (s *APIServer) Run(_ []string) error {
ClusterName: s.ClusterName,
ExternalHost: s.ExternalHost,
MinRequestTimeout: s.MinRequestTimeout,
SSHUser: s.SSHUser,
SSHKeyfile: s.SSHKeyfile,
InstallSSHKey: installSSH,
ProxyDialer: proxyDialerFn,
ProxyTLSClientConfig: proxyTLSClientConfig,
Tunneler: tunneler,
ServiceNodePortRange: s.ServiceNodePortRange,
KubernetesServiceNodePort: s.KubernetesServiceNodePort,
}

View File

@ -41,7 +41,6 @@ type APIInstaller struct {
info *APIRequestInfoResolver
prefix string // Path prefix where API resources are to be registered.
minRequestTimeout time.Duration
proxyDialerFn ProxyDialerFunc
}
// Struct capturing information about an action ("GET", "POST", "WATCH", PROXY", etc).
@ -64,7 +63,7 @@ var errEmptyName = errors.NewBadRequest("name must be provided")
func (a *APIInstaller) Install(ws *restful.WebService) (apiResources []api.APIResource, errors []error) {
errors = make([]error, 0)
proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info, a.proxyDialerFn})
proxyHandler := (&ProxyHandler{a.prefix + "/proxy/", a.group.Storage, a.group.Codec, a.group.Context, a.info})
// Register the paths in a deterministic (sorted) order to get a deterministic swagger spec.
paths := make([]string, len(a.group.Storage))

View File

@ -105,7 +105,6 @@ type APIGroupVersion struct {
Admit admission.Interface
Context api.RequestContextMapper
ProxyDialerFn ProxyDialerFunc
MinRequestTimeout time.Duration
}
@ -164,7 +163,6 @@ func (g *APIGroupVersion) newInstaller() *APIInstaller {
info: g.APIRequestInfoResolver,
prefix: prefix,
minRequestTimeout: g.MinRequestTimeout,
proxyDialerFn: g.ProxyDialerFn,
}
return installer
}

View File

@ -17,11 +17,8 @@ limitations under the License.
package apiserver
import (
"crypto/tls"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
@ -40,7 +37,6 @@ import (
proxyutil "k8s.io/kubernetes/pkg/util/proxy"
"github.com/golang/glog"
"k8s.io/kubernetes/third_party/golang/netutil"
)
// ProxyHandler provides a http.Handler which will proxy traffic to locations
@ -51,8 +47,6 @@ type ProxyHandler struct {
codec runtime.Codec
context api.RequestContextMapper
apiRequestInfoResolver *APIRequestInfoResolver
dial func(network, addr string) (net.Conn, error)
}
func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@ -125,11 +119,8 @@ func (r *ProxyHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
httpCode = http.StatusNotFound
return
}
// If we have a custom dialer, and no pre-existing transport, initialize it to use the dialer.
if roundTripper == nil && r.dial != nil {
glog.V(5).Infof("[%x: %v] making a dial-only transport...", proxyHandlerTraceID, req.URL)
roundTripper = &http.Transport{Dial: r.dial}
} else if roundTripper != nil {
if roundTripper != nil {
glog.V(5).Infof("[%x: %v] using transport %T...", proxyHandlerTraceID, req.URL, roundTripper)
}
@ -217,7 +208,7 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque
if !httpstream.IsUpgradeRequest(req) {
return false
}
backendConn, err := dialURL(location, transport)
backendConn, err := proxyutil.DialURL(location, transport)
if err != nil {
status := errToAPIStatus(err)
writeJSON(status.Code, r.codec, status, w, true)
@ -264,46 +255,6 @@ func (r *ProxyHandler) tryUpgrade(w http.ResponseWriter, req, newReq *http.Reque
return true
}
func dialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
switch url.Scheme {
case "http":
return net.Dial("tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
if transport != nil {
httpTransport, ok := transport.(*http.Transport)
if ok {
tlsConfig = httpTransport.TLSClientConfig
}
}
// Dial
tlsConn, err := tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
// Return if we were configured to skip validation
if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}
// borrowed from net/http/httputil/reverseproxy.go
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")

View File

@ -23,6 +23,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -171,6 +172,21 @@ func TestProxyUpgrade(t *testing.T) {
},
ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}},
},
"https (valid hostname + RootCAs + custom dialer)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}},
},
}
for k, tc := range testcases {

View File

@ -28,6 +28,8 @@ import (
"net/http"
"reflect"
"runtime"
"k8s.io/kubernetes/pkg/util"
)
// chaosrt provides the ability to perform simulations of HTTP client failures
@ -86,6 +88,12 @@ func (rt *chaosrt) RoundTrip(req *http.Request) (*http.Response, error) {
return rt.rt.RoundTrip(req)
}
var _ = util.RoundTripperWrapper(&chaosrt{})
func (rt *chaosrt) WrappedRoundTripper() http.RoundTripper {
return rt.rt
}
// Seed represents a consistent stream of chaos.
type Seed struct {
*rand.Rand

View File

@ -23,6 +23,7 @@ import (
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/sets"
)
@ -133,3 +134,9 @@ func (rt *DebuggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return response, err
}
var _ = util.RoundTripperWrapper(&DebuggingRoundTripper{})
func (rt *DebuggingRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.delegatedRoundTripper
}

View File

@ -22,6 +22,8 @@ import (
"fmt"
"io/ioutil"
"net/http"
"k8s.io/kubernetes/pkg/util"
)
type userAgentRoundTripper struct {
@ -42,6 +44,12 @@ func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return rt.rt.RoundTrip(req)
}
var _ = util.RoundTripperWrapper(&userAgentRoundTripper{})
func (rt *userAgentRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.rt
}
type basicAuthRoundTripper struct {
username string
password string
@ -63,6 +71,12 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return rt.rt.RoundTrip(req)
}
var _ = util.RoundTripperWrapper(&basicAuthRoundTripper{})
func (rt *basicAuthRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.rt
}
type bearerAuthRoundTripper struct {
bearer string
rt http.RoundTripper
@ -84,6 +98,12 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response,
return rt.rt.RoundTrip(req)
}
var _ = util.RoundTripperWrapper(&bearerAuthRoundTripper{})
func (rt *bearerAuthRoundTripper) WrappedRoundTripper() http.RoundTripper {
return rt.rt
}
// TLSConfigFor returns a tls.Config that will provide the transport level security defined
// by the provided Config. Will return nil if no transport level security is requested.
func TLSConfigFor(config *Config) (*tls.Config, error) {

View File

@ -17,18 +17,15 @@ limitations under the License.
package master
import (
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"net/http/pprof"
"net/url"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"k8s.io/kubernetes/pkg/admission"
@ -240,10 +237,12 @@ type Config struct {
// The range of ports to be assigned to services with type=NodePort or greater
ServiceNodePortRange util.PortRange
// Used for secure proxy. If empty, don't use secure proxy.
SSHUser string
SSHKeyfile string
InstallSSHKey InstallSSHKey
// Used to customize default proxy dial/tls options
ProxyDialer apiserver.ProxyDialerFunc
ProxyTLSClientConfig *tls.Config
// Used to start and monitor tunneling
Tunneler Tunneler
KubernetesServiceNodePort int
}
@ -305,14 +304,11 @@ type Master struct {
Handler http.Handler
InsecureHandler http.Handler
// Used for secure proxy
dialer apiserver.ProxyDialerFunc
tunnels *util.SSHTunnelList
tunnelsLock sync.Mutex
installSSHKey InstallSSHKey
lastSync int64 // Seconds since Epoch
lastSyncMetric prometheus.GaugeFunc
clock util.Clock
// Used for custom proxy dialing, and proxy TLS options
proxyTransport http.RoundTripper
// Used to start and monitor tunneling
tunneler Tunneler
// storage for third party objects
thirdPartyStorage storage.Interface
@ -453,7 +449,8 @@ func New(c *Config) *Master {
// TODO: serviceReadWritePort should be passed in as an argument, it may not always be 443
serviceReadWritePort: 443,
installSSHKey: c.InstallSSHKey,
tunneler: c.Tunneler,
KubernetesServiceNodePort: c.KubernetesServiceNodePort,
}
@ -505,10 +502,18 @@ func NewHandlerContainer(mux *http.ServeMux) *restful.Container {
// init initializes master.
func (m *Master) init(c *Config) {
if c.ProxyDialer != nil || c.ProxyTLSClientConfig != nil {
m.proxyTransport = util.SetTransportDefaults(&http.Transport{
Dial: c.ProxyDialer,
TLSClientConfig: c.ProxyTLSClientConfig,
})
}
healthzChecks := []healthz.HealthzChecker{}
m.clock = util.RealClock{}
dbClient := func(resource string) storage.Interface { return c.StorageDestinations.get("", resource) }
podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient)
podStorage := podetcd.NewStorage(dbClient("pods"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport)
podTemplateStorage := podtemplateetcd.NewREST(dbClient("podTemplates"))
@ -527,7 +532,7 @@ func (m *Master) init(c *Config) {
endpointsStorage := endpointsetcd.NewREST(dbClient("endpoints"), c.EnableWatchCache)
m.endpointRegistry = endpoint.NewRegistry(endpointsStorage)
nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient)
nodeStorage, nodeStatusStorage := nodeetcd.NewREST(dbClient("nodes"), c.EnableWatchCache, c.KubeletClient, m.proxyTransport)
m.nodeRegistry = node.NewRegistry(nodeStorage)
serviceStorage := serviceetcd.NewREST(dbClient("services"))
@ -569,7 +574,7 @@ func (m *Master) init(c *Config) {
"replicationControllers": controllerStorage,
"replicationControllers/status": controllerStatusStorage,
"services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator),
"services": service.NewStorage(m.serviceRegistry, m.endpointRegistry, serviceClusterIPAllocator, serviceNodePortAllocator, m.proxyTransport),
"endpoints": endpointsStorage,
"nodes": nodeStorage,
"nodes/status": nodeStatusStorage,
@ -591,51 +596,13 @@ func (m *Master) init(c *Config) {
"componentStatuses": componentstatus.NewStorage(func() map[string]apiserver.Server { return m.getServersToValidate(c) }),
}
// establish the node proxy dialer
if len(c.SSHUser) > 0 {
// Usernames are capped @ 32
if len(c.SSHUser) > 32 {
glog.Warning("SSH User is too long, truncating to 32 chars")
c.SSHUser = c.SSHUser[0:32]
}
glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile)
// public keyfile is written last, so check for that.
publicKeyFile := c.SSHKeyfile + ".pub"
exists, err := util.FileExists(publicKeyFile)
if err != nil {
glog.Errorf("Error detecting if key exists: %v", err)
} else if !exists {
glog.Infof("Key doesn't exist, attempting to create")
err := m.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile)
if err != nil {
glog.Errorf("Failed to create key pair: %v", err)
}
}
m.tunnels = &util.SSHTunnelList{}
m.dialer = m.Dial
m.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile)
m.lastSync = m.clock.Now().Unix()
// This is pretty ugly. A better solution would be to pull this all the way up into the
// server.go file.
httpKubeletClient, ok := c.KubeletClient.(*client.HTTPKubeletClient)
if ok {
httpKubeletClient.Config.Dial = m.dialer
transport, err := client.MakeTransport(httpKubeletClient.Config)
if err != nil {
glog.Errorf("Error setting up transport over SSH: %v", err)
} else {
httpKubeletClient.Client.Transport = transport
}
} else {
glog.Errorf("Failed to cast %v to HTTPKubeletClient, skipping SSH tunnel.", c.KubeletClient)
}
if m.tunneler != nil {
m.tunneler.Run(m.getNodeAddresses)
healthzChecks = append(healthzChecks, healthz.NamedCheck("SSH Tunnel Check", m.IsTunnelSyncHealthy))
m.lastSyncMetric = prometheus.NewGaugeFunc(prometheus.GaugeOpts{
prometheus.NewGaugeFunc(prometheus.GaugeOpts{
Name: "apiserver_proxy_tunnel_sync_latency_secs",
Help: "The time since the last successful synchronization of the SSH tunnels for proxy requests.",
}, func() float64 { return float64(m.secondsSinceSync()) })
}, func() float64 { return float64(m.tunneler.SecondsSinceSync()) })
}
apiVersions := []string{}
@ -875,7 +842,6 @@ func (m *Master) defaultAPIGroupVersion() *apiserver.APIGroupVersion {
Admit: m.admissionControl,
Context: m.requestContextMapper,
ProxyDialerFn: m.dialer,
MinRequestTimeout: m.minRequestTimeout,
}
}
@ -1031,7 +997,6 @@ func (m *Master) thirdpartyapi(group, kind, version string) *apiserver.APIGroupV
Context: m.requestContextMapper,
ProxyDialerFn: m.dialer,
MinRequestTimeout: m.minRequestTimeout,
}
}
@ -1094,7 +1059,6 @@ func (m *Master) experimental(c *Config) *apiserver.APIGroupVersion {
Admit: m.admissionControl,
Context: m.requestContextMapper,
ProxyDialerFn: m.dialer,
MinRequestTimeout: m.minRequestTimeout,
}
}
@ -1117,41 +1081,6 @@ func findExternalAddress(node *api.Node) (string, error) {
return "", fmt.Errorf("Couldn't find external address: %v", node)
}
func (m *Master) Dial(net, addr string) (net.Conn, error) {
// Only lock while picking a tunnel.
tunnel, err := func() (util.SSHTunnelEntry, error) {
m.tunnelsLock.Lock()
defer m.tunnelsLock.Unlock()
return m.tunnels.PickRandomTunnel()
}()
if err != nil {
return nil, err
}
start := time.Now()
id := rand.Int63() // So you can match begins/ends in the log.
glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address)
defer func() {
glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start))
}()
return tunnel.Tunnel.Dial(net, addr)
}
func (m *Master) needToReplaceTunnels(addrs []string) bool {
m.tunnelsLock.Lock()
defer m.tunnelsLock.Unlock()
if m.tunnels == nil || m.tunnels.Len() != len(addrs) {
return true
}
// TODO (cjcullen): This doesn't need to be n^2
for ix := range addrs {
if !m.tunnels.Has(addrs[ix]) {
return true
}
}
return false
}
func (m *Master) getNodeAddresses() ([]string, error) {
nodes, err := m.nodeRegistry.ListNodes(api.NewDefaultContext(), labels.Everything(), fields.Everything())
if err != nil {
@ -1170,126 +1099,12 @@ func (m *Master) getNodeAddresses() ([]string, error) {
}
func (m *Master) IsTunnelSyncHealthy(req *http.Request) error {
lag := m.secondsSinceSync()
if m.tunneler == nil {
return nil
}
lag := m.tunneler.SecondsSinceSync()
if lag > 600 {
return fmt.Errorf("Tunnel sync is taking to long: %d", lag)
}
return nil
}
func (m *Master) secondsSinceSync() int64 {
now := m.clock.Now().Unix()
then := atomic.LoadInt64(&m.lastSync)
return now - then
}
func (m *Master) replaceTunnels(user, keyfile string, newAddrs []string) error {
glog.Infof("replacing tunnels. New addrs: %v", newAddrs)
tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs)
if err := tunnels.Open(); err != nil {
return err
}
m.tunnelsLock.Lock()
defer m.tunnelsLock.Unlock()
if m.tunnels != nil {
m.tunnels.Close()
}
m.tunnels = tunnels
atomic.StoreInt64(&m.lastSync, m.clock.Now().Unix())
return nil
}
func (m *Master) loadTunnels(user, keyfile string) error {
addrs, err := m.getNodeAddresses()
if err != nil {
return err
}
if !m.needToReplaceTunnels(addrs) {
return nil
}
// TODO: This is going to unnecessarily close connections to unchanged nodes.
// See comment about using Watch above.
glog.Info("found different nodes. Need to replace tunnels")
return m.replaceTunnels(user, keyfile, addrs)
}
func (m *Master) refreshTunnels(user, keyfile string) error {
addrs, err := m.getNodeAddresses()
if err != nil {
return err
}
return m.replaceTunnels(user, keyfile, addrs)
}
func (m *Master) setupSecureProxy(user, privateKeyfile, publicKeyfile string) {
// Sync loop to ensure that the SSH key has been installed.
go util.Until(func() {
if m.installSSHKey == nil {
glog.Error("Won't attempt to install ssh key: installSSHKey function is nil")
return
}
key, err := util.ParsePublicKeyFromFile(publicKeyfile)
if err != nil {
glog.Errorf("Failed to load public key: %v", err)
return
}
keyData, err := util.EncodeSSHKey(key)
if err != nil {
glog.Errorf("Failed to encode public key: %v", err)
return
}
if err := m.installSSHKey(user, keyData); err != nil {
glog.Errorf("Failed to install ssh key: %v", err)
}
}, 5*time.Minute, util.NeverStop)
// Sync loop for tunnels
// TODO: switch this to watch.
go util.Until(func() {
if err := m.loadTunnels(user, privateKeyfile); err != nil {
glog.Errorf("Failed to load SSH Tunnels: %v", err)
}
if m.tunnels != nil && m.tunnels.Len() != 0 {
// Sleep for 10 seconds if we have some tunnels.
// TODO (cjcullen): tunnels can lag behind actually existing nodes.
time.Sleep(9 * time.Second)
}
}, 1*time.Second, util.NeverStop)
// Refresh loop for tunnels
// TODO: could make this more controller-ish
go util.Until(func() {
time.Sleep(5 * time.Minute)
if err := m.refreshTunnels(user, privateKeyfile); err != nil {
glog.Errorf("Failed to refresh SSH Tunnels: %v", err)
}
}, 0*time.Second, util.NeverStop)
}
func (m *Master) generateSSHKey(user, privateKeyfile, publicKeyfile string) error {
// TODO: user is not used. Consider removing it as an input to the function.
private, public, err := util.GenerateKey(2048)
if err != nil {
return err
}
// If private keyfile already exists, we must have only made it halfway
// through last time, so delete it.
exists, err := util.FileExists(privateKeyfile)
if err != nil {
glog.Errorf("Error detecting if private key exists: %v", err)
} else if exists {
glog.Infof("Private key exists, but public key does not")
if err := os.Remove(privateKeyfile); err != nil {
glog.Errorf("Failed to remove stale private key: %v", err)
}
}
if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil {
return err
}
publicKeyBytes, err := util.EncodePublicKey(public)
if err != nil {
return err
}
if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil {
return err
}
return os.Rename(publicKeyfile+".tmp", publicKeyfile)
}

View File

@ -18,6 +18,7 @@ package master
import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
@ -25,12 +26,9 @@ import (
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/api/latest"
@ -81,7 +79,12 @@ func setUp(t *testing.T) (Master, Config, *assert.Assertions) {
// using the configuration properly.
func TestNew(t *testing.T) {
_, config, assert := setUp(t)
config.KubeletClient = client.FakeKubeletClient{}
config.ProxyDialer = func(network, addr string) (net.Conn, error) { return nil, nil }
config.ProxyTLSClientConfig = &tls.Config{}
master := New(&config)
// Verify many of the variables match their config counterparts
@ -106,7 +109,15 @@ func TestNew(t *testing.T) {
assert.Equal(master.clusterIP, config.PublicAddress)
assert.Equal(master.publicReadWritePort, config.ReadWritePort)
assert.Equal(master.serviceReadWriteIP, config.ServiceReadWriteIP)
assert.Equal(master.installSSHKey, config.InstallSSHKey)
assert.Equal(master.tunneler, config.Tunneler)
// These functions should point to the same memory location
masterDialer, _ := util.Dialer(master.proxyTransport)
masterDialerFunc := fmt.Sprintf("%p", masterDialer)
configDialerFunc := fmt.Sprintf("%p", config.ProxyDialer)
assert.Equal(masterDialerFunc, configDialerFunc)
assert.Equal(master.proxyTransport.(*http.Transport).TLSClientConfig, config.ProxyTLSClientConfig)
}
// TestNewEtcdStorage verifies that the usage of NewEtcdStorage reacts properly when
@ -271,7 +282,6 @@ func TestInstallSwaggerAPI(t *testing.T) {
// creates the expected APIGroupVersion based off of master.
func TestDefaultAPIGroupVersion(t *testing.T) {
master, _, assert := setUp(t)
master.dialer = func(network, addr string) (net.Conn, error) { return nil, nil }
apiGroup := master.defaultAPIGroupVersion()
@ -279,11 +289,6 @@ func TestDefaultAPIGroupVersion(t *testing.T) {
assert.Equal(apiGroup.Admit, master.admissionControl)
assert.Equal(apiGroup.Context, master.requestContextMapper)
assert.Equal(apiGroup.MinRequestTimeout, master.minRequestTimeout)
// These functions should be different instances of the same function
groupDialerFunc := fmt.Sprintf("%+v", apiGroup.ProxyDialerFn)
masterDialerFunc := fmt.Sprintf("%+v", master.dialer)
assert.Equal(groupDialerFunc, masterDialerFunc)
}
// TestExpapi verifies that the unexported exapi creates
@ -299,42 +304,6 @@ func TestExpapi(t *testing.T) {
assert.Equal(expAPIGroup.Version, latest.GroupOrDie("extensions").GroupVersion)
}
// TestSecondsSinceSync verifies that proper results are returned
// when checking the time between syncs
func TestSecondsSinceSync(t *testing.T) {
master, _, assert := setUp(t)
master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix()
// Nano Second. No difference.
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)}
assert.Equal(int64(0), master.secondsSinceSync())
// Second
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)}
assert.Equal(int64(1), master.secondsSinceSync())
// Minute
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)}
assert.Equal(int64(60), master.secondsSinceSync())
// Hour
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)}
assert.Equal(int64(3600), master.secondsSinceSync())
// Day
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(86400), master.secondsSinceSync())
// Month
master.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(2678400), master.secondsSinceSync())
// Future Month. Should be -Month.
master.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix()
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(-2678400), master.secondsSinceSync())
}
// TestGetNodeAddresses verifies that proper results are returned
// when requesting node addresses.
func TestGetNodeAddresses(t *testing.T) {
@ -366,73 +335,6 @@ func TestGetNodeAddresses(t *testing.T) {
assert.Equal([]string{"127.0.0.2", "127.0.0.2"}, addrs)
}
// TestRefreshTunnels verifies that the function errors when no addresses
// are associated with nodes
func TestRefreshTunnels(t *testing.T) {
master, _, assert := setUp(t)
// Fail case (no addresses associated with nodes)
assert.Error(master.refreshTunnels("test", "/tmp/undefined"))
// TODO: pass case without needing actual connections?
}
// TestIsTunnelSyncHealthy verifies that the 600 second lag test
// is honored.
func TestIsTunnelSyncHealthy(t *testing.T) {
master, _, assert := setUp(t)
// Pass case: 540 second lag
master.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix()
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)}
err := master.IsTunnelSyncHealthy(nil)
assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.")
// Fail case: 720 second lag
master.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)}
err = master.IsTunnelSyncHealthy(nil)
assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.")
}
// generateTempFile creates a temporary file path
func generateTempFilePath(prefix string) string {
tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix()))
return tmpPath
}
// TestGenerateSSHKey verifies that SSH key generation does indeed
// generate keys even with keys already exist.
func TestGenerateSSHKey(t *testing.T) {
master, _, assert := setUp(t)
privateKey := generateTempFilePath("private")
publicKey := generateTempFilePath("public")
// Make sure we have no test keys laying around
os.Remove(privateKey)
os.Remove(publicKey)
// Pass case: Sunny day case
err := master.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Pass case: PrivateKey exists test case
os.Remove(publicKey)
err = master.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Pass case: PublicKey exists test case
os.Remove(privateKey)
err = master.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Make sure we have no test keys laying around
os.Remove(privateKey)
os.Remove(publicKey)
// TODO: testing error cases where the file can not be removed?
}
func TestDiscoveryAtAPIS(t *testing.T) {
master, config, assert := setUp(t)
master.exp = true

262
pkg/master/tunneler.go Normal file
View File

@ -0,0 +1,262 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package master
import (
"io/ioutil"
"math/rand"
"net"
"os"
"sync"
"sync/atomic"
"time"
"k8s.io/kubernetes/pkg/util"
"github.com/golang/glog"
"github.com/prometheus/client_golang/prometheus"
)
type AddressFunc func() (addresses []string, err error)
type Tunneler interface {
Run(AddressFunc)
Stop()
Dial(net, addr string) (net.Conn, error)
SecondsSinceSync() int64
}
type SSHTunneler struct {
SSHUser string
SSHKeyfile string
InstallSSHKey InstallSSHKey
tunnels *util.SSHTunnelList
tunnelsLock sync.Mutex
lastSync int64 // Seconds since Epoch
lastSyncMetric prometheus.GaugeFunc
clock util.Clock
getAddresses AddressFunc
stopChan chan struct{}
}
func NewSSHTunneler(sshUser string, sshKeyfile string, installSSHKey InstallSSHKey) Tunneler {
return &SSHTunneler{
SSHUser: sshUser,
SSHKeyfile: sshKeyfile,
InstallSSHKey: installSSHKey,
clock: util.RealClock{},
}
}
// Run establishes tunnel loops and returns
func (c *SSHTunneler) Run(getAddresses AddressFunc) {
if c.stopChan != nil {
return
}
c.stopChan = make(chan struct{})
// Save the address getter
if getAddresses != nil {
c.getAddresses = getAddresses
}
// Usernames are capped @ 32
if len(c.SSHUser) > 32 {
glog.Warning("SSH User is too long, truncating to 32 chars")
c.SSHUser = c.SSHUser[0:32]
}
glog.Infof("Setting up proxy: %s %s", c.SSHUser, c.SSHKeyfile)
// public keyfile is written last, so check for that.
publicKeyFile := c.SSHKeyfile + ".pub"
exists, err := util.FileExists(publicKeyFile)
if err != nil {
glog.Errorf("Error detecting if key exists: %v", err)
} else if !exists {
glog.Infof("Key doesn't exist, attempting to create")
err := c.generateSSHKey(c.SSHUser, c.SSHKeyfile, publicKeyFile)
if err != nil {
glog.Errorf("Failed to create key pair: %v", err)
}
}
c.tunnels = &util.SSHTunnelList{}
c.setupSecureProxy(c.SSHUser, c.SSHKeyfile, publicKeyFile)
c.lastSync = c.clock.Now().Unix()
}
// Stop gracefully shuts down the tunneler
func (c *SSHTunneler) Stop() {
if c.stopChan != nil {
close(c.stopChan)
c.stopChan = nil
}
}
func (c *SSHTunneler) Dial(net, addr string) (net.Conn, error) {
// Only lock while picking a tunnel.
tunnel, err := func() (util.SSHTunnelEntry, error) {
c.tunnelsLock.Lock()
defer c.tunnelsLock.Unlock()
return c.tunnels.PickRandomTunnel()
}()
if err != nil {
return nil, err
}
start := time.Now()
id := rand.Int63() // So you can match begins/ends in the log.
glog.V(3).Infof("[%x: %v] Dialing...", id, tunnel.Address)
defer func() {
glog.V(3).Infof("[%x: %v] Dialed in %v.", id, tunnel.Address, time.Now().Sub(start))
}()
return tunnel.Tunnel.Dial(net, addr)
}
func (c *SSHTunneler) SecondsSinceSync() int64 {
now := c.clock.Now().Unix()
then := atomic.LoadInt64(&c.lastSync)
return now - then
}
func (c *SSHTunneler) needToReplaceTunnels(addrs []string) bool {
c.tunnelsLock.Lock()
defer c.tunnelsLock.Unlock()
if c.tunnels == nil || c.tunnels.Len() != len(addrs) {
return true
}
// TODO (cjcullen): This doesn't need to be n^2
for ix := range addrs {
if !c.tunnels.Has(addrs[ix]) {
return true
}
}
return false
}
func (c *SSHTunneler) replaceTunnels(user, keyfile string, newAddrs []string) error {
glog.Infof("replacing tunnels. New addrs: %v", newAddrs)
tunnels := util.MakeSSHTunnels(user, keyfile, newAddrs)
if err := tunnels.Open(); err != nil {
return err
}
c.tunnelsLock.Lock()
defer c.tunnelsLock.Unlock()
if c.tunnels != nil {
c.tunnels.Close()
}
c.tunnels = tunnels
atomic.StoreInt64(&c.lastSync, c.clock.Now().Unix())
return nil
}
func (c *SSHTunneler) loadTunnels(user, keyfile string) error {
addrs, err := c.getAddresses()
if err != nil {
return err
}
if !c.needToReplaceTunnels(addrs) {
return nil
}
// TODO: This is going to unnecessarily close connections to unchanged nodes.
// See comment about using Watch above.
glog.Info("found different nodes. Need to replace tunnels")
return c.replaceTunnels(user, keyfile, addrs)
}
func (c *SSHTunneler) refreshTunnels(user, keyfile string) error {
addrs, err := c.getAddresses()
if err != nil {
return err
}
return c.replaceTunnels(user, keyfile, addrs)
}
func (c *SSHTunneler) setupSecureProxy(user, privateKeyfile, publicKeyfile string) {
// Sync loop to ensure that the SSH key has been installed.
go util.Until(func() {
if c.InstallSSHKey == nil {
glog.Error("Won't attempt to install ssh key: InstallSSHKey function is nil")
return
}
key, err := util.ParsePublicKeyFromFile(publicKeyfile)
if err != nil {
glog.Errorf("Failed to load public key: %v", err)
return
}
keyData, err := util.EncodeSSHKey(key)
if err != nil {
glog.Errorf("Failed to encode public key: %v", err)
return
}
if err := c.InstallSSHKey(user, keyData); err != nil {
glog.Errorf("Failed to install ssh key: %v", err)
}
}, 5*time.Minute, c.stopChan)
// Sync loop for tunnels
// TODO: switch this to watch.
go util.Until(func() {
if err := c.loadTunnels(user, privateKeyfile); err != nil {
glog.Errorf("Failed to load SSH Tunnels: %v", err)
}
if c.tunnels != nil && c.tunnels.Len() != 0 {
// Sleep for 10 seconds if we have some tunnels.
// TODO (cjcullen): tunnels can lag behind actually existing nodes.
time.Sleep(9 * time.Second)
}
}, 1*time.Second, c.stopChan)
// Refresh loop for tunnels
// TODO: could make this more controller-ish
go util.Until(func() {
time.Sleep(5 * time.Minute)
if err := c.refreshTunnels(user, privateKeyfile); err != nil {
glog.Errorf("Failed to refresh SSH Tunnels: %v", err)
}
}, 0*time.Second, c.stopChan)
}
func (c *SSHTunneler) generateSSHKey(user, privateKeyfile, publicKeyfile string) error {
// TODO: user is not used. Consider removing it as an input to the function.
private, public, err := util.GenerateKey(2048)
if err != nil {
return err
}
// If private keyfile already exists, we must have only made it halfway
// through last time, so delete it.
exists, err := util.FileExists(privateKeyfile)
if err != nil {
glog.Errorf("Error detecting if private key exists: %v", err)
} else if exists {
glog.Infof("Private key exists, but public key does not")
if err := os.Remove(privateKeyfile); err != nil {
glog.Errorf("Failed to remove stale private key: %v", err)
}
}
if err := ioutil.WriteFile(privateKeyfile, util.EncodePrivateKey(private), 0600); err != nil {
return err
}
publicKeyBytes, err := util.EncodePublicKey(public)
if err != nil {
return err
}
if err := ioutil.WriteFile(publicKeyfile+".tmp", publicKeyBytes, 0600); err != nil {
return err
}
return os.Rename(publicKeyfile+".tmp", publicKeyfile)
}

139
pkg/master/tunneler_test.go Normal file
View File

@ -0,0 +1,139 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package master
import (
"fmt"
"os"
"path/filepath"
"testing"
"time"
"k8s.io/kubernetes/pkg/util"
"github.com/stretchr/testify/assert"
)
// TestSecondsSinceSync verifies that proper results are returned
// when checking the time between syncs
func TestSecondsSinceSync(t *testing.T) {
tunneler := &SSHTunneler{}
assert := assert.New(t)
tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix()
// Nano Second. No difference.
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 2, time.UTC)}
assert.Equal(int64(0), tunneler.SecondsSinceSync())
// Second
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 2, 1, time.UTC)}
assert.Equal(int64(1), tunneler.SecondsSinceSync())
// Minute
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 2, 1, 1, time.UTC)}
assert.Equal(int64(60), tunneler.SecondsSinceSync())
// Hour
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 2, 1, 1, 1, time.UTC)}
assert.Equal(int64(3600), tunneler.SecondsSinceSync())
// Day
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 2, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(86400), tunneler.SecondsSinceSync())
// Month
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(2678400), tunneler.SecondsSinceSync())
// Future Month. Should be -Month.
tunneler.lastSync = time.Date(2015, time.February, 1, 1, 1, 1, 1, time.UTC).Unix()
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC)}
assert.Equal(int64(-2678400), tunneler.SecondsSinceSync())
}
// TestRefreshTunnels verifies that the function errors when no addresses
// are associated with nodes
func TestRefreshTunnels(t *testing.T) {
tunneler := &SSHTunneler{}
tunneler.getAddresses = func() ([]string, error) { return []string{}, nil }
assert := assert.New(t)
// Fail case (no addresses associated with nodes)
assert.Error(tunneler.refreshTunnels("test", "/tmp/undefined"))
// TODO: pass case without needing actual connections?
}
// TestIsTunnelSyncHealthy verifies that the 600 second lag test
// is honored.
func TestIsTunnelSyncHealthy(t *testing.T) {
tunneler := &SSHTunneler{}
master, _, assert := setUp(t)
master.tunneler = tunneler
// Pass case: 540 second lag
tunneler.lastSync = time.Date(2015, time.January, 1, 1, 1, 1, 1, time.UTC).Unix()
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 9, 1, 1, time.UTC)}
err := master.IsTunnelSyncHealthy(nil)
assert.NoError(err, "IsTunnelSyncHealthy() should not have returned an error.")
// Fail case: 720 second lag
tunneler.clock = &util.FakeClock{Time: time.Date(2015, time.January, 1, 1, 12, 1, 1, time.UTC)}
err = master.IsTunnelSyncHealthy(nil)
assert.Error(err, "IsTunnelSyncHealthy() should have returned an error.")
}
// generateTempFile creates a temporary file path
func generateTempFilePath(prefix string) string {
tmpPath, _ := filepath.Abs(fmt.Sprintf("%s/%s-%d", os.TempDir(), prefix, time.Now().Unix()))
return tmpPath
}
// TestGenerateSSHKey verifies that SSH key generation does indeed
// generate keys even with keys already exist.
func TestGenerateSSHKey(t *testing.T) {
tunneler := &SSHTunneler{}
assert := assert.New(t)
privateKey := generateTempFilePath("private")
publicKey := generateTempFilePath("public")
// Make sure we have no test keys laying around
os.Remove(privateKey)
os.Remove(publicKey)
// Pass case: Sunny day case
err := tunneler.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Pass case: PrivateKey exists test case
os.Remove(publicKey)
err = tunneler.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Pass case: PublicKey exists test case
os.Remove(privateKey)
err = tunneler.generateSSHKey("unused", privateKey, publicKey)
assert.NoError(err, "generateSSHKey should not have retuend an error: %s", err)
// Make sure we have no test keys laying around
os.Remove(privateKey)
os.Remove(publicKey)
// TODO: testing error cases where the file can not be removed?
}

View File

@ -17,10 +17,7 @@ limitations under the License.
package rest
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
@ -29,12 +26,12 @@ import (
"time"
"k8s.io/kubernetes/pkg/api/errors"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/proxy"
"github.com/golang/glog"
"github.com/mxk/go-flowrate/flowrate"
"k8s.io/kubernetes/third_party/golang/netutil"
)
// UpgradeAwareProxyHandler is a handler for proxy requests that may require an upgrade
@ -128,7 +125,7 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R
return false
}
backendConn, err := h.dialURL()
backendConn, err := proxy.DialURL(h.Location, h.Transport)
if err != nil {
h.err = err
return true
@ -189,79 +186,6 @@ func (h *UpgradeAwareProxyHandler) tryUpgrade(w http.ResponseWriter, req *http.R
return true
}
func (h *UpgradeAwareProxyHandler) dialURL() (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(h.Location)
var dialer func(network, addr string) (net.Conn, error)
if httpTransport, ok := h.Transport.(*http.Transport); ok && httpTransport.Dial != nil {
dialer = httpTransport.Dial
}
switch h.Location.Scheme {
case "http":
if dialer != nil {
return dialer("tcp", dialAddr)
}
return net.Dial("tcp", dialAddr)
case "https":
// TODO: this TLS logic can probably be cleaned up; it's messy in an attempt
// to preserve behavior that we don't know for sure is exercised.
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
var tlsConn *tls.Conn
var err error
if h.Transport != nil {
httpTransport, ok := h.Transport.(*http.Transport)
if ok {
tlsConfig = httpTransport.TLSClientConfig
}
}
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr)
if err != nil {
return nil, err
}
// tls.Client requires non-nil config
if tlsConfig == nil {
glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return nil, err
}
} else {
// Dial
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
}
// Return if we were configured to skip validation
if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("unknown scheme: %s", h.Location.Scheme)
}
}
func (h *UpgradeAwareProxyHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper {
scheme := url.Scheme
host := url.Host
@ -294,7 +218,12 @@ func (p *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, er
}
removeCORSHeaders(resp)
return resp, nil
}
var _ = util.RoundTripperWrapper(&corsRemovingTransport{})
func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// removeCORSHeaders strip CORS headers sent from the backend

View File

@ -22,6 +22,7 @@ import (
"crypto/x509"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
@ -305,6 +306,21 @@ func TestProxyUpgrade(t *testing.T) {
},
ProxyTransport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: localhostPool}},
},
"https (valid hostname + RootCAs + custom dialer)": {
ServerFunc: func(h http.Handler) *httptest.Server {
cert, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
t.Errorf("https (valid hostname): proxy_test: %v", err)
}
ts := httptest.NewUnstartedServer(h)
ts.TLS = &tls.Config{
Certificates: []tls.Certificate{cert},
}
ts.StartTLS()
return ts
},
ProxyTransport: &http.Transport{Dial: net.Dial, TLSClientConfig: &tls.Config{RootCAs: localhostPool}},
},
}
for k, tc := range testcases {

View File

@ -31,7 +31,8 @@ import (
type REST struct {
*etcdgeneric.Etcd
connection client.ConnectionInfoGetter
connection client.ConnectionInfoGetter
proxyTransport http.RoundTripper
}
// StatusREST implements the REST endpoint for changing the status of a pod.
@ -49,7 +50,7 @@ func (r *StatusREST) Update(ctx api.Context, obj runtime.Object) (runtime.Object
}
// NewREST returns a RESTStorage object that will work against nodes.
func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter) (*REST, *StatusREST) {
func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper) (*REST, *StatusREST) {
prefix := "/minions"
storageInterface := s
@ -91,7 +92,7 @@ func NewREST(s storage.Interface, useCacher bool, connection client.ConnectionIn
statusStore := *store
statusStore.UpdateStrategy = node.StatusStrategy
return &REST{store, connection}, &StatusREST{store: &statusStore}
return &REST{store, connection, proxyTransport}, &StatusREST{store: &statusStore}
}
// Implement Redirector.
@ -99,5 +100,5 @@ var _ = rest.Redirector(&REST{})
// ResourceLocation returns a URL to which one can send traffic for the specified node.
func (r *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
return node.ResourceLocation(r, r.connection, ctx, id)
return node.ResourceLocation(r, r.connection, r.proxyTransport, ctx, id)
}

View File

@ -38,7 +38,7 @@ func (fakeConnectionInfoGetter) GetConnectionInfo(host string) (string, uint, ht
func newStorage(t *testing.T) (*REST, *tools.FakeEtcdClient) {
etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "")
storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{})
storage, _ := NewREST(etcdStorage, false, fakeConnectionInfoGetter{}, nil)
return storage, fakeClient
}

View File

@ -136,7 +136,7 @@ func MatchNode(label labels.Selector, field fields.Selector) generic.Matcher {
}
// ResourceLocation returns an URL and transport which one can use to send traffic for the specified node.
func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGetter, proxyTransport http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
schemeReq, name, portReq, valid := util.SplitSchemeNamePort(id)
if !valid {
return nil, nil, errors.NewBadRequest(fmt.Sprintf("invalid node request %q", id))
@ -155,7 +155,7 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet
if portReq == "" || strconv.Itoa(ports.KubeletPort) == portReq {
// Ignore requested scheme, use scheme provided by GetConnectionInfo
scheme, port, transport, err := connection.GetConnectionInfo(host)
scheme, port, kubeletTransport, err := connection.GetConnectionInfo(host)
if err != nil {
return nil, nil, err
}
@ -166,8 +166,8 @@ func ResourceLocation(getter ResourceGetter, connection client.ConnectionInfoGet
strconv.FormatUint(uint64(port), 10),
),
},
transport,
kubeletTransport,
nil
}
return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, nil, nil
return &url.URL{Scheme: schemeReq, Host: net.JoinHostPort(host, portReq)}, proxyTransport, nil
}

View File

@ -56,10 +56,11 @@ type PodStorage struct {
// REST implements a RESTStorage for pods against etcd
type REST struct {
*etcdgeneric.Etcd
proxyTransport http.RoundTripper
}
// NewStorage returns a RESTStorage object that will work against pods.
func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter) PodStorage {
func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGetter, proxyTransport http.RoundTripper) PodStorage {
prefix := "/pods"
storageInterface := s
@ -106,11 +107,11 @@ func NewStorage(s storage.Interface, useCacher bool, k client.ConnectionInfoGett
statusStore.UpdateStrategy = pod.StatusStrategy
return PodStorage{
Pod: &REST{store},
Pod: &REST{store, proxyTransport},
Binding: &BindingREST{store: store},
Status: &StatusREST{store: &statusStore},
Log: &LogREST{store: store, kubeletConn: k},
Proxy: &ProxyREST{store: store},
Proxy: &ProxyREST{store: store, proxyTransport: proxyTransport},
Exec: &ExecREST{store: store, kubeletConn: k},
Attach: &AttachREST{store: store, kubeletConn: k},
PortForward: &PortForwardREST{store: store, kubeletConn: k},
@ -122,7 +123,7 @@ var _ = rest.Redirector(&REST{})
// ResourceLocation returns a pods location from its HostIP
func (r *REST) ResourceLocation(ctx api.Context, name string) (*url.URL, http.RoundTripper, error) {
return pod.ResourceLocation(r, ctx, name)
return pod.ResourceLocation(r, r.proxyTransport, ctx, name)
}
// BindingREST implements the REST endpoint for binding pods to nodes when etcd is in use.
@ -256,7 +257,8 @@ func (r *LogREST) NewGetOptions() (runtime.Object, bool, string) {
// ProxyREST implements the proxy subresource for a Pod
// TODO: move me into pod/rest - I'm generic to store type via ResourceGetter
type ProxyREST struct {
store *etcdgeneric.Etcd
store *etcdgeneric.Etcd
proxyTransport http.RoundTripper
}
// Implement Connecter
@ -285,7 +287,7 @@ func (r *ProxyREST) Connect(ctx api.Context, id string, opts runtime.Object) (re
if !ok {
return nil, fmt.Errorf("Invalid options object: %#v", opts)
}
location, transport, err := pod.ResourceLocation(r.store, ctx, id)
location, transport, err := pod.ResourceLocation(r.store, r.proxyTransport, ctx, id)
if err != nil {
return nil, err
}

View File

@ -38,7 +38,7 @@ import (
func newStorage(t *testing.T) (*REST, *BindingREST, *StatusREST, *tools.FakeEtcdClient) {
etcdStorage, fakeClient := registrytest.NewEtcdStorage(t, "")
storage := NewStorage(etcdStorage, false, nil)
storage := NewStorage(etcdStorage, false, nil, nil)
return storage.Pod, storage.Binding, storage.Status, fakeClient
}
@ -740,7 +740,7 @@ func TestEtcdUpdateStatus(t *testing.T) {
func TestPodLogValidates(t *testing.T) {
etcdStorage, _ := registrytest.NewEtcdStorage(t, "")
storage := NewStorage(etcdStorage, false, nil)
storage := NewStorage(etcdStorage, false, nil, nil)
negativeOne := int64(-1)
testCases := []*api.PodLogOptions{

View File

@ -17,7 +17,6 @@ limitations under the License.
package pod
import (
"crypto/tls"
"fmt"
"net"
"net/http"
@ -47,13 +46,6 @@ type podStrategy struct {
// objects via the REST API.
var Strategy = podStrategy{api.Scheme, api.SimpleNameGenerator}
// PodProxyTransport is used by the API proxy to connect to pods
// Exported to allow overriding TLS options (like adding a client certificate)
var PodProxyTransport = util.SetTransportDefaults(&http.Transport{
// Turn off hostname verification, because connections are to assigned IPs, not deterministic
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
})
// NamespaceScoped is true for pods.
func (podStrategy) NamespaceScoped() bool {
return true
@ -195,7 +187,7 @@ func getPod(getter ResourceGetter, ctx api.Context, name string) (*api.Pod, erro
}
// ResourceLocation returns a URL to which one can send traffic for the specified pod.
func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
func ResourceLocation(getter ResourceGetter, rt http.RoundTripper, ctx api.Context, id string) (*url.URL, http.RoundTripper, error) {
// Allow ID as "podname" or "podname:port" or "scheme:podname:port".
// If port is not specified, try to use the first defined port on the pod.
scheme, name, port, valid := util.SplitSchemeNamePort(id)
@ -227,7 +219,7 @@ func ResourceLocation(getter ResourceGetter, ctx api.Context, id string) (*url.U
} else {
loc.Host = net.JoinHostPort(pod.Status.PodIP, port)
}
return loc, PodProxyTransport, nil
return loc, rt, nil
}
// LogLocation returns the log URL for a pod container. If opts.Container is blank

View File

@ -17,7 +17,6 @@ limitations under the License.
package service
import (
"crypto/tls"
"fmt"
"math/rand"
"net"
@ -48,23 +47,18 @@ type REST struct {
endpoints endpoint.Registry
serviceIPs ipallocator.Interface
serviceNodePorts portallocator.Interface
proxyTransport http.RoundTripper
}
// ServiceProxyTransport is used by the API proxy to connect to services
// Exported to allow overriding TLS options (like adding a client certificate)
var ServiceProxyTransport = util.SetTransportDefaults(&http.Transport{
// Turn off hostname verification, because connections are to assigned IPs, not deterministic
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
})
// NewStorage returns a new REST.
func NewStorage(registry Registry, endpoints endpoint.Registry, serviceIPs ipallocator.Interface,
serviceNodePorts portallocator.Interface) *REST {
serviceNodePorts portallocator.Interface, proxyTransport http.RoundTripper) *REST {
return &REST{
registry: registry,
endpoints: endpoints,
serviceIPs: serviceIPs,
serviceNodePorts: serviceNodePorts,
proxyTransport: proxyTransport,
}
}
@ -314,7 +308,7 @@ func (rs *REST) ResourceLocation(ctx api.Context, id string) (*url.URL, http.Rou
return &url.URL{
Scheme: svcScheme,
Host: net.JoinHostPort(ip, strconv.Itoa(port)),
}, ServiceProxyTransport, nil
}, rs.proxyTransport, nil
}
}
}

View File

@ -46,7 +46,7 @@ func NewTestREST(t *testing.T, endpoints *api.EndpointsList) (*REST, *registryte
portRange := util.PortRange{Base: 30000, Size: 1000}
portAllocator := portallocator.NewPortAllocator(portRange)
storage := NewStorage(registry, endpointRegistry, r, portAllocator)
storage := NewStorage(registry, endpointRegistry, r, portAllocator, nil)
return storage, registry
}

View File

@ -17,7 +17,10 @@ limitations under the License.
package util
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@ -62,3 +65,40 @@ func SetTransportDefaults(t *http.Transport) *http.Transport {
}
return t
}
type RoundTripperWrapper interface {
http.RoundTripper
WrappedRoundTripper() http.RoundTripper
}
type DialFunc func(net, addr string) (net.Conn, error)
func Dialer(transport http.RoundTripper) (DialFunc, error) {
if transport == nil {
return nil, nil
}
switch transport := transport.(type) {
case *http.Transport:
return transport.Dial, nil
case RoundTripperWrapper:
return Dialer(transport.WrappedRoundTripper())
default:
return nil, fmt.Errorf("unknown transport type: %v", transport)
}
}
func TLSClientConfig(transport http.RoundTripper) (*tls.Config, error) {
if transport == nil {
return nil, nil
}
switch transport := transport.(type) {
case *http.Transport:
return transport.TLSClientConfig, nil
case RoundTripperWrapper:
return TLSClientConfig(transport.WrappedRoundTripper())
default:
return nil, fmt.Errorf("unknown transport type: %v", transport)
}
}

106
pkg/util/proxy/dial.go Normal file
View File

@ -0,0 +1,106 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/third_party/golang/netutil"
)
func DialURL(url *url.URL, transport http.RoundTripper) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)
dialer, _ := util.Dialer(transport)
switch url.Scheme {
case "http":
if dialer != nil {
return dialer("tcp", dialAddr)
}
return net.Dial("tcp", dialAddr)
case "https":
// Get the tls config from the transport if we recognize it
var tlsConfig *tls.Config
var tlsConn *tls.Conn
var err error
tlsConfig, _ = util.TLSClientConfig(transport)
if dialer != nil {
// We have a dialer; use it to open the connection, then
// create a tls client using the connection.
netConn, err := dialer("tcp", dialAddr)
if err != nil {
return nil, err
}
if tlsConfig == nil {
// tls.Client requires non-nil config
glog.Warningf("using custom dialer with no TLSClientConfig. Defaulting to InsecureSkipVerify")
// tls.Handshake() requires ServerName or InsecureSkipVerify
tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
} else if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
// tls.Handshake() requires ServerName or InsecureSkipVerify
// infer the ServerName from the hostname we're connecting to.
inferredHost := dialAddr
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
inferredHost = host
}
// Make a copy to avoid polluting the provided config
tlsConfigCopy := *tlsConfig
tlsConfigCopy.ServerName = inferredHost
tlsConfig = &tlsConfigCopy
}
tlsConn = tls.Client(netConn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
netConn.Close()
return nil, err
}
} else {
// Dial
tlsConn, err = tls.Dial("tcp", dialAddr, tlsConfig)
if err != nil {
return nil, err
}
}
// Return if we were configured to skip validation
if tlsConfig != nil && tlsConfig.InsecureSkipVerify {
return tlsConn, nil
}
// Verify
host, _, _ := net.SplitHostPort(dialAddr)
if err := tlsConn.VerifyHostname(host); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
default:
return nil, fmt.Errorf("Unknown scheme: %s", url.Scheme)
}
}

View File

@ -31,6 +31,7 @@ import (
"golang.org/x/net/html"
"golang.org/x/net/html/atom"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/sets"
)
@ -118,6 +119,12 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.rewriteResponse(req, resp)
}
var _ = util.RoundTripperWrapper(&Transport{})
func (rt *Transport) WrappedRoundTripper() http.RoundTripper {
return rt.RoundTripper
}
// rewriteURL rewrites a single URL to go through the proxy, if the URL refers
// to the same host as sourceURL, which is the page on which the target URL
// occurred. If any error occurs (e.g. parsing), it returns targetURL.

View File

@ -669,6 +669,9 @@ var _ = Describe("Kubectl client", func() {
By("curling proxy /api/ output")
localAddr := fmt.Sprintf("http://localhost:%d/api/", port)
apiVersions, err := getAPIVersions(localAddr)
if err != nil {
Failf("Expected at least one supported apiversion, got error %v", err)
}
if len(apiVersions.Versions) < 1 {
Failf("Expected at least one supported apiversion, got %v", apiVersions)
}