fix(edgestacks): avoid a data race in edge stack status update endpoint EE-4737 (#8168)

pull/8136/head^2
andres-portainer 2022-12-14 10:41:45 -03:00 committed by GitHub
parent f38b8234d9
commit 37896661d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 51 deletions

View File

@ -82,12 +82,22 @@ func (service *Service) Create(id portainer.EdgeStackID, edgeStack *portainer.Ed
)
}
// UpdateEdgeStack updates an Edge stack.
// Deprecated: Use UpdateEdgeStackFunc instead.
func (service *Service) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error {
identifier := service.connection.ConvertToKey(int(ID))
return service.connection.UpdateObject(BucketName, identifier, edgeStack)
}
// UpdateEdgeStackFunc updates an Edge stack inside a transaction avoiding data races.
func (service *Service) UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error {
id := service.connection.ConvertToKey(int(ID))
edgeStack := &portainer.EdgeStack{}
return service.connection.UpdateObjectFunc(BucketName, id, edgeStack, func() {
updateFunc(edgeStack)
})
}
// DeleteEdgeStack deletes an Edge stack.
func (service *Service) DeleteEdgeStack(ID portainer.EdgeStackID) error {
identifier := service.connection.ConvertToKey(int(ID))

View File

@ -89,6 +89,7 @@ type (
EdgeStack(ID portainer.EdgeStackID) (*portainer.EdgeStack, error)
Create(id portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error
UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error
UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
DeleteEdgeStack(ID portainer.EdgeStackID) error
GetNextIdentifier() int
BucketName() string

View File

@ -4,11 +4,12 @@ import (
"errors"
"net/http"
"github.com/asaskevich/govalidator"
httperror "github.com/portainer/libhttp/error"
"github.com/portainer/libhttp/request"
"github.com/portainer/libhttp/response"
portainer "github.com/portainer/portainer/api"
"github.com/asaskevich/govalidator"
httperror "github.com/portainer/libhttp/error"
)
type updateStatusPayload struct {
@ -52,13 +53,6 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
return httperror.BadRequest("Invalid stack identifier route variable", err)
}
stack, err := handler.DataStore.EdgeStack().EdgeStack(portainer.EdgeStackID(stackID))
if handler.DataStore.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err)
} else if err != nil {
return httperror.InternalServerError("Unable to find a stack with the specified identifier inside the database", err)
}
var payload updateStatusPayload
err = request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil {
@ -77,17 +71,22 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
return httperror.Forbidden("Permission denied to access environment", err)
}
stack.Status[payload.EndpointID] = portainer.EdgeStackStatus{
Type: *payload.Status,
Error: payload.Error,
EndpointID: payload.EndpointID,
}
var stack portainer.EdgeStack
err = handler.DataStore.EdgeStack().UpdateEdgeStack(stack.ID, stack)
if err != nil {
err = handler.DataStore.EdgeStack().UpdateEdgeStackFunc(portainer.EdgeStackID(stackID), func(edgeStack *portainer.EdgeStack) {
edgeStack.Status[payload.EndpointID] = portainer.EdgeStackStatus{
Type: *payload.Status,
Error: payload.Error,
EndpointID: payload.EndpointID,
}
stack = *edgeStack
})
if handler.DataStore.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err)
} else if err != nil {
return httperror.InternalServerError("Unable to persist the stack changes inside the database", err)
}
return response.JSON(w, stack)
}

View File

@ -103,10 +103,9 @@ func setupHandler(t *testing.T) (*Handler, string, func()) {
return handler, rawAPIKey, storeTeardown
}
func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint {
func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.Endpoint {
t.Helper()
endpointID := portainer.EndpointID(5)
endpoint := portainer.Endpoint{
ID: endpointID,
Name: "test-endpoint-" + strconv.Itoa(int(endpointID)),
@ -124,6 +123,10 @@ func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoi
return endpoint
}
func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint {
return createEndpointWithId(t, store, 5)
}
func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.EdgeStack {
t.Helper()
@ -203,7 +206,7 @@ func TestInspectInvalidEdgeID(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -263,7 +266,7 @@ func TestCreateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
data := portainer.EdgeStack{}
@ -283,7 +286,7 @@ func TestCreateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
data = portainer.EdgeStack{}
@ -293,7 +296,7 @@ func TestCreateAndInspect(t *testing.T) {
}
if payload.Name != data.Name {
t.Fatalf(fmt.Sprintf("expected EdgeStack Name %s, found %s", payload.Name, data.Name))
t.Fatalf("expected EdgeStack Name %s, found %s", payload.Name, data.Name)
}
}
@ -421,7 +424,7 @@ func TestCreateWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -447,7 +450,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
data := portainer.EdgeStack{}
@ -457,7 +460,7 @@ func TestDeleteAndInspect(t *testing.T) {
}
if data.ID != edgeStack.ID {
t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID))
t.Fatalf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID)
}
// Delete
@ -471,7 +474,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNoContent, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusNoContent, rec.Code)
}
// Inspect
@ -485,7 +488,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNotFound, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusNotFound, rec.Code)
}
}
@ -514,7 +517,7 @@ func TestDeleteInvalidEdgeStack(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -530,14 +533,7 @@ func TestUpdateAndInspect(t *testing.T) {
// Update edge stack: create new Endpoint, EndpointRelation and EdgeGroup
endpointID := portainer.EndpointID(6)
newEndpoint := portainer.Endpoint{
ID: endpointID,
Name: "test-endpoint-" + strconv.Itoa(int(endpointID)),
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "https://portainer.io:9443",
EdgeID: "edge-id",
LastCheckInDate: time.Now().Unix(),
}
newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID)
err := handler.DataStore.Endpoint().Create(&newEndpoint)
if err != nil {
@ -594,7 +590,7 @@ func TestUpdateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
// Get updated edge stack
@ -608,7 +604,7 @@ func TestUpdateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
data := portainer.EdgeStack{}
@ -618,11 +614,11 @@ func TestUpdateAndInspect(t *testing.T) {
}
if data.Version != *payload.Version {
t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version))
t.Fatalf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version)
}
if data.DeploymentType != payload.DeploymentType {
t.Fatalf(fmt.Sprintf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType))
t.Fatalf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType)
}
if !reflect.DeepEqual(data.EdgeGroups, payload.EdgeGroups) {
@ -705,7 +701,7 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -764,7 +760,7 @@ func TestUpdateWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -802,7 +798,7 @@ func TestUpdateStatusAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
// Get updated edge stack
@ -816,7 +812,7 @@ func TestUpdateStatusAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
data := portainer.EdgeStack{}
@ -826,15 +822,15 @@ func TestUpdateStatusAndInspect(t *testing.T) {
}
if data.Status[endpoint.ID].Type != *payload.Status {
t.Fatalf(fmt.Sprintf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type))
t.Fatalf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type)
}
if data.Status[endpoint.ID].Error != payload.Error {
t.Fatalf(fmt.Sprintf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error))
t.Fatalf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error)
}
if data.Status[endpoint.ID].EndpointID != payload.EndpointID {
t.Fatalf(fmt.Sprintf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID))
t.Fatalf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID)
}
}
func TestUpdateStatusWithInvalidPayload(t *testing.T) {
@ -903,7 +899,7 @@ func TestUpdateStatusWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code))
t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
}
})
}
@ -927,6 +923,6 @@ func TestDeleteStatus(t *testing.T) {
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code))
t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
}
}