diff --git a/api/dataservices/edgestack/edgestack.go b/api/dataservices/edgestack/edgestack.go index b04a53d13..aa6561015 100644 --- a/api/dataservices/edgestack/edgestack.go +++ b/api/dataservices/edgestack/edgestack.go @@ -15,7 +15,7 @@ type Service struct { connection portainer.Connection idxVersion map[portainer.EdgeStackID]int mu sync.RWMutex - cacheInvalidationFn func(portainer.EdgeStackID) + cacheInvalidationFn func(portainer.Transaction, portainer.EdgeStackID) } func (service *Service) BucketName() string { @@ -23,7 +23,7 @@ func (service *Service) BucketName() string { } // NewService creates a new instance of a service. -func NewService(connection portainer.Connection, cacheInvalidationFn func(portainer.EdgeStackID)) (*Service, error) { +func NewService(connection portainer.Connection, cacheInvalidationFn func(portainer.Transaction, portainer.EdgeStackID)) (*Service, error) { err := connection.SetServiceName(BucketName) if err != nil { return nil, err @@ -36,7 +36,7 @@ func NewService(connection portainer.Connection, cacheInvalidationFn func(portai } if s.cacheInvalidationFn == nil { - s.cacheInvalidationFn = func(portainer.EdgeStackID) {} + s.cacheInvalidationFn = func(portainer.Transaction, portainer.EdgeStackID) {} } es, err := s.EdgeStacks() @@ -106,7 +106,7 @@ func (service *Service) Create(id portainer.EdgeStackID, edgeStack *portainer.Ed service.mu.Lock() service.idxVersion[id] = edgeStack.Version - service.cacheInvalidationFn(id) + service.cacheInvalidationFn(service.connection, id) service.mu.Unlock() return nil @@ -125,7 +125,7 @@ func (service *Service) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *por } service.idxVersion[ID] = edgeStack.Version - service.cacheInvalidationFn(ID) + service.cacheInvalidationFn(service.connection, ID) return nil } @@ -142,7 +142,7 @@ func (service *Service) UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc updateFunc(edgeStack) service.idxVersion[ID] = edgeStack.Version - service.cacheInvalidationFn(ID) + service.cacheInvalidationFn(service.connection, ID) }) } @@ -165,7 +165,7 @@ func (service *Service) DeleteEdgeStack(ID portainer.EdgeStackID) error { delete(service.idxVersion, ID) - service.cacheInvalidationFn(ID) + service.cacheInvalidationFn(service.connection, ID) return nil } diff --git a/api/dataservices/edgestack/tx.go b/api/dataservices/edgestack/tx.go index 25c07c2ef..d401e26cd 100644 --- a/api/dataservices/edgestack/tx.go +++ b/api/dataservices/edgestack/tx.go @@ -44,8 +44,7 @@ func (service ServiceTx) EdgeStack(ID portainer.EdgeStackID) (*portainer.EdgeSta var stack portainer.EdgeStack identifier := service.service.connection.ConvertToKey(int(ID)) - err := service.tx.GetObject(BucketName, identifier, &stack) - if err != nil { + if err := service.tx.GetObject(BucketName, identifier, &stack); err != nil { return nil, err } @@ -65,18 +64,17 @@ func (service ServiceTx) EdgeStackVersion(ID portainer.EdgeStackID) (int, bool) func (service ServiceTx) Create(id portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error { edgeStack.ID = id - err := service.tx.CreateObjectWithId( + if err := service.tx.CreateObjectWithId( BucketName, int(edgeStack.ID), edgeStack, - ) - if err != nil { + ); err != nil { return err } service.service.mu.Lock() service.service.idxVersion[id] = edgeStack.Version - service.service.cacheInvalidationFn(id) + service.service.cacheInvalidationFn(service.tx, id) service.service.mu.Unlock() return nil @@ -89,13 +87,12 @@ func (service ServiceTx) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *po identifier := service.service.connection.ConvertToKey(int(ID)) - err := service.tx.UpdateObject(BucketName, identifier, edgeStack) - if err != nil { + if err := service.tx.UpdateObject(BucketName, identifier, edgeStack); err != nil { return err } service.service.idxVersion[ID] = edgeStack.Version - service.service.cacheInvalidationFn(ID) + service.service.cacheInvalidationFn(service.tx, ID) return nil } @@ -119,14 +116,13 @@ func (service ServiceTx) DeleteEdgeStack(ID portainer.EdgeStackID) error { identifier := service.service.connection.ConvertToKey(int(ID)) - err := service.tx.DeleteObject(BucketName, identifier) - if err != nil { + if err := service.tx.DeleteObject(BucketName, identifier); err != nil { return err } delete(service.service.idxVersion, ID) - service.service.cacheInvalidationFn(ID) + service.service.cacheInvalidationFn(service.tx, ID) return nil } diff --git a/api/dataservices/endpointrelation/endpointrelation.go b/api/dataservices/endpointrelation/endpointrelation.go index 8b591ec04..4b7ff6b82 100644 --- a/api/dataservices/endpointrelation/endpointrelation.go +++ b/api/dataservices/endpointrelation/endpointrelation.go @@ -1,6 +1,8 @@ package endpointrelation import ( + "sync" + portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge/cache" @@ -13,9 +15,11 @@ const BucketName = "endpoint_relations" // Service represents a service for managing environment(endpoint) relation data. type Service struct { - connection portainer.Connection - updateStackFn func(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error - updateStackFnTx func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error + connection portainer.Connection + updateStackFn func(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error + updateStackFnTx func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error + endpointRelationsCache []portainer.EndpointRelation + mu sync.Mutex } func (service *Service) BucketName() string { @@ -76,6 +80,10 @@ func (service *Service) Create(endpointRelation *portainer.EndpointRelation) err err := service.connection.CreateObjectWithId(BucketName, int(endpointRelation.EndpointID), endpointRelation) cache.Del(endpointRelation.EndpointID) + service.mu.Lock() + service.endpointRelationsCache = nil + service.mu.Unlock() + return err } @@ -92,6 +100,10 @@ func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, updatedRelationState, _ := service.EndpointRelation(endpointID) + service.mu.Lock() + service.endpointRelationsCache = nil + service.mu.Unlock() + service.updateEdgeStacksAfterRelationChange(previousRelationState, updatedRelationState) return nil @@ -108,27 +120,15 @@ func (service *Service) DeleteEndpointRelation(endpointID portainer.EndpointID) return err } + service.mu.Lock() + service.endpointRelationsCache = nil + service.mu.Unlock() + service.updateEdgeStacksAfterRelationChange(deletedRelation, nil) return nil } -func (service *Service) InvalidateEdgeCacheForEdgeStack(edgeStackID portainer.EdgeStackID) { - rels, err := service.EndpointRelations() - if err != nil { - log.Error().Err(err).Msg("cannot retrieve endpoint relations") - return - } - - for _, rel := range rels { - for id := range rel.EdgeStacks { - if edgeStackID == id { - cache.Del(rel.EndpointID) - } - } - } -} - func (service *Service) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) { relations, _ := service.EndpointRelations() diff --git a/api/dataservices/endpointrelation/tx.go b/api/dataservices/endpointrelation/tx.go index 63ab85ec3..097748767 100644 --- a/api/dataservices/endpointrelation/tx.go +++ b/api/dataservices/endpointrelation/tx.go @@ -45,6 +45,10 @@ func (service ServiceTx) Create(endpointRelation *portainer.EndpointRelation) er err := service.tx.CreateObjectWithId(BucketName, int(endpointRelation.EndpointID), endpointRelation) cache.Del(endpointRelation.EndpointID) + service.service.mu.Lock() + service.service.endpointRelationsCache = nil + service.service.mu.Unlock() + return err } @@ -61,6 +65,10 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID, updatedRelationState, _ := service.EndpointRelation(endpointID) + service.service.mu.Lock() + service.service.endpointRelationsCache = nil + service.service.mu.Unlock() + service.updateEdgeStacksAfterRelationChange(previousRelationState, updatedRelationState) return nil @@ -77,27 +85,44 @@ func (service ServiceTx) DeleteEndpointRelation(endpointID portainer.EndpointID) return err } + service.service.mu.Lock() + service.service.endpointRelationsCache = nil + service.service.mu.Unlock() + service.updateEdgeStacksAfterRelationChange(deletedRelation, nil) return nil } func (service ServiceTx) InvalidateEdgeCacheForEdgeStack(edgeStackID portainer.EdgeStackID) { - rels, err := service.EndpointRelations() + rels, err := service.cachedEndpointRelations() if err != nil { log.Error().Err(err).Msg("cannot retrieve endpoint relations") return } for _, rel := range rels { - for id := range rel.EdgeStacks { - if edgeStackID == id { - cache.Del(rel.EndpointID) - } + if _, ok := rel.EdgeStacks[edgeStackID]; ok { + cache.Del(rel.EndpointID) } } } +func (service ServiceTx) cachedEndpointRelations() ([]portainer.EndpointRelation, error) { + service.service.mu.Lock() + defer service.service.mu.Unlock() + + if service.service.endpointRelationsCache == nil { + var err error + service.service.endpointRelationsCache, err = service.EndpointRelations() + if err != nil { + return nil, err + } + } + + return service.service.endpointRelationsCache, nil +} + func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) { relations, _ := service.EndpointRelations() @@ -133,6 +158,7 @@ func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationSta } numDeployments := 0 + for _, r := range relations { for sId, enabled := range r.EdgeStacks { if enabled && sId == refStackId { diff --git a/api/datastore/services.go b/api/datastore/services.go index 0e9208f5c..b5363afe9 100644 --- a/api/datastore/services.go +++ b/api/datastore/services.go @@ -100,7 +100,9 @@ func (store *Store) initServices() error { } store.EndpointRelationService = endpointRelationService - edgeStackService, err := edgestack.NewService(store.connection, endpointRelationService.InvalidateEdgeCacheForEdgeStack) + edgeStackService, err := edgestack.NewService(store.connection, func(tx portainer.Transaction, ID portainer.EdgeStackID) { + endpointRelationService.Tx(tx).InvalidateEdgeCacheForEdgeStack(ID) + }) if err != nil { return err }