From 6330cee9ea1be4907e0ccfae6195d9420c53e65b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20H=C3=B6rl?= Date: Mon, 15 Oct 2018 20:55:37 +0100 Subject: [PATCH] Add tests for `csiDriverClient` As #69219 outlines the unit tests in `csi_client_test.go` where not testing the actual implementation of the `csiDriverClient` but was testing the fake. To fix this, we changed the `csiDriverClient` to use a `nodeClientCreator` which is responsible for creating a new `NodeClient`, a real one in prod and a fake one in the tests. The setup of the gRPC connection has been pushed into that creator. The node client uses that connection; that's transparent to the driver client. It's the responsibility of the driver client to close the connection when it is done with the node client. To achieve this, we have the node client creator return a closer which handles the connection teardown. In the tests we now also check if the driver client actually calls this closer, thus closing the gRPC connection. Closes: #69219 Co-authored-by: Rosie Bloxsom Co-authored-by: Maria Ntalla --- pkg/volume/csi/csi_client.go | 62 ++++++++----- pkg/volume/csi/csi_client_test.go | 134 +++++++++++++++++++++-------- pkg/volume/csi/fake/BUILD | 5 +- pkg/volume/csi/fake/fake_closer.go | 47 ++++++++++ 4 files changed, 188 insertions(+), 60 deletions(-) create mode 100644 pkg/volume/csi/fake/fake_closer.go diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index 36055db9aa..1cff7b40ec 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "io" "net" "time" @@ -69,14 +70,39 @@ type csiClient interface { // csiClient encapsulates all csi-plugin methods type csiDriverClient struct { - driverName string - nodeClient csipb.NodeClient + driverName string + nodeClientCreator nodeClientCreator } var _ csiClient = &csiDriverClient{} +type nodeClientCreator func(driverName string) ( + nodeClient csipb.NodeClient, + closer io.Closer, + err error, +) + +// newNodeClient creates a new NodeClient with the internally used gRPC +// connection set up. It also returns a closer which must to be called to close +// the gRPC connection when the NodeClient is not used anymore. +// This is the default implementation for the nodeClientCreator, used in +// newCsiDriverClient. +func newNodeClient(driverName string) (nodeClient csipb.NodeClient, closer io.Closer, err error) { + var conn *grpc.ClientConn + conn, err = newGrpcConn(driverName) + if err != nil { + return nil, nil, err + } + + nodeClient = csipb.NewNodeClient(conn) + return nodeClient, conn, nil +} + func newCsiDriverClient(driverName string) *csiDriverClient { - c := &csiDriverClient{driverName: driverName} + c := &csiDriverClient{ + driverName: driverName, + nodeClientCreator: newNodeClient, + } return c } @@ -87,12 +113,11 @@ func (c *csiDriverClient) NodeGetInfo(ctx context.Context) ( err error) { glog.V(4).Info(log("calling NodeGetInfo rpc")) - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return "", 0, nil, err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{}) if err != nil { @@ -122,12 +147,11 @@ func (c *csiDriverClient) NodePublishVolume( return errors.New("missing target path") } - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() req := &csipb.NodePublishVolumeRequest{ VolumeId: volID, @@ -171,12 +195,11 @@ func (c *csiDriverClient) NodeUnpublishVolume(ctx context.Context, volID string, return errors.New("missing target path") } - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() req := &csipb.NodeUnpublishVolumeRequest{ VolumeId: volID, @@ -204,12 +227,11 @@ func (c *csiDriverClient) NodeStageVolume(ctx context.Context, return errors.New("missing staging target path") } - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() req := &csipb.NodeStageVolumeRequest{ VolumeId: volID, @@ -249,12 +271,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT return errors.New("missing staging target path") } - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() req := &csipb.NodeUnstageVolumeRequest{ VolumeId: volID, @@ -267,12 +288,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT func (c *csiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) { glog.V(4).Info(log("calling NodeGetCapabilities rpc")) - conn, err := newGrpcConn(c.driverName) + nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { return nil, err } - defer conn.Close() - nodeClient := csipb.NewNodeClient(conn) + defer closer.Close() req := &csipb.NodeGetCapabilitiesRequest{} resp, err := nodeClient.NodeGetCapabilities(ctx, req) diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 026aa36883..f589aeaa9a 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -19,12 +19,13 @@ package csi import ( "context" "errors" + "io" + "reflect" "testing" csipb "github.com/container-storage-interface/spec/lib/go/csi/v0" api "k8s.io/api/core/v1" "k8s.io/kubernetes/pkg/volume/csi/fake" - "reflect" ) type fakeCsiDriverClient struct { @@ -151,6 +152,20 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient { return newFakeCsiDriverClient(t, stageUnstageSet) } +func checkErr(t *testing.T, expectedAnError bool, actualError error) { + t.Helper() + + errOccurred := actualError != nil + + if expectedAnError && !errOccurred { + t.Error("expected an error") + } + + if !expectedAnError && errOccurred { + t.Errorf("expected no error, got: %v", actualError) + } +} + func TestClientNodeGetInfo(t *testing.T) { testCases := []struct { name string @@ -168,28 +183,33 @@ func TestClientNodeGetInfo(t *testing.T) { Segments: map[string]string{"com.example.csi-topology/zone": "zone1"}, }, }, - {name: "grpc error", mustFail: true, err: errors.New("grpc error")}, + { + name: "grpc error", + mustFail: true, + err: errors.New("grpc error"), + }, } - client := setupClient(t, false /* stageUnstageSet */) - for _, tc := range testCases { t.Logf("test case: %s", tc.name) - client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) - client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{ - NodeId: tc.expectedNodeID, - MaxVolumesPerNode: tc.expectedMaxVolumePerNode, - AccessibleTopology: tc.expectedAccessibleTopology, - }) + + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) { + nodeClient := fake.NewNodeClient(false /* stagingCapable */) + nodeClient.SetNextError(tc.err) + nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{ + NodeId: tc.expectedNodeID, + MaxVolumesPerNode: tc.expectedMaxVolumePerNode, + AccessibleTopology: tc.expectedAccessibleTopology, + }) + return nodeClient, fakeCloser, nil + }, + } + nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background()) - - if tc.mustFail && err == nil { - t.Error("expected an error but got none") - } - - if !tc.mustFail && err != nil { - t.Errorf("expected no errors but got: %v", err) - } + checkErr(t, tc.mustFail, err) if nodeID != tc.expectedNodeID { t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID) @@ -202,6 +222,10 @@ func TestClientNodeGetInfo(t *testing.T) { if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) { t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology) } + + if !tc.mustFail { + fakeCloser.Check() + } } } @@ -221,11 +245,18 @@ func TestClientNodePublishVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t, false) - for _, tc := range testCases { t.Logf("test case: %s", tc.name) - client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) { + nodeClient := fake.NewNodeClient(false /* stagingCapable */) + nodeClient.SetNextError(tc.err) + return nodeClient, fakeCloser, nil + }, + } + err := client.NodePublishVolume( context.Background(), tc.volID, @@ -238,9 +269,10 @@ func TestClientNodePublishVolume(t *testing.T) { map[string]string{}, tc.fsType, ) + checkErr(t, tc.mustFail, err) - if tc.mustFail && err == nil { - t.Error("test must fail, but err is nil") + if !tc.mustFail { + fakeCloser.Check() } } } @@ -259,14 +291,23 @@ func TestClientNodeUnpublishVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t, false) - for _, tc := range testCases { t.Logf("test case: %s", tc.name) - client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) { + nodeClient := fake.NewNodeClient(false /* stagingCapable */) + nodeClient.SetNextError(tc.err) + return nodeClient, fakeCloser, nil + }, + } + err := client.NodeUnpublishVolume(context.Background(), tc.volID, tc.targetPath) - if tc.mustFail && err == nil { - t.Error("test must fail, but err is nil") + checkErr(t, tc.mustFail, err) + + if !tc.mustFail { + fakeCloser.Check() } } } @@ -288,11 +329,18 @@ func TestClientNodeStageVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t, false) - for _, tc := range testCases { t.Logf("Running test case: %s", tc.name) - client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) { + nodeClient := fake.NewNodeClient(false /* stagingCapable */) + nodeClient.SetNextError(tc.err) + return nodeClient, fakeCloser, nil + }, + } + err := client.NodeStageVolume( context.Background(), tc.volID, @@ -303,9 +351,10 @@ func TestClientNodeStageVolume(t *testing.T) { tc.secret, map[string]string{"attr0": "val0"}, ) + checkErr(t, tc.mustFail, err) - if tc.mustFail && err == nil { - t.Error("test must fail, but err is nil") + if !tc.mustFail { + fakeCloser.Check() } } } @@ -324,17 +373,26 @@ func TestClientNodeUnstageVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t, false) - for _, tc := range testCases { t.Logf("Running test case: %s", tc.name) - client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) { + nodeClient := fake.NewNodeClient(false /* stagingCapable */) + nodeClient.SetNextError(tc.err) + return nodeClient, fakeCloser, nil + }, + } + err := client.NodeUnstageVolume( context.Background(), tc.volID, tc.stagingTargetPath, ) - if tc.mustFail && err == nil { - t.Error("test must fail, but err is nil") + checkErr(t, tc.mustFail, err) + + if !tc.mustFail { + fakeCloser.Check() } } } diff --git a/pkg/volume/csi/fake/BUILD b/pkg/volume/csi/fake/BUILD index 4c32879373..679150a5b5 100644 --- a/pkg/volume/csi/fake/BUILD +++ b/pkg/volume/csi/fake/BUILD @@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "go_default_library", - srcs = ["fake_client.go"], + srcs = [ + "fake_client.go", + "fake_closer.go", + ], importpath = "k8s.io/kubernetes/pkg/volume/csi/fake", visibility = ["//visibility:public"], deps = [ diff --git a/pkg/volume/csi/fake/fake_closer.go b/pkg/volume/csi/fake/fake_closer.go new file mode 100644 index 0000000000..b790e155b5 --- /dev/null +++ b/pkg/volume/csi/fake/fake_closer.go @@ -0,0 +1,47 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fake + +import ( + "testing" +) + +func NewCloser(t *testing.T) *Closer { + return &Closer{ + t: t, + } +} + +type Closer struct { + wasCalled bool + t *testing.T +} + +func (c *Closer) Close() error { + c.wasCalled = true + return nil +} + +func (c *Closer) Check() *Closer { + c.t.Helper() + + if !c.wasCalled { + c.t.Error("expected closer to have been called") + } + + return c +}