diff --git a/api/dataservices/endpointrelation/endpointrelation.go b/api/dataservices/endpointrelation/endpointrelation.go index 4b7ff6b82..a81c258b9 100644 --- a/api/dataservices/endpointrelation/endpointrelation.go +++ b/api/dataservices/endpointrelation/endpointrelation.go @@ -22,6 +22,8 @@ type Service struct { mu sync.Mutex } +var _ dataservices.EndpointRelationService = &Service{} + func (service *Service) BucketName() string { return BucketName } @@ -109,6 +111,18 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, return nil } +func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + return service.connection.ViewTx(func(tx portainer.Transaction) error { + return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID) + }) +} + +func (service *Service) RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + return service.connection.ViewTx(func(tx portainer.Transaction) error { + return service.Tx(tx).RemoveEndpointRelationsForEdgeStack(endpointIDs, edgeStackID) + }) +} + // DeleteEndpointRelation deletes an Environment(Endpoint) relation object func (service *Service) DeleteEndpointRelation(endpointID portainer.EndpointID) error { deletedRelation, _ := service.EndpointRelation(endpointID) diff --git a/api/dataservices/endpointrelation/tx.go b/api/dataservices/endpointrelation/tx.go index 097748767..2b2d90280 100644 --- a/api/dataservices/endpointrelation/tx.go +++ b/api/dataservices/endpointrelation/tx.go @@ -13,6 +13,8 @@ type ServiceTx struct { tx portainer.Transaction } +var _ dataservices.EndpointRelationService = &ServiceTx{} + func (service ServiceTx) BucketName() string { return BucketName } @@ -74,6 +76,58 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID, return nil } +func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + for _, endpointID := range endpointIDs { + rel, err := service.EndpointRelation(endpointID) + if err != nil { + return err + } + + rel.EdgeStacks[edgeStackID] = true + + identifier := service.service.connection.ConvertToKey(int(endpointID)) + err = service.tx.UpdateObject(BucketName, identifier, rel) + cache.Del(endpointID) + if err != nil { + return err + } + } + + if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) { + edgeStack.NumDeployments += len(endpointIDs) + }); err != nil { + log.Error().Err(err).Msg("could not update the number of deployments") + } + + return nil +} + +func (service ServiceTx) RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + for _, endpointID := range endpointIDs { + rel, err := service.EndpointRelation(endpointID) + if err != nil { + return err + } + + delete(rel.EdgeStacks, edgeStackID) + + identifier := service.service.connection.ConvertToKey(int(endpointID)) + err = service.tx.UpdateObject(BucketName, identifier, rel) + cache.Del(endpointID) + if err != nil { + return err + } + } + + if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) { + edgeStack.NumDeployments -= len(endpointIDs) + }); err != nil { + log.Error().Err(err).Msg("could not update the number of deployments") + } + + return nil +} + // DeleteEndpointRelation deletes an Environment(Endpoint) relation object func (service ServiceTx) DeleteEndpointRelation(endpointID portainer.EndpointID) error { deletedRelation, _ := service.EndpointRelation(endpointID) diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index 1efef4f70..2bc2df7f3 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -115,6 +115,8 @@ type ( EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error) Create(endpointRelation *portainer.EndpointRelation) error UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error + AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error + RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error DeleteEndpointRelation(EndpointID portainer.EndpointID) error BucketName() string } diff --git a/api/http/handler/edgestacks/edgestack_update.go b/api/http/handler/edgestacks/edgestack_update.go index 593e403ef..27a279fb3 100644 --- a/api/http/handler/edgestacks/edgestack_update.go +++ b/api/http/handler/edgestacks/edgestack_update.go @@ -138,57 +138,19 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database") } - oldRelatedSet := set.ToSet(oldRelatedEnvironmentIDs) - newRelatedSet := set.ToSet(newRelatedEnvironmentIDs) + oldRelatedEnvironmentsSet := set.ToSet(oldRelatedEnvironmentIDs) + newRelatedEnvironmentsSet := set.ToSet(newRelatedEnvironmentIDs) - endpointsToRemove := set.Set[portainer.EndpointID]{} - for endpointID := range oldRelatedSet { - if !newRelatedSet[endpointID] { - endpointsToRemove[endpointID] = true - } + relatedEnvironmentsToAdd := newRelatedEnvironmentsSet.Difference(oldRelatedEnvironmentsSet) + relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet) + + if len(relatedEnvironmentsToRemove) > 0 { + tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID) } - for endpointID := range endpointsToRemove { - relation, err := tx.EndpointRelation().EndpointRelation(endpointID) - if err != nil { - if tx.IsErrObjectNotFound(err) { - continue - } - return nil, nil, errors.WithMessage(err, "Unable to find environment relation in database") - } - - delete(relation.EdgeStacks, edgeStackID) - - if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil { - return nil, nil, errors.WithMessage(err, "Unable to persist environment relation in database") - } + if len(relatedEnvironmentsToAdd) > 0 { + tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID) } - endpointsToAdd := set.Set[portainer.EndpointID]{} - for endpointID := range newRelatedSet { - if !oldRelatedSet[endpointID] { - endpointsToAdd[endpointID] = true - } - } - - for endpointID := range endpointsToAdd { - relation, err := tx.EndpointRelation().EndpointRelation(endpointID) - if err != nil && !tx.IsErrObjectNotFound(err) { - return nil, nil, errors.WithMessage(err, "Unable to find environment relation in database") - } - - if relation == nil { - relation = &portainer.EndpointRelation{ - EndpointID: endpointID, - EdgeStacks: map[portainer.EdgeStackID]bool{}, - } - } - relation.EdgeStacks[edgeStackID] = true - - if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil { - return nil, nil, errors.WithMessage(err, "Unable to persist environment relation in database") - } - } - - return newRelatedEnvironmentIDs, endpointsToAdd, nil + return newRelatedEnvironmentIDs, relatedEnvironmentsToAdd, nil } diff --git a/api/internal/edge/edgestacks/service.go b/api/internal/edge/edgestacks/service.go index 3a19e8c9d..6986a6917 100644 --- a/api/internal/edge/edgestacks/service.go +++ b/api/internal/edge/edgestacks/service.go @@ -11,7 +11,6 @@ import ( httperrors "github.com/portainer/portainer/api/http/errors" "github.com/portainer/portainer/api/internal/edge" edgetypes "github.com/portainer/portainer/api/internal/edge/types" - "github.com/rs/zerolog/log" "github.com/pkg/errors" ) @@ -100,12 +99,15 @@ func (service *Service) PersistEdgeStack( stack.ManifestPath = manifestPath stack.ProjectPath = projectPath stack.EntryPoint = composePath - stack.NumDeployments = len(relatedEndpointIds) if err := tx.EdgeStack().Create(stack.ID, stack); err != nil { return nil, err } + if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEndpointIds, stack.ID); err != nil { + return nil, fmt.Errorf("unable to add endpoint relations: %w", err) + } + if err := service.updateEndpointRelations(tx, stack.ID, relatedEndpointIds); err != nil { return nil, fmt.Errorf("unable to update endpoint relations: %w", err) } @@ -148,25 +150,8 @@ func (service *Service) DeleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID return errors.WithMessage(err, "Unable to retrieve edge stack related environments from database") } - for _, endpointID := range relatedEndpointIds { - relation, err := tx.EndpointRelation().EndpointRelation(endpointID) - if err != nil { - if tx.IsErrObjectNotFound(err) { - log.Warn(). - Int("endpoint_id", int(endpointID)). - Msg("Unable to find endpoint relation in database, skipping") - - continue - } - - return errors.WithMessage(err, "Unable to find environment relation in database") - } - - delete(relation.EdgeStacks, edgeStackID) - - if err := tx.EndpointRelation().UpdateEndpointRelation(endpointID, relation); err != nil { - return errors.WithMessage(err, "Unable to persist environment relation in database") - } + if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEndpointIds, edgeStackID); err != nil { + return errors.WithMessage(err, "unable to remove environment relation in database") } if err := tx.EdgeStack().DeleteEdgeStack(edgeStackID); err != nil { diff --git a/api/internal/testhelpers/datastore.go b/api/internal/testhelpers/datastore.go index f0bba23fd..19d4df762 100644 --- a/api/internal/testhelpers/datastore.go +++ b/api/internal/testhelpers/datastore.go @@ -9,6 +9,8 @@ import ( "github.com/portainer/portainer/api/dataservices/errors" ) +var _ dataservices.DataStore = &testDatastore{} + type testDatastore struct { customTemplate dataservices.CustomTemplateService edgeGroup dataservices.EdgeGroupService @@ -227,6 +229,30 @@ func (s *stubEndpointRelationService) UpdateEndpointRelation(ID portainer.Endpoi return nil } +func (s *stubEndpointRelationService) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + for _, endpointID := range endpointIDs { + for i, r := range s.relations { + if r.EndpointID == endpointID { + s.relations[i].EdgeStacks[edgeStackID] = true + } + } + } + + return nil +} + +func (s *stubEndpointRelationService) RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error { + for _, endpointID := range endpointIDs { + for i, r := range s.relations { + if r.EndpointID == endpointID { + delete(s.relations[i].EdgeStacks, edgeStackID) + } + } + } + + return nil +} + func (s *stubEndpointRelationService) DeleteEndpointRelation(ID portainer.EndpointID) error { return nil }