fix(edgestacks): avoid a data race in edge stack status update endpoint EE-4737 (#8168)

pull/8136/head^2
andres-portainer 2022-12-14 10:41:45 -03:00 committed by GitHub
parent f38b8234d9
commit 37896661d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 51 deletions

View File

@ -82,12 +82,22 @@ func (service *Service) Create(id portainer.EdgeStackID, edgeStack *portainer.Ed
) )
} }
// UpdateEdgeStack updates an Edge stack. // Deprecated: Use UpdateEdgeStackFunc instead.
func (service *Service) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error { func (service *Service) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error {
identifier := service.connection.ConvertToKey(int(ID)) identifier := service.connection.ConvertToKey(int(ID))
return service.connection.UpdateObject(BucketName, identifier, edgeStack) return service.connection.UpdateObject(BucketName, identifier, edgeStack)
} }
// UpdateEdgeStackFunc updates an Edge stack inside a transaction avoiding data races.
func (service *Service) UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error {
id := service.connection.ConvertToKey(int(ID))
edgeStack := &portainer.EdgeStack{}
return service.connection.UpdateObjectFunc(BucketName, id, edgeStack, func() {
updateFunc(edgeStack)
})
}
// DeleteEdgeStack deletes an Edge stack. // DeleteEdgeStack deletes an Edge stack.
func (service *Service) DeleteEdgeStack(ID portainer.EdgeStackID) error { func (service *Service) DeleteEdgeStack(ID portainer.EdgeStackID) error {
identifier := service.connection.ConvertToKey(int(ID)) identifier := service.connection.ConvertToKey(int(ID))

View File

@ -89,6 +89,7 @@ type (
EdgeStack(ID portainer.EdgeStackID) (*portainer.EdgeStack, error) EdgeStack(ID portainer.EdgeStackID) (*portainer.EdgeStack, error)
Create(id portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error Create(id portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error
UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error
UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
DeleteEdgeStack(ID portainer.EdgeStackID) error DeleteEdgeStack(ID portainer.EdgeStackID) error
GetNextIdentifier() int GetNextIdentifier() int
BucketName() string BucketName() string

View File

@ -4,11 +4,12 @@ import (
"errors" "errors"
"net/http" "net/http"
"github.com/asaskevich/govalidator"
httperror "github.com/portainer/libhttp/error"
"github.com/portainer/libhttp/request" "github.com/portainer/libhttp/request"
"github.com/portainer/libhttp/response" "github.com/portainer/libhttp/response"
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/asaskevich/govalidator"
httperror "github.com/portainer/libhttp/error"
) )
type updateStatusPayload struct { type updateStatusPayload struct {
@ -52,13 +53,6 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
return httperror.BadRequest("Invalid stack identifier route variable", err) return httperror.BadRequest("Invalid stack identifier route variable", err)
} }
stack, err := handler.DataStore.EdgeStack().EdgeStack(portainer.EdgeStackID(stackID))
if handler.DataStore.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err)
} else if err != nil {
return httperror.InternalServerError("Unable to find a stack with the specified identifier inside the database", err)
}
var payload updateStatusPayload var payload updateStatusPayload
err = request.DecodeAndValidateJSONPayload(r, &payload) err = request.DecodeAndValidateJSONPayload(r, &payload)
if err != nil { if err != nil {
@ -77,17 +71,22 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req
return httperror.Forbidden("Permission denied to access environment", err) return httperror.Forbidden("Permission denied to access environment", err)
} }
stack.Status[payload.EndpointID] = portainer.EdgeStackStatus{ var stack portainer.EdgeStack
err = handler.DataStore.EdgeStack().UpdateEdgeStackFunc(portainer.EdgeStackID(stackID), func(edgeStack *portainer.EdgeStack) {
edgeStack.Status[payload.EndpointID] = portainer.EdgeStackStatus{
Type: *payload.Status, Type: *payload.Status,
Error: payload.Error, Error: payload.Error,
EndpointID: payload.EndpointID, EndpointID: payload.EndpointID,
} }
err = handler.DataStore.EdgeStack().UpdateEdgeStack(stack.ID, stack) stack = *edgeStack
if err != nil { })
if handler.DataStore.IsErrObjectNotFound(err) {
return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err)
} else if err != nil {
return httperror.InternalServerError("Unable to persist the stack changes inside the database", err) return httperror.InternalServerError("Unable to persist the stack changes inside the database", err)
} }
return response.JSON(w, stack) return response.JSON(w, stack)
} }

View File

@ -103,10 +103,9 @@ func setupHandler(t *testing.T) (*Handler, string, func()) {
return handler, rawAPIKey, storeTeardown return handler, rawAPIKey, storeTeardown
} }
func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint { func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.Endpoint {
t.Helper() t.Helper()
endpointID := portainer.EndpointID(5)
endpoint := portainer.Endpoint{ endpoint := portainer.Endpoint{
ID: endpointID, ID: endpointID,
Name: "test-endpoint-" + strconv.Itoa(int(endpointID)), Name: "test-endpoint-" + strconv.Itoa(int(endpointID)),
@ -124,6 +123,10 @@ func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoi
return endpoint return endpoint
} }
func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint {
return createEndpointWithId(t, store, 5)
}
func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.EdgeStack { func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.EdgeStack {
t.Helper() t.Helper()
@ -203,7 +206,7 @@ func TestInspectInvalidEdgeID(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -263,7 +266,7 @@ func TestCreateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
data := portainer.EdgeStack{} data := portainer.EdgeStack{}
@ -283,7 +286,7 @@ func TestCreateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
data = portainer.EdgeStack{} data = portainer.EdgeStack{}
@ -293,7 +296,7 @@ func TestCreateAndInspect(t *testing.T) {
} }
if payload.Name != data.Name { if payload.Name != data.Name {
t.Fatalf(fmt.Sprintf("expected EdgeStack Name %s, found %s", payload.Name, data.Name)) t.Fatalf("expected EdgeStack Name %s, found %s", payload.Name, data.Name)
} }
} }
@ -421,7 +424,7 @@ func TestCreateWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -447,7 +450,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
data := portainer.EdgeStack{} data := portainer.EdgeStack{}
@ -457,7 +460,7 @@ func TestDeleteAndInspect(t *testing.T) {
} }
if data.ID != edgeStack.ID { if data.ID != edgeStack.ID {
t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID)) t.Fatalf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID)
} }
// Delete // Delete
@ -471,7 +474,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent { if rec.Code != http.StatusNoContent {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNoContent, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusNoContent, rec.Code)
} }
// Inspect // Inspect
@ -485,7 +488,7 @@ func TestDeleteAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound { if rec.Code != http.StatusNotFound {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNotFound, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusNotFound, rec.Code)
} }
} }
@ -514,7 +517,7 @@ func TestDeleteInvalidEdgeStack(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -530,14 +533,7 @@ func TestUpdateAndInspect(t *testing.T) {
// Update edge stack: create new Endpoint, EndpointRelation and EdgeGroup // Update edge stack: create new Endpoint, EndpointRelation and EdgeGroup
endpointID := portainer.EndpointID(6) endpointID := portainer.EndpointID(6)
newEndpoint := portainer.Endpoint{ newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID)
ID: endpointID,
Name: "test-endpoint-" + strconv.Itoa(int(endpointID)),
Type: portainer.EdgeAgentOnDockerEnvironment,
URL: "https://portainer.io:9443",
EdgeID: "edge-id",
LastCheckInDate: time.Now().Unix(),
}
err := handler.DataStore.Endpoint().Create(&newEndpoint) err := handler.DataStore.Endpoint().Create(&newEndpoint)
if err != nil { if err != nil {
@ -594,7 +590,7 @@ func TestUpdateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
// Get updated edge stack // Get updated edge stack
@ -608,7 +604,7 @@ func TestUpdateAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
data := portainer.EdgeStack{} data := portainer.EdgeStack{}
@ -618,11 +614,11 @@ func TestUpdateAndInspect(t *testing.T) {
} }
if data.Version != *payload.Version { if data.Version != *payload.Version {
t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version)) t.Fatalf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version)
} }
if data.DeploymentType != payload.DeploymentType { if data.DeploymentType != payload.DeploymentType {
t.Fatalf(fmt.Sprintf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType)) t.Fatalf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType)
} }
if !reflect.DeepEqual(data.EdgeGroups, payload.EdgeGroups) { if !reflect.DeepEqual(data.EdgeGroups, payload.EdgeGroups) {
@ -705,7 +701,7 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -764,7 +760,7 @@ func TestUpdateWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -802,7 +798,7 @@ func TestUpdateStatusAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
// Get updated edge stack // Get updated edge stack
@ -816,7 +812,7 @@ func TestUpdateStatusAndInspect(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
data := portainer.EdgeStack{} data := portainer.EdgeStack{}
@ -826,15 +822,15 @@ func TestUpdateStatusAndInspect(t *testing.T) {
} }
if data.Status[endpoint.ID].Type != *payload.Status { if data.Status[endpoint.ID].Type != *payload.Status {
t.Fatalf(fmt.Sprintf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type)) t.Fatalf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type)
} }
if data.Status[endpoint.ID].Error != payload.Error { if data.Status[endpoint.ID].Error != payload.Error {
t.Fatalf(fmt.Sprintf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error)) t.Fatalf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error)
} }
if data.Status[endpoint.ID].EndpointID != payload.EndpointID { if data.Status[endpoint.ID].EndpointID != payload.EndpointID {
t.Fatalf(fmt.Sprintf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID)) t.Fatalf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID)
} }
} }
func TestUpdateStatusWithInvalidPayload(t *testing.T) { func TestUpdateStatusWithInvalidPayload(t *testing.T) {
@ -903,7 +899,7 @@ func TestUpdateStatusWithInvalidPayload(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatusCode { if rec.Code != tc.ExpectedStatusCode {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)
} }
}) })
} }
@ -927,6 +923,6 @@ func TestDeleteStatus(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK { if rec.Code != http.StatusOK {
t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code)
} }
} }