diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go
index 36055db9aa..1cff7b40ec 100644
--- a/pkg/volume/csi/csi_client.go
+++ b/pkg/volume/csi/csi_client.go
@@ -20,6 +20,7 @@ import (
     "context"
     "errors"
     "fmt"
+    "io"
     "net"
     "time"
 
@@ -69,14 +70,39 @@ type csiClient interface {
 
 // csiClient encapsulates all csi-plugin methods
 type csiDriverClient struct {
-    driverName string
-    nodeClient csipb.NodeClient
+    driverName        string
+    nodeClientCreator nodeClientCreator
 }
 
 var _ csiClient = &csiDriverClient{}
 
+type nodeClientCreator func(driverName string) (
+    nodeClient csipb.NodeClient,
+    closer io.Closer,
+    err error,
+)
+
+// newNodeClient creates a new NodeClient with the internally used gRPC
+// connection set up. It also returns a closer which must be called to close
+// the gRPC connection when the NodeClient is not used anymore.
+// This is the default implementation for the nodeClientCreator, used in
+// newCsiDriverClient.
+func newNodeClient(driverName string) (nodeClient csipb.NodeClient, closer io.Closer, err error) {
+    var conn *grpc.ClientConn
+    conn, err = newGrpcConn(driverName)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    nodeClient = csipb.NewNodeClient(conn)
+    return nodeClient, conn, nil
+}
+
 func newCsiDriverClient(driverName string) *csiDriverClient {
-    c := &csiDriverClient{driverName: driverName}
+    c := &csiDriverClient{
+        driverName:        driverName,
+        nodeClientCreator: newNodeClient,
+    }
     return c
 }
 
@@ -87,12 +113,11 @@ func (c *csiDriverClient) NodeGetInfo(ctx context.Context) (
     err error) {
     glog.V(4).Info(log("calling NodeGetInfo rpc"))
 
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return "", 0, nil, err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
     if err != nil {
@@ -122,12 +147,11 @@ func (c *csiDriverClient) NodePublishVolume(
         return errors.New("missing target path")
     }
 
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     req := &csipb.NodePublishVolumeRequest{
         VolumeId:       volID,
@@ -171,12 +195,11 @@ func (c *csiDriverClient) NodeUnpublishVolume(ctx context.Context, volID string,
         return errors.New("missing target path")
     }
 
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     req := &csipb.NodeUnpublishVolumeRequest{
         VolumeId:   volID,
@@ -204,12 +227,11 @@ func (c *csiDriverClient) NodeStageVolume(ctx context.Context,
         return errors.New("missing staging target path")
     }
 
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     req := &csipb.NodeStageVolumeRequest{
         VolumeId:          volID,
@@ -249,12 +271,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT
         return errors.New("missing staging target path")
     }
 
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     req := &csipb.NodeUnstageVolumeRequest{
         VolumeId:          volID,
@@ -267,12 +288,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT
 
 func (c *csiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) {
     glog.V(4).Info(log("calling NodeGetCapabilities rpc"))
-    conn, err := newGrpcConn(c.driverName)
+    nodeClient, closer, err := c.nodeClientCreator(c.driverName)
     if err != nil {
         return nil, err
     }
-    defer conn.Close()
-    nodeClient := csipb.NewNodeClient(conn)
+    defer closer.Close()
 
     req := &csipb.NodeGetCapabilitiesRequest{}
     resp, err := nodeClient.NodeGetCapabilities(ctx, req)
diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go
index 026aa36883..f589aeaa9a 100644
--- a/pkg/volume/csi/csi_client_test.go
+++ b/pkg/volume/csi/csi_client_test.go
@@ -19,12 +19,13 @@ package csi
 import (
     "context"
     "errors"
+    "io"
+    "reflect"
     "testing"
 
     csipb "github.com/container-storage-interface/spec/lib/go/csi/v0"
     api "k8s.io/api/core/v1"
     "k8s.io/kubernetes/pkg/volume/csi/fake"
-    "reflect"
 )
 
 type fakeCsiDriverClient struct {
@@ -151,6 +152,20 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient {
     return newFakeCsiDriverClient(t, stageUnstageSet)
 }
 
+func checkErr(t *testing.T, expectedAnError bool, actualError error) {
+    t.Helper()
+
+    errOccurred := actualError != nil
+
+    if expectedAnError && !errOccurred {
+        t.Error("expected an error")
+    }
+
+    if !expectedAnError && errOccurred {
+        t.Errorf("expected no error, got: %v", actualError)
+    }
+}
+
 func TestClientNodeGetInfo(t *testing.T) {
     testCases := []struct {
         name string
@@ -168,28 +183,33 @@ func TestClientNodeGetInfo(t *testing.T) {
                Segments: map[string]string{"com.example.csi-topology/zone": "zone1"},
            },
        },
-       {name: "grpc error", mustFail: true, err: errors.New("grpc error")},
+       {
+           name:     "grpc error",
+           mustFail: true,
+           err:      errors.New("grpc error"),
+       },
     }
 
-    client := setupClient(t, false /* stageUnstageSet */)
-
     for _, tc := range testCases {
        t.Logf("test case: %s", tc.name)
-       client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
-       client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
-           NodeId:             tc.expectedNodeID,
-           MaxVolumesPerNode:  tc.expectedMaxVolumePerNode,
-           AccessibleTopology: tc.expectedAccessibleTopology,
-       })
+
+       fakeCloser := fake.NewCloser(t)
+       client := &csiDriverClient{
+           driverName: "Fake Driver Name",
+           nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
+               nodeClient := fake.NewNodeClient(false /* stagingCapable */)
+               nodeClient.SetNextError(tc.err)
+               nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
+                   NodeId:             tc.expectedNodeID,
+                   MaxVolumesPerNode:  tc.expectedMaxVolumePerNode,
+                   AccessibleTopology: tc.expectedAccessibleTopology,
+               })
+               return nodeClient, fakeCloser, nil
+           },
+       }
+
        nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background())
-
-       if tc.mustFail && err == nil {
-           t.Error("expected an error but got none")
-       }
-
-       if !tc.mustFail && err != nil {
-           t.Errorf("expected no errors but got: %v", err)
-       }
+       checkErr(t, tc.mustFail, err)
 
        if nodeID != tc.expectedNodeID {
            t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID)
@@ -202,6 +222,10 @@ func TestClientNodeGetInfo(t *testing.T) {
        if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) {
            t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology)
        }
+
+       if !tc.mustFail {
+           fakeCloser.Check()
+       }
     }
 }
 
@@ -221,11 +245,18 @@ func TestClientNodePublishVolume(t *testing.T) {
        {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
     }
 
-    client := setupClient(t, false)
-
     for _, tc := range testCases {
        t.Logf("test case: %s", tc.name)
-       client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
+       fakeCloser := fake.NewCloser(t)
+       client := &csiDriverClient{
+           driverName: "Fake Driver Name",
+           nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
+               nodeClient := fake.NewNodeClient(false /* stagingCapable */)
+               nodeClient.SetNextError(tc.err)
+               return nodeClient, fakeCloser, nil
+           },
+       }
+
        err := client.NodePublishVolume(
            context.Background(),
            tc.volID,
@@ -238,9 +269,10 @@ func TestClientNodePublishVolume(t *testing.T) {
            map[string]string{},
            tc.fsType,
        )
+       checkErr(t, tc.mustFail, err)
 
-       if tc.mustFail && err == nil {
-           t.Error("test must fail, but err is nil")
+       if !tc.mustFail {
+           fakeCloser.Check()
        }
     }
 }
@@ -259,14 +291,23 @@ func TestClientNodeUnpublishVolume(t *testing.T) {
        {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
     }
 
-    client := setupClient(t, false)
-
     for _, tc := range testCases {
        t.Logf("test case: %s", tc.name)
-       client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
+       fakeCloser := fake.NewCloser(t)
+       client := &csiDriverClient{
+           driverName: "Fake Driver Name",
+           nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
+               nodeClient := fake.NewNodeClient(false /* stagingCapable */)
+               nodeClient.SetNextError(tc.err)
+               return nodeClient, fakeCloser, nil
+           },
+       }
+
        err := client.NodeUnpublishVolume(context.Background(), tc.volID, tc.targetPath)
-       if tc.mustFail && err == nil {
-           t.Error("test must fail, but err is nil")
+       checkErr(t, tc.mustFail, err)
+
+       if !tc.mustFail {
+           fakeCloser.Check()
        }
     }
 }
@@ -288,11 +329,18 @@ func TestClientNodeStageVolume(t *testing.T) {
        {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
     }
 
-    client := setupClient(t, false)
-
     for _, tc := range testCases {
        t.Logf("Running test case: %s", tc.name)
-       client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
+       fakeCloser := fake.NewCloser(t)
+       client := &csiDriverClient{
+           driverName: "Fake Driver Name",
+           nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
+               nodeClient := fake.NewNodeClient(false /* stagingCapable */)
+               nodeClient.SetNextError(tc.err)
+               return nodeClient, fakeCloser, nil
+           },
+       }
+
        err := client.NodeStageVolume(
            context.Background(),
            tc.volID,
@@ -303,9 +351,10 @@ func TestClientNodeStageVolume(t *testing.T) {
            tc.secret,
            map[string]string{"attr0": "val0"},
        )
+       checkErr(t, tc.mustFail, err)
 
-       if tc.mustFail && err == nil {
-           t.Error("test must fail, but err is nil")
+       if !tc.mustFail {
+           fakeCloser.Check()
        }
     }
 }
@@ -324,17 +373,26 @@ func TestClientNodeUnstageVolume(t *testing.T) {
        {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
     }
 
-    client := setupClient(t, false)
-
     for _, tc := range testCases {
        t.Logf("Running test case: %s", tc.name)
-       client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
+       fakeCloser := fake.NewCloser(t)
+       client := &csiDriverClient{
+           driverName: "Fake Driver Name",
+           nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
+               nodeClient := fake.NewNodeClient(false /* stagingCapable */)
+               nodeClient.SetNextError(tc.err)
+               return nodeClient, fakeCloser, nil
+           },
+       }
+
        err := client.NodeUnstageVolume(
            context.Background(),
            tc.volID, tc.stagingTargetPath,
        )
-       if tc.mustFail && err == nil {
-           t.Error("test must fail, but err is nil")
+       checkErr(t, tc.mustFail, err)
+
+       if !tc.mustFail {
+           fakeCloser.Check()
        }
     }
 }
diff --git a/pkg/volume/csi/fake/BUILD b/pkg/volume/csi/fake/BUILD
index 4c32879373..679150a5b5 100644
--- a/pkg/volume/csi/fake/BUILD
+++ b/pkg/volume/csi/fake/BUILD
@@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
 
 go_library(
     name = "go_default_library",
-    srcs = ["fake_client.go"],
+    srcs = [
+        "fake_client.go",
+        "fake_closer.go",
+    ],
     importpath = "k8s.io/kubernetes/pkg/volume/csi/fake",
     visibility = ["//visibility:public"],
     deps = [
diff --git a/pkg/volume/csi/fake/fake_closer.go b/pkg/volume/csi/fake/fake_closer.go
new file mode 100644
index 0000000000..b790e155b5
--- /dev/null
+++ b/pkg/volume/csi/fake/fake_closer.go
@@ -0,0 +1,47 @@
+/*
+Copyright 2018 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package fake
+
+import (
+    "testing"
+)
+
+func NewCloser(t *testing.T) *Closer {
+    return &Closer{
+        t: t,
+    }
+}
+
+type Closer struct {
+    wasCalled bool
+    t         *testing.T
+}
+
+func (c *Closer) Close() error {
+    c.wasCalled = true
+    return nil
+}
+
+func (c *Closer) Check() *Closer {
+    c.t.Helper()
+
+    if !c.wasCalled {
+        c.t.Error("expected closer to have been called")
+    }
+
+    return c
+}
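
The change above is essentially a small dependency-injection pattern: instead of dialing gRPC inline, csiDriverClient asks an injected creator function for a NodeClient plus an io.Closer, so a test can substitute fakes and then assert that the connection was closed. The following stand-alone sketch illustrates the same idea with simplified stand-in types; nodeClient, driverClient, fakeNode and fakeCloser here are illustrative only, not the real csipb.NodeClient or the fake package touched by this patch.

package main

import (
    "fmt"
    "io"
)

// nodeClient is a stand-in for csipb.NodeClient with just one call modelled.
type nodeClient interface {
    NodeGetInfo() (nodeID string, err error)
}

// clientCreator mirrors the nodeClientCreator idea: it hands back the client
// together with an io.Closer that owns the underlying connection.
type clientCreator func(driverName string) (nodeClient, io.Closer, error)

type driverClient struct {
    driverName string
    creator    clientCreator
}

func (c *driverClient) NodeGetInfo() (string, error) {
    client, closer, err := c.creator(c.driverName)
    if err != nil {
        return "", err
    }
    // The caller of the creator is responsible for closing the connection.
    defer closer.Close()

    return client.NodeGetInfo()
}

// Test-side fakes, analogous in spirit to fake.NewNodeClient and fake.NewCloser.
type fakeNode struct{}

func (*fakeNode) NodeGetInfo() (string, error) { return "node-1", nil }

type fakeCloser struct{ closed bool }

func (f *fakeCloser) Close() error { f.closed = true; return nil }

func main() {
    closer := &fakeCloser{}
    c := &driverClient{
        driverName: "fake-driver",
        creator: func(string) (nodeClient, io.Closer, error) {
            return &fakeNode{}, closer, nil
        },
    }

    id, err := c.NodeGetInfo()
    // Prints: node-1 <nil> closed: true — the injected closer was closed
    // because the production code defers closer.Close().
    fmt.Println(id, err, "closed:", closer.closed)
}

The final assertion plays the role of fakeCloser.Check() in the test diff: if the production path ever stops closing the connection, the check fails.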