From ea03024fbcfac0db1c2dc843bce018335754267d Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Wed, 15 Mar 2023 14:53:38 -0300 Subject: [PATCH] fix(edgegroup): fix data race in edge group update EE-4441 (#8523) --- .../handler/edgegroups/edgegroup_create.go | 14 +- .../handler/edgegroups/edgegroup_update.go | 180 +++++++++--------- api/http/handler/edgegroups/handler.go | 16 +- 3 files changed, 114 insertions(+), 96 deletions(-) diff --git a/api/http/handler/edgegroups/edgegroup_create.go b/api/http/handler/edgegroups/edgegroup_create.go index 68340d7c0..18797f14a 100644 --- a/api/http/handler/edgegroups/edgegroup_create.go +++ b/api/http/handler/edgegroups/edgegroup_create.go @@ -6,7 +6,6 @@ import ( httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" - "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/endpointutils" @@ -26,12 +25,15 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error { if govalidator.IsNull(payload.Name) { return errors.New("invalid Edge group name") } + if payload.Dynamic && len(payload.TagIDs) == 0 { return errors.New("tagIDs is mandatory for a dynamic Edge group") } + if !payload.Dynamic && len(payload.Endpoints) == 0 { return errors.New("environment is mandatory for a static Edge group") } + return nil } @@ -56,7 +58,6 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) } var edgeGroup *portainer.EdgeGroup - err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { edgeGroups, err := tx.EdgeGroup().EdgeGroups() if err != nil { @@ -101,13 +102,6 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) return nil }) - if err != nil { - if httpErr, ok := err.(*httperror.HandlerError); ok { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, edgeGroup) + return txResponse(w, edgeGroup, err) } diff --git a/api/http/handler/edgegroups/edgegroup_update.go b/api/http/handler/edgegroups/edgegroup_update.go index ee1b8b3a5..0cffc1201 100644 --- a/api/http/handler/edgegroups/edgegroup_update.go +++ b/api/http/handler/edgegroups/edgegroup_update.go @@ -6,8 +6,8 @@ import ( httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" - "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/internal/slices" @@ -27,12 +27,15 @@ func (payload *edgeGroupUpdatePayload) Validate(r *http.Request) error { if govalidator.IsNull(payload.Name) { return errors.New("invalid Edge group name") } + if payload.Dynamic && len(payload.TagIDs) == 0 { return errors.New("tagIDs is mandatory for a dynamic Edge group") } + if !payload.Dynamic && len(payload.Endpoints) == 0 { return errors.New("environments is mandatory for a static Edge group") } + return nil } @@ -62,128 +65,135 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request) return httperror.BadRequest("Invalid request payload", err) } - edgeGroup, err := handler.DataStore.EdgeGroup().EdgeGroup(portainer.EdgeGroupID(edgeGroupID)) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.NotFound("Unable to find an Edge group with the specified identifier inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to find an Edge group with the specified identifier inside the database", err) - } - - if payload.Name != "" { - edgeGroups, err := handler.DataStore.EdgeGroup().EdgeGroups() - if err != nil { - return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err) - } - for _, edgeGroup := range edgeGroups { - if edgeGroup.Name == payload.Name && edgeGroup.ID != portainer.EdgeGroupID(edgeGroupID) { - return httperror.BadRequest("Edge group name must be unique", errors.New("edge group name must be unique")) - } + var edgeGroup *portainer.EdgeGroup + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + edgeGroup, err = tx.EdgeGroup().EdgeGroup(portainer.EdgeGroupID(edgeGroupID)) + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.NotFound("Unable to find an Edge group with the specified identifier inside the database", err) + } else if err != nil { + return httperror.InternalServerError("Unable to find an Edge group with the specified identifier inside the database", err) } - edgeGroup.Name = payload.Name - } - endpoints, err := handler.DataStore.Endpoint().Endpoints() - if err != nil { - return httperror.InternalServerError("Unable to retrieve environments from database", err) - } - - endpointGroups, err := handler.DataStore.EndpointGroup().EndpointGroups() - if err != nil { - return httperror.InternalServerError("Unable to retrieve environment groups from database", err) - } - - oldRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) - - edgeGroup.Dynamic = payload.Dynamic - if edgeGroup.Dynamic { - edgeGroup.TagIDs = payload.TagIDs - } else { - endpointIDs := []portainer.EndpointID{} - for _, endpointID := range payload.Endpoints { - endpoint, err := handler.DataStore.Endpoint().Endpoint(endpointID) + if payload.Name != "" { + edgeGroups, err := tx.EdgeGroup().EdgeGroups() if err != nil { - return httperror.InternalServerError("Unable to retrieve environment from the database", err) + return httperror.InternalServerError("Unable to retrieve Edge groups from the database", err) } - if endpointutils.IsEdgeEndpoint(endpoint) { - endpointIDs = append(endpointIDs, endpoint.ID) + for _, edgeGroup := range edgeGroups { + if edgeGroup.Name == payload.Name && edgeGroup.ID != portainer.EdgeGroupID(edgeGroupID) { + return httperror.BadRequest("Edge group name must be unique", errors.New("edge group name must be unique")) + } } + + edgeGroup.Name = payload.Name } - edgeGroup.Endpoints = endpointIDs - } - if payload.PartialMatch != nil { - edgeGroup.PartialMatch = *payload.PartialMatch - } - - err = handler.DataStore.EdgeGroup().UpdateEdgeGroup(edgeGroup.ID, edgeGroup) - if err != nil { - return httperror.InternalServerError("Unable to persist Edge group changes inside the database", err) - } - - newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) - endpointsToUpdate := append(newRelatedEndpoints, oldRelatedEndpoints...) - - edgeJobs, err := handler.DataStore.EdgeJob().EdgeJobs() - if err != nil { - return httperror.InternalServerError("Unable to fetch Edge jobs", err) - } - - for _, endpointID := range endpointsToUpdate { - err = handler.updateEndpointStacks(endpointID) + endpoints, err := tx.Endpoint().Endpoints() if err != nil { - return httperror.InternalServerError("Unable to persist Environment relation changes inside the database", err) + return httperror.InternalServerError("Unable to retrieve environments from database", err) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(endpointID) + endpointGroups, err := tx.EndpointGroup().EndpointGroups() if err != nil { - return httperror.InternalServerError("Unable to get Environment from database", err) + return httperror.InternalServerError("Unable to retrieve environment groups from database", err) } - if !endpointutils.IsEdgeEndpoint(endpoint) { - continue - } + oldRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) - var operation string - if slices.Contains(newRelatedEndpoints, endpointID) { - operation = "add" - } else if slices.Contains(oldRelatedEndpoints, endpointID) { - operation = "remove" + edgeGroup.Dynamic = payload.Dynamic + if edgeGroup.Dynamic { + edgeGroup.TagIDs = payload.TagIDs } else { - continue + endpointIDs := []portainer.EndpointID{} + for _, endpointID := range payload.Endpoints { + endpoint, err := tx.Endpoint().Endpoint(endpointID) + if err != nil { + return httperror.InternalServerError("Unable to retrieve environment from the database", err) + } + + if endpointutils.IsEdgeEndpoint(endpoint) { + endpointIDs = append(endpointIDs, endpoint.ID) + } + } + edgeGroup.Endpoints = endpointIDs } - err = handler.updateEndpointEdgeJobs(edgeGroup.ID, endpoint, edgeJobs, operation) + if payload.PartialMatch != nil { + edgeGroup.PartialMatch = *payload.PartialMatch + } + + err = tx.EdgeGroup().UpdateEdgeGroup(edgeGroup.ID, edgeGroup) if err != nil { - return httperror.InternalServerError("Unable to persist Environment Edge Jobs changes inside the database", err) + return httperror.InternalServerError("Unable to persist Edge group changes inside the database", err) } - } - return response.JSON(w, edgeGroup) + newRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) + endpointsToUpdate := append(newRelatedEndpoints, oldRelatedEndpoints...) + + edgeJobs, err := tx.EdgeJob().EdgeJobs() + if err != nil { + return httperror.InternalServerError("Unable to fetch Edge jobs", err) + } + + for _, endpointID := range endpointsToUpdate { + err = handler.updateEndpointStacks(tx, endpointID) + if err != nil { + return httperror.InternalServerError("Unable to persist Environment relation changes inside the database", err) + } + + endpoint, err := tx.Endpoint().Endpoint(endpointID) + if err != nil { + return httperror.InternalServerError("Unable to get Environment from database", err) + } + + if !endpointutils.IsEdgeEndpoint(endpoint) { + continue + } + + var operation string + if slices.Contains(newRelatedEndpoints, endpointID) { + operation = "add" + } else if slices.Contains(oldRelatedEndpoints, endpointID) { + operation = "remove" + } else { + continue + } + + err = handler.updateEndpointEdgeJobs(edgeGroup.ID, endpoint, edgeJobs, operation) + if err != nil { + return httperror.InternalServerError("Unable to persist Environment Edge Jobs changes inside the database", err) + } + } + + return nil + }) + + return txResponse(w, edgeGroup, err) } -func (handler *Handler) updateEndpointStacks(endpointID portainer.EndpointID) error { - relation, err := handler.DataStore.EndpointRelation().EndpointRelation(endpointID) +func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpointID portainer.EndpointID) error { + relation, err := tx.EndpointRelation().EndpointRelation(endpointID) if err != nil { return err } - endpoint, err := handler.DataStore.Endpoint().Endpoint(endpointID) + endpoint, err := tx.Endpoint().Endpoint(endpointID) if err != nil { return err } - endpointGroup, err := handler.DataStore.EndpointGroup().EndpointGroup(endpoint.GroupID) + endpointGroup, err := tx.EndpointGroup().EndpointGroup(endpoint.GroupID) if err != nil { return err } - edgeGroups, err := handler.DataStore.EdgeGroup().EdgeGroups() + edgeGroups, err := tx.EdgeGroup().EdgeGroups() if err != nil { return err } - edgeStacks, err := handler.DataStore.EdgeStack().EdgeStacks() + edgeStacks, err := tx.EdgeStack().EdgeStacks() if err != nil { return err } @@ -197,7 +207,7 @@ func (handler *Handler) updateEndpointStacks(endpointID portainer.EndpointID) er relation.EdgeStacks = edgeStackSet - return handler.DataStore.EndpointRelation().UpdateEndpointRelation(endpoint.ID, relation) + return tx.EndpointRelation().UpdateEndpointRelation(endpoint.ID, relation) } func (handler *Handler) updateEndpointEdgeJobs(edgeGroupID portainer.EdgeGroupID, endpoint *portainer.Endpoint, edgeJobs []portainer.EdgeJob, operation string) error { diff --git a/api/http/handler/edgegroups/handler.go b/api/http/handler/edgegroups/handler.go index 5b950a581..3e1a44e74 100644 --- a/api/http/handler/edgegroups/handler.go +++ b/api/http/handler/edgegroups/handler.go @@ -3,11 +3,13 @@ package edgegroups import ( "net/http" - "github.com/gorilla/mux" httperror "github.com/portainer/libhttp/error" + "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" + + "github.com/gorilla/mux" ) // Handler is the HTTP handler used to handle environment(endpoint) group operations. @@ -34,3 +36,15 @@ func NewHandler(bouncer *security.RequestBouncer) *Handler { bouncer.AdminAccess(bouncer.EdgeComputeOperation(httperror.LoggerHandler(h.edgeGroupDelete)))).Methods(http.MethodDelete) return h } + +func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { + if err != nil { + if httpErr, ok := err.(*httperror.HandlerError); ok { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.JSON(w, r) +}