From 84c46629c519cf1774cc3ddd2c983d0892ca2cf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hannes=20H=C3=B6rl?= Date: Fri, 25 Jan 2019 12:00:31 +0000 Subject: [PATCH] Refactor to use new csi.DriversList & csi.Driver --- pkg/volume/csi/csi_client.go | 11 ++----- pkg/volume/csi/csi_plugin.go | 50 +++++++------------------------ pkg/volume/csi/csi_plugin_test.go | 14 ++++++--- 3 files changed, 22 insertions(+), 53 deletions(-) diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index bbe3013462..d6d1fcbf9a 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -139,19 +139,12 @@ func newCsiDriverClient(driverName csiDriverName) (*csiDriverClient, error) { addr := fmt.Sprintf(csiAddrTemplate, driverName) requiresV0Client := true if utilfeature.DefaultFeatureGate.Enabled(features.KubeletPluginsWatcher) { - var existingDriver csiDriver - driverExists := false - func() { - csiDrivers.RLock() - defer csiDrivers.RUnlock() - existingDriver, driverExists = csiDrivers.driversMap[string(driverName)] - }() - + existingDriver, driverExists := csiDrivers.Get(string(driverName)) if !driverExists { return nil, fmt.Errorf("driver name %s not found in the list of registered CSI drivers", driverName) } - addr = existingDriver.driverEndpoint + addr = existingDriver.endpoint requiresV0Client = versionRequiresV0Client(existingDriver.highestSupportedVersion) } diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index 295aed85c8..c1a92c429f 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -23,7 +23,6 @@ import ( "path" "sort" "strings" - "sync" "time" "context" @@ -84,17 +83,6 @@ func ProbeVolumePlugins() []volume.VolumePlugin { // volume.VolumePlugin methods var _ volume.VolumePlugin = &csiPlugin{} -type csiDriver struct { - driverName string - driverEndpoint string - highestSupportedVersion *utilversion.Version -} - -type csiDriversStore struct { - driversMap map[string]csiDriver - sync.RWMutex -} - // RegistrationHandler is the handler which is fed to the pluginwatcher API. type RegistrationHandler struct { } @@ -102,7 +90,7 @@ type RegistrationHandler struct { // TODO (verult) consider using a struct instead of global variables // csiDrivers map keep track of all registered CSI drivers on the node and their // corresponding sockets -var csiDrivers csiDriversStore +var csiDrivers = &DriversStore{} var nim nodeinfomanager.Interface @@ -141,17 +129,12 @@ func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, return err } - func() { - // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key - // all other CSI components will be able to get the actual socket of CSI drivers by its name. - - // It's not necessary to lock the entire RegistrationCallback() function because only the CSI - // client depends on this driver map, and the CSI client does not depend on node information - // updated in the rest of the function. - csiDrivers.Lock() - defer csiDrivers.Unlock() - csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint, highestSupportedVersion: highestSupportedVersion} - }() + // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key + // all other CSI components will be able to get the actual socket of CSI drivers by its name. + csiDrivers.Set(pluginName, Driver{ + endpoint: endpoint, + highestSupportedVersion: highestSupportedVersion, + }) // Get node info from the driver. csi, err := newCsiDriverClient(csiDriverName(pluginName)) @@ -201,15 +184,7 @@ func (h *RegistrationHandler) validateVersions(callerName, pluginName string, en return nil, err } - // Check for existing drivers with the same name - var existingDriver csiDriver - driverExists := false - func() { - csiDrivers.RLock() - defer csiDrivers.RUnlock() - existingDriver, driverExists = csiDrivers.driversMap[pluginName] - }() - + existingDriver, driverExists := csiDrivers.Get(pluginName) if driverExists { if !existingDriver.highestSupportedVersion.LessThan(newDriverHighestVersion) { err := fmt.Errorf("%s for CSI driver %q failed. Another driver with the same name is already registered with a higher supported version: %q", callerName, pluginName, existingDriver.highestSupportedVersion) @@ -246,8 +221,7 @@ func (p *csiPlugin) Init(host volume.VolumeHost) error { } } - // Initializing csiDrivers map and label management channels - csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}} + // Initializing the label management channels nim = nodeinfomanager.NewNodeInfoManager(host.GetNodeName(), host) // TODO(#70514) Init CSINodeInfo object if the CRD exists and create Driver @@ -658,11 +632,7 @@ func (p *csiPlugin) getPublishContext(client clientset.Interface, handle, driver } func unregisterDriver(driverName string) error { - func() { - csiDrivers.Lock() - defer csiDrivers.Unlock() - delete(csiDrivers.driversMap, driverName) - }() + csiDrivers.Delete(driverName) if err := nim.UninstallCSIDriver(driverName); err != nil { klog.Errorf("Error uninstalling CSI driver: %v", err) diff --git a/pkg/volume/csi/csi_plugin_test.go b/pkg/volume/csi/csi_plugin_test.go index ab82fde645..442dfc981a 100644 --- a/pkg/volume/csi/csi_plugin_test.go +++ b/pkg/volume/csi/csi_plugin_test.go @@ -105,13 +105,16 @@ func makeTestPV(name string, sizeGig int, driverName, volID string) *api.Persist } func registerFakePlugin(pluginName, endpoint string, versions []string, t *testing.T) { - csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}} highestSupportedVersions, err := highestSupportedVersion(versions) if err != nil { t.Fatalf("unexpected error parsing versions (%v) for pluginName % q endpoint %q: %#v", versions, pluginName, endpoint, err) } - csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint, highestSupportedVersion: highestSupportedVersions} + csiDrivers.Clear() + csiDrivers.Set(pluginName, Driver{ + endpoint: endpoint, + highestSupportedVersion: highestSupportedVersions, + }) } func TestPluginGetPluginName(t *testing.T) { @@ -839,13 +842,16 @@ func TestValidatePluginExistingDriver(t *testing.T) { for _, tc := range testCases { // Arrange & Act - csiDrivers = csiDriversStore{driversMap: map[string]csiDriver{}} highestSupportedVersions1, err := highestSupportedVersion(tc.versions1) if err != nil { t.Fatalf("unexpected error parsing version for testcase: %#v", tc) } - csiDrivers.driversMap[tc.pluginName1] = csiDriver{driverName: tc.pluginName1, driverEndpoint: tc.endpoint1, highestSupportedVersion: highestSupportedVersions1} + csiDrivers.Clear() + csiDrivers.Set(tc.pluginName1, Driver{ + endpoint: tc.endpoint1, + highestSupportedVersion: highestSupportedVersions1, + }) // Arrange & Act err = PluginHandler.ValidatePlugin(tc.pluginName2, tc.endpoint2, tc.versions2, tc.foundInDeprecatedDir2)