Extract connection rotating dialer into a package

This will be re-used for exec auth plugin to rotate connections on
credential change.
pull/8/head
Andrew Lytvynov 2018-05-16 10:30:53 -07:00
parent 835afe683f
commit 85a61ff3aa
6 changed files with 201 additions and 65 deletions

View File

@ -26,6 +26,7 @@ go_library(
"//vendor/k8s.io/client-go/kubernetes/typed/certificates/v1beta1:go_default_library",
"//vendor/k8s.io/client-go/rest:go_default_library",
"//vendor/k8s.io/client-go/util/certificate:go_default_library",
"//vendor/k8s.io/client-go/util/connrotation:go_default_library",
],
)

View File

@ -17,12 +17,10 @@ limitations under the License.
package certificate
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/golang/glog"
@ -31,6 +29,7 @@ import (
"k8s.io/apimachinery/pkg/util/wait"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/util/certificate"
"k8s.io/client-go/util/connrotation"
)
// UpdateTransport instruments a restconfig with a transport that dynamically uses
@ -64,11 +63,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig
return nil, fmt.Errorf("there is already a transport or dialer configured")
}
// Custom dialer that will track all connections it creates.
t := &connTracker{
dialer: &net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second},
conns: make(map[*closableConn]struct{}),
}
d := connrotation.NewDialer((&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second}).DialContext)
tlsConfig, err := restclient.TLSConfigFor(clientConfig)
if err != nil {
@ -128,7 +123,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig
// to reperform its TLS handshake with new cert.
//
// See: https://github.com/kubernetes-incubator/bootkube/pull/663#issuecomment-318506493
t.closeAllConns()
d.CloseAll()
}, period, stopCh)
}
@ -137,7 +132,7 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: tlsConfig,
MaxIdleConnsPerHost: 25,
DialContext: t.DialContext, // Use custom dialer.
DialContext: d.DialContext, // Use custom dialer.
})
// Zero out all existing TLS options since our new transport enforces them.
@ -149,60 +144,5 @@ func updateTransport(stopCh <-chan struct{}, period time.Duration, clientConfig
clientConfig.CAFile = ""
clientConfig.Insecure = false
return t.closeAllConns, nil
}
// connTracker is a dialer that tracks all open connections it creates.
type connTracker struct {
dialer *net.Dialer
mu sync.Mutex
conns map[*closableConn]struct{}
}
// closeAllConns forcibly closes all tracked connections.
func (c *connTracker) closeAllConns() {
c.mu.Lock()
conns := c.conns
c.conns = make(map[*closableConn]struct{})
c.mu.Unlock()
for conn := range conns {
conn.Close()
}
}
func (c *connTracker) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := c.dialer.DialContext(ctx, network, address)
if err != nil {
return nil, err
}
closable := &closableConn{Conn: conn}
// Start tracking the connection
c.mu.Lock()
c.conns[closable] = struct{}{}
c.mu.Unlock()
// When the connection is closed, remove it from the map. This will
// be no-op if the connection isn't in the map, e.g. if closeAllConns()
// is called.
closable.onClose = func() {
c.mu.Lock()
delete(c.conns, closable)
c.mu.Unlock()
}
return closable, nil
}
type closableConn struct {
onClose func()
net.Conn
}
func (c *closableConn) Close() error {
go c.onClose()
return c.Conn.Close()
return d.CloseAll, nil
}

View File

@ -162,6 +162,7 @@ filegroup(
"//staging/src/k8s.io/client-go/util/buffer:all-srcs",
"//staging/src/k8s.io/client-go/util/cert:all-srcs",
"//staging/src/k8s.io/client-go/util/certificate:all-srcs",
"//staging/src/k8s.io/client-go/util/connrotation:all-srcs",
"//staging/src/k8s.io/client-go/util/exec:all-srcs",
"//staging/src/k8s.io/client-go/util/flowcontrol:all-srcs",
"//staging/src/k8s.io/client-go/util/homedir:all-srcs",

View File

@ -0,0 +1,28 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["connrotation.go"],
importpath = "k8s.io/client-go/util/connrotation",
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["connrotation_test.go"],
embed = [":go_default_library"],
)
filegroup(
name = "package-srcs",
srcs = glob(["**"]),
tags = ["automanaged"],
visibility = ["//visibility:private"],
)
filegroup(
name = "all-srcs",
srcs = [":package-srcs"],
tags = ["automanaged"],
visibility = ["//visibility:public"],
)

View File

@ -0,0 +1,105 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package connrotation implements a connection dialer that tracks and can close
// all created connections.
//
// This is used for credential rotation of long-lived connections, when there's
// no way to re-authenticate on a live connection.
package connrotation
import (
"context"
"net"
"sync"
)
// DialFunc is a shorthand for signature of net.DialContext.
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
// Dialer opens connections through Dial and tracks them.
type Dialer struct {
dial DialFunc
mu sync.Mutex
conns map[*closableConn]struct{}
}
// NewDialer creates a new Dialer instance.
//
// If dial is not nil, it will be used to create new underlying connections.
// Otherwise net.DialContext is used.
func NewDialer(dial DialFunc) *Dialer {
return &Dialer{
dial: dial,
conns: make(map[*closableConn]struct{}),
}
}
// CloseAll forcibly closes all tracked connections.
//
// Note: new connections may get created before CloseAll returns.
func (d *Dialer) CloseAll() {
d.mu.Lock()
conns := d.conns
d.conns = make(map[*closableConn]struct{})
d.mu.Unlock()
for conn := range conns {
conn.Close()
}
}
// Dial creates a new tracked connection.
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
// DialContext creates a new tracked connection.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.dial(ctx, network, address)
if err != nil {
return nil, err
}
closable := &closableConn{Conn: conn}
// Start tracking the connection
d.mu.Lock()
d.conns[closable] = struct{}{}
d.mu.Unlock()
// When the connection is closed, remove it from the map. This will
// be no-op if the connection isn't in the map, e.g. if CloseAll()
// is called.
closable.onClose = func() {
d.mu.Lock()
delete(d.conns, closable)
d.mu.Unlock()
}
return closable, nil
}
type closableConn struct {
onClose func()
net.Conn
}
func (c *closableConn) Close() error {
go c.onClose()
return c.Conn.Close()
}

View File

@ -0,0 +1,61 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package connrotation
import (
"context"
"net"
"testing"
"time"
)
func TestCloseAll(t *testing.T) {
closed := make(chan struct{})
dialFn := func(ctx context.Context, network, address string) (net.Conn, error) {
return closeOnlyConn{onClose: func() { closed <- struct{}{} }}, nil
}
dialer := NewDialer(dialFn)
const numConns = 10
// Outer loop to ensure Dialer is re-usable after CloseAll.
for i := 0; i < 5; i++ {
for j := 0; j < numConns; j++ {
if _, err := dialer.Dial("", ""); err != nil {
t.Fatal(err)
}
}
dialer.CloseAll()
for j := 0; j < numConns; j++ {
select {
case <-closed:
case <-time.After(time.Second):
t.Fatalf("iteration %d: 1s after CloseAll only %d/%d connections closed", i, j, numConns)
}
}
}
}
type closeOnlyConn struct {
net.Conn
onClose func()
}
func (c closeOnlyConn) Close() error {
go c.onClose()
return nil
}