diff --git a/pkg/volume/csi/BUILD b/pkg/volume/csi/BUILD index 2bc3929fe0..9e0a5ef3fa 100644 --- a/pkg/volume/csi/BUILD +++ b/pkg/volume/csi/BUILD @@ -23,6 +23,7 @@ go_library( "//staging/src/k8s.io/apimachinery/pkg/api/errors:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/types:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/version:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/watch:go_default_library", "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library", @@ -86,6 +87,7 @@ filegroup( name = "all-srcs", srcs = [ ":package-srcs", + "//pkg/volume/csi/csiv0:all-srcs", "//pkg/volume/csi/fake:all-srcs", "//pkg/volume/csi/nodeinfomanager:all-srcs", ], diff --git a/pkg/volume/csi/csi_attacher.go b/pkg/volume/csi/csi_attacher.go index c66792af0f..26a58ca707 100644 --- a/pkg/volume/csi/csi_attacher.go +++ b/pkg/volume/csi/csi_attacher.go @@ -29,7 +29,6 @@ import ( "k8s.io/klog" - csipb "github.com/container-storage-interface/spec/lib/go/csi" "k8s.io/api/core/v1" storage "k8s.io/api/storage/v1beta1" apierrs "k8s.io/apimachinery/pkg/api/errors" @@ -349,7 +348,7 @@ func (c *csiAttacher) MountDevice(spec *volume.Spec, devicePath string, deviceMo ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) defer cancel() // Check whether "STAGE_UNSTAGE_VOLUME" is set - stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + stageUnstageSet, err := csi.NodeSupportsStageUnstage(ctx) if err != nil { return err } @@ -529,7 +528,7 @@ func (c *csiAttacher) UnmountDevice(deviceMountPath string) error { ctx, cancel := context.WithTimeout(context.Background(), csiTimeout) defer cancel() // Check whether "STAGE_UNSTAGE_VOLUME" is set - stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + stageUnstageSet, err := csi.NodeSupportsStageUnstage(ctx) if err != nil { klog.Errorf(log("attacher.UnmountDevice failed to check whether STAGE_UNSTAGE_VOLUME set: %v", err)) return err @@ -563,24 +562,6 @@ func (c *csiAttacher) UnmountDevice(deviceMountPath string) error { return nil } -func hasStageUnstageCapability(ctx context.Context, csi csiClient) (bool, error) { - capabilities, err := csi.NodeGetCapabilities(ctx) - if err != nil { - return false, err - } - - stageUnstageSet := false - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipb.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { - stageUnstageSet = true - } - } - return stageUnstageSet, nil -} - // getAttachmentName returns csi- func getAttachmentName(volName, csiDriverName, nodeName string) string { result := sha256.Sum256([]byte(fmt.Sprintf("%s%s%s", volName, csiDriverName, nodeName))) diff --git a/pkg/volume/csi/csi_block.go b/pkg/volume/csi/csi_block.go index 6a536cda1d..5355c3b268 100644 --- a/pkg/volume/csi/csi_block.go +++ b/pkg/volume/csi/csi_block.go @@ -96,7 +96,7 @@ func (m *csiBlockMapper) stageVolumeForBlock( klog.V(4).Infof(log("blockMapper.stageVolumeForBlock stagingPath set [%s]", stagingPath)) // Check whether "STAGE_UNSTAGE_VOLUME" is set - stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + stageUnstageSet, err := csi.NodeSupportsStageUnstage(ctx) if err != nil { klog.Error(log("blockMapper.stageVolumeForBlock failed to check STAGE_UNSTAGE_VOLUME capability: %v", err)) return "", err @@ -287,7 +287,7 @@ func (m *csiBlockMapper) unpublishVolumeForBlock(ctx context.Context, csi csiCli // unstageVolumeForBlock unstages a block volume from stagingPath func (m *csiBlockMapper) unstageVolumeForBlock(ctx context.Context, csi csiClient, stagingPath string) error { // Check whether "STAGE_UNSTAGE_VOLUME" is set - stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + stageUnstageSet, err := csi.NodeSupportsStageUnstage(ctx) if err != nil { klog.Error(log("blockMapper.unstageVolumeForBlock failed to check STAGE_UNSTAGE_VOLUME capability: %v", err)) return err diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index ea6ba190d8..56c28191ab 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -36,7 +36,7 @@ type csiClient interface { NodeGetInfo(ctx context.Context) ( nodeID string, maxVolumePerNode int64, - accessibleTopology *csipb.Topology, + accessibleTopology map[string]string, err error) NodePublishVolume( ctx context.Context, @@ -66,7 +66,7 @@ type csiClient interface { volumeContext map[string]string, ) error NodeUnstageVolume(ctx context.Context, volID, stagingTargetPath string) error - NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) + NodeSupportsStageUnstage(ctx context.Context) (bool, error) } // csiClient encapsulates all csi-plugin methods @@ -110,7 +110,7 @@ func newCsiDriverClient(driverName string) *csiDriverClient { func (c *csiDriverClient) NodeGetInfo(ctx context.Context) ( nodeID string, maxVolumePerNode int64, - accessibleTopology *csipb.Topology, + accessibleTopology map[string]string, err error) { klog.V(4).Info(log("calling NodeGetInfo rpc")) @@ -125,7 +125,11 @@ func (c *csiDriverClient) NodeGetInfo(ctx context.Context) ( return "", 0, nil, err } - return res.GetNodeId(), res.GetMaxVolumesPerNode(), res.GetAccessibleTopology(), nil + topology := res.GetAccessibleTopology() + if topology != nil { + accessibleTopology = topology.Segments + } + return res.GetNodeId(), res.GetMaxVolumesPerNode(), accessibleTopology, nil } func (c *csiDriverClient) NodePublishVolume( @@ -288,21 +292,33 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT return err } -func (c *csiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) { - klog.V(4).Info(log("calling NodeGetCapabilities rpc")) +func (c *csiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) { + klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsStageUnstage")) nodeClient, closer, err := c.nodeClientCreator(c.driverName) if err != nil { - return nil, err + return false, err } defer closer.Close() req := &csipb.NodeGetCapabilitiesRequest{} resp, err := nodeClient.NodeGetCapabilities(ctx, req) if err != nil { - return nil, err + return false, err } - return resp.GetCapabilities(), nil + + capabilities := resp.GetCapabilities() + + stageUnstageSet := false + if capabilities == nil { + return false, nil + } + for _, capability := range capabilities { + if capability.GetRpc().GetType() == csipb.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { + stageUnstageSet = true + } + } + return stageUnstageSet, nil } func asCSIAccessMode(am api.PersistentVolumeAccessMode) csipb.VolumeCapability_AccessMode_Mode { diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 777be98b2d..699200d191 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -43,10 +43,14 @@ func newFakeCsiDriverClient(t *testing.T, stagingCapable bool) *fakeCsiDriverCli func (c *fakeCsiDriverClient) NodeGetInfo(ctx context.Context) ( nodeID string, maxVolumePerNode int64, - accessibleTopology *csipb.Topology, + accessibleTopology map[string]string, err error) { resp, err := c.nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{}) - return resp.GetNodeId(), resp.GetMaxVolumesPerNode(), resp.GetAccessibleTopology(), err + topology := resp.GetAccessibleTopology() + if topology != nil { + accessibleTopology = topology.Segments + } + return resp.GetNodeId(), resp.GetMaxVolumesPerNode(), accessibleTopology, err } func (c *fakeCsiDriverClient) NodePublishVolume( @@ -140,14 +144,26 @@ func (c *fakeCsiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stag return err } -func (c *fakeCsiDriverClient) NodeGetCapabilities(ctx context.Context) ([]*csipb.NodeServiceCapability, error) { - c.t.Log("calling fake.NodeGetCapabilities...") +func (c *fakeCsiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) { + c.t.Log("calling fake.NodeGetCapabilities for NodeSupportsStageUnstage...") req := &csipb.NodeGetCapabilitiesRequest{} resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) if err != nil { - return nil, err + return false, err } - return resp.GetCapabilities(), nil + + capabilities := resp.GetCapabilities() + + stageUnstageSet := false + if capabilities == nil { + return false, nil + } + for _, capability := range capabilities { + if capability.GetRpc().GetType() == csipb.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { + stageUnstageSet = true + } + } + return stageUnstageSet, nil } func setupClient(t *testing.T, stageUnstageSet bool) csiClient { @@ -173,17 +189,15 @@ func TestClientNodeGetInfo(t *testing.T) { name string expectedNodeID string expectedMaxVolumePerNode int64 - expectedAccessibleTopology *csipb.Topology + expectedAccessibleTopology map[string]string mustFail bool err error }{ { - name: "test ok", - expectedNodeID: "node1", - expectedMaxVolumePerNode: 16, - expectedAccessibleTopology: &csipb.Topology{ - Segments: map[string]string{"com.example.csi-topology/zone": "zone1"}, - }, + name: "test ok", + expectedNodeID: "node1", + expectedMaxVolumePerNode: 16, + expectedAccessibleTopology: map[string]string{"com.example.csi-topology/zone": "zone1"}, }, { name: "grpc error", @@ -202,9 +216,11 @@ func TestClientNodeGetInfo(t *testing.T) { nodeClient := fake.NewNodeClient(false /* stagingCapable */) nodeClient.SetNextError(tc.err) nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{ - NodeId: tc.expectedNodeID, - MaxVolumesPerNode: tc.expectedMaxVolumePerNode, - AccessibleTopology: tc.expectedAccessibleTopology, + NodeId: tc.expectedNodeID, + MaxVolumesPerNode: tc.expectedMaxVolumePerNode, + AccessibleTopology: &csipb.Topology{ + Segments: tc.expectedAccessibleTopology, + }, }) return nodeClient, fakeCloser, nil }, @@ -222,7 +238,7 @@ func TestClientNodeGetInfo(t *testing.T) { } if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) { - t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology) + t.Errorf("expected accessibleTopology: %v; got: %v", tc.expectedAccessibleTopology, accessibleTopology) } if !tc.mustFail { diff --git a/pkg/volume/csi/csi_mounter.go b/pkg/volume/csi/csi_mounter.go index b088411277..d5dbe78de3 100644 --- a/pkg/volume/csi/csi_mounter.go +++ b/pkg/volume/csi/csi_mounter.go @@ -121,7 +121,7 @@ func (c *csiMountMgr) SetUpAt(dir string, fsGroup *int64) error { // Check for STAGE_UNSTAGE_VOLUME set and populate deviceMountPath if so deviceMountPath := "" - stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + stageUnstageSet, err := csi.NodeSupportsStageUnstage(ctx) if err != nil { klog.Error(log("mounter.SetUpAt failed to check for STAGE_UNSTAGE_VOLUME capabilty: %v", err)) return err diff --git a/pkg/volume/csi/nodeinfomanager/BUILD b/pkg/volume/csi/nodeinfomanager/BUILD index e949893feb..6af576c422 100644 --- a/pkg/volume/csi/nodeinfomanager/BUILD +++ b/pkg/volume/csi/nodeinfomanager/BUILD @@ -21,7 +21,6 @@ go_library( "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library", "//staging/src/k8s.io/csi-api/pkg/apis/csi/v1alpha1:go_default_library", "//staging/src/k8s.io/csi-api/pkg/client/clientset/versioned:go_default_library", - "//vendor/github.com/container-storage-interface/spec/lib/go/csi:go_default_library", "//vendor/k8s.io/klog:go_default_library", ], ) @@ -63,7 +62,6 @@ go_test( "//staging/src/k8s.io/client-go/util/testing:go_default_library", "//staging/src/k8s.io/csi-api/pkg/apis/csi/v1alpha1:go_default_library", "//staging/src/k8s.io/csi-api/pkg/client/clientset/versioned/fake:go_default_library", - "//vendor/github.com/container-storage-interface/spec/lib/go/csi:go_default_library", "//vendor/github.com/stretchr/testify/assert:go_default_library", ], ) diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go index a895e61755..a61e220465 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go @@ -23,7 +23,8 @@ import ( "fmt" "strings" - csipb "github.com/container-storage-interface/spec/lib/go/csi" + "time" + "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" @@ -40,7 +41,6 @@ import ( nodeutil "k8s.io/kubernetes/pkg/util/node" "k8s.io/kubernetes/pkg/volume" "k8s.io/kubernetes/pkg/volume/util" - "time" ) const ( @@ -75,7 +75,7 @@ type Interface interface { // Record in the cluster the given node information from the CSI driver with the given name. // Concurrent calls to InstallCSIDriver() is allowed, but they should not be intertwined with calls // to other methods in this interface. - InstallCSIDriver(driverName string, driverNodeID string, maxVolumeLimit int64, topology *csipb.Topology) error + InstallCSIDriver(driverName string, driverNodeID string, maxVolumeLimit int64, topology map[string]string) error // Remove in the cluster node information from the CSI driver with the given name. // Concurrent calls to UninstallCSIDriver() is allowed, but they should not be intertwined with calls @@ -97,7 +97,7 @@ func NewNodeInfoManager( // CSINodeInfo object. If the CSINodeInfo object doesn't yet exist, it will be created. // If multiple calls to InstallCSIDriver() are made in parallel, some calls might receive Node or // CSINodeInfo update conflicts, which causes the function to retry the corresponding update. -func (nim *nodeInfoManager) InstallCSIDriver(driverName string, driverNodeID string, maxAttachLimit int64, topology *csipb.Topology) error { +func (nim *nodeInfoManager) InstallCSIDriver(driverName string, driverNodeID string, maxAttachLimit int64, topology map[string]string) error { if driverNodeID == "" { return fmt.Errorf("error adding CSI driver node info: driverNodeID must not be empty") } @@ -321,13 +321,13 @@ func removeNodeIDFromNode(csiDriverName string) nodeUpdateFunc { // updateTopologyLabels returns a function that updates labels of a Node object with the given // topology information. -func updateTopologyLabels(topology *csipb.Topology) nodeUpdateFunc { +func updateTopologyLabels(topology map[string]string) nodeUpdateFunc { return func(node *v1.Node) (*v1.Node, bool, error) { - if topology == nil || len(topology.Segments) == 0 { + if topology == nil || len(topology) == 0 { return node, false, nil } - for k, v := range topology.Segments { + for k, v := range topology { if curVal, exists := node.Labels[k]; exists && curVal != v { return nil, false, fmt.Errorf("detected topology value collision: driver reported %q:%q but existing label is %q:%q", k, v, k, curVal) } @@ -336,7 +336,7 @@ func updateTopologyLabels(topology *csipb.Topology) nodeUpdateFunc { if node.Labels == nil { node.Labels = make(map[string]string) } - for k, v := range topology.Segments { + for k, v := range topology { node.Labels[k] = v } return node, true, nil @@ -346,7 +346,7 @@ func updateTopologyLabels(topology *csipb.Topology) nodeUpdateFunc { func (nim *nodeInfoManager) updateCSINodeInfo( driverName string, driverNodeID string, - topology *csipb.Topology) error { + topology map[string]string) error { csiKubeClient := nim.volumeHost.GetCSIClient() if csiKubeClient == nil { @@ -371,7 +371,7 @@ func (nim *nodeInfoManager) tryUpdateCSINodeInfo( csiKubeClient csiclientset.Interface, driverName string, driverNodeID string, - topology *csipb.Topology) error { + topology map[string]string) error { nodeInfo, err := csiKubeClient.CsiV1alpha1().CSINodeInfos().Get(string(nim.nodeName), metav1.GetOptions{}) if nodeInfo == nil || errors.IsNotFound(err) { @@ -428,7 +428,7 @@ func (nim *nodeInfoManager) installDriverToCSINodeInfo( nodeInfo *csiv1alpha1.CSINodeInfo, driverName string, driverNodeID string, - topology *csipb.Topology) error { + topology map[string]string) error { csiKubeClient := nim.volumeHost.GetCSIClient() if csiKubeClient == nil { @@ -436,10 +436,8 @@ func (nim *nodeInfoManager) installDriverToCSINodeInfo( } topologyKeys := make(sets.String) - if topology != nil { - for k := range topology.Segments { - topologyKeys.Insert(k) - } + for k := range topology { + topologyKeys.Insert(k) } specModified := true diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go index 39ee2fd0cd..14f0fc67f6 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go @@ -21,7 +21,6 @@ import ( "fmt" "testing" - "github.com/container-storage-interface/spec/lib/go/csi" "github.com/stretchr/testify/assert" "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -49,7 +48,7 @@ type testcase struct { existingNode *v1.Node existingNodeInfo *csiv1alpha1.CSINodeInfo inputNodeID string - inputTopology *csi.Topology + inputTopology map[string]string inputVolumeLimit int64 expectedNodeIDMap map[string]string expectedTopologyMap map[string]sets.String @@ -71,10 +70,8 @@ func TestInstallCSIDriver(t *testing.T) { driverName: "com.example.csi/driver1", existingNode: generateNode(nil /* nodeIDs */, nil /* labels */, nil /*capacity*/), inputNodeID: "com.example.csi/csi-node1", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/zone": "zoneA", - }, + inputTopology: map[string]string{ + "com.example.csi/zone": "zoneA", }, expectedNodeIDMap: map[string]string{ "com.example.csi/driver1": "com.example.csi/csi-node1", @@ -104,10 +101,8 @@ func TestInstallCSIDriver(t *testing.T) { }, ), inputNodeID: "com.example.csi/csi-node1", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/zone": "zoneA", - }, + inputTopology: map[string]string{ + "com.example.csi/zone": "zoneA", }, expectedNodeIDMap: map[string]string{ "com.example.csi/driver1": "com.example.csi/csi-node1", @@ -134,10 +129,8 @@ func TestInstallCSIDriver(t *testing.T) { nil, /* topologyKeys */ ), inputNodeID: "com.example.csi/csi-node1", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/zone": "zoneA", - }, + inputTopology: map[string]string{ + "com.example.csi/zone": "zoneA", }, expectedNodeIDMap: map[string]string{ "com.example.csi/driver1": "com.example.csi/csi-node1", @@ -168,10 +161,8 @@ func TestInstallCSIDriver(t *testing.T) { }, ), inputNodeID: "com.example.csi/csi-node1", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/zone": "zoneA", - }, + inputTopology: map[string]string{ + "com.example.csi/zone": "zoneA", }, expectedNodeIDMap: map[string]string{ "com.example.csi/driver1": "com.example.csi/csi-node1", @@ -205,10 +196,8 @@ func TestInstallCSIDriver(t *testing.T) { }, ), inputNodeID: "com.example.csi/csi-node1", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/zone": "other-zone", - }, + inputTopology: map[string]string{ + "com.example.csi/zone": "other-zone", }, expectFail: true, }, @@ -231,10 +220,8 @@ func TestInstallCSIDriver(t *testing.T) { }, ), inputNodeID: "com.example.csi/other-node", - inputTopology: &csi.Topology{ - Segments: map[string]string{ - "com.example.csi/rack": "rack1", - }, + inputTopology: map[string]string{ + "com.example.csi/rack": "rack1", }, expectedNodeIDMap: map[string]string{ "com.example.csi/driver1": "com.example.csi/other-node",