Lazily dial kms-plugin.

pull/58/head
immutablet 2018-09-12 14:56:44 -07:00
parent 2fb9fc2400
commit 07cbf2545f
4 changed files with 249 additions and 71 deletions

View File

@ -24,6 +24,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"os" "os"
"time"
yaml "github.com/ghodss/yaml" yaml "github.com/ghodss/yaml"
@ -40,6 +41,7 @@ const (
aesGCMTransformerPrefixV1 = "k8s:enc:aesgcm:v1:" aesGCMTransformerPrefixV1 = "k8s:enc:aesgcm:v1:"
secretboxTransformerPrefixV1 = "k8s:enc:secretbox:v1:" secretboxTransformerPrefixV1 = "k8s:enc:secretbox:v1:"
kmsTransformerPrefixV1 = "k8s:enc:kms:v1:" kmsTransformerPrefixV1 = "k8s:enc:kms:v1:"
kmsPluginConnectionTimeout = 3 * time.Second
) )
// GetTransformerOverrides returns the transformer overrides by reading and parsing the encryption provider configuration file // GetTransformerOverrides returns the transformer overrides by reading and parsing the encryption provider configuration file
@ -160,7 +162,7 @@ func GetPrefixTransformers(config *ResourceConfig) ([]value.PrefixTransformer, e
} }
// Get gRPC client service with endpoint. // Get gRPC client service with endpoint.
envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint) envelopeService, err := envelopeServiceFactory(provider.KMS.Endpoint, kmsPluginConnectionTimeout)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not configure KMS plugin %q, error: %v", provider.KMS.Name, err) return nil, fmt.Errorf("could not configure KMS plugin %q, error: %v", provider.KMS.Name, err)
} }

View File

@ -21,6 +21,7 @@ import (
"encoding/base64" "encoding/base64"
"strings" "strings"
"testing" "testing"
"time"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apiserver/pkg/storage/value" "k8s.io/apiserver/pkg/storage/value"
@ -239,7 +240,7 @@ func (t *testEnvelopeService) Encrypt(data []byte) ([]byte, error) {
} }
// The factory method to create mock envelope service. // The factory method to create mock envelope service.
func newMockEnvelopeService(endpoint string) (envelope.Service, error) { func newMockEnvelopeService(endpoint string, timeout time.Duration) (envelope.Service, error) {
return &testEnvelopeService{}, nil return &testEnvelopeService{}, nil
} }

View File

@ -23,6 +23,7 @@ import (
"net" "net"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/golang/glog" "github.com/golang/glog"
@ -39,19 +40,20 @@ const (
// Current version for the protocol interface definition. // Current version for the protocol interface definition.
kmsapiVersion = "v1beta1" kmsapiVersion = "v1beta1"
// The timeout that communicate with KMS server. versionErrorf = "KMS provider api version %s is not supported, only %s is supported now"
timeout = 30 * time.Second
) )
// The gRPC implementation for envelope.Service. // The gRPC implementation for envelope.Service.
type gRPCService struct { type gRPCService struct {
// gRPC client instance kmsClient kmsapi.KeyManagementServiceClient
kmsClient kmsapi.KeyManagementServiceClient connection *grpc.ClientConn
connection *grpc.ClientConn callTimeout time.Duration
mux sync.RWMutex
versionChecked bool
} }
// NewGRPCService returns an envelope.Service which use gRPC to communicate the remote KMS provider. // NewGRPCService returns an envelope.Service which use gRPC to communicate the remote KMS provider.
func NewGRPCService(endpoint string) (Service, error) { func NewGRPCService(endpoint string, callTimeout time.Duration) (Service, error) {
glog.V(4).Infof("Configure KMS provider with endpoint: %s", endpoint) glog.V(4).Infof("Configure KMS provider with endpoint: %s", endpoint)
addr, err := parseEndpoint(endpoint) addr, err := parseEndpoint(endpoint)
@ -59,28 +61,28 @@ func NewGRPCService(endpoint string) (Service, error) {
return nil, err return nil, err
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) connection, err := grpc.Dial(addr, grpc.WithInsecure(), grpc.WithDefaultCallOptions(grpc.FailFast(false)), grpc.WithDialer(
defer cancel() func(string, time.Duration) (net.Conn, error) {
// Ignoring addr and timeout arguments:
// addr - comes from the closure
// timeout - is ignored since we are connecting in a non-blocking configuration
c, err := net.DialTimeout(unixProtocol, addr, 0)
if err != nil {
glog.Errorf("failed to create connection to unix socket: %s, error: %v", addr, err)
}
return c, err
}))
connection, err := grpc.DialContext(ctx, addr, grpc.WithInsecure(), grpc.WithDialer(unixDial))
if err != nil { if err != nil {
return nil, fmt.Errorf("connect remote KMS provider %q failed, error: %v", addr, err) return nil, fmt.Errorf("failed to create connection to %s, error: %v", endpoint, err)
} }
kmsClient := kmsapi.NewKeyManagementServiceClient(connection) kmsClient := kmsapi.NewKeyManagementServiceClient(connection)
return &gRPCService{
err = checkAPIVersion(kmsClient) kmsClient: kmsClient,
if err != nil { connection: connection,
connection.Close() callTimeout: callTimeout,
return nil, fmt.Errorf("failed check version for %q, error: %v", addr, err) }, nil
}
return &gRPCService{kmsClient: kmsClient, connection: connection}, nil
}
// This dialer explicitly ask gRPC to use unix socket as network.
func unixDial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(unixProtocol, addr, timeout)
} }
// Parse the endpoint to extract schema, host or path. // Parse the endpoint to extract schema, host or path.
@ -109,31 +111,37 @@ func parseEndpoint(endpoint string) (string, error) {
return u.Path, nil return u.Path, nil
} }
// Check the KMS provider API version. func (g *gRPCService) checkAPIVersion(ctx context.Context) error {
// Only matching kmsapiVersion is supported now. g.mux.Lock()
func checkAPIVersion(kmsClient kmsapi.KeyManagementServiceClient) error { defer g.mux.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() if g.versionChecked {
return nil
}
request := &kmsapi.VersionRequest{Version: kmsapiVersion} request := &kmsapi.VersionRequest{Version: kmsapiVersion}
response, err := kmsClient.Version(ctx, request) response, err := g.kmsClient.Version(ctx, request)
if err != nil { if err != nil {
return fmt.Errorf("failed get version from remote KMS provider: %v", err) return fmt.Errorf("failed get version from remote KMS provider: %v", err)
} }
if response.Version != kmsapiVersion { if response.Version != kmsapiVersion {
return fmt.Errorf("KMS provider api version %s is not supported, only %s is supported now", return fmt.Errorf(versionErrorf, response.Version, kmsapiVersion)
response.Version, kmsapiVersion)
} }
g.versionChecked = true
glog.V(4).Infof("KMS provider %s initialized, version: %s", response.RuntimeName, response.RuntimeVersion) glog.V(4).Infof("Version of KMS provider is %s", response.Version)
return nil return nil
} }
// Decrypt a given data string to obtain the original byte data. // Decrypt a given data string to obtain the original byte data.
func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) { func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout)
defer cancel() defer cancel()
if err := g.checkAPIVersion(ctx); err != nil {
return nil, err
}
request := &kmsapi.DecryptRequest{Cipher: cipher, Version: kmsapiVersion} request := &kmsapi.DecryptRequest{Cipher: cipher, Version: kmsapiVersion}
response, err := g.kmsClient.Decrypt(ctx, request) response, err := g.kmsClient.Decrypt(ctx, request)
if err != nil { if err != nil {
@ -144,8 +152,11 @@ func (g *gRPCService) Decrypt(cipher []byte) ([]byte, error) {
// Encrypt bytes to a string ciphertext. // Encrypt bytes to a string ciphertext.
func (g *gRPCService) Encrypt(plain []byte) ([]byte, error) { func (g *gRPCService) Encrypt(plain []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), g.callTimeout)
defer cancel() defer cancel()
if err := g.checkAPIVersion(ctx); err != nil {
return nil, err
}
request := &kmsapi.EncryptRequest{Plain: plain, Version: kmsapiVersion} request := &kmsapi.EncryptRequest{Plain: plain, Version: kmsapiVersion}
response, err := g.kmsClient.Encrypt(ctx, request) response, err := g.kmsClient.Encrypt(ctx, request)

View File

@ -25,7 +25,9 @@ import (
"fmt" "fmt"
"net" "net"
"reflect" "reflect"
"sync"
"testing" "testing"
"time"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -36,17 +38,143 @@ const (
endpoint = "unix:///@kms-socket.sock" endpoint = "unix:///@kms-socket.sock"
) )
// Normal encryption and decryption operation. // TestKMSPluginLateStart tests the scenario where kms-plugin pod/container starts after kube-apiserver pod/container.
func TestGRPCService(t *testing.T) { // Since the Dial to kms-plugin is non-blocking we expect the construction of gRPC service to succeed even when
// Start a test gRPC server. // kms-plugin is not yet up - dialing happens in the background.
server, err := startTestKMSProvider() func TestKMSPluginLateStart(t *testing.T) {
callTimeout := 3 * time.Second
service, err := NewGRPCService(endpoint, callTimeout)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
time.Sleep(callTimeout / 2)
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil { if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err) t.Fatalf("failed to start test KMS provider server, error: %v", err)
} }
defer stopTestKMSProvider(server) defer f.server.Stop()
data := []byte("test data")
_, err = service.Encrypt(data)
if err != nil {
t.Fatalf("failed when execute encrypt, error: %v", err)
}
}
// TestIntermittentConnectionLoss tests the scenario where the connection with kms-plugin is intermittently lost.
func TestIntermittentConnectionLoss(t *testing.T) {
var (
wg1 sync.WaitGroup
wg2 sync.WaitGroup
timeout = 30 * time.Second
blackOut = 1 * time.Second
data = []byte("test data")
)
// Start KMS Plugin
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
// connect to kms plugin
service, err := NewGRPCService(endpoint, timeout)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
_, err = service.Encrypt(data)
if err != nil {
t.Fatalf("failed when execute encrypt, error: %v", err)
}
t.Log("Connected to KMSPlugin")
// Stop KMS Plugin - simulating connection loss
f.server.Stop()
t.Log("KMS Plugin is stopped")
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
// Call service to encrypt data.
t.Log("Sending encrypt request")
wg1.Done()
_, err := service.Encrypt(data)
if err != nil {
t.Fatalf("failed when executing encrypt, error: %v", err)
}
}()
wg1.Wait()
time.Sleep(blackOut)
// Start KMS Plugin
f, err = startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
t.Log("Restarted KMS Plugin")
wg2.Wait()
}
func TestUnsupportedVersion(t *testing.T) {
ver := "invalid"
data := []byte("test data")
wantErr := fmt.Errorf(versionErrorf, ver, kmsapiVersion)
f, err := startFakeKMSProvider(ver)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %ver", err)
}
defer f.server.Stop()
s, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatal(err)
}
defer destroyService(s)
// Encrypt
_, err = s.Encrypt(data)
if err == nil || err.Error() != wantErr.Error() {
t.Errorf("got err: %ver, want: %ver", err, wantErr)
}
destroyService(s)
s, err = NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatal(err)
}
defer destroyService(s)
// Decrypt
_, err = s.Decrypt(data)
if err == nil || err.Error() != wantErr.Error() {
t.Errorf("got err: %ver, want: %ver", err, wantErr)
}
}
func TestConcurrentAccess(t *testing.T) {
}
// Normal encryption and decryption operation.
func TestGRPCService(t *testing.T) {
// Start a test gRPC server.
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
// Create the gRPC client service. // Create the gRPC client service.
service, err := NewGRPCService(endpoint) service, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil { if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err) t.Fatalf("failed to create envelope service, error: %v", err)
} }
@ -70,19 +198,65 @@ func TestGRPCService(t *testing.T) {
} }
} }
// Normal encryption and decryption operation by multiple go-routines.
func TestGRPCServiceConcurrentAccess(t *testing.T) {
// Start a test gRPC server.
f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err)
}
defer f.server.Stop()
// Create the gRPC client service.
service, err := NewGRPCService(endpoint, 1*time.Second)
if err != nil {
t.Fatalf("failed to create envelope service, error: %v", err)
}
defer destroyService(service)
var wg sync.WaitGroup
n := 1000
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
// Call service to encrypt data.
data := []byte("test data")
cipher, err := service.Encrypt(data)
if err != nil {
t.Errorf("failed when execute encrypt, error: %v", err)
}
// Call service to decrypt data.
result, err := service.Decrypt(cipher)
if err != nil {
t.Errorf("failed when execute decrypt, error: %v", err)
}
if !reflect.DeepEqual(data, result) {
t.Errorf("expect: %v, but: %v", data, result)
}
}()
}
wg.Wait()
}
func destroyService(service Service) { func destroyService(service Service) {
s := service.(*gRPCService) if service != nil {
s.connection.Close() s := service.(*gRPCService)
s.connection.Close()
}
} }
// Test all those invalid configuration for KMS provider. // Test all those invalid configuration for KMS provider.
func TestInvalidConfiguration(t *testing.T) { func TestInvalidConfiguration(t *testing.T) {
// Start a test gRPC server. // Start a test gRPC server.
server, err := startTestKMSProvider() f, err := startFakeKMSProvider(kmsapiVersion)
if err != nil { if err != nil {
t.Fatalf("failed to start test KMS provider server, error: %v", err) t.Fatalf("failed to start test KMS provider server, error: %v", err)
} }
defer stopTestKMSProvider(server) defer f.server.Stop()
invalidConfigs := []struct { invalidConfigs := []struct {
name string name string
@ -91,16 +265,12 @@ func TestInvalidConfiguration(t *testing.T) {
}{ }{
{"emptyConfiguration", kmsapiVersion, ""}, {"emptyConfiguration", kmsapiVersion, ""},
{"invalidScheme", kmsapiVersion, "tcp://localhost:6060"}, {"invalidScheme", kmsapiVersion, "tcp://localhost:6060"},
{"unavailableEndpoint", kmsapiVersion, unixProtocol + ":///kms-socket.nonexist"},
{"invalidAPIVersion", "invalidVersion", endpoint},
} }
for _, testCase := range invalidConfigs { for _, testCase := range invalidConfigs {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
setAPIVersion(testCase.apiVersion) f.apiVersion = testCase.apiVersion
defer setAPIVersion(kmsapiVersion) _, err := NewGRPCService(testCase.endpoint, 1*time.Second)
_, err := NewGRPCService(testCase.endpoint)
if err == nil { if err == nil {
t.Fatalf("should fail to create envelope service for %s.", testCase.name) t.Fatalf("should fail to create envelope service for %s.", testCase.name)
} }
@ -109,7 +279,7 @@ func TestInvalidConfiguration(t *testing.T) {
} }
// Start the gRPC server that listens on unix socket. // Start the gRPC server that listens on unix socket.
func startTestKMSProvider() (*grpc.Server, error) { func startFakeKMSProvider(version string) (*fakeKMSPlugin, error) {
sockFile, err := parseEndpoint(endpoint) sockFile, err := parseEndpoint(endpoint)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse endpoint:%q, error %v", endpoint, err) return nil, fmt.Errorf("failed to parse endpoint:%q, error %v", endpoint, err)
@ -119,31 +289,25 @@ func startTestKMSProvider() (*grpc.Server, error) {
return nil, fmt.Errorf("failed to listen on the unix socket, error: %v", err) return nil, fmt.Errorf("failed to listen on the unix socket, error: %v", err)
} }
server := grpc.NewServer() s := grpc.NewServer()
kmsapi.RegisterKeyManagementServiceServer(server, &base64Server{}) f := &fakeKMSPlugin{apiVersion: version, server: s}
go server.Serve(listener) kmsapi.RegisterKeyManagementServiceServer(s, f)
return server, nil go s.Serve(listener)
} return f, nil
func stopTestKMSProvider(server *grpc.Server) {
server.Stop()
} }
// Fake gRPC sever for remote KMS provider. // Fake gRPC sever for remote KMS provider.
// Use base64 to simulate encrypt and decrypt. // Use base64 to simulate encrypt and decrypt.
type base64Server struct{} type fakeKMSPlugin struct {
apiVersion string
var testProviderAPIVersion = kmsapiVersion server *grpc.Server
func setAPIVersion(apiVersion string) {
testProviderAPIVersion = apiVersion
} }
func (s *base64Server) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) { func (s *fakeKMSPlugin) Version(ctx context.Context, request *kmsapi.VersionRequest) (*kmsapi.VersionResponse, error) {
return &kmsapi.VersionResponse{Version: testProviderAPIVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil return &kmsapi.VersionResponse{Version: s.apiVersion, RuntimeName: "testKMS", RuntimeVersion: "0.0.1"}, nil
} }
func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) { func (s *fakeKMSPlugin) Decrypt(ctx context.Context, request *kmsapi.DecryptRequest) (*kmsapi.DecryptResponse, error) {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(request.Cipher))) buf := make([]byte, base64.StdEncoding.DecodedLen(len(request.Cipher)))
n, err := base64.StdEncoding.Decode(buf, request.Cipher) n, err := base64.StdEncoding.Decode(buf, request.Cipher)
if err != nil { if err != nil {
@ -153,7 +317,7 @@ func (s *base64Server) Decrypt(ctx context.Context, request *kmsapi.DecryptReque
return &kmsapi.DecryptResponse{Plain: buf[:n]}, nil return &kmsapi.DecryptResponse{Plain: buf[:n]}, nil
} }
func (s *base64Server) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) { func (s *fakeKMSPlugin) Encrypt(ctx context.Context, request *kmsapi.EncryptRequest) (*kmsapi.EncryptResponse, error) {
buf := make([]byte, base64.StdEncoding.EncodedLen(len(request.Plain))) buf := make([]byte, base64.StdEncoding.EncodedLen(len(request.Plain)))
base64.StdEncoding.Encode(buf, request.Plain) base64.StdEncoding.Encode(buf, request.Plain)
return &kmsapi.EncryptResponse{Cipher: buf}, nil return &kmsapi.EncryptResponse{Cipher: buf}, nil