Merge pull request #69371 from pivotal-k8s/fix-well-tested-fake

Add tests for `csiDriverClient`
pull/58/head
k8s-ci-robot 2018-10-22 18:28:37 -07:00 committed by GitHub
commit 91ac9d50fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 188 additions and 60 deletions

View File

@ -20,6 +20,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
@ -70,13 +71,38 @@ type csiClient interface {
// csiClient encapsulates all csi-plugin methods
type csiDriverClient struct {
driverName string
nodeClient csipb.NodeClient
nodeClientCreator nodeClientCreator
}
var _ csiClient = &csiDriverClient{}
type nodeClientCreator func(driverName string) (
nodeClient csipb.NodeClient,
closer io.Closer,
err error,
)
// newNodeClient creates a new NodeClient with the internally used gRPC
// connection set up. It also returns a closer which must to be called to close
// the gRPC connection when the NodeClient is not used anymore.
// This is the default implementation for the nodeClientCreator, used in
// newCsiDriverClient.
func newNodeClient(driverName string) (nodeClient csipb.NodeClient, closer io.Closer, err error) {
var conn *grpc.ClientConn
conn, err = newGrpcConn(driverName)
if err != nil {
return nil, nil, err
}
nodeClient = csipb.NewNodeClient(conn)
return nodeClient, conn, nil
}
func newCsiDriverClient(driverName string) *csiDriverClient {
c := &csiDriverClient{driverName: driverName}
c := &csiDriverClient{
driverName: driverName,
nodeClientCreator: newNodeClient,
}
return c
}
@ -87,12 +113,11 @@ func (c *csiDriverClient) NodeGetInfo(ctx context.Context) (
err error) {
glog.V(4).Info(log("calling NodeGetInfo rpc"))
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return "", 0, nil, err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
if err != nil {
@ -122,12 +147,11 @@ func (c *csiDriverClient) NodePublishVolume(
return errors.New("missing target path")
}
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
req := &csipb.NodePublishVolumeRequest{
VolumeId: volID,
@ -171,12 +195,11 @@ func (c *csiDriverClient) NodeUnpublishVolume(ctx context.Context, volID string,
return errors.New("missing target path")
}
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
req := &csipb.NodeUnpublishVolumeRequest{
VolumeId: volID,
@ -204,12 +227,11 @@ func (c *csiDriverClient) NodeStageVolume(ctx context.Context,
return errors.New("missing staging target path")
}
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
req := &csipb.NodeStageVolumeRequest{
VolumeId: volID,
@ -249,12 +271,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT
return errors.New("missing staging target path")
}
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
req := &csipb.NodeUnstageVolumeRequest{
VolumeId: volID,
@ -267,12 +288,11 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT
func (c *csiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) {
glog.V(4).Info(log("calling NodeGetCapabilities rpc"))
conn, err := newGrpcConn(c.driverName)
nodeClient, closer, err := c.nodeClientCreator(c.driverName)
if err != nil {
return nil, err
}
defer conn.Close()
nodeClient := csipb.NewNodeClient(conn)
defer closer.Close()
req := &csipb.NodeGetCapabilitiesRequest{}
resp, err := nodeClient.NodeGetCapabilities(ctx, req)

View File

@ -19,12 +19,13 @@ package csi
import (
"context"
"errors"
"io"
"reflect"
"testing"
csipb "github.com/container-storage-interface/spec/lib/go/csi/v0"
api "k8s.io/api/core/v1"
"k8s.io/kubernetes/pkg/volume/csi/fake"
"reflect"
)
type fakeCsiDriverClient struct {
@ -151,6 +152,20 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient {
return newFakeCsiDriverClient(t, stageUnstageSet)
}
func checkErr(t *testing.T, expectedAnError bool, actualError error) {
t.Helper()
errOccurred := actualError != nil
if expectedAnError && !errOccurred {
t.Error("expected an error")
}
if !expectedAnError && errOccurred {
t.Errorf("expected no error, got: %v", actualError)
}
}
func TestClientNodeGetInfo(t *testing.T) {
testCases := []struct {
name string
@ -168,28 +183,33 @@ func TestClientNodeGetInfo(t *testing.T) {
Segments: map[string]string{"com.example.csi-topology/zone": "zone1"},
},
},
{name: "grpc error", mustFail: true, err: errors.New("grpc error")},
{
name: "grpc error",
mustFail: true,
err: errors.New("grpc error"),
},
}
client := setupClient(t, false /* stageUnstageSet */)
for _, tc := range testCases {
t.Logf("test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
nodeClient := fake.NewNodeClient(false /* stagingCapable */)
nodeClient.SetNextError(tc.err)
nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
NodeId: tc.expectedNodeID,
MaxVolumesPerNode: tc.expectedMaxVolumePerNode,
AccessibleTopology: tc.expectedAccessibleTopology,
})
return nodeClient, fakeCloser, nil
},
}
nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background())
if tc.mustFail && err == nil {
t.Error("expected an error but got none")
}
if !tc.mustFail && err != nil {
t.Errorf("expected no errors but got: %v", err)
}
checkErr(t, tc.mustFail, err)
if nodeID != tc.expectedNodeID {
t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID)
@ -202,6 +222,10 @@ func TestClientNodeGetInfo(t *testing.T) {
if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) {
t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology)
}
if !tc.mustFail {
fakeCloser.Check()
}
}
}
@ -221,11 +245,18 @@ func TestClientNodePublishVolume(t *testing.T) {
{name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
}
client := setupClient(t, false)
for _, tc := range testCases {
t.Logf("test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
nodeClient := fake.NewNodeClient(false /* stagingCapable */)
nodeClient.SetNextError(tc.err)
return nodeClient, fakeCloser, nil
},
}
err := client.NodePublishVolume(
context.Background(),
tc.volID,
@ -238,9 +269,10 @@ func TestClientNodePublishVolume(t *testing.T) {
map[string]string{},
tc.fsType,
)
checkErr(t, tc.mustFail, err)
if tc.mustFail && err == nil {
t.Error("test must fail, but err is nil")
if !tc.mustFail {
fakeCloser.Check()
}
}
}
@ -259,14 +291,23 @@ func TestClientNodeUnpublishVolume(t *testing.T) {
{name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
}
client := setupClient(t, false)
for _, tc := range testCases {
t.Logf("test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
nodeClient := fake.NewNodeClient(false /* stagingCapable */)
nodeClient.SetNextError(tc.err)
return nodeClient, fakeCloser, nil
},
}
err := client.NodeUnpublishVolume(context.Background(), tc.volID, tc.targetPath)
if tc.mustFail && err == nil {
t.Error("test must fail, but err is nil")
checkErr(t, tc.mustFail, err)
if !tc.mustFail {
fakeCloser.Check()
}
}
}
@ -288,11 +329,18 @@ func TestClientNodeStageVolume(t *testing.T) {
{name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
}
client := setupClient(t, false)
for _, tc := range testCases {
t.Logf("Running test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
nodeClient := fake.NewNodeClient(false /* stagingCapable */)
nodeClient.SetNextError(tc.err)
return nodeClient, fakeCloser, nil
},
}
err := client.NodeStageVolume(
context.Background(),
tc.volID,
@ -303,9 +351,10 @@ func TestClientNodeStageVolume(t *testing.T) {
tc.secret,
map[string]string{"attr0": "val0"},
)
checkErr(t, tc.mustFail, err)
if tc.mustFail && err == nil {
t.Error("test must fail, but err is nil")
if !tc.mustFail {
fakeCloser.Check()
}
}
}
@ -324,17 +373,26 @@ func TestClientNodeUnstageVolume(t *testing.T) {
{name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")},
}
client := setupClient(t, false)
for _, tc := range testCases {
t.Logf("Running test case: %s", tc.name)
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeClientCreator: func(driverName string) (csipb.NodeClient, io.Closer, error) {
nodeClient := fake.NewNodeClient(false /* stagingCapable */)
nodeClient.SetNextError(tc.err)
return nodeClient, fakeCloser, nil
},
}
err := client.NodeUnstageVolume(
context.Background(),
tc.volID, tc.stagingTargetPath,
)
if tc.mustFail && err == nil {
t.Error("test must fail, but err is nil")
checkErr(t, tc.mustFail, err)
if !tc.mustFail {
fakeCloser.Check()
}
}
}

View File

@ -2,7 +2,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = ["fake_client.go"],
srcs = [
"fake_client.go",
"fake_closer.go",
],
importpath = "k8s.io/kubernetes/pkg/volume/csi/fake",
visibility = ["//visibility:public"],
deps = [

View File

@ -0,0 +1,47 @@
/*
Copyright 2018 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package fake
import (
"testing"
)
func NewCloser(t *testing.T) *Closer {
return &Closer{
t: t,
}
}
type Closer struct {
wasCalled bool
t *testing.T
}
func (c *Closer) Close() error {
c.wasCalled = true
return nil
}
func (c *Closer) Check() *Closer {
c.t.Helper()
if !c.wasCalled {
c.t.Error("expected closer to have been called")
}
return c
}