Lazily dial kms-plugin.

pull/58/head
immutablet 2018-09-12 14:56:44 -07:00
parent 2fb9fc2400
commit 07cbf2545f
4 changed files with 249 additions and 71 deletions

View File

@ -24,6 +24,7 @@ import (
"io"
"io/ioutil"
"os"
"time"
yaml "github.com/ghodss/yaml"
@ -40,6 +41,7 @@ const (
aesGCMTransformerPrefixV1 = "k8s:enc:aesgcm:v1:"
secretboxTransformerPrefixV1 = "k8s:enc:secretbox:v1:"
kmsTransformerPrefixV1 = "k8s:enc:kms:v1:"
kmsPluginConnectionTimeout = 3 * time.Second
)
// GetTransformerOverrides returns the transformer overrides by reading and parsing the encryption provider configuration file
@ -160,7 +162,7 @@ func GetPrefixTransformers(config *ResourceConfig) ([]value.PrefixTransformer, e
}
// Get gRPC client service with endpoint.
envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint)
envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint, kmsPluginConnectionTimeout)
if err != nil {
return nil, fmt.Errorf("could not configure KMS plugin %q, error: %v", provider.KMS.Name, err)
}

View File

@ -21,6 +21,7 @@ import (
"encoding/base64"
"strings"
"testing"
"time"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apiserver/pkg/storage/value"
@ -239,7 +240,7 @@ func (t *testEnvelopeService) Encrypt(data []byte) ([]byte, error) {
}
// The factory method to create mock envelope service.
func newMockEnvelopeService(endpoint string) (envelope.Service, error) {
func newMockEnvelopeService(endpoint string, timeout time.Duration) (envelope.Service, error) {
return &testEnvelopeService{}, nil
}

View File

@ -23,6 +23,7 @@ import (
"net"
"net/url"
"strings"
"sync"
"time"
"github.com/golang/glog"
@ -39,19 +40,20 @@ const (
// Current version for the protocol interface definition.
kmsapiVersion = "v1beta1"
// The timeout that communicate with KMS server.
timeout = 30 * time.Second
versionErrorf = "KMS provider api version %s is not supported, only %s is supported now"
)
// The gRPC implementation for envelope.Service.
type gRPCService struct {
// gRPC client instance
kmsClient kmsapi.KeyManagementServiceClient
connection *grpc.ClientConn
kmsClient kmsapi.KeyManagementServiceClient
connection *grpc.ClientConn
callTimeout time.Duration
mux sync.RWMutex
versionChecked bool
}
// NewGRPCService returns an envelope.Service which use gRPC to communicate the remote KMS provider.
func NewGRPCService(endpoint string) (Service, error) {
func NewGRPCService(endpoint string, callTimeout time.Duration) (Service, error) {
glog.V(4).Infof("Configure KMS provider with endpoint: %s", endpoint)
addr, err := parseEndpoint(endpoint)
@ -59,28 +61,28 @@ func NewGRPCService(endpoint string) (Service, error) {
return nil, err
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
connection, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.FailFast(false)), grpc.WithDialer(
func(string, time.Duration) (net.Conn, error) {
// Ignoring addr and timeout arguments:
// addr - comes from the closure
// timeout - is ignored since we are connecting in a non-blocking configuration
c, err := net.DialTimeout(unixProtocol, addr, 0)
if err != nil {
glog.Errorf("failed to create connection to unix socket: %s, error: %v", addr, err)
}
return c, err
}))
connection, err := grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithDialer(unixDial))
if err != nil {
return nil, fmt.Errorf("connect remote KMS provider %q failed, error: %v", addr, err)
return nil, fmt.Errorf("failed to create connection to %s, error: %v", endpoint, err)
}
kmsClient := kmsapi.NewKeyManagementServiceClient(connection)
err = checkAPIVersion(kmsClient)
if err != nil {
connection.Close()
return nil, fmt.Errorf("failed check version for %q, error: %v", addr, err)
}
return &gRPCService{kmsClient: kmsClient, connection: connection}, nil
}
// This dialer explicitly ask gRPC to use unix socket as network.
func unixDial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(unixProtocol, addr, timeout)
return &gRPCService{
kmsClient: kmsClient,
connection: connection,
callTimeout: callTimeout,
}, nil
}
// Parse the endpoint to extract schema, host or path.
@ -109,31 +111,37 @@ func parseEndpoint(endpoint string) (string, error) {
return u.Path, nil
}
// Check the KMS provider API version.
// Only matching kmsapiVersion is supported now.
func checkAPIVersion(kmsClient kmsapi.KeyManagementServiceClient) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
func (g *gRPCService) checkAPIVersion(ctx context.Context) error {
g.mux.Lock()
defer g.mux.Unlock()
if g.versionChecked {
return nil
}
request := &kmsapi.VersionRequest{Version: kmsapiVersion}
response, err := kmsClient.Version(ctx, request)
response, err := g.kmsClient.Version(ctx, request)
if err != nil {
return fmt.Errorf("failed get version from remote KMS provider: %v", err)
}
if response.Version != kmsapiVersion {
return fmt.Errorf("KMS provider api version %s is not supported, only %s is supported now",
response.Version, kmsapiVersion)
return fmt.Errorf(versionErrorf, response.Version, kmsapiVersion)
}
g.versionChecked = true
glog.V(4).Infof("KMS provider %s initialized, version: %s", response.RuntimeName, response.RuntimeVersion)
glog.V(4).Infof("Version of KMS provider is %s", response.Version)
return nil
}
// Decrypt a given data string to obtain the original byte data.
func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout)
defer cancel()
if err := g.checkAPIVersion(ctx); err != nil {
return nil, err
}
request := &kmsapi.DecryptRequest{Cipher: cipher, Version: kmsapiVersion}
response, err := g.kmsClient.Decrypt(ctx, request)
if err != nil {
@ -144,8 +152,11 @@ func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) {
// Encrypt bytes to a string ciphertext.
func (g *gRPCService) Encrypt(plain []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout)
defer cancel()
if err := g.checkAPIVersion(ctx); err != nil {
return nil, err
}
request := &kmsapi.EncryptRequest{Plain: plain, Version: kmsapiVersion}
response, err := g.kmsClient.Encrypt(ctx, request)

View File

@ -25,7 +25,9 @@ import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"time"
"google.golang.org/grpc"
@ -36,17 +38,143 @@ const (
endpoint = "unix:///@kms-socket.sock"
)
// Normal encryption and decryption operation.
func TestGRPCService(t *testing.T) {
// Start a test gRPC server.
server, err := startTestKMSProvider()
// TestKMSPluginLateStart tests the scenario where kms-plugin pod/container starts after kube-apiserver pod/container.
// Since the Dial to kms-plugin is non-blocking we expect the construction of gRPC service to succeed even when
// kms-plugin is not yet up - dialing happens in the background.
func TestKMSPluginLateStart(t *testing.T) {
callTimeout := 3 * time.Second
service, err := NewGRPCService(endpoint, callTimeout)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
time.Sleep(callTimeout / 2)
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer stopTestKMSProvider(server)
defer f.server.Stop()
data := []byte("test data")
_, err = service.Encrypt(data)
if err != nil {
t.Fatalf("failed when execute encrypt, error: %v", err)
}
}
// TestIntermittentConnectionLoss tests the scenario where the connection with kms-plugin is intermittently lost.
func TestIntermittentConnectionLoss(t *testing.T) {
var (
wg1 sync.WaitGroup
wg2 sync.WaitGroup
timeout = 30 * time.Second
blackOut = 1 * time.Second
data = []byte("test data")
)
// Start KMS Plugin
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
// connect to kms plugin
service, err := NewGRPCService(endpoint, timeout)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
_, err = service.Encrypt(data)
if err != nil {
t.Fatalf("failed when execute encrypt, error: %v", err)
}
t.Log("Connected to KMSPlugin")
// Stop KMS Plugin - simulating connection loss
f.server.Stop()
t.Log("KMS Plugin is stopped")
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
// Call service to encrypt data.
t.Log("Sending encrypt request")
wg1.Done()
_, err := service.Encrypt(data)
if err != nil {
t.Fatalf("failed when executing encrypt, error: %v", err)
}
}()
wg1.Wait()
time.Sleep(blackOut)
// Start KMS Plugin
f, err = startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
t.Log("Restarted KMS Plugin")
wg2.Wait()
}
func TestUnsupportedVersion(t *testing.T) {
ver := "invalid"
data := []byte("test data")
wantErr := fmt.Errorf(versionErrorf, ver, kmsapiVersion)
f, err := startFakeKMSProvider(ver)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %ver", err)
}
defer f.server.Stop()
s, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatal(err)
}
defer destroyService(s)
// Encrypt
_, err = s.Encrypt(data)
if err == nil || err.Error() != wantErr.Error() {
t.Errorf("got err: %ver, want: %ver", err, wantErr)
}
destroyService(s)
s, err = NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatal(err)
}
defer destroyService(s)
// Decrypt
_, err = s.Decrypt(data)
if err == nil || err.Error() != wantErr.Error() {
t.Errorf("got err: %ver, want: %ver", err, wantErr)
}
}
func TestConcurrentAccess(t *testing.T) {
}
// Normal encryption and decryption operation.
func TestGRPCService(t *testing.T) {
// Start a test gRPC server.
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
// Create the gRPC client service.
service, err := NewGRPCService(endpoint)
service, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
@ -70,19 +198,65 @@ func TestGRPCService(t *testing.T) {
}
}
// Normal encryption and decryption operation by multiple go-routines.
func TestGRPCServiceConcurrentAccess(t *testing.T) {
// Start a test gRPC server.
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
// Create the gRPC client service.
service, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
var wg sync.WaitGroup
n := 1000
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
// Call service to encrypt data.
data := []byte("test data")
cipher, err := service.Encrypt(data)
if err != nil {
t.Errorf("failed when execute encrypt, error: %v", err)
}
// Call service to decrypt data.
result, err := service.Decrypt(cipher)
if err != nil {
t.Errorf("failed when execute decrypt, error: %v", err)
}
if !reflect.DeepEqual(data, result) {
t.Errorf("expect: %v, but: %v", data, result)
}
}()
}
wg.Wait()
}
func destroyService(service Service) {
s := service.(*gRPCService)
s.connection.Close()
if service != nil {
s := service.(*gRPCService)
s.connection.Close()
}
}
// Test all those invalid configuration for KMS provider.
func TestInvalidConfiguration(t *testing.T) {
// Start a test gRPC server.
server, err := startTestKMSProvider()
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer stopTestKMSProvider(server)
defer f.server.Stop()
invalidConfigs := []struct {
name string
@ -91,16 +265,12 @@ func TestInvalidConfiguration(t *testing.T) {
}{
{"emptyConfiguration", kmsapiVersion, ""},
{"invalidScheme", kmsapiVersion, "tcp://localhost:6060"},
{"unavailableEndpoint", kmsapiVersion, unixProtocol + ":///kms-socket.nonexist"},
{"invalidAPIVersion", "invalidVersion", endpoint},
}
for _, testCase := range invalidConfigs {
t.Run(testCase.name, func(t *testing.T) {
setAPIVersion(testCase.apiVersion)
defer setAPIVersion(kmsapiVersion)
_, err := NewGRPCService(testCase.endpoint)
f.apiVersion = testCase.apiVersion
_, err := NewGRPCService(testCase.endpoint, 1*time.Second)
if err == nil {
t.Fatalf("should fail to create envelope service for %s.", testCase.name)
}
@ -109,7 +279,7 @@ func TestInvalidConfiguration(t *testing.T) {
}
// Start the gRPC server that listens on unix socket.
func startTestKMSProvider() (*grpc.Server, error) {
func startFakeKMSProvider(version string) (*fakeKMSPlugin, error) {
sockFile, err := parseEndpoint(endpoint)
if err != nil {
return nil, fmt.Errorf("failed to parse endpoint:%q, error %v", endpoint, err)
@ -119,31 +289,25 @@ func startTestKMSProvider() (*grpc.Server, error) {
return nil, fmt.Errorf("failed to listen on the unix socket, error: %v", err)
}
server := grpc.NewServer()
kmsapi.RegisterKeyManagementServiceServer(server, &base64Server{})
go server.Serve(listener)
return server, nil
}
func stopTestKMSProvider(server *grpc.Server) {
server.Stop()
s := grpc.NewServer()
f := &fakeKMSPlugin{apiVersion: version, server: s}
kmsapi.RegisterKeyManagementServiceServer(s, f)
go s.Serve(listener)
return f, nil
}
// Fake gRPC sever for remote KMS provider.
// Use base64 to simulate encrypt and decrypt.
type base64Server struct{}
var testProviderAPIVersion = kmsapiVersion
func setAPIVersion(apiVersion string) {
testProviderAPIVersion = apiVersion
type fakeKMSPlugin struct {
apiVersion string
server *grpc.Server
}
func (s *base64Server) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) {
return &kmsapi.VersionResponse{Version: testProviderAPIVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil
func (s *fakeKMSPlugin) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) {
return &kmsapi.VersionResponse{Version: s.apiVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil
}
func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) {
func (s *fakeKMSPlugin) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(request.Cipher)))
n, err := base64.StdEncoding.Decode(buf, request.Cipher)
if err != nil {
@ -153,7 +317,7 @@ func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptReque
return &kmsapi.DecryptResponse{Plain: buf[:n]}, nil
}
func (s *base64Server) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) {
func (s *fakeKMSPlugin) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) {
buf := make([]byte, base64.StdEncoding.EncodedLen(len(request.Plain)))
base64.StdEncoding.Encode(buf, request.Plain)
return &kmsapi.EncryptResponse{Cipher: buf}, nil