diff --git a/pkg/volume/csi/nodeinfomanager/BUILD b/pkg/volume/csi/nodeinfomanager/BUILD index 588871e29e..3eadf06129 100644 --- a/pkg/volume/csi/nodeinfomanager/BUILD +++ b/pkg/volume/csi/nodeinfomanager/BUILD @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/features:go_default_library", + "//pkg/util/node:go_default_library", "//pkg/volume:go_default_library", "//pkg/volume/util:go_default_library", "//staging/src/k8s.io/api/core/v1:go_default_library", @@ -53,12 +54,15 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/types:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/util/strategicpatch:go_default_library", "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library", "//staging/src/k8s.io/apiserver/pkg/util/feature/testing:go_default_library", "//staging/src/k8s.io/client-go/kubernetes/fake:go_default_library", + "//staging/src/k8s.io/client-go/testing:go_default_library", "//staging/src/k8s.io/client-go/util/testing:go_default_library", "//staging/src/k8s.io/csi-api/pkg/apis/csi/v1alpha1:go_default_library", "//staging/src/k8s.io/csi-api/pkg/client/clientset/versioned/fake:go_default_library", "//vendor/github.com/container-storage-interface/spec/lib/go/csi/v0:go_default_library", + "//vendor/github.com/stretchr/testify/assert:go_default_library", ], ) diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go index bea909d6a5..0a28d9c4ba 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go @@ -34,6 +34,7 @@ import ( "k8s.io/client-go/util/retry" csiv1alpha1 "k8s.io/csi-api/pkg/apis/csi/v1alpha1" "k8s.io/kubernetes/pkg/features" + nodeutil "k8s.io/kubernetes/pkg/util/node" "k8s.io/kubernetes/pkg/volume" "k8s.io/kubernetes/pkg/volume/util" ) @@ -150,7 +151,8 @@ func (nim *nodeInfoManager) updateNode(updateFuncs ...nodeUpdateFunc) error { } nodeClient := kubeClient.CoreV1().Nodes() - node, err := nodeClient.Get(string(nim.nodeName), metav1.GetOptions{}) + originalNode, err := nodeClient.Get(string(nim.nodeName), metav1.GetOptions{}) + node := originalNode.DeepCopy() if err != nil { return err // do not wrap error } @@ -166,7 +168,9 @@ func (nim *nodeInfoManager) updateNode(updateFuncs ...nodeUpdateFunc) error { } if needUpdate { - _, updateErr := nodeClient.Update(node) + // PatchNodeStatus can update both node's status and labels or annotations + // Updating status by directly updating node does not work + _, _, updateErr := nodeutil.PatchNodeStatus(kubeClient.CoreV1(), types.NodeName(node.Name), originalNode, node) return updateErr // do not wrap error } diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go index 6d5c299b39..8e42f77869 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go @@ -18,18 +18,22 @@ package nodeinfomanager import ( "encoding/json" + "fmt" "testing" "github.com/container-storage-interface/spec/lib/go/csi/v0" + "github.com/stretchr/testify/assert" "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/strategicpatch" utilfeature "k8s.io/apiserver/pkg/util/feature" utilfeaturetesting "k8s.io/apiserver/pkg/util/feature/testing" "k8s.io/client-go/kubernetes/fake" + clienttesting "k8s.io/client-go/testing" utiltesting "k8s.io/client-go/util/testing" csiv1alpha1 "k8s.io/csi-api/pkg/apis/csi/v1alpha1" csifake "k8s.io/csi-api/pkg/client/clientset/versioned/fake" @@ -682,9 +686,18 @@ func test(t *testing.T, addNodeInfo bool, csiNodeInfoEnabled bool, testcases []t continue } - /* Node Validation */ - node, err := client.CoreV1().Nodes().Get(nodeName, metav1.GetOptions{}) - if err != nil { + actions := client.Actions() + + var node *v1.Node + if hasPatchAction(actions) { + node, err = applyNodeStatusPatch(tc.existingNode, actions[1].(clienttesting.PatchActionImpl).GetPatch()) + assert.NoError(t, err) + } else { + node, err = client.CoreV1().Nodes().Get(nodeName, metav1.GetOptions{}) + assert.NoError(t, err) + } + + if node == nil { t.Errorf("error getting node: %v", err) continue } @@ -807,3 +820,29 @@ func generateNodeInfo(nodeIDs map[string]string, topologyKeys map[string][]strin CSIDrivers: drivers, } } + +func applyNodeStatusPatch(originalNode *v1.Node, patch []byte) (*v1.Node, error) { + original, err := json.Marshal(originalNode) + if err != nil { + return nil, fmt.Errorf("failed to marshal original node %#v: %v", originalNode, err) + } + updated, err := strategicpatch.StrategicMergePatch(original, patch, v1.Node{}) + if err != nil { + return nil, fmt.Errorf("failed to apply strategic merge patch %q on node %#v: %v", + patch, originalNode, err) + } + updatedNode := &v1.Node{} + if err := json.Unmarshal(updated, updatedNode); err != nil { + return nil, fmt.Errorf("failed to unmarshal updated node %q: %v", updated, err) + } + return updatedNode, nil +} + +func hasPatchAction(actions []clienttesting.Action) bool { + for _, action := range actions { + if action.GetVerb() == "patch" { + return true + } + } + return false +}