From 7fa120c18c361fe7374a21497f70293d454e5d55 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Tue, 7 Aug 2018 12:28:19 -0700 Subject: [PATCH] CSI plugin now calls NodeGetInfo() to get driver's node ID --- pkg/volume/csi/csi_client.go | 23 ++++++++ pkg/volume/csi/csi_client_test.go | 64 +++++++++++++++++++++ pkg/volume/csi/csi_plugin.go | 28 +++++++-- pkg/volume/csi/fake/fake_client.go | 13 +++++ pkg/volume/csi/labelmanager/labelmanager.go | 7 +-- 5 files changed, 126 insertions(+), 9 deletions(-) diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index e12772d7a9..4a4cb4176b 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -32,6 +32,11 @@ import ( ) type csiClient interface { + NodeGetInfo(ctx context.Context) ( + nodeID string, + maxVolumePerNode int64, + accessibleTopology *csipb.Topology, + err error) NodePublishVolume( ctx context.Context, volumeid string, @@ -75,6 +80,24 @@ func newCsiDriverClient(driverName string) *csiDriverClient { return c } +func (c *csiDriverClient) NodeGetInfo(ctx context.Context) ( + nodeID string, + maxVolumePerNode int64, + accessibleTopology *csipb.Topology, + err error) { + glog.V(4).Info(log("calling NodeGetInfo rpc")) + + conn, err := newGrpcConn(c.driverName) + if err != nil { + return "", 0, nil, err + } + defer conn.Close() + nodeClient := csipb.NewNodeClient(conn) + + res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{}) + return res.GetNodeId(), res.GetMaxVolumesPerNode(), res.GetAccessibleTopology(), nil +} + func (c *csiDriverClient) NodePublishVolume( ctx context.Context, volID string, diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index c5b8cf01f7..026aa36883 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -24,6 +24,7 @@ import ( csipb "github.com/container-storage-interface/spec/lib/go/csi/v0" api "k8s.io/api/core/v1" "k8s.io/kubernetes/pkg/volume/csi/fake" + "reflect" ) type fakeCsiDriverClient struct { @@ -38,6 +39,15 @@ func newFakeCsiDriverClient(t *testing.T, stagingCapable bool) *fakeCsiDriverCli } } +func (c *fakeCsiDriverClient) NodeGetInfo(ctx context.Context) ( + nodeID string, + maxVolumePerNode int64, + accessibleTopology *csipb.Topology, + err error) { + resp, err := c.nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{}) + return resp.GetNodeId(), resp.GetMaxVolumesPerNode(), resp.GetAccessibleTopology(), err +} + func (c *fakeCsiDriverClient) NodePublishVolume( ctx context.Context, volID string, @@ -141,6 +151,60 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient { return newFakeCsiDriverClient(t, stageUnstageSet) } +func TestClientNodeGetInfo(t *testing.T) { + testCases := []struct { + name string + expectedNodeID string + expectedMaxVolumePerNode int64 + expectedAccessibleTopology *csipb.Topology + mustFail bool + err error + }{ + { + name: "test ok", + expectedNodeID: "node1", + expectedMaxVolumePerNode: 16, + expectedAccessibleTopology: &csipb.Topology{ + Segments: map[string]string{"com.example.csi-topology/zone": "zone1"}, + }, + }, + {name: "grpc error", mustFail: true, err: errors.New("grpc error")}, + } + + client := setupClient(t, false /* stageUnstageSet */) + + for _, tc := range testCases { + t.Logf("test case: %s", tc.name) + client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err) + client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{ + NodeId: tc.expectedNodeID, + MaxVolumesPerNode: tc.expectedMaxVolumePerNode, + AccessibleTopology: tc.expectedAccessibleTopology, + }) + nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background()) + + if tc.mustFail && err == nil { + t.Error("expected an error but got none") + } + + if !tc.mustFail && err != nil { + t.Errorf("expected no errors but got: %v", err) + } + + if nodeID != tc.expectedNodeID { + t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID) + } + + if maxVolumePerNode != tc.expectedMaxVolumePerNode { + t.Errorf("expected maxVolumePerNode: %v; got: %v", tc.expectedMaxVolumePerNode, maxVolumePerNode) + } + + if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) { + t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology) + } + } +} + func TestClientNodePublishVolume(t *testing.T) { testCases := []struct { name string diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index 5eba4f79c4..c485d8815c 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "context" "github.com/golang/glog" api "k8s.io/api/core/v1" meta "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -76,6 +77,7 @@ type csiDriversStore struct { sync.RWMutex } +// TODO (verult) consider using a struct instead of global variables // csiDrivers map keep track of all registered CSI drivers on the node and their // corresponding sockets var csiDrivers csiDriversStore @@ -92,17 +94,33 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string, if endpoint == "" { endpoint = socketPath } - // Calling nodeLabelManager to update label for newly registered CSI driver - err := lm.AddLabels(pluginName) - if err != nil { - return nil, err - } + // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key // all other CSI components will be able to get the actual socket of CSI drivers by its name. csiDrivers.Lock() defer csiDrivers.Unlock() csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint} + // Get node info from the driver. + csi := newCsiDriverClient(pluginName) + // TODO (verult) retry with exponential backoff, possibly added in csi client library. + ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) + defer cancel() + driverNodeID, _, _, err := csi.NodeGetInfo(ctx) + if err != nil { + return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err) + } + + // Calling nodeLabelManager to update annotations and labels for newly registered CSI driver + err = lm.AddLabels(pluginName, driverNodeID) + if err != nil { + // Unregister the driver and return error + csiDrivers.Lock() + defer csiDrivers.Unlock() + delete(csiDrivers.driversMap, pluginName) + return nil, err + } + return nil, nil } diff --git a/pkg/volume/csi/fake/fake_client.go b/pkg/volume/csi/fake/fake_client.go index b4341e7151..346c994917 100644 --- a/pkg/volume/csi/fake/fake_client.go +++ b/pkg/volume/csi/fake/fake_client.go @@ -61,6 +61,7 @@ type NodeClient struct { nodePublishedVolumes map[string]string nodeStagedVolumes map[string]string stageUnstageSet bool + nodeGetInfoResp *csipb.NodeGetInfoResponse nextErr error } @@ -78,6 +79,10 @@ func (f *NodeClient) SetNextError(err error) { f.nextErr = err } +func (f *NodeClient) SetNodeGetInfoResp(resp *csipb.NodeGetInfoResponse) { + f.nodeGetInfoResp = resp +} + // GetNodePublishedVolumes returns node published volumes func (f *NodeClient) GetNodePublishedVolumes() map[string]string { return f.nodePublishedVolumes @@ -179,6 +184,14 @@ func (f *NodeClient) NodeGetId(ctx context.Context, in *csipb.NodeGetIdRequest, return nil, nil } +// NodeGetId implements csi method +func (f *NodeClient) NodeGetInfo(ctx context.Context, in *csipb.NodeGetInfoRequest, opts ...grpc.CallOption) (*csipb.NodeGetInfoResponse, error) { + if f.nextErr != nil { + return nil, f.nextErr + } + return f.nodeGetInfoResp, nil +} + // NodeGetCapabilities implements csi method func (f *NodeClient) NodeGetCapabilities(ctx context.Context, in *csipb.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipb.NodeGetCapabilitiesResponse, error) { resp := &csipb.NodeGetCapabilitiesResponse{ diff --git a/pkg/volume/csi/labelmanager/labelmanager.go b/pkg/volume/csi/labelmanager/labelmanager.go index 02508a20c2..79fd531145 100644 --- a/pkg/volume/csi/labelmanager/labelmanager.go +++ b/pkg/volume/csi/labelmanager/labelmanager.go @@ -34,7 +34,6 @@ const ( // Name of node annotation that contains JSON map of driver names to node // names annotationKey = "csi.volume.kubernetes.io/nodeid" - csiPluginName = "kubernetes.io/csi" ) // labelManagementStruct is struct of channels used for communication between the driver registration @@ -46,7 +45,7 @@ type labelManagerStruct struct { // Interface implements an interface for managing labels of a node type Interface interface { - AddLabels(driverName string) error + AddLabels(driverName string, driverNodeId string) error } // NewLabelManager initializes labelManagerStruct and returns available interfaces @@ -59,8 +58,8 @@ func NewLabelManager(nodeName types.NodeName, kubeClient kubernetes.Interface) I // nodeLabelManager waits for labeling requests initiated by the driver's registration // process. -func (lm labelManagerStruct) AddLabels(driverName string) error { - err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, string(lm.nodeName)) +func (lm labelManagerStruct) AddLabels(driverName string, driverNodeId string) error { + err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, driverNodeId) if err != nil { return fmt.Errorf("failed to update node %s's annotation with error: %+v", lm.nodeName, err) }