From c9d18b614b9e9bcc2b6e7090014db60afc96c5d1 Mon Sep 17 00:00:00 2001 From: LP B Date: Sat, 16 Aug 2025 03:52:13 +0200 Subject: [PATCH] fix(api/edge-stacks): avoid overriding updates with old values (#1047) --- .../endpointrelation/endpointrelation.go | 4 +-- .../endpointrelation/endpointrelation_test.go | 36 +++++++++++++++++++ api/dataservices/endpointrelation/tx.go | 12 ++++--- api/dataservices/interface.go | 2 +- .../handler/edgestacks/edgestack_update.go | 8 ++--- api/internal/edge/edgestacks/service.go | 2 +- api/internal/testhelpers/datastore.go | 4 +-- 7 files changed, 54 insertions(+), 14 deletions(-) diff --git a/api/dataservices/endpointrelation/endpointrelation.go b/api/dataservices/endpointrelation/endpointrelation.go index 91c00f05a..3877d4361 100644 --- a/api/dataservices/endpointrelation/endpointrelation.go +++ b/api/dataservices/endpointrelation/endpointrelation.go @@ -91,9 +91,9 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, }) } -func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { +func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error { return service.connection.UpdateTx(func(tx portainer.Transaction) error { - return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID) + return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack) }) } diff --git a/api/dataservices/endpointrelation/endpointrelation_test.go b/api/dataservices/endpointrelation/endpointrelation_test.go index f1ead0919..8e5807b82 100644 --- a/api/dataservices/endpointrelation/endpointrelation_test.go +++ b/api/dataservices/endpointrelation/endpointrelation_test.go @@ -5,6 +5,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/database/boltdb" + "github.com/portainer/portainer/api/dataservices/edgestack" "github.com/portainer/portainer/api/internal/edge/cache" "github.com/stretchr/testify/require" @@ -102,3 +103,38 @@ func TestUpdateRelation(t *testing.T) { require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments) require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments) } + +func TestAddEndpointRelationsForEdgeStack(t *testing.T) { + var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()} + err := conn.Open() + require.NoError(t, err) + + defer conn.Close() + + service, err := NewService(conn) + require.NoError(t, err) + + edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {}) + require.NoError(t, err) + + service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx) + require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{})) + require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}})) + require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1})) +} + +func TestEndpointRelations(t *testing.T) { + var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()} + err := conn.Open() + require.NoError(t, err) + + defer conn.Close() + + service, err := NewService(conn) + require.NoError(t, err) + + require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1})) + rels, err := service.EndpointRelations() + require.NoError(t, err) + require.Equal(t, 1, len(rels)) +} diff --git a/api/dataservices/endpointrelation/tx.go b/api/dataservices/endpointrelation/tx.go index 54e66a31b..36639a4df 100644 --- a/api/dataservices/endpointrelation/tx.go +++ b/api/dataservices/endpointrelation/tx.go @@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID, return nil } -func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { +func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error { for _, endpointID := range endpointIDs { rel, err := service.EndpointRelation(endpointID) if err != nil { return err } - rel.EdgeStacks[edgeStackID] = true + rel.EdgeStacks[edgeStack.ID] = true identifier := service.service.connection.ConvertToKey(int(endpointID)) err = service.tx.UpdateObject(BucketName, identifier, rel) @@ -97,8 +97,12 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine service.service.endpointRelationsCache = nil service.service.mu.Unlock() - if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) { - edgeStack.NumDeployments += len(endpointIDs) + if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) { + es.NumDeployments += len(endpointIDs) + + // sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call + // to avoid overriding with the previous values + edgeStack.NumDeployments = es.NumDeployments }); err != nil { log.Error().Err(err).Msg("could not update the number of deployments") } diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index d330d4959..9255c6361 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -126,7 +126,7 @@ type ( EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error) Create(endpointRelation *portainer.EndpointRelation) error UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error - AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error + AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error DeleteEndpointRelation(EndpointID portainer.EndpointID) error BucketName() string diff --git a/api/http/handler/edgestacks/edgestack_update.go b/api/http/handler/edgestacks/edgestack_update.go index db896d0eb..7dc8915f6 100644 --- a/api/http/handler/edgestacks/edgestack_update.go +++ b/api/http/handler/edgestacks/edgestack_update.go @@ -99,7 +99,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por groupsIds := stack.EdgeGroups if payload.EdgeGroups != nil { - newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack.ID, payload.EdgeGroups, relatedEndpointIds, relationConfig) + newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack, payload.EdgeGroups, relatedEndpointIds, relationConfig) if err != nil { return nil, httperror.InternalServerError("Unable to handle edge groups change", err) } @@ -136,7 +136,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por return stack, nil } -func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) { +func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStack *portainer.EdgeStack, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) { newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups) if err != nil { return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database") @@ -149,13 +149,13 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet) if len(relatedEnvironmentsToRemove) > 0 { - if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID); err != nil { + if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStack.ID); err != nil { return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database") } } if len(relatedEnvironmentsToAdd) > 0 { - if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID); err != nil { + if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStack); err != nil { return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database") } } diff --git a/api/internal/edge/edgestacks/service.go b/api/internal/edge/edgestacks/service.go index c0ecb5caf..35215aa19 100644 --- a/api/internal/edge/edgestacks/service.go +++ b/api/internal/edge/edgestacks/service.go @@ -111,7 +111,7 @@ func (service *Service) PersistEdgeStack( } } - if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEndpointIds, stack.ID); err != nil { + if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEndpointIds, stack); err != nil { return nil, fmt.Errorf("unable to add endpoint relations: %w", err) } diff --git a/api/internal/testhelpers/datastore.go b/api/internal/testhelpers/datastore.go index 024c4b34d..d89b07531 100644 --- a/api/internal/testhelpers/datastore.go +++ b/api/internal/testhelpers/datastore.go @@ -230,11 +230,11 @@ func (s *stubEndpointRelationService) UpdateEndpointRelation(ID portainer.Endpoi return nil } -func (s *stubEndpointRelationService) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { +func (s *stubEndpointRelationService) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error { for _, endpointID := range endpointIDs { for i, r := range s.relations { if r.EndpointID == endpointID { - s.relations[i].EdgeStacks[edgeStackID] = true + s.relations[i].EdgeStacks[edgeStack.ID] = true } } }