From d04f5968293293668c675ef6c76439b37afe790f Mon Sep 17 00:00:00 2001 From: hui luo Date: Wed, 30 May 2018 17:07:58 -0700 Subject: [PATCH] Add hierarchy support for plugin directory it traverses and watch plugin directory and its sub directory recursively, plugin socket file only need be unique within one directory, - plugin socket directory - | - ---->sub directory 1 - | | - | -----> socket1, socket2 ... - ----->sub directory 2 - | - ------> socket1, socket2 ... the design itself allow sub directory be anything, but in practical, each plugin type could just use one sub directory. four bonus changes added as below 1. extract example handler out from test, it is easier to read the code with the seperation. 2. there are two variables here: "Watcher" and "watcher". "Watcher" is the plugin watcher, and "watcher" is the fsnotify watcher. so rename the "watcher" to "fsWatcher" to make code easier to understand. 3. change RegisterCallbackFn() return value order, it is conventional to return error last, after this change, the pkg/volume/csi is compliance with golint, so remove it from hack/.golint_failures 4. refactor errors handling at invokeRegistrationCallbackAtHandler() to make error message more clear. --- hack/.golint_failures | 1 - pkg/kubelet/util/pluginwatcher/BUILD | 5 +- pkg/kubelet/util/pluginwatcher/README | 13 +- .../util/pluginwatcher/example_handler.go | 105 ++++++++ .../util/pluginwatcher/example_plugin.go | 19 +- .../util/pluginwatcher/plugin_watcher.go | 163 ++++++------ .../util/pluginwatcher/plugin_watcher_test.go | 248 +++++++++--------- pkg/volume/csi/csi_plugin.go | 4 +- 8 files changed, 334 insertions(+), 224 deletions(-) create mode 100644 pkg/kubelet/util/pluginwatcher/example_handler.go diff --git a/hack/.golint_failures b/hack/.golint_failures index 66fb728748..76e47dbb22 100644 --- a/hack/.golint_failures +++ b/hack/.golint_failures @@ -380,7 +380,6 @@ pkg/volume/azure_dd pkg/volume/azure_file pkg/volume/cephfs pkg/volume/configmap -pkg/volume/csi pkg/volume/csi/fake pkg/volume/csi/labelmanager pkg/volume/empty_dir diff --git a/pkg/kubelet/util/pluginwatcher/BUILD b/pkg/kubelet/util/pluginwatcher/BUILD index 0b62c3a465..7b887b444f 100644 --- a/pkg/kubelet/util/pluginwatcher/BUILD +++ b/pkg/kubelet/util/pluginwatcher/BUILD @@ -9,6 +9,7 @@ load( go_library( name = "go_default_library", srcs = [ + "example_handler.go", "example_plugin.go", "plugin_watcher.go", ], @@ -20,6 +21,7 @@ go_library( "//pkg/util/filesystem:go_default_library", "//vendor/github.com/fsnotify/fsnotify:go_default_library", "//vendor/github.com/golang/glog:go_default_library", + "//vendor/github.com/pkg/errors:go_default_library", "//vendor/golang.org/x/net/context:go_default_library", "//vendor/google.golang.org/grpc:go_default_library", ], @@ -49,10 +51,7 @@ go_test( embed = [":go_default_library"], deps = [ "//pkg/kubelet/apis/pluginregistration/v1alpha1:go_default_library", - "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1:go_default_library", - "//pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library", "//vendor/github.com/stretchr/testify/require:go_default_library", - "//vendor/golang.org/x/net/context:go_default_library", ], ) diff --git a/pkg/kubelet/util/pluginwatcher/README b/pkg/kubelet/util/pluginwatcher/README index 9654b2cf62..c8b6cc2844 100644 --- a/pkg/kubelet/util/pluginwatcher/README +++ b/pkg/kubelet/util/pluginwatcher/README @@ -13,17 +13,22 @@ communication with any API version supported by the plugin. Here are the general rules that Kubelet plugin developers should follow: - Run as 'root' user. Currently creating socket under PluginsSockDir, a root owned directory, requires plugin process to be running as 'root'. + - Implements the Registration service specified in pkg/kubelet/apis/pluginregistration/v*/api.proto. + - The plugin name sent during Registration.GetInfo grpc should be unique for the given plugin type (CSIPlugin or DevicePlugin). -- The socket path needs to be unique and doesn't conflict with the path chosen - by any other potential plugins. Currently we only support flat fs namespace - under PluginsSockDir but will soon support recursive inotify watch for - hierarchical socket paths. + +- The socket path needs to be unique within one directory, in normal case, + each plugin type has its own sub directory, but the design does support socket file + under any sub directory of PluginSockDir. + - A plugin should clean up its own socket upon exiting or when a new instance comes up. A plugin should NOT remove any sockets belonging to other plugins. + - A plugin should make sure it has service ready for any supported service API version listed in the PluginInfo. + - For an example plugin implementation, take a look at example_plugin.go included in this directory. diff --git a/pkg/kubelet/util/pluginwatcher/example_handler.go b/pkg/kubelet/util/pluginwatcher/example_handler.go new file mode 100644 index 0000000000..4eae4188d6 --- /dev/null +++ b/pkg/kubelet/util/pluginwatcher/example_handler.go @@ -0,0 +1,105 @@ +/* +Copyright 2018 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package pluginwatcher + +import ( + "errors" + "fmt" + "reflect" + "sync" + "time" + + "golang.org/x/net/context" + + v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" + v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" +) + +type exampleHandler struct { + registeredPlugins map[string]struct{} + mutex sync.Mutex + chanForHandlerAckErrors chan error // for testing +} + +// NewExampleHandler provide a example handler +func NewExampleHandler() *exampleHandler { + return &exampleHandler{ + chanForHandlerAckErrors: make(chan error), + registeredPlugins: make(map[string]struct{}), + } +} + +func (h *exampleHandler) Cleanup() error { + h.mutex.Lock() + defer h.mutex.Unlock() + h.registeredPlugins = make(map[string]struct{}) + return nil +} + +func (h *exampleHandler) Handler(pluginName string, endpoint string, versions []string, sockPath string) (chan bool, error) { + + // check for supported versions + if !reflect.DeepEqual([]string{"v1beta1", "v1beta2"}, versions) { + return nil, fmt.Errorf("not the supported versions: %s", versions) + } + + // this handler expects non-empty endpoint as an example + if len(endpoint) == 0 { + return nil, errors.New("expecting non empty endpoint") + } + + _, conn, err := dial(sockPath) + if err != nil { + return nil, err + } + defer conn.Close() + + // The plugin handler should be able to use any listed service API version. + v1beta1Client := v1beta1.NewExampleClient(conn) + v1beta2Client := v1beta2.NewExampleClient(conn) + + // Tests v1beta1 GetExampleInfo + if _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}); err != nil { + return nil, err + } + + // Tests v1beta2 GetExampleInfo + if _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}); err != nil { + return nil, err + } + + // handle registered plugin + h.mutex.Lock() + if _, exist := h.registeredPlugins[pluginName]; exist { + h.mutex.Unlock() + return nil, fmt.Errorf("plugin %s already registered", pluginName) + } + h.registeredPlugins[pluginName] = struct{}{} + h.mutex.Unlock() + + chanForAckOfNotification := make(chan bool) + go func() { + select { + case <-chanForAckOfNotification: + // TODO: handle the negative scenario + close(chanForAckOfNotification) + case <-time.After(time.Second): + h.chanForHandlerAckErrors <- errors.New("Timed out while waiting for notification ack") + } + }() + return chanForAckOfNotification, nil +} diff --git a/pkg/kubelet/util/pluginwatcher/example_plugin.go b/pkg/kubelet/util/pluginwatcher/example_plugin.go index fbca43acad..5c2dd966ba 100644 --- a/pkg/kubelet/util/pluginwatcher/example_plugin.go +++ b/pkg/kubelet/util/pluginwatcher/example_plugin.go @@ -17,7 +17,7 @@ limitations under the License. package pluginwatcher import ( - "fmt" + "errors" "net" "sync" "time" @@ -31,17 +31,14 @@ import ( v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" ) -const ( - PluginName = "example-plugin" - PluginType = "example-plugin-type" -) - // examplePlugin is a sample plugin to work with plugin watcher type examplePlugin struct { grpcServer *grpc.Server wg sync.WaitGroup registrationStatus chan registerapi.RegistrationStatus // for testing endpoint string // for testing + pluginName string + pluginType string } type pluginServiceV1Beta1 struct { @@ -76,8 +73,10 @@ func NewExamplePlugin() *examplePlugin { } // NewTestExamplePlugin returns an initialized examplePlugin instance for testing -func NewTestExamplePlugin(endpoint string) *examplePlugin { +func NewTestExamplePlugin(pluginName string, pluginType string, endpoint string) *examplePlugin { return &examplePlugin{ + pluginName: pluginName, + pluginType: pluginType, registrationStatus: make(chan registerapi.RegistrationStatus), endpoint: endpoint, } @@ -86,8 +85,8 @@ func NewTestExamplePlugin(endpoint string) *examplePlugin { // GetInfo is the RPC invoked by plugin watcher func (e *examplePlugin) GetInfo(ctx context.Context, req *registerapi.InfoRequest) (*registerapi.PluginInfo, error) { return ®isterapi.PluginInfo{ - Type: PluginType, - Name: PluginName, + Type: e.pluginType, + Name: e.pluginName, Endpoint: e.endpoint, SupportedVersions: []string{"v1beta1", "v1beta2"}, }, nil @@ -145,6 +144,6 @@ func (e *examplePlugin) Stop() error { return nil case <-time.After(time.Second): glog.Errorf("Timed out on waiting for stop completion") - return fmt.Errorf("Timed out on waiting for stop completion") + return errors.New("Timed out on waiting for stop completion") } } diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go index 9a5241cb2e..6db743dd4f 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher.go @@ -20,13 +20,12 @@ import ( "fmt" "net" "os" - "path" - "path/filepath" "sync" "time" "github.com/fsnotify/fsnotify" "github.com/golang/glog" + "github.com/pkg/errors" "golang.org/x/net/context" "google.golang.org/grpc" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" @@ -34,17 +33,17 @@ import ( ) // RegisterCallbackFn is the type of the callback function that handlers will provide -type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (error, chan bool) +type RegisterCallbackFn func(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) // Watcher is the plugin watcher type Watcher struct { - path string - handlers map[string]RegisterCallbackFn - stopCh chan interface{} - fs utilfs.Filesystem - watcher *fsnotify.Watcher - wg sync.WaitGroup - mutex sync.Mutex + path string + handlers map[string]RegisterCallbackFn + stopCh chan interface{} + fs utilfs.Filesystem + fsWatcher *fsnotify.Watcher + wg sync.WaitGroup + mutex sync.Mutex } // NewWatcher provides a new watcher @@ -57,40 +56,45 @@ func NewWatcher(sockDir string) Watcher { } // AddHandler registers a callback to be invoked for a particular type of plugin -func (w *Watcher) AddHandler(handlerType string, handlerCbkFn RegisterCallbackFn) { +func (w *Watcher) AddHandler(pluginType string, handlerCbkFn RegisterCallbackFn) { w.mutex.Lock() defer w.mutex.Unlock() - w.handlers[handlerType] = handlerCbkFn + w.handlers[pluginType] = handlerCbkFn } // Creates the plugin directory, if it doesn't already exist. func (w *Watcher) createPluginDir() error { glog.V(4).Infof("Ensuring Plugin directory at %s ", w.path) if err := w.fs.MkdirAll(w.path, 0755); err != nil { - return fmt.Errorf("error (re-)creating driver directory: %s", err) + return fmt.Errorf("error (re-)creating root %s: %v", w.path, err) } + return nil } -// Walks through the plugin directory to discover any existing plugin sockets. -func (w *Watcher) traversePluginDir() error { - files, err := w.fs.ReadDir(w.path) - if err != nil { - return fmt.Errorf("error reading the plugin directory: %v", err) - } - for _, f := range files { - // Currently only supports flat fs namespace under the plugin directory. - // TODO: adds support for hierarchical fs namespace. - if !f.IsDir() && filepath.Base(f.Name())[0] != '.' { - go func(sockName string) { - w.watcher.Events <- fsnotify.Event{ - Name: sockName, - Op: fsnotify.Op(uint32(1)), - } - }(path.Join(w.path, f.Name())) +// Walks through the plugin directory discover any existing plugin sockets. +func (w *Watcher) traversePluginDir(dir string) error { + return w.fs.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("error accessing path: %s error: %v", path, err) } - } - return nil + + switch mode := info.Mode(); { + case mode.IsDir(): + if err := w.fsWatcher.Add(path); err != nil { + return fmt.Errorf("failed to watch %s, err: %v", path, err) + } + case mode&os.ModeSocket != 0: + go func() { + w.fsWatcher.Events <- fsnotify.Event{ + Name: path, + Op: fsnotify.Create, + } + }() + } + + return nil + }) } func (w *Watcher) init() error { @@ -102,7 +106,6 @@ func (w *Watcher) init() error { func (w *Watcher) registerPlugin(socketPath string) error { //TODO: Implement rate limiting to mitigate any DOS kind of attacks. - glog.V(4).Infof("registerPlugin called for socketPath: %s", socketPath) client, conn, err := dial(socketPath) if err != nil { return fmt.Errorf("dial failed at socket %s, err: %v", socketPath, err) @@ -115,11 +118,8 @@ func (w *Watcher) registerPlugin(socketPath string) error { if err != nil { return fmt.Errorf("failed to get plugin info using RPC GetInfo at socket %s, err: %v", socketPath, err) } - if err := w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath); err != nil { - return fmt.Errorf("failed to register plugin. Callback handler returned err: %v", err) - } - glog.V(4).Infof("Successfully registered plugin for plugin type: %s, name: %s, socket: %s", infoResp.Type, infoResp.Name, socketPath) - return nil + + return w.invokeRegistrationCallbackAtHandler(ctx, client, infoResp, socketPath) } func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, client registerapi.RegistrationClient, infoResp *registerapi.PluginInfo, socketPath string) error { @@ -127,13 +127,14 @@ func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, clien var ok bool handlerCbkFn, ok = w.handlers[infoResp.Type] if !ok { + errStr := fmt.Sprintf("no handler registered for plugin type: %s at socket %s", infoResp.Type, socketPath) if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: false, - Error: fmt.Sprintf("No handler found registered for plugin type: %s, socket: %s", infoResp.Type, socketPath), + Error: errStr, }); err != nil { - glog.Errorf("Failed to send registration status at socket %s, err: %v", socketPath, err) + return errors.Wrap(err, errStr) } - return fmt.Errorf("no handler found registered for plugin type: %s, socket: %s", infoResp.Type, socketPath) + return errors.New(errStr) } var versions []string @@ -141,27 +142,51 @@ func (w *Watcher) invokeRegistrationCallbackAtHandler(ctx context.Context, clien versions = append(versions, version) } // calls handler callback to verify registration request - err, chanForAckOfNotification := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) + chanForAckOfNotification, err := handlerCbkFn(infoResp.Name, infoResp.Endpoint, versions, socketPath) if err != nil { + errStr := fmt.Sprintf("plugin registration failed with err: %v", err) if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: false, - Error: fmt.Sprintf("Plugin registration failed with err: %v", err), + Error: errStr, }); err != nil { - glog.Errorf("Failed to send registration status at socket %s, err: %v", socketPath, err) + return errors.Wrap(err, errStr) } - chanForAckOfNotification <- false - return fmt.Errorf("plugin registration failed with err: %v", err) + return errors.New(errStr) } if _, err := client.NotifyRegistrationStatus(ctx, ®isterapi.RegistrationStatus{ PluginRegistered: true, }); err != nil { + chanForAckOfNotification <- false return fmt.Errorf("failed to send registration status at socket %s, err: %v", socketPath, err) } + chanForAckOfNotification <- true return nil } +// Handle filesystem notify event. +func (w *Watcher) handleFsNotifyEvent(event fsnotify.Event) error { + if event.Op&fsnotify.Create != fsnotify.Create { + return nil + } + + fi, err := os.Stat(event.Name) + if err != nil { + return fmt.Errorf("stat file %s failed: %v", event.Name, err) + } + + if !fi.IsDir() { + return w.registerPlugin(event.Name) + } + + if err := w.traversePluginDir(event.Name); err != nil { + return fmt.Errorf("failed to traverse plugin path %s, err: %v", event.Name, err) + } + + return nil +} + // Start watches for the creation of plugin sockets at the path func (w *Watcher) Start() error { glog.V(2).Infof("Plugin Watcher Start at %s", w.path) @@ -173,52 +198,42 @@ func (w *Watcher) Start() error { return err } - watcher, err := fsnotify.NewWatcher() + fsWatcher, err := fsnotify.NewWatcher() if err != nil { - return fmt.Errorf("failed to start plugin watcher, err: %v", err) + return fmt.Errorf("failed to start plugin fsWatcher, err: %v", err) } + w.fsWatcher = fsWatcher - if err := watcher.Add(w.path); err != nil { - watcher.Close() - return fmt.Errorf("failed to start plugin watcher, err: %v", err) - } - - w.watcher = watcher - - if err := w.traversePluginDir(); err != nil { - watcher.Close() + if err := w.traversePluginDir(w.path); err != nil { + fsWatcher.Close() return fmt.Errorf("failed to traverse plugin socket path, err: %v", err) } w.wg.Add(1) - go func(watcher *fsnotify.Watcher) { + go func(fsWatcher *fsnotify.Watcher) { defer w.wg.Done() for { select { - case event := <-watcher.Events: - if event.Op&fsnotify.Create == fsnotify.Create { - go func(eventName string) { - err := w.registerPlugin(eventName) - if err != nil { - glog.Errorf("Plugin %s registration failed with error: %v", eventName, err) - } - }(event.Name) - } - continue - case err := <-watcher.Errors: + case event := <-fsWatcher.Events: //TODO: Handle errors by taking corrective measures + go func() { + err := w.handleFsNotifyEvent(event) + if err != nil { + glog.Errorf("error %v when handle event: %s", err, event) + } + }() + continue + case err := <-fsWatcher.Errors: if err != nil { - glog.Errorf("Watcher received error: %v", err) + glog.Errorf("fsWatcher received error: %v", err) } continue - case <-w.stopCh: - watcher.Close() - break + fsWatcher.Close() + return } - break } - }(watcher) + }(fsWatcher) return nil } diff --git a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go index 44bccf9a6f..5bfb49568e 100644 --- a/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go +++ b/pkg/kubelet/util/pluginwatcher/plugin_watcher_test.go @@ -17,135 +17,56 @@ limitations under the License. package pluginwatcher import ( - "fmt" + "errors" "io/ioutil" + "path/filepath" "strconv" "sync" "testing" "time" "github.com/stretchr/testify/require" - "golang.org/x/net/context" "k8s.io/apimachinery/pkg/util/sets" registerapi "k8s.io/kubernetes/pkg/kubelet/apis/pluginregistration/v1alpha1" - v1beta1 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta1" - v1beta2 "k8s.io/kubernetes/pkg/kubelet/util/pluginwatcher/example_plugin_apis/v1beta2" ) -func TestExamplePlugin(t *testing.T) { - socketDir, err := ioutil.TempDir("", "plugin_test") - require.NoError(t, err) - socketPath := socketDir + "/plugin.sock" - w := NewWatcher(socketDir) - - testCases := []struct { - description string - expectedEndpoint string - returnErr error - }{ - { - description: "Successfully register plugin through inotify", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Successfully register plugin through inotify and got expected optional endpoint", - expectedEndpoint: "dummyEndpoint", - returnErr: nil, - }, - { - description: "Fails registration because endpoint is expected to be non-empty", - expectedEndpoint: "dummyEndpoint", - returnErr: fmt.Errorf("empty endpoint received"), - }, - { - description: "Successfully register plugin through inotify after plugin restarts", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Fails registration with conflicting plugin name", - expectedEndpoint: "", - returnErr: fmt.Errorf("conflicting plugin name"), - }, - { - description: "Successfully register plugin during initial traverse after plugin watcher restarts", - expectedEndpoint: "", - returnErr: nil, - }, - { - description: "Fails registration with conflicting plugin name during initial traverse after plugin watcher restarts", - expectedEndpoint: "", - returnErr: fmt.Errorf("conflicting plugin name"), - }, +// helper function +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false // completed normally + case <-time.After(timeout): + return true // timed out } +} - callbackCount := struct { - mutex sync.Mutex - count int32 - }{} - w.AddHandler(PluginType, func(name string, endpoint string, versions []string, sockPath string) (error, chan bool) { - callbackCount.mutex.Lock() - localCount := callbackCount.count - callbackCount.count = callbackCount.count + 1 - callbackCount.mutex.Unlock() +func TestExamplePlugin(t *testing.T) { + rootDir, err := ioutil.TempDir("", "plugin_test") + require.NoError(t, err) + w := NewWatcher(rootDir) + h := NewExampleHandler() + w.AddHandler(registerapi.DevicePlugin, h.Handler) - require.True(t, localCount <= int32((len(testCases)-1))) - require.Equal(t, PluginName, name, "Plugin name mismatched!!") - retError := testCases[localCount].returnErr - if retError == nil || retError.Error() != "empty endpoint received" { - require.Equal(t, testCases[localCount].expectedEndpoint, endpoint, "Unexpected endpoint") - } else { - require.NotEqual(t, testCases[localCount].expectedEndpoint, endpoint, "Unexpected endpoint") - } - - require.Equal(t, []string{"v1beta1", "v1beta2"}, versions, "Plugin version mismatched!!") - // Verifies the grpcServer is ready to serve services. - _, conn, err := dial(sockPath) - require.Nil(t, err) - defer conn.Close() - - // The plugin handler should be able to use any listed service API version. - v1beta1Client := v1beta1.NewExampleClient(conn) - v1beta2Client := v1beta2.NewExampleClient(conn) - - // Tests v1beta1 GetExampleInfo - _, err = v1beta1Client.GetExampleInfo(context.Background(), &v1beta1.ExampleRequest{}) - require.Nil(t, err) - - // Tests v1beta1 GetExampleInfo - _, err = v1beta2Client.GetExampleInfo(context.Background(), &v1beta2.ExampleRequest{}) - //atomic.AddInt32(&callbackCount, 1) - chanForAckOfNotification := make(chan bool) - - go func() { - select { - case <-chanForAckOfNotification: - close(chanForAckOfNotification) - case <-time.After(time.Second): - t.Fatalf("Timed out while waiting for notification ack") - } - }() - return retError, chanForAckOfNotification - }) require.NoError(t, w.Start()) - p := NewTestExamplePlugin("") - require.NoError(t, p.Serve(socketPath)) - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + socketPath := filepath.Join(rootDir, "plugin.sock") + PluginName := "example-plugin" - require.NoError(t, p.Stop()) - - p = NewTestExamplePlugin("dummyEndpoint") - require.NoError(t, p.Serve(socketPath)) - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) - - require.NoError(t, p.Stop()) - - p = NewTestExamplePlugin("") + // handler expecting plugin has a non-empty endpoint + p := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "") require.NoError(t, p.Serve(socketPath)) require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + require.NoError(t, p.Stop()) + + p = NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") + require.NoError(t, p.Serve(socketPath)) + require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) // Trying to start a plugin service at the same socket path should fail // with "bind: address already in use" @@ -154,27 +75,20 @@ func TestExamplePlugin(t *testing.T) { // grpcServer.Stop() will remove the socket and starting plugin service // at the same path again should succeeds and trigger another callback. require.NoError(t, p.Stop()) - p = NewTestExamplePlugin("") - go func() { - require.Nil(t, p.Serve(socketPath)) - }() - require.True(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) + require.Nil(t, p.Serve(socketPath)) + require.False(t, waitForPluginRegistrationStatus(t, p.registrationStatus)) // Starting another plugin with the same name got verification error. - p2 := NewTestExamplePlugin("") - socketPath2 := socketDir + "/plugin2.sock" - go func() { - require.NoError(t, p2.Serve(socketPath2)) - }() + p2 := NewTestExamplePlugin(PluginName, registerapi.DevicePlugin, "dummyEndpoint") + socketPath2 := filepath.Join(rootDir, "plugin2.sock") + require.NoError(t, p2.Serve(socketPath2)) require.False(t, waitForPluginRegistrationStatus(t, p2.registrationStatus)) // Restarts plugin watcher should traverse the socket directory and issues a // callback for every existing socket. require.NoError(t, w.Stop()) - errCh := make(chan error) - go func() { - errCh <- w.Start() - }() + require.NoError(t, h.Cleanup()) + require.NoError(t, w.Start()) var wg sync.WaitGroup wg.Add(2) @@ -188,7 +102,11 @@ func TestExamplePlugin(t *testing.T) { p2Status = strconv.FormatBool(waitForPluginRegistrationStatus(t, p2.registrationStatus)) wg.Done() }() - wg.Wait() + + if waitTimeout(&wg, 2*time.Second) { + t.Fatalf("Timed out waiting for wait group") + } + expectedSet := sets.NewString() expectedSet.Insert("true", "false") actualSet := sets.NewString() @@ -197,16 +115,86 @@ func TestExamplePlugin(t *testing.T) { require.Equal(t, expectedSet, actualSet) select { - case err = <-errCh: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatalf("Timed out while waiting for watcher start") - + case err := <-h.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case <-time.After(2 * time.Second): } require.NoError(t, w.Stop()) - err = w.Cleanup() + require.NoError(t, w.Cleanup()) +} + +func TestPluginWithSubDir(t *testing.T) { + rootDir, err := ioutil.TempDir("", "plugin_test") require.NoError(t, err) + + w := NewWatcher(rootDir) + hcsi := NewExampleHandler() + hdp := NewExampleHandler() + + w.AddHandler(registerapi.CSIPlugin, hcsi.Handler) + w.AddHandler(registerapi.DevicePlugin, hdp.Handler) + + err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.DevicePlugin), 0755) + require.NoError(t, err) + err = w.fs.MkdirAll(filepath.Join(rootDir, registerapi.CSIPlugin), 0755) + require.NoError(t, err) + + dpSocketPath := filepath.Join(rootDir, registerapi.DevicePlugin, "plugin.sock") + csiSocketPath := filepath.Join(rootDir, registerapi.CSIPlugin, "plugin.sock") + + require.NoError(t, w.Start()) + + // two plugins using the same name but with different type + dp := NewTestExamplePlugin("exampleplugin", registerapi.DevicePlugin, "example-endpoint") + require.NoError(t, dp.Serve(dpSocketPath)) + require.True(t, waitForPluginRegistrationStatus(t, dp.registrationStatus)) + + csi := NewTestExamplePlugin("exampleplugin", registerapi.CSIPlugin, "example-endpoint") + require.NoError(t, csi.Serve(csiSocketPath)) + require.True(t, waitForPluginRegistrationStatus(t, csi.registrationStatus)) + + // Restarts plugin watcher should traverse the socket directory and issues a + // callback for every existing socket. + require.NoError(t, w.Stop()) + require.NoError(t, hcsi.Cleanup()) + require.NoError(t, hdp.Cleanup()) + require.NoError(t, w.Start()) + + var wg sync.WaitGroup + wg.Add(2) + var dpStatus string + var csiStatus string + go func() { + dpStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, dp.registrationStatus)) + wg.Done() + }() + go func() { + csiStatus = strconv.FormatBool(waitForPluginRegistrationStatus(t, csi.registrationStatus)) + wg.Done() + }() + + if waitTimeout(&wg, 4*time.Second) { + require.NoError(t, errors.New("Timed out waiting for wait group")) + } + + expectedSet := sets.NewString() + expectedSet.Insert("true", "true") + actualSet := sets.NewString() + actualSet.Insert(dpStatus, csiStatus) + + require.Equal(t, expectedSet, actualSet) + + select { + case err := <-hcsi.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case err := <-hdp.chanForHandlerAckErrors: + t.Fatalf("%v", err) + case <-time.After(4 * time.Second): + } + + require.NoError(t, w.Stop()) + require.NoError(t, w.Cleanup()) } func waitForPluginRegistrationStatus(t *testing.T, statusCh chan registerapi.RegistrationStatus) bool { diff --git a/pkg/volume/csi/csi_plugin.go b/pkg/volume/csi/csi_plugin.go index 85d513d7c4..60f04a2419 100644 --- a/pkg/volume/csi/csi_plugin.go +++ b/pkg/volume/csi/csi_plugin.go @@ -84,7 +84,7 @@ var lm labelmanager.Interface // RegistrationCallback is called by kubelet's plugin watcher upon detection // of a new registration socket opened by CSI Driver registrar side car. -func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (error, chan bool) { +func RegistrationCallback(pluginName string, endpoint string, versions []string, socketPath string) (chan bool, error) { glog.Infof(log("Callback from kubelet with plugin name: %s endpoint: %s versions: %s socket path: %s", pluginName, endpoint, strings.Join(versions, ","), socketPath)) @@ -95,7 +95,7 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string, // Calling nodeLabelManager to update label for newly registered CSI driver err := lm.AddLabels(pluginName) if err != nil { - return err, nil + return nil, err } // Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key // all other CSI components will be able to get the actual socket of CSI drivers by its name.