diff --git a/api/connection.go b/api/connection.go index 643767058..6c9e4f440 100644 --- a/api/connection.go +++ b/api/connection.go @@ -24,6 +24,7 @@ type Connection interface { SetServiceName(bucketName string) error GetObject(bucketName string, key []byte, object interface{}) error UpdateObject(bucketName string, key []byte, object interface{}) error + UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error DeleteObject(bucketName string, key []byte) error DeleteAllObjects(bucketName string, matching func(o interface{}) (id int, ok bool)) error GetNextIdentifier(bucketName string) int diff --git a/api/database/boltdb/db.go b/api/database/boltdb/db.go index 203f9255c..a829cedd8 100644 --- a/api/database/boltdb/db.go +++ b/api/database/boltdb/db.go @@ -179,7 +179,7 @@ func (connection *DbConnection) ConvertToKey(v int) []byte { return b } -// CreateBucket is a generic function used to create a bucket inside a database database. +// CreateBucket is a generic function used to create a bucket inside a database. func (connection *DbConnection) SetServiceName(bucketName string) error { return connection.Batch(func(tx *bolt.Tx) error { _, err := tx.CreateBucketIfNotExists([]byte(bucketName)) @@ -187,7 +187,7 @@ func (connection *DbConnection) SetServiceName(bucketName string) error { }) } -// GetObject is a generic function used to retrieve an unmarshalled object from a database database. +// GetObject is a generic function used to retrieve an unmarshalled object from a database. func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error { var data []byte @@ -219,7 +219,7 @@ func (connection *DbConnection) getEncryptionKey() []byte { return connection.EncryptionKey } -// UpdateObject is a generic function used to update an object inside a database database. +// UpdateObject is a generic function used to update an object inside a database. func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error { data, err := connection.MarshalObject(object) if err != nil { @@ -232,7 +232,33 @@ func (connection *DbConnection) UpdateObject(bucketName string, key []byte, obje }) } -// DeleteObject is a generic function used to delete an object inside a database database. +// UpdateObjectFunc is a generic function used to update an object safely without race conditions. +func (connection *DbConnection) UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error { + return connection.Batch(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(bucketName)) + + data := bucket.Get(key) + if data == nil { + return dserrors.ErrObjectNotFound + } + + err := connection.UnmarshalObjectWithJsoniter(data, object) + if err != nil { + return err + } + + updateFn() + + data, err = connection.MarshalObject(object) + if err != nil { + return err + } + + return bucket.Put(key, data) + }) +} + +// DeleteObject is a generic function used to delete an object inside a database. func (connection *DbConnection) DeleteObject(bucketName string, key []byte) error { return connection.Batch(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(bucketName)) diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index 619cc6f4d..e936f5de1 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -251,6 +251,7 @@ type ( Tag(ID portainer.TagID) (*portainer.Tag, error) Create(tag *portainer.Tag) error UpdateTag(ID portainer.TagID, tag *portainer.Tag) error + UpdateTagFunc(ID portainer.TagID, updateFunc func(tag *portainer.Tag)) error DeleteTag(ID portainer.TagID) error BucketName() string } diff --git a/api/dataservices/tag/tag.go b/api/dataservices/tag/tag.go index 42c289598..7641afc37 100644 --- a/api/dataservices/tag/tag.go +++ b/api/dataservices/tag/tag.go @@ -80,12 +80,24 @@ func (service *Service) Create(tag *portainer.Tag) error { ) } -// UpdateTag updates a tag. +// Deprecated: Use UpdateTagFunc instead. func (service *Service) UpdateTag(ID portainer.TagID, tag *portainer.Tag) error { identifier := service.connection.ConvertToKey(int(ID)) return service.connection.UpdateObject(BucketName, identifier, tag) } +// UpdateTagFunc updates a tag inside a transaction avoiding data races. +func (service *Service) UpdateTagFunc(ID portainer.TagID, updateFunc func(tag *portainer.Tag)) error { + id := service.connection.ConvertToKey(int(ID)) + tag := &portainer.Tag{} + + service.connection.UpdateObjectFunc(BucketName, id, tag, func() { + updateFunc(tag) + }) + + return nil +} + // DeleteTag deletes a tag. func (service *Service) DeleteTag(ID portainer.TagID) error { identifier := service.connection.ConvertToKey(int(ID)) diff --git a/api/http/handler/endpointgroups/endpointgroup_create.go b/api/http/handler/endpointgroups/endpointgroup_create.go index cd41498f3..3a9096433 100644 --- a/api/http/handler/endpointgroups/endpointgroup_create.go +++ b/api/http/handler/endpointgroups/endpointgroup_create.go @@ -91,15 +91,13 @@ func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Reque } for _, tagID := range endpointGroup.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve tag from the database", err) - } + handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.EndpointGroups[endpointGroup.ID] = true + }) - tag.EndpointGroups[endpointGroup.ID] = true - - err = handler.DataStore.Tag().UpdateTag(tagID, tag) - if err != nil { + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } diff --git a/api/http/handler/endpointgroups/endpointgroup_delete.go b/api/http/handler/endpointgroups/endpointgroup_delete.go index feea30c53..a96bd7278 100644 --- a/api/http/handler/endpointgroups/endpointgroup_delete.go +++ b/api/http/handler/endpointgroups/endpointgroup_delete.go @@ -66,15 +66,13 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque } for _, tagID := range endpointGroup.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve tag from the database", err) - } + handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.EndpointGroups, endpointGroup.ID) + }) - delete(tag.EndpointGroups, endpointGroup.ID) - - err = handler.DataStore.Tag().UpdateTag(tagID, tag) - if err != nil { + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } diff --git a/api/http/handler/endpointgroups/endpointgroup_update.go b/api/http/handler/endpointgroups/endpointgroup_update.go index 8f9578f3c..fbbd3384a 100644 --- a/api/http/handler/endpointgroups/endpointgroup_update.go +++ b/api/http/handler/endpointgroups/endpointgroup_update.go @@ -81,28 +81,26 @@ func (handler *Handler) endpointGroupUpdate(w http.ResponseWriter, r *http.Reque removeTags := tag.Difference(endpointGroupTagSet, payloadTagSet) for tagID := range removeTags { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { + handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.EndpointGroups, endpointGroup.ID) + }) + + if handler.DataStore.IsErrObjectNotFound(err) { return httperror.InternalServerError("Unable to find a tag inside the database", err) - } - delete(tag.EndpointGroups, endpointGroup.ID) - err = handler.DataStore.Tag().UpdateTag(tag.ID, tag) - if err != nil { + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } endpointGroup.TagIDs = payload.TagIDs for _, tagID := range payload.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { + handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.EndpointGroups[endpointGroup.ID] = true + }) + + if handler.DataStore.IsErrObjectNotFound(err) { return httperror.InternalServerError("Unable to find a tag inside the database", err) - } - - tag.EndpointGroups[endpointGroup.ID] = true - - err = handler.DataStore.Tag().UpdateTag(tag.ID, tag) - if err != nil { + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } diff --git a/api/http/handler/endpoints/endpoint_create.go b/api/http/handler/endpoints/endpoint_create.go index a54dad51c..c2df56ab8 100644 --- a/api/http/handler/endpoints/endpoint_create.go +++ b/api/http/handler/endpoints/endpoint_create.go @@ -530,14 +530,9 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(endpoint *portainer. } for _, tagID := range endpoint.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { - return err - } - - tag.Endpoints[endpoint.ID] = true - - err = handler.DataStore.Tag().UpdateTag(tagID, tag) + err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.Endpoints[endpoint.ID] = true + }) if err != nil { return err } diff --git a/api/http/handler/endpoints/endpoint_delete.go b/api/http/handler/endpoints/endpoint_delete.go index e271fbbe7..45a840970 100644 --- a/api/http/handler/endpoints/endpoint_delete.go +++ b/api/http/handler/endpoints/endpoint_delete.go @@ -62,15 +62,13 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * } for _, tagID := range endpoint.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { + err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.Endpoints, endpoint.ID) + }) + + if handler.DataStore.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find tag inside the database", err) - } - - delete(tag.Endpoints, endpoint.ID) - - err = handler.DataStore.Tag().UpdateTag(tagID, tag) - if err != nil { + } else if err != nil { return httperror.InternalServerError("Unable to persist tag relation inside the database", err) } } diff --git a/api/http/handler/endpoints/endpoint_update.go b/api/http/handler/endpoints/endpoint_update.go index d34d33428..a783a1b01 100644 --- a/api/http/handler/endpoints/endpoint_update.go +++ b/api/http/handler/endpoints/endpoint_update.go @@ -139,29 +139,26 @@ func (handler *Handler) endpointUpdate(w http.ResponseWriter, r *http.Request) * removeTags := tag.Difference(endpointTagSet, payloadTagSet) for tagID := range removeTags { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { - return httperror.InternalServerError("Unable to find a tag inside the database", err) - } + err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.Endpoints, endpoint.ID) + }) - delete(tag.Endpoints, endpoint.ID) - err = handler.DataStore.Tag().UpdateTag(tag.ID, tag) - if err != nil { + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } endpoint.TagIDs = payload.TagIDs for _, tagID := range payload.TagIDs { - tag, err := handler.DataStore.Tag().Tag(tagID) - if err != nil { + err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.Endpoints[endpoint.ID] = true + }) + + if handler.DataStore.IsErrObjectNotFound(err) { return httperror.InternalServerError("Unable to find a tag inside the database", err) - } - - tag.Endpoints[endpoint.ID] = true - - err = handler.DataStore.Tag().UpdateTag(tag.ID, tag) - if err != nil { + } else if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } }