mirror of https://github.com/k3s-io/k3s
Merge pull request #67110 from verult/kubelet-nodeid
Automatic merge from submit-queue (batch tested with PRs 67017, 67190, 67110, 67140, 66873). If you want to cherry-pick this change to another branch, please follow the instructions <a href="https://github.com/kubernetes/community/blob/master/contributors/devel/cherry-picks.md">here</a>. CSI plugin now calls NodeGetInfo() to get driver's node ID **Which issue(s) this PR fixes** *(optional, in `fixes #<issue number>(, fixes #<issue_number>, ...)` format, will close the issue(s) when PR gets merged)*: Fixes #67040 **Special notes for your reviewer**: **Release note**: ```release-note NONE ``` /sig storage @sbezverk @vladimirvivien @saad-alipull/8/head
commit
032a096d86
|
@ -32,6 +32,11 @@ import (
|
|||
)
|
||||
|
||||
type csiClient interface {
|
||||
NodeGetInfo(ctx context.Context) (
|
||||
nodeID string,
|
||||
maxVolumePerNode int64,
|
||||
accessibleTopology *csipb.Topology,
|
||||
err error)
|
||||
NodePublishVolume(
|
||||
ctx context.Context,
|
||||
volumeid string,
|
||||
|
@ -75,6 +80,24 @@ func newCsiDriverClient(driverName string) *csiDriverClient {
|
|||
return c
|
||||
}
|
||||
|
||||
func (c *csiDriverClient) NodeGetInfo(ctx context.Context) (
|
||||
nodeID string,
|
||||
maxVolumePerNode int64,
|
||||
accessibleTopology *csipb.Topology,
|
||||
err error) {
|
||||
glog.V(4).Info(log("calling NodeGetInfo rpc"))
|
||||
|
||||
conn, err := newGrpcConn(c.driverName)
|
||||
if err != nil {
|
||||
return "", 0, nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
nodeClient := csipb.NewNodeClient(conn)
|
||||
|
||||
res, err := nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
|
||||
return res.GetNodeId(), res.GetMaxVolumesPerNode(), res.GetAccessibleTopology(), nil
|
||||
}
|
||||
|
||||
func (c *csiDriverClient) NodePublishVolume(
|
||||
ctx context.Context,
|
||||
volID string,
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
csipb "github.com/container-storage-interface/spec/lib/go/csi/v0"
|
||||
api "k8s.io/api/core/v1"
|
||||
"k8s.io/kubernetes/pkg/volume/csi/fake"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type fakeCsiDriverClient struct {
|
||||
|
@ -38,6 +39,15 @@ func newFakeCsiDriverClient(t *testing.T, stagingCapable bool) *fakeCsiDriverCli
|
|||
}
|
||||
}
|
||||
|
||||
func (c *fakeCsiDriverClient) NodeGetInfo(ctx context.Context) (
|
||||
nodeID string,
|
||||
maxVolumePerNode int64,
|
||||
accessibleTopology *csipb.Topology,
|
||||
err error) {
|
||||
resp, err := c.nodeClient.NodeGetInfo(ctx, &csipb.NodeGetInfoRequest{})
|
||||
return resp.GetNodeId(), resp.GetMaxVolumesPerNode(), resp.GetAccessibleTopology(), err
|
||||
}
|
||||
|
||||
func (c *fakeCsiDriverClient) NodePublishVolume(
|
||||
ctx context.Context,
|
||||
volID string,
|
||||
|
@ -141,6 +151,60 @@ func setupClient(t *testing.T, stageUnstageSet bool) csiClient {
|
|||
return newFakeCsiDriverClient(t, stageUnstageSet)
|
||||
}
|
||||
|
||||
func TestClientNodeGetInfo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
expectedNodeID string
|
||||
expectedMaxVolumePerNode int64
|
||||
expectedAccessibleTopology *csipb.Topology
|
||||
mustFail bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "test ok",
|
||||
expectedNodeID: "node1",
|
||||
expectedMaxVolumePerNode: 16,
|
||||
expectedAccessibleTopology: &csipb.Topology{
|
||||
Segments: map[string]string{"com.example.csi-topology/zone": "zone1"},
|
||||
},
|
||||
},
|
||||
{name: "grpc error", mustFail: true, err: errors.New("grpc error")},
|
||||
}
|
||||
|
||||
client := setupClient(t, false /* stageUnstageSet */)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Logf("test case: %s", tc.name)
|
||||
client.(*fakeCsiDriverClient).nodeClient.SetNextError(tc.err)
|
||||
client.(*fakeCsiDriverClient).nodeClient.SetNodeGetInfoResp(&csipb.NodeGetInfoResponse{
|
||||
NodeId: tc.expectedNodeID,
|
||||
MaxVolumesPerNode: tc.expectedMaxVolumePerNode,
|
||||
AccessibleTopology: tc.expectedAccessibleTopology,
|
||||
})
|
||||
nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background())
|
||||
|
||||
if tc.mustFail && err == nil {
|
||||
t.Error("expected an error but got none")
|
||||
}
|
||||
|
||||
if !tc.mustFail && err != nil {
|
||||
t.Errorf("expected no errors but got: %v", err)
|
||||
}
|
||||
|
||||
if nodeID != tc.expectedNodeID {
|
||||
t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID)
|
||||
}
|
||||
|
||||
if maxVolumePerNode != tc.expectedMaxVolumePerNode {
|
||||
t.Errorf("expected maxVolumePerNode: %v; got: %v", tc.expectedMaxVolumePerNode, maxVolumePerNode)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(accessibleTopology, tc.expectedAccessibleTopology) {
|
||||
t.Errorf("expected accessibleTopology: %v; got: %v", *tc.expectedAccessibleTopology, *accessibleTopology)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientNodePublishVolume(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"context"
|
||||
"github.com/golang/glog"
|
||||
api "k8s.io/api/core/v1"
|
||||
meta "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
@ -76,6 +77,7 @@ type csiDriversStore struct {
|
|||
sync.RWMutex
|
||||
}
|
||||
|
||||
// TODO (verult) consider using a struct instead of global variables
|
||||
// csiDrivers map keep track of all registered CSI drivers on the node and their
|
||||
// corresponding sockets
|
||||
var csiDrivers csiDriversStore
|
||||
|
@ -92,17 +94,33 @@ func RegistrationCallback(pluginName string, endpoint string, versions []string,
|
|||
if endpoint == "" {
|
||||
endpoint = socketPath
|
||||
}
|
||||
// Calling nodeLabelManager to update label for newly registered CSI driver
|
||||
err := lm.AddLabels(pluginName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Storing endpoint of newly registered CSI driver into the map, where CSI driver name will be the key
|
||||
// all other CSI components will be able to get the actual socket of CSI drivers by its name.
|
||||
csiDrivers.Lock()
|
||||
defer csiDrivers.Unlock()
|
||||
csiDrivers.driversMap[pluginName] = csiDriver{driverName: pluginName, driverEndpoint: endpoint}
|
||||
|
||||
// Get node info from the driver.
|
||||
csi := newCsiDriverClient(pluginName)
|
||||
// TODO (verult) retry with exponential backoff, possibly added in csi client library.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), csiTimeout)
|
||||
defer cancel()
|
||||
driverNodeID, _, _, err := csi.NodeGetInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during CSI NodeGetInfo() call: %v", err)
|
||||
}
|
||||
|
||||
// Calling nodeLabelManager to update annotations and labels for newly registered CSI driver
|
||||
err = lm.AddLabels(pluginName, driverNodeID)
|
||||
if err != nil {
|
||||
// Unregister the driver and return error
|
||||
csiDrivers.Lock()
|
||||
defer csiDrivers.Unlock()
|
||||
delete(csiDrivers.driversMap, pluginName)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -61,6 +61,7 @@ type NodeClient struct {
|
|||
nodePublishedVolumes map[string]string
|
||||
nodeStagedVolumes map[string]string
|
||||
stageUnstageSet bool
|
||||
nodeGetInfoResp *csipb.NodeGetInfoResponse
|
||||
nextErr error
|
||||
}
|
||||
|
||||
|
@ -78,6 +79,10 @@ func (f *NodeClient) SetNextError(err error) {
|
|||
f.nextErr = err
|
||||
}
|
||||
|
||||
func (f *NodeClient) SetNodeGetInfoResp(resp *csipb.NodeGetInfoResponse) {
|
||||
f.nodeGetInfoResp = resp
|
||||
}
|
||||
|
||||
// GetNodePublishedVolumes returns node published volumes
|
||||
func (f *NodeClient) GetNodePublishedVolumes() map[string]string {
|
||||
return f.nodePublishedVolumes
|
||||
|
@ -179,6 +184,14 @@ func (f *NodeClient) NodeGetId(ctx context.Context, in *csipb.NodeGetIdRequest,
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
// NodeGetId implements csi method
|
||||
func (f *NodeClient) NodeGetInfo(ctx context.Context, in *csipb.NodeGetInfoRequest, opts ...grpc.CallOption) (*csipb.NodeGetInfoResponse, error) {
|
||||
if f.nextErr != nil {
|
||||
return nil, f.nextErr
|
||||
}
|
||||
return f.nodeGetInfoResp, nil
|
||||
}
|
||||
|
||||
// NodeGetCapabilities implements csi method
|
||||
func (f *NodeClient) NodeGetCapabilities(ctx context.Context, in *csipb.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipb.NodeGetCapabilitiesResponse, error) {
|
||||
resp := &csipb.NodeGetCapabilitiesResponse{
|
||||
|
|
|
@ -34,7 +34,6 @@ const (
|
|||
// Name of node annotation that contains JSON map of driver names to node
|
||||
// names
|
||||
annotationKey = "csi.volume.kubernetes.io/nodeid"
|
||||
csiPluginName = "kubernetes.io/csi"
|
||||
)
|
||||
|
||||
// labelManagementStruct is struct of channels used for communication between the driver registration
|
||||
|
@ -46,7 +45,7 @@ type labelManagerStruct struct {
|
|||
|
||||
// Interface implements an interface for managing labels of a node
|
||||
type Interface interface {
|
||||
AddLabels(driverName string) error
|
||||
AddLabels(driverName string, driverNodeId string) error
|
||||
}
|
||||
|
||||
// NewLabelManager initializes labelManagerStruct and returns available interfaces
|
||||
|
@ -59,8 +58,8 @@ func NewLabelManager(nodeName types.NodeName, kubeClient kubernetes.Interface) I
|
|||
|
||||
// nodeLabelManager waits for labeling requests initiated by the driver's registration
|
||||
// process.
|
||||
func (lm labelManagerStruct) AddLabels(driverName string) error {
|
||||
err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, string(lm.nodeName))
|
||||
func (lm labelManagerStruct) AddLabels(driverName string, driverNodeId string) error {
|
||||
err := verifyAndAddNodeId(string(lm.nodeName), lm.k8s.CoreV1().Nodes(), driverName, driverNodeId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update node %s's annotation with error: %+v", lm.nodeName, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue