Merge pull request #14758 from deads2k/fix-patch

make patch handle conflicts gracefully
pull/6/head
Alex Robinson 2015-10-05 14:41:51 -07:00
commit 52b8c40bfa
6 changed files with 494 additions and 154 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
package apiserver
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
@ -83,6 +84,9 @@ type RequestScope struct {
// may be used to deserialize an options object to pass to the getter.
type getterFunc func(ctx api.Context, name string, req *restful.Request) (runtime.Object, error)
// MaxPatchConflicts is the maximum number of conflicts retry for during a patch operation before returning failure
const MaxPatchConflicts = 5
// getResourceHandler is an HTTP handler function for get requests. It delegates to the
// passed-in getterFunc to perform the actual get.
func getResourceHandler(scope RequestScope, getter getterFunc) restful.RouteFunction {
@ -392,49 +396,26 @@ func PatchResource(r rest.Patcher, scope RequestScope, typer runtime.ObjectTyper
}
}
versionedObj, err := converter.ConvertToVersion(obj, scope.APIVersion)
versionedObj, err := converter.ConvertToVersion(r.New(), scope.APIVersion)
if err != nil {
errorJSON(err, scope.Codec, w)
return
}
original, err := r.Get(ctx, name)
if err != nil {
errorJSON(err, scope.Codec, w)
return
contentType := req.HeaderParameter("Content-Type")
// Remove "; charset=" if included in header.
if idx := strings.Index(contentType, ";"); idx > 0 {
contentType = contentType[:idx]
}
patchType := api.PatchType(contentType)
originalObjJS, err := scope.Codec.Encode(original)
if err != nil {
errorJSON(err, scope.Codec, w)
return
}
patchJS, err := readBody(req.Request)
if err != nil {
errorJSON(err, scope.Codec, w)
return
}
contentType := req.HeaderParameter("Content-Type")
patchedObjJS, err := getPatchedJS(contentType, originalObjJS, patchJS, versionedObj)
if err != nil {
errorJSON(err, scope.Codec, w)
return
}
if err := scope.Codec.DecodeInto(patchedObjJS, obj); err != nil {
errorJSON(err, scope.Codec, w)
return
}
if err := checkName(obj, name, namespace, scope.Namer); err != nil {
errorJSON(err, scope.Codec, w)
return
}
result, err := finishRequest(timeout, func() (runtime.Object, error) {
// update should never create as previous get would fail
obj, _, err := r.Update(ctx, obj)
return obj, err
})
result, err := patchResource(ctx, timeout, versionedObj, r, name, patchType, patchJS, scope.Namer, scope.Codec)
if err != nil {
errorJSON(err, scope.Codec, w)
return
@ -447,6 +428,95 @@ func PatchResource(r rest.Patcher, scope RequestScope, typer runtime.ObjectTyper
write(http.StatusOK, scope.APIVersion, scope.Codec, result, w, req.Request)
}
}
// patchResource divides PatchResource for easier unit testing
func patchResource(ctx api.Context, timeout time.Duration, versionedObj runtime.Object, patcher rest.Patcher, name string, patchType api.PatchType, patchJS []byte, namer ScopeNamer, codec runtime.Codec) (runtime.Object, error) {
namespace := api.NamespaceValue(ctx)
original, err := patcher.Get(ctx, name)
if err != nil {
return nil, err
}
originalObjJS, err := codec.Encode(original)
if err != nil {
return nil, err
}
originalPatchedObjJS, err := getPatchedJS(patchType, originalObjJS, patchJS, versionedObj)
if err != nil {
return nil, err
}
objToUpdate := patcher.New()
if err := codec.DecodeInto(originalPatchedObjJS, objToUpdate); err != nil {
return nil, err
}
if err := checkName(objToUpdate, name, namespace, namer); err != nil {
return nil, err
}
return finishRequest(timeout, func() (runtime.Object, error) {
// update should never create as previous get would fail
updateObject, _, updateErr := patcher.Update(ctx, objToUpdate)
for i := 0; i < MaxPatchConflicts && (errors.IsConflict(updateErr)); i++ {
// on a conflict,
// 1. build a strategic merge patch from originalJS and the patchedJS. Different patch types can
// be specified, but a strategic merge patch should be expressive enough handle them. Build the
// patch with this type to handle those cases.
// 2. build a strategic merge patch from originalJS and the currentJS
// 3. ensure no conflicts between the two patches
// 4. apply the #1 patch to the currentJS object
// 5. retry the update
currentObject, err := patcher.Get(ctx, name)
if err != nil {
return nil, err
}
currentObjectJS, err := codec.Encode(currentObject)
if err != nil {
return nil, err
}
currentPatch, err := strategicpatch.CreateStrategicMergePatch(originalObjJS, currentObjectJS, patcher.New())
if err != nil {
return nil, err
}
originalPatch, err := strategicpatch.CreateStrategicMergePatch(originalObjJS, originalPatchedObjJS, patcher.New())
if err != nil {
return nil, err
}
diff1 := make(map[string]interface{})
if err := json.Unmarshal(originalPatch, &diff1); err != nil {
return nil, err
}
diff2 := make(map[string]interface{})
if err := json.Unmarshal(currentPatch, &diff2); err != nil {
return nil, err
}
hasConflicts, err := strategicpatch.HasConflicts(diff1, diff2)
if err != nil {
return nil, err
}
if hasConflicts {
return updateObject, updateErr
}
newlyPatchedObjJS, err := getPatchedJS(api.StrategicMergePatchType, currentObjectJS, originalPatch, versionedObj)
if err != nil {
return nil, err
}
if err := codec.DecodeInto(newlyPatchedObjJS, objToUpdate); err != nil {
return nil, err
}
updateObject, _, updateErr = patcher.Update(ctx, objToUpdate)
}
return updateObject, updateErr
})
}
// UpdateResource returns a function that will handle a resource update
@ -736,12 +806,7 @@ func setListSelfLink(obj runtime.Object, req *restful.Request, namer ScopeNamer)
}
func getPatchedJS(contentType string, originalJS, patchJS []byte, obj runtime.Object) ([]byte, error) {
// Remove "; charset=" if included in header.
if idx := strings.Index(contentType, ";"); idx > 0 {
contentType = contentType[:idx]
}
patchType := api.PatchType(contentType)
func getPatchedJS(patchType api.PatchType, originalJS, patchJS []byte, obj runtime.Object) ([]byte, error) {
switch patchType {
case api.JSONPatchType:
patchObj, err := jsonpatch.DecodePatch(patchJS)
@ -755,6 +820,6 @@ func getPatchedJS(contentType string, originalJS, patchJS []byte, obj runtime.Ob
return strategicpatch.StrategicMergePatchData(originalJS, patchJS, obj)
default:
// only here as a safety net - go-restful filters content-type
return nil, fmt.Errorf("unknown Content-Type header for patch: %s", contentType)
return nil, fmt.Errorf("unknown Content-Type header for patch: %v", patchType)
}
}

View File

@ -17,10 +17,22 @@ limitations under the License.
package apiserver
import (
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/emicklei/go-restful"
"github.com/evanphx/json-patch"
"k8s.io/kubernetes/pkg/api"
apierrors "k8s.io/kubernetes/pkg/api/errors"
"k8s.io/kubernetes/pkg/api/latest"
"k8s.io/kubernetes/pkg/api/unversioned"
"k8s.io/kubernetes/pkg/runtime"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/strategicpatch"
)
type testPatchType struct {
@ -40,12 +52,280 @@ func TestPatchAnonymousField(t *testing.T) {
patch := `{"theField": "changed!"}`
expectedJS := `{"kind":"testPatchType","theField":"changed!"}`
actualBytes, err := getPatchedJS(string(api.StrategicMergePatchType), []byte(originalJS), []byte(patch), &testPatchType{})
actualBytes, err := getPatchedJS(api.StrategicMergePatchType, []byte(originalJS), []byte(patch), &testPatchType{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if string(actualBytes) != expectedJS {
t.Errorf("expected %v, got %v", expectedJS, string(actualBytes))
}
}
type testPatcher struct {
// startingPod is used for the first Get
startingPod *api.Pod
// updatePod is the pod that is used for conflict comparison and returned for the SECOND Get
updatePod *api.Pod
numGets int
}
func (p *testPatcher) New() runtime.Object {
return &api.Pod{}
}
func (p *testPatcher) Update(ctx api.Context, obj runtime.Object) (runtime.Object, bool, error) {
inPod := obj.(*api.Pod)
if inPod.ResourceVersion != p.updatePod.ResourceVersion {
return nil, false, apierrors.NewConflict("Pod", inPod.Name, fmt.Errorf("existing %v, new %v", p.updatePod.ResourceVersion, inPod.ResourceVersion))
}
return inPod, false, nil
}
func (p *testPatcher) Get(ctx api.Context, name string) (runtime.Object, error) {
if p.numGets > 0 {
return p.updatePod, nil
}
p.numGets++
return p.startingPod, nil
}
type testNamer struct {
namespace string
name string
}
func (p *testNamer) Namespace(req *restful.Request) (namespace string, err error) {
return p.namespace, nil
}
// Name returns the name from the request, and an optional namespace value if this is a namespace
// scoped call. An error is returned if the name is not available.
func (p *testNamer) Name(req *restful.Request) (namespace, name string, err error) {
return p.namespace, p.name, nil
}
// ObjectName returns the namespace and name from an object if they exist, or an error if the object
// does not support names.
func (p *testNamer) ObjectName(obj runtime.Object) (namespace, name string, err error) {
return p.namespace, p.name, nil
}
// SetSelfLink sets the provided URL onto the object. The method should return nil if the object
// does not support selfLinks.
func (p *testNamer) SetSelfLink(obj runtime.Object, url string) error {
return errors.New("not implemented")
}
// GenerateLink creates a path and query for a given runtime object that represents the canonical path.
func (p *testNamer) GenerateLink(req *restful.Request, obj runtime.Object) (path, query string, err error) {
return "", "", errors.New("not implemented")
}
// GenerateLink creates a path and query for a list that represents the canonical path.
func (p *testNamer) GenerateListLink(req *restful.Request) (path, query string, err error) {
return "", "", errors.New("not implemented")
}
type patchTestCase struct {
name string
// startingPod is used for the first Get
startingPod *api.Pod
// changedPod is the "destination" pod for the patch. The test will create a patch from the startingPod to the changedPod
// to use when calling the patch operation
changedPod *api.Pod
// updatePod is the pod that is used for conflict comparison and returned for the SECOND Get
updatePod *api.Pod
// expectedPod is the pod that you expect to get back after the patch is complete
expectedPod *api.Pod
expectedError string
}
func (tc *patchTestCase) Run(t *testing.T) {
t.Logf("Starting test %s", tc.name)
namespace := tc.startingPod.Namespace
name := tc.startingPod.Name
codec := latest.GroupOrDie("").Codec
testPatcher := &testPatcher{}
testPatcher.startingPod = tc.startingPod
testPatcher.updatePod = tc.updatePod
ctx := api.NewDefaultContext()
ctx = api.WithNamespace(ctx, namespace)
namer := &testNamer{namespace, name}
versionedObj, err := api.Scheme.ConvertToVersion(&api.Pod{}, "v1")
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
for _, patchType := range []api.PatchType{api.JSONPatchType, api.MergePatchType, api.StrategicMergePatchType} {
// TODO SUPPORT THIS!
if patchType == api.JSONPatchType {
continue
}
t.Logf("Working with patchType %v", patchType)
originalObjJS, err := codec.Encode(tc.startingPod)
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
changedJS, err := codec.Encode(tc.changedPod)
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
patch := []byte{}
switch patchType {
case api.JSONPatchType:
continue
case api.StrategicMergePatchType:
patch, err = strategicpatch.CreateStrategicMergePatch(originalObjJS, changedJS, &api.Pod{})
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
case api.MergePatchType:
patch, err = jsonpatch.CreateMergePatch(originalObjJS, changedJS)
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
}
resultObj, err := patchResource(ctx, 1*time.Second, versionedObj, testPatcher, name, patchType, patch, namer, codec)
if len(tc.expectedError) != 0 {
if err == nil || err.Error() != tc.expectedError {
t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err)
return
}
} else {
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
}
if tc.expectedPod == nil {
if resultObj != nil {
t.Errorf("%s: unexpected result: %v", tc.name, resultObj)
}
return
}
resultPod := resultObj.(*api.Pod)
// roundtrip to get defaulting
expectedJS, err := codec.Encode(tc.expectedPod)
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
expectedObj, err := codec.Decode(expectedJS)
if err != nil {
t.Errorf("%s: unexpected error: %v", tc.name, err)
return
}
reallyExpectedPod := expectedObj.(*api.Pod)
if !reflect.DeepEqual(*reallyExpectedPod, *resultPod) {
t.Errorf("%s mismatch: %v\n", tc.name, util.ObjectGoPrintDiff(reallyExpectedPod, resultPod))
return
}
}
}
func TestPatchResourceWithVersionConflict(t *testing.T) {
namespace := "bar"
name := "foo"
fifteen := int64(15)
thirty := int64(30)
tc := &patchTestCase{
name: "TestPatchResourceWithVersionConflict",
startingPod: &api.Pod{},
changedPod: &api.Pod{},
updatePod: &api.Pod{},
expectedPod: &api.Pod{},
}
tc.startingPod.Name = name
tc.startingPod.Namespace = namespace
tc.startingPod.ResourceVersion = "1"
tc.startingPod.APIVersion = "v1"
tc.startingPod.Spec.ActiveDeadlineSeconds = &fifteen
tc.changedPod.Name = name
tc.changedPod.Namespace = namespace
tc.changedPod.ResourceVersion = "1"
tc.changedPod.APIVersion = "v1"
tc.changedPod.Spec.ActiveDeadlineSeconds = &thirty
tc.updatePod.Name = name
tc.updatePod.Namespace = namespace
tc.updatePod.ResourceVersion = "2"
tc.updatePod.APIVersion = "v1"
tc.updatePod.Spec.ActiveDeadlineSeconds = &fifteen
tc.updatePod.Spec.NodeName = "anywhere"
tc.expectedPod.Name = name
tc.expectedPod.Namespace = namespace
tc.expectedPod.ResourceVersion = "2"
tc.expectedPod.Spec.ActiveDeadlineSeconds = &thirty
tc.expectedPod.Spec.NodeName = "anywhere"
tc.Run(t)
}
func TestPatchResourceWithConflict(t *testing.T) {
namespace := "bar"
name := "foo"
tc := &patchTestCase{
name: "TestPatchResourceWithConflict",
startingPod: &api.Pod{},
changedPod: &api.Pod{},
updatePod: &api.Pod{},
expectedError: `Pod "foo" cannot be updated: existing 2, new 1`,
}
tc.startingPod.Name = name
tc.startingPod.Namespace = namespace
tc.startingPod.ResourceVersion = "1"
tc.startingPod.APIVersion = "v1"
tc.startingPod.Spec.NodeName = "here"
tc.changedPod.Name = name
tc.changedPod.Namespace = namespace
tc.changedPod.ResourceVersion = "1"
tc.changedPod.APIVersion = "v1"
tc.changedPod.Spec.NodeName = "there"
tc.updatePod.Name = name
tc.updatePod.Namespace = namespace
tc.updatePod.ResourceVersion = "2"
tc.updatePod.APIVersion = "v1"
tc.updatePod.Spec.NodeName = "anywhere"
tc.Run(t)
}

View File

@ -19,11 +19,11 @@ package jsonmerge
import (
"encoding/json"
"fmt"
"reflect"
"github.com/evanphx/json-patch"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/util/strategicpatch"
"k8s.io/kubernetes/pkg/util/yaml"
)
@ -161,9 +161,14 @@ func (d *Delta) Apply(latest []byte) ([]byte, error) {
}
glog.V(6).Infof("Testing for conflict between:\n%s\n%s", string(d.edit), string(changes))
if hasConflicts(diff1, diff2) {
hasConflicts, err := strategicpatch.HasConflicts(diff1, diff2)
if err != nil {
return nil, err
}
if hasConflicts {
return nil, ErrConflict
}
return jsonpatch.MergePatch(base, d.edit)
}
@ -183,45 +188,6 @@ func IsPreconditionFailed(err error) bool {
var ErrPreconditionFailed = fmt.Errorf("a precondition failed")
var ErrConflict = fmt.Errorf("changes are in conflict")
// hasConflicts returns true if the left and right JSON interface objects overlap with
// different values in any key. The code will panic if an unrecognized type is passed
// (anything not returned by a JSON decode). All keys are required to be strings.
func hasConflicts(left, right interface{}) bool {
switch typedLeft := left.(type) {
case map[string]interface{}:
switch typedRight := right.(type) {
case map[string]interface{}:
for key, leftValue := range typedLeft {
if rightValue, ok := typedRight[key]; ok && hasConflicts(leftValue, rightValue) {
return true
}
}
return false
default:
return true
}
case []interface{}:
switch typedRight := right.(type) {
case []interface{}:
if len(typedLeft) != len(typedRight) {
return true
}
for i := range typedLeft {
if hasConflicts(typedLeft[i], typedRight[i]) {
return true
}
}
return false
default:
return true
}
case string, float64, bool, int, int64, nil:
return !reflect.DeepEqual(left, right)
default:
panic(fmt.Sprintf("unknown type: %v", reflect.TypeOf(left)))
}
}
func (d *Delta) Edit() []byte {
return d.edit
}

View File

@ -1,75 +0,0 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package jsonmerge
import (
"testing"
)
func TestHasConflicts(t *testing.T) {
testCases := []struct {
A interface{}
B interface{}
Ret bool
}{
{A: "hello", B: "hello", Ret: false}, // 0
{A: "hello", B: "hell", Ret: true},
{A: "hello", B: nil, Ret: true},
{A: "hello", B: 1, Ret: true},
{A: "hello", B: float64(1.0), Ret: true},
{A: "hello", B: false, Ret: true},
{A: "hello", B: []interface{}{}, Ret: true}, // 6
{A: []interface{}{1}, B: []interface{}{}, Ret: true},
{A: []interface{}{}, B: []interface{}{}, Ret: false},
{A: []interface{}{1}, B: []interface{}{1}, Ret: false},
{A: map[string]interface{}{}, B: []interface{}{1}, Ret: true},
{A: map[string]interface{}{}, B: map[string]interface{}{"a": 1}, Ret: false}, // 11
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"a": 1}, Ret: false},
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"a": 2}, Ret: true},
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"b": 2}, Ret: false},
{ // 15
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": []interface{}{1}},
Ret: false,
},
{
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": []interface{}{}},
Ret: true,
},
{
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": 1},
Ret: true,
},
}
for i, testCase := range testCases {
out := hasConflicts(testCase.A, testCase.B)
if out != testCase.Ret {
t.Errorf("%d: expected %t got %t", i, testCase.Ret, out)
continue
}
out = hasConflicts(testCase.B, testCase.A)
if out != testCase.Ret {
t.Errorf("%d: expected reversed %t got %t", i, testCase.Ret, out)
}
}
}

View File

@ -827,3 +827,44 @@ func sliceElementType(slices ...[]interface{}) (reflect.Type, error) {
return prevType, nil
}
// HasConflicts returns true if the left and right JSON interface objects overlap with
// different values in any key. The code will panic if an unrecognized type is passed
// (anything not returned by a JSON decode). All keys are required to be strings.
// Since patches of the same Type have congruent keys, this is valid for multiple patch
// types.
func HasConflicts(left, right interface{}) (bool, error) {
switch typedLeft := left.(type) {
case map[string]interface{}:
switch typedRight := right.(type) {
case map[string]interface{}:
for key, leftValue := range typedLeft {
rightValue, ok := typedRight[key]
if !ok {
return false, nil
}
return HasConflicts(leftValue, rightValue)
}
return false, nil
default:
return true, nil
}
case []interface{}:
switch typedRight := right.(type) {
case []interface{}:
if len(typedLeft) != len(typedRight) {
return true, nil
}
for i := range typedLeft {
return HasConflicts(typedLeft[i], typedRight[i])
}
return false, nil
default:
return true, nil
}
case string, float64, bool, int, int64, nil:
return !reflect.DeepEqual(left, right), nil
default:
return true, fmt.Errorf("unknown type: %v", reflect.TypeOf(left))
}
}

View File

@ -764,3 +764,66 @@ func jsonToYAML(j []byte) ([]byte, error) {
return y, nil
}
func TestHasConflicts(t *testing.T) {
testCases := []struct {
A interface{}
B interface{}
Ret bool
}{
{A: "hello", B: "hello", Ret: false}, // 0
{A: "hello", B: "hell", Ret: true},
{A: "hello", B: nil, Ret: true},
{A: "hello", B: 1, Ret: true},
{A: "hello", B: float64(1.0), Ret: true},
{A: "hello", B: false, Ret: true},
{A: 1, B: 1, Ret: false},
{A: false, B: false, Ret: false},
{A: float64(3), B: float64(3), Ret: false},
{A: "hello", B: []interface{}{}, Ret: true}, // 6
{A: []interface{}{1}, B: []interface{}{}, Ret: true},
{A: []interface{}{}, B: []interface{}{}, Ret: false},
{A: []interface{}{1}, B: []interface{}{1}, Ret: false},
{A: map[string]interface{}{}, B: []interface{}{1}, Ret: true},
{A: map[string]interface{}{}, B: map[string]interface{}{"a": 1}, Ret: false}, // 11
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"a": 1}, Ret: false},
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"a": 2}, Ret: true},
{A: map[string]interface{}{"a": 1}, B: map[string]interface{}{"b": 2}, Ret: false},
{ // 15
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": []interface{}{1}},
Ret: false,
},
{
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": []interface{}{}},
Ret: true,
},
{
A: map[string]interface{}{"a": []interface{}{1}},
B: map[string]interface{}{"a": 1},
Ret: true,
},
}
for i, testCase := range testCases {
out, err := HasConflicts(testCase.A, testCase.B)
if err != nil {
t.Errorf("%d: unexpected error: %v", i, err)
}
if out != testCase.Ret {
t.Errorf("%d: expected %t got %t", i, testCase.Ret, out)
continue
}
out, err = HasConflicts(testCase.B, testCase.A)
if err != nil {
t.Errorf("%d: unexpected error: %v", i, err)
}
if out != testCase.Ret {
t.Errorf("%d: expected reversed %t got %t", i, testCase.Ret, out)
}
}
}