From 93c2656d5a5251642edec1d94d4ce652ddfaa766 Mon Sep 17 00:00:00 2001 From: Michail Kargakis Date: Mon, 2 May 2016 11:31:52 +0200 Subject: [PATCH] api: validate generation updates --- pkg/api/rest/resttest/resttest.go | 35 ++++++++++++++++++++++ pkg/api/rest/update.go | 6 ++++ pkg/api/validation/validation.go | 5 ++++ pkg/api/validation/validation_test.go | 43 +++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) diff --git a/pkg/api/rest/resttest/resttest.go b/pkg/api/rest/resttest/resttest.go index 9c71f1a6f7..a6f50a804b 100644 --- a/pkg/api/rest/resttest/resttest.go +++ b/pkg/api/rest/resttest/resttest.go @@ -173,6 +173,7 @@ func (t *Tester) TestUpdate(valid runtime.Object, createFn CreateFunc, getFn Get t.testUpdateWithWrongUID(copyOrDie(valid), createFn, getFn) t.testUpdateRetrievesOldObject(copyOrDie(valid), createFn, getFn) t.testUpdatePropagatesUpdatedObjectError(copyOrDie(valid), createFn, getFn) + t.testUpdateIgnoreGenerationUpdates(copyOrDie(valid), createFn, getFn) } // Test deleting an object. @@ -613,6 +614,40 @@ func (t *Tester) testUpdatePropagatesUpdatedObjectError(obj runtime.Object, crea } } +func (t *Tester) testUpdateIgnoreGenerationUpdates(obj runtime.Object, createFn CreateFunc, getFn GetFunc) { + ctx := t.TestContext() + + foo := copyOrDie(obj) + name := t.namer(8) + t.setObjectMeta(foo, name) + + if err := createFn(ctx, foo); err != nil { + t.Errorf("unexpected error: %v", err) + } + + storedFoo, err := getFn(ctx, foo) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + older := copyOrDie(storedFoo) + olderMeta := t.getObjectMetaOrFail(older) + olderMeta.Generation = 2 + + _, _, err = t.storage.(rest.Updater).Update(t.TestContext(), olderMeta.Name, rest.DefaultUpdatedObjectInfo(older, api.Scheme)) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + updatedFoo, err := getFn(ctx, older) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if exp, got := int64(1), t.getObjectMetaOrFail(updatedFoo).Generation; exp != got { + t.Errorf("Unexpected generation update: expected %d, got %d", exp, got) + } +} + func (t *Tester) testUpdateOnNotFound(obj runtime.Object) { t.setObjectMeta(obj, t.namer(0)) _, created, err := t.storage.(rest.Updater).Update(t.TestContext(), t.namer(0), rest.DefaultUpdatedObjectInfo(obj, api.Scheme)) diff --git a/pkg/api/rest/update.go b/pkg/api/rest/update.go index 73f1339197..68cccb213d 100644 --- a/pkg/api/rest/update.go +++ b/pkg/api/rest/update.go @@ -86,6 +86,12 @@ func BeforeUpdate(strategy RESTUpdateStrategy, ctx api.Context, obj, old runtime } else { objectMeta.Namespace = api.NamespaceNone } + // Ensure requests cannot update generation + oldMeta, err := api.ObjectMetaFor(old) + if err != nil { + return err + } + objectMeta.Generation = oldMeta.Generation strategy.PrepareForUpdate(obj, old) diff --git a/pkg/api/validation/validation.go b/pkg/api/validation/validation.go index dcb922b040..ec11f35c7c 100644 --- a/pkg/api/validation/validation.go +++ b/pkg/api/validation/validation.go @@ -373,6 +373,11 @@ func ValidateObjectMetaUpdate(newMeta, oldMeta *api.ObjectMeta, fldPath *field.P allErrs = append(allErrs, field.Invalid(fldPath.Child("resourceVersion"), newMeta.ResourceVersion, "must be specified for an update")) } + // Generation shouldn't be decremented + if newMeta.Generation < oldMeta.Generation { + allErrs = append(allErrs, field.Invalid(fldPath.Child("generation"), newMeta.Generation, "must not be decremented")) + } + allErrs = append(allErrs, ValidateImmutableField(newMeta.Name, oldMeta.Name, fldPath.Child("name"))...) allErrs = append(allErrs, ValidateImmutableField(newMeta.Namespace, oldMeta.Namespace, fldPath.Child("namespace"))...) allErrs = append(allErrs, ValidateImmutableField(newMeta.UID, oldMeta.UID, fldPath.Child("uid"))...) diff --git a/pkg/api/validation/validation_test.go b/pkg/api/validation/validation_test.go index 55c124fee7..936d1db706 100644 --- a/pkg/api/validation/validation_test.go +++ b/pkg/api/validation/validation_test.go @@ -330,6 +330,49 @@ func TestValidateObjectMetaUpdatePreventsDeletionFieldMutation(t *testing.T) { } } +func TestObjectMetaGenerationUpdate(t *testing.T) { + testcases := map[string]struct { + Old api.ObjectMeta + New api.ObjectMeta + ExpectedErrs []string + }{ + "invalid generation change - decremented": { + Old: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 5}, + New: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 4}, + ExpectedErrs: []string{"field.generation: Invalid value: 4: must not be decremented"}, + }, + "valid generation change - incremented by one": { + Old: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 1}, + New: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 2}, + ExpectedErrs: []string{}, + }, + "valid generation field - not updated": { + Old: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 5}, + New: api.ObjectMeta{Name: "test", ResourceVersion: "1", Generation: 5}, + ExpectedErrs: []string{}, + }, + } + + for k, tc := range testcases { + errList := []string{} + errs := ValidateObjectMetaUpdate(&tc.New, &tc.Old, field.NewPath("field")) + if len(errs) != len(tc.ExpectedErrs) { + t.Logf("%s: Expected: %#v", k, tc.ExpectedErrs) + for _, err := range errs { + errList = append(errList, err.Error()) + } + t.Logf("%s: Got: %#v", k, errList) + t.Errorf("%s: expected %d errors, got %d", k, len(tc.ExpectedErrs), len(errs)) + continue + } + for i := range errList { + if errList[i] != tc.ExpectedErrs[i] { + t.Errorf("%s: error #%d: expected %q, got %q", k, i, tc.ExpectedErrs[i], errList[i]) + } + } + } +} + // Ensure trailing slash is allowed in generate name func TestValidateObjectMetaTrimsTrailingSlash(t *testing.T) { errs := ValidateObjectMeta(