diff --git a/api/http/handler/edgegroups/edgegroup_create.go b/api/http/handler/edgegroups/edgegroup_create.go index c6e103e37..989599eba 100644 --- a/api/http/handler/edgegroups/edgegroup_create.go +++ b/api/http/handler/edgegroups/edgegroup_create.go @@ -33,6 +33,29 @@ func (payload *edgeGroupCreatePayload) Validate(r *http.Request) error { return nil } +func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.EdgeGroup, endpoints []portainer.EndpointID, tagIDs []portainer.TagID) error { + if edgeGroup.Dynamic { + edgeGroup.TagIDs = tagIDs + } else { + endpointIDs := []portainer.EndpointID{} + + for _, endpointID := range endpoints { + endpoint, err := tx.Endpoint().Endpoint(endpointID) + if err != nil { + return httperror.InternalServerError("Unable to retrieve environment from the database", err) + } + + if endpointutils.IsEdgeEndpoint(endpoint) { + endpointIDs = append(endpointIDs, endpoint.ID) + } + } + + edgeGroup.Endpoints = endpointIDs + } + + return nil +} + // @id EdgeGroupCreate // @summary Create an EdgeGroup // @description **Access policy**: administrator @@ -74,21 +97,8 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) PartialMatch: payload.PartialMatch, } - if edgeGroup.Dynamic { - edgeGroup.TagIDs = payload.TagIDs - } else { - endpointIDs := []portainer.EndpointID{} - for _, endpointID := range payload.Endpoints { - endpoint, err := tx.Endpoint().Endpoint(endpointID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve environment from the database", err) - } - - if endpointutils.IsEdgeEndpoint(endpoint) { - endpointIDs = append(endpointIDs, endpoint.ID) - } - } - edgeGroup.Endpoints = endpointIDs + if err := calculateEndpointsOrTags(tx, edgeGroup, payload.Endpoints, payload.TagIDs); err != nil { + return err } err = tx.EdgeGroup().Create(edgeGroup) diff --git a/api/http/handler/edgegroups/edgegroup_update.go b/api/http/handler/edgegroups/edgegroup_update.go index 4441a6073..b3cd4aabd 100644 --- a/api/http/handler/edgegroups/edgegroup_update.go +++ b/api/http/handler/edgegroups/edgegroup_update.go @@ -99,21 +99,8 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request) oldRelatedEndpoints := edge.EdgeGroupRelatedEndpoints(edgeGroup, endpoints, endpointGroups) edgeGroup.Dynamic = payload.Dynamic - if edgeGroup.Dynamic { - edgeGroup.TagIDs = payload.TagIDs - } else { - endpointIDs := []portainer.EndpointID{} - for _, endpointID := range payload.Endpoints { - endpoint, err := tx.Endpoint().Endpoint(endpointID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve environment from the database", err) - } - - if endpointutils.IsEdgeEndpoint(endpoint) { - endpointIDs = append(endpointIDs, endpoint.ID) - } - } - edgeGroup.Endpoints = endpointIDs + if err := calculateEndpointsOrTags(tx, edgeGroup, payload.Endpoints, payload.TagIDs); err != nil { + return err } if payload.PartialMatch != nil { @@ -138,6 +125,13 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request) return err } + // Update the edgeGroups with the modified edgeGroup for updateEndpointStacks() + for i := range edgeGroups { + if edgeGroups[i].ID == edgeGroup.ID { + edgeGroups[i] = *edgeGroup + } + } + for _, endpointID := range endpointsToUpdate { endpoint, err := tx.Endpoint().Endpoint(endpointID) if err != nil {