diff --git a/api/http/handler/endpointgroups/endpointgroup_create.go b/api/http/handler/endpointgroups/endpointgroup_create.go index 939868fea..464a894a0 100644 --- a/api/http/handler/endpointgroups/endpointgroup_create.go +++ b/api/http/handler/endpointgroups/endpointgroup_create.go @@ -4,11 +4,14 @@ import ( "errors" "net/http" - "github.com/asaskevich/govalidator" httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/featureflags" + + "github.com/asaskevich/govalidator" ) type endpointGroupCreatePayload struct { @@ -24,11 +27,13 @@ type endpointGroupCreatePayload struct { func (payload *endpointGroupCreatePayload) Validate(r *http.Request) error { if govalidator.IsNull(payload.Name) { - return errors.New("Invalid environment group name") + return errors.New("invalid environment group name") } + if payload.TagIDs == nil { payload.TagIDs = []portainer.TagID{} } + return nil } @@ -52,6 +57,29 @@ func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Reque return httperror.BadRequest("Invalid request payload", err) } + var endpointGroup *portainer.EndpointGroup + if featureflags.IsEnabled(portainer.FeatureNoTx) { + endpointGroup, err = handler.createEndpointGroup(handler.DataStore, payload) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + endpointGroup, err = handler.createEndpointGroup(tx, payload) + return err + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.JSON(w, endpointGroup) +} + +func (handler *Handler) createEndpointGroup(tx dataservices.DataStoreTx, payload endpointGroupCreatePayload) (*portainer.EndpointGroup, error) { endpointGroup := &portainer.EndpointGroup{ Name: payload.Name, Description: payload.Description, @@ -60,14 +88,14 @@ func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Reque TagIDs: payload.TagIDs, } - err = handler.DataStore.EndpointGroup().Create(endpointGroup) + err := tx.EndpointGroup().Create(endpointGroup) if err != nil { - return httperror.InternalServerError("Unable to persist the environment group inside the database", err) + return nil, httperror.InternalServerError("Unable to persist the environment group inside the database", err) } - endpoints, err := handler.DataStore.Endpoint().Endpoints() + endpoints, err := tx.Endpoint().Endpoints() if err != nil { - return httperror.InternalServerError("Unable to retrieve environments from the database", err) + return nil, httperror.InternalServerError("Unable to retrieve environments from the database", err) } for _, id := range payload.AssociatedEndpoints { @@ -75,14 +103,14 @@ func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Reque if endpoint.ID == id { endpoint.GroupID = endpointGroup.ID - err := handler.DataStore.Endpoint().UpdateEndpoint(endpoint.ID, &endpoint) + err := tx.Endpoint().UpdateEndpoint(endpoint.ID, &endpoint) if err != nil { - return httperror.InternalServerError("Unable to update environment", err) + return nil, httperror.InternalServerError("Unable to update environment", err) } - err = handler.updateEndpointRelations(&endpoint, endpointGroup) + err = handler.updateEndpointRelations(tx, &endpoint, endpointGroup) if err != nil { - return httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) + return nil, httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) } break @@ -91,16 +119,32 @@ func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Reque } for _, tagID := range endpointGroup.TagIDs { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { - tag.EndpointGroups[endpointGroup.ID] = true - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.EndpointGroups[endpointGroup.ID] = true + }) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.InternalServerError("Unable to find a tag inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to persist tag changes inside the database", err) + if tx.IsErrObjectNotFound(err) { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) + } + + continue + } + + tag, err := tx.Tag().Tag(tagID) + if err != nil { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } + + tag.EndpointGroups[endpointGroup.ID] = true + + err = tx.Tag().UpdateTag(tagID, tag) + if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } - return response.JSON(w, endpointGroup) + return endpointGroup, nil } diff --git a/api/http/handler/endpointgroups/endpointgroup_delete.go b/api/http/handler/endpointgroups/endpointgroup_delete.go index c82eee7fb..5869c5720 100644 --- a/api/http/handler/endpointgroups/endpointgroup_delete.go +++ b/api/http/handler/endpointgroups/endpointgroup_delete.go @@ -8,6 +8,8 @@ import ( "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/featureflags" ) // @id EndpointGroupDelete @@ -33,19 +35,40 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque return httperror.Forbidden("Unable to remove the default 'Unassigned' group", errors.New("Cannot remove the default environment group")) } - endpointGroup, err := handler.DataStore.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) - if handler.DataStore.IsErrObjectNotFound(err) { + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.deleteEndpointGroup(handler.DataStore, portainer.EndpointGroupID(endpointGroupID)) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + return handler.deleteEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID)) + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.Empty(w) +} + +func (handler *Handler) deleteEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID) error { + endpointGroup, err := tx.EndpointGroup().EndpointGroup(endpointGroupID) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment group with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment group with the specified identifier inside the database", err) } - err = handler.DataStore.EndpointGroup().DeleteEndpointGroup(portainer.EndpointGroupID(endpointGroupID)) + err = tx.EndpointGroup().DeleteEndpointGroup(endpointGroupID) if err != nil { return httperror.InternalServerError("Unable to remove the environment group from the database", err) } - endpoints, err := handler.DataStore.Endpoint().Endpoints() + endpoints, err := tx.Endpoint().Endpoints() if err != nil { return httperror.InternalServerError("Unable to retrieve environment from the database", err) } @@ -53,12 +76,12 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque for _, endpoint := range endpoints { if endpoint.GroupID == portainer.EndpointGroupID(endpointGroupID) { endpoint.GroupID = portainer.EndpointGroupID(1) - err = handler.DataStore.Endpoint().UpdateEndpoint(endpoint.ID, &endpoint) + err = tx.Endpoint().UpdateEndpoint(endpoint.ID, &endpoint) if err != nil { return httperror.InternalServerError("Unable to update environment", err) } - err = handler.updateEndpointRelations(&endpoint, nil) + err = handler.updateEndpointRelations(tx, &endpoint, nil) if err != nil { return httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) } @@ -66,16 +89,32 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque } for _, tagID := range endpointGroup.TagIDs { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { - delete(tag.EndpointGroups, endpointGroup.ID) - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.EndpointGroups, endpointGroup.ID) + }) - if handler.DataStore.IsErrObjectNotFound(err) { + if tx.IsErrObjectNotFound(err) { + return httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { + return httperror.InternalServerError("Unable to persist tag changes inside the database", err) + } + + continue + } + + tag, err := tx.Tag().Tag(tagID) + if tx.IsErrObjectNotFound(err) { return httperror.InternalServerError("Unable to find a tag inside the database", err) - } else if err != nil { + } + + delete(tag.EndpointGroups, endpointGroup.ID) + + err = tx.Tag().UpdateTag(tagID, tag) + if err != nil { return httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } - return response.Empty(w) + return nil } diff --git a/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go b/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go index 3910ede8d..1edfbbce0 100644 --- a/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go +++ b/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go @@ -1,12 +1,15 @@ package endpointgroups import ( + "errors" "net/http" httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/featureflags" ) // @id EndpointGroupAddEndpoint @@ -34,15 +37,36 @@ func (handler *Handler) endpointGroupAddEndpoint(w http.ResponseWriter, r *http. return httperror.BadRequest("Invalid environment identifier route variable", err) } - endpointGroup, err := handler.DataStore.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) - if handler.DataStore.IsErrObjectNotFound(err) { + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.addEndpoint(handler.DataStore, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + return handler.addEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.Empty(w) +} + +func (handler *Handler) addEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error { + endpointGroup, err := tx.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment group with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment group with the specified identifier inside the database", err) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) - if handler.DataStore.IsErrObjectNotFound(err) { + endpoint, err := tx.Endpoint().Endpoint(portainer.EndpointID(endpointID)) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) @@ -50,15 +74,15 @@ func (handler *Handler) endpointGroupAddEndpoint(w http.ResponseWriter, r *http. endpoint.GroupID = endpointGroup.ID - err = handler.DataStore.Endpoint().UpdateEndpoint(endpoint.ID, endpoint) + err = tx.Endpoint().UpdateEndpoint(endpoint.ID, endpoint) if err != nil { return httperror.InternalServerError("Unable to persist environment changes inside the database", err) } - err = handler.updateEndpointRelations(endpoint, endpointGroup) + err = handler.updateEndpointRelations(tx, endpoint, endpointGroup) if err != nil { return httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) } - return response.Empty(w) + return nil } diff --git a/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go b/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go index ff81dd839..6a2a047d3 100644 --- a/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go +++ b/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go @@ -1,12 +1,15 @@ package endpointgroups import ( + "errors" "net/http" httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/pkg/featureflags" ) // @id EndpointGroupDeleteEndpoint @@ -33,15 +36,36 @@ func (handler *Handler) endpointGroupDeleteEndpoint(w http.ResponseWriter, r *ht return httperror.BadRequest("Invalid environment identifier route variable", err) } - _, err = handler.DataStore.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) - if handler.DataStore.IsErrObjectNotFound(err) { + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.removeEndpoint(handler.DataStore, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + return handler.removeEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.Empty(w) +} + +func (handler *Handler) removeEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error { + _, err := tx.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment group with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment group with the specified identifier inside the database", err) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) - if handler.DataStore.IsErrObjectNotFound(err) { + endpoint, err := tx.Endpoint().Endpoint(portainer.EndpointID(endpointID)) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) @@ -49,15 +73,15 @@ func (handler *Handler) endpointGroupDeleteEndpoint(w http.ResponseWriter, r *ht endpoint.GroupID = portainer.EndpointGroupID(1) - err = handler.DataStore.Endpoint().UpdateEndpoint(endpoint.ID, endpoint) + err = tx.Endpoint().UpdateEndpoint(endpoint.ID, endpoint) if err != nil { return httperror.InternalServerError("Unable to persist environment changes inside the database", err) } - err = handler.updateEndpointRelations(endpoint, nil) + err = handler.updateEndpointRelations(tx, endpoint, nil) if err != nil { return httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) } - return response.Empty(w) + return nil } diff --git a/api/http/handler/endpointgroups/endpointgroup_update.go b/api/http/handler/endpointgroups/endpointgroup_update.go index ee0c40fce..0bdfec41d 100644 --- a/api/http/handler/endpointgroups/endpointgroup_update.go +++ b/api/http/handler/endpointgroups/endpointgroup_update.go @@ -1,6 +1,7 @@ package endpointgroups import ( + "errors" "net/http" "reflect" @@ -8,7 +9,9 @@ import ( "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/tag" + "github.com/portainer/portainer/pkg/featureflags" ) type endpointGroupUpdatePayload struct { @@ -54,11 +57,34 @@ func (handler *Handler) endpointGroupUpdate(w http.ResponseWriter, r *http.Reque return httperror.BadRequest("Invalid request payload", err) } - endpointGroup, err := handler.DataStore.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.NotFound("Unable to find an environment group with the specified identifier inside the database", err) + var endpointGroup *portainer.EndpointGroup + if featureflags.IsEnabled(portainer.FeatureNoTx) { + endpointGroup, err = handler.updateEndpointGroup(handler.DataStore, portainer.EndpointGroupID(endpointGroupID), payload) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + endpointGroup, err = handler.updateEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID), payload) + return err + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.JSON(w, endpointGroup) +} + +func (handler *Handler) updateEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, payload endpointGroupUpdatePayload) (*portainer.EndpointGroup, error) { + endpointGroup, err := tx.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(endpointGroupID)) + if tx.IsErrObjectNotFound(err) { + return nil, httperror.NotFound("Unable to find an environment group with the specified identifier inside the database", err) } else if err != nil { - return httperror.InternalServerError("Unable to find an environment group with the specified identifier inside the database", err) + return nil, httperror.InternalServerError("Unable to find an environment group with the specified identifier inside the database", err) } if payload.Name != "" { @@ -81,27 +107,59 @@ func (handler *Handler) endpointGroupUpdate(w http.ResponseWriter, r *http.Reque removeTags := tag.Difference(endpointGroupTagSet, payloadTagSet) for tagID := range removeTags { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { - delete(tag.EndpointGroups, endpointGroup.ID) - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.EndpointGroups, endpointGroup.ID) + }) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.InternalServerError("Unable to find a tag inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to persist tag changes inside the database", err) + if tx.IsErrObjectNotFound(err) { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) + } + + continue + } + + tag, err := tx.Tag().Tag(tagID) + if err != nil { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } + + delete(tag.EndpointGroups, endpointGroup.ID) + + err = tx.Tag().UpdateTag(tagID, tag) + if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } endpointGroup.TagIDs = payload.TagIDs for _, tagID := range payload.TagIDs { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { - tag.EndpointGroups[endpointGroup.ID] = true - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + tag.EndpointGroups[endpointGroup.ID] = true + }) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.InternalServerError("Unable to find a tag inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to persist tag changes inside the database", err) + if tx.IsErrObjectNotFound(err) { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } else if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) + } + + continue + } + + tag, err := tx.Tag().Tag(tagID) + if err != nil { + return nil, httperror.InternalServerError("Unable to find a tag inside the database", err) + } + + tag.EndpointGroups[endpointGroup.ID] = true + + err = tx.Tag().UpdateTag(tagID, tag) + if err != nil { + return nil, httperror.InternalServerError("Unable to persist tag changes inside the database", err) } } } @@ -119,44 +177,44 @@ func (handler *Handler) endpointGroupUpdate(w http.ResponseWriter, r *http.Reque } if updateAuthorizations { - endpoints, err := handler.DataStore.Endpoint().Endpoints() + endpoints, err := tx.Endpoint().Endpoints() if err != nil { - return httperror.InternalServerError("Unable to retrieve environments from the database", err) + return nil, httperror.InternalServerError("Unable to retrieve environments from the database", err) } for _, endpoint := range endpoints { if endpoint.GroupID == endpointGroup.ID { if endpoint.Type == portainer.KubernetesLocalEnvironment || endpoint.Type == portainer.AgentOnKubernetesEnvironment || endpoint.Type == portainer.EdgeAgentOnKubernetesEnvironment { - err = handler.AuthorizationService.CleanNAPWithOverridePolicies(&endpoint, endpointGroup) + err = handler.AuthorizationService.CleanNAPWithOverridePolicies(tx, &endpoint, endpointGroup) if err != nil { - return httperror.InternalServerError("Unable to update user authorizations", err) + return nil, httperror.InternalServerError("Unable to update user authorizations", err) } } } } } - err = handler.DataStore.EndpointGroup().UpdateEndpointGroup(endpointGroup.ID, endpointGroup) + err = tx.EndpointGroup().UpdateEndpointGroup(endpointGroup.ID, endpointGroup) if err != nil { - return httperror.InternalServerError("Unable to persist environment group changes inside the database", err) + return nil, httperror.InternalServerError("Unable to persist environment group changes inside the database", err) } if tagsChanged { - endpoints, err := handler.DataStore.Endpoint().Endpoints() + endpoints, err := tx.Endpoint().Endpoints() if err != nil { - return httperror.InternalServerError("Unable to retrieve environments from the database", err) + return nil, httperror.InternalServerError("Unable to retrieve environments from the database", err) } for _, endpoint := range endpoints { if endpoint.GroupID == endpointGroup.ID { - err = handler.updateEndpointRelations(&endpoint, endpointGroup) + err = handler.updateEndpointRelations(tx, &endpoint, endpointGroup) if err != nil { - return httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) + return nil, httperror.InternalServerError("Unable to persist environment relations changes inside the database", err) } } } } - return response.JSON(w, endpointGroup) + return endpointGroup, nil } diff --git a/api/http/handler/endpointgroups/endpoints.go b/api/http/handler/endpointgroups/endpoints.go index c1757d3c7..7e6462415 100644 --- a/api/http/handler/endpointgroups/endpoints.go +++ b/api/http/handler/endpointgroups/endpoints.go @@ -2,16 +2,17 @@ package endpointgroups import ( portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/edge" ) -func (handler *Handler) updateEndpointRelations(endpoint *portainer.Endpoint, endpointGroup *portainer.EndpointGroup) error { +func (handler *Handler) updateEndpointRelations(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, endpointGroup *portainer.EndpointGroup) error { if endpoint.Type != portainer.EdgeAgentOnKubernetesEnvironment && endpoint.Type != portainer.EdgeAgentOnDockerEnvironment { return nil } if endpointGroup == nil { - unassignedGroup, err := handler.DataStore.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(1)) + unassignedGroup, err := tx.EndpointGroup().EndpointGroup(portainer.EndpointGroupID(1)) if err != nil { return err } @@ -19,17 +20,17 @@ func (handler *Handler) updateEndpointRelations(endpoint *portainer.Endpoint, en endpointGroup = unassignedGroup } - endpointRelation, err := handler.DataStore.EndpointRelation().EndpointRelation(endpoint.ID) + endpointRelation, err := tx.EndpointRelation().EndpointRelation(endpoint.ID) if err != nil { return err } - edgeGroups, err := handler.DataStore.EdgeGroup().EdgeGroups() + edgeGroups, err := tx.EdgeGroup().EdgeGroups() if err != nil { return err } - edgeStacks, err := handler.DataStore.EdgeStack().EdgeStacks() + edgeStacks, err := tx.EdgeStack().EdgeStacks() if err != nil { return err } @@ -41,5 +42,5 @@ func (handler *Handler) updateEndpointRelations(endpoint *portainer.Endpoint, en } endpointRelation.EdgeStacks = stacksSet - return handler.DataStore.EndpointRelation().UpdateEndpointRelation(endpoint.ID, endpointRelation) + return tx.EndpointRelation().UpdateEndpointRelation(endpoint.ID, endpointRelation) } diff --git a/api/http/handler/endpoints/endpoint_update.go b/api/http/handler/endpoints/endpoint_update.go index a21898a2e..8e2400512 100644 --- a/api/http/handler/endpoints/endpoint_update.go +++ b/api/http/handler/endpoints/endpoint_update.go @@ -256,7 +256,7 @@ func (handler *Handler) endpointUpdate(w http.ResponseWriter, r *http.Request) * if updateAuthorizations { if endpoint.Type == portainer.KubernetesLocalEnvironment || endpoint.Type == portainer.AgentOnKubernetesEnvironment || endpoint.Type == portainer.EdgeAgentOnKubernetesEnvironment { - err = handler.AuthorizationService.CleanNAPWithOverridePolicies(endpoint, nil) + err = handler.AuthorizationService.CleanNAPWithOverridePolicies(handler.DataStore, endpoint, nil) if err != nil { return httperror.InternalServerError("Unable to update user authorizations", err) } diff --git a/api/internal/authorization/endpint_role_with_override.go b/api/internal/authorization/endpoint_role_with_override.go similarity index 73% rename from api/internal/authorization/endpint_role_with_override.go rename to api/internal/authorization/endpoint_role_with_override.go index 14d004487..3ba1a49d3 100644 --- a/api/internal/authorization/endpint_role_with_override.go +++ b/api/internal/authorization/endpoint_role_with_override.go @@ -1,9 +1,13 @@ package authorization -import portainer "github.com/portainer/portainer/api" +import ( + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" +) // CleanNAPWithOverridePolicies Clean Namespace Access Policies with override policies func (service *Service) CleanNAPWithOverridePolicies( + tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, endpointGroup *portainer.EndpointGroup, ) error { @@ -21,10 +25,11 @@ func (service *Service) CleanNAPWithOverridePolicies( for namespace, policy := range accessPolicies { for teamID := range policy.TeamAccessPolicies { - access, err := service.getTeamEndpointAccessWithPolicies(teamID, endpoint, endpointGroup) + access, err := service.getTeamEndpointAccessWithPolicies(tx, teamID, endpoint, endpointGroup) if err != nil { return err } + if !access { delete(accessPolicies[namespace].TeamAccessPolicies, teamID) hasChange = true @@ -32,10 +37,11 @@ func (service *Service) CleanNAPWithOverridePolicies( } for userID := range policy.UserAccessPolicies { - access, err := service.getUserEndpointAccessWithPolicies(userID, endpoint, endpointGroup) + access, err := service.getUserEndpointAccessWithPolicies(tx, userID, endpoint, endpointGroup) if err != nil { return err } + if !access { delete(accessPolicies[namespace].UserAccessPolicies, userID) hasChange = true @@ -51,27 +57,28 @@ func (service *Service) CleanNAPWithOverridePolicies( } func (service *Service) getUserEndpointAccessWithPolicies( + tx dataservices.DataStoreTx, userID portainer.UserID, endpoint *portainer.Endpoint, endpointGroup *portainer.EndpointGroup, ) (bool, error) { - memberships, err := service.dataStore.TeamMembership().TeamMembershipsByUserID(userID) + memberships, err := tx.TeamMembership().TeamMembershipsByUserID(userID) if err != nil { return false, err } if endpointGroup == nil { - endpointGroup, err = service.dataStore.EndpointGroup().EndpointGroup(endpoint.GroupID) + endpointGroup, err = tx.EndpointGroup().EndpointGroup(endpoint.GroupID) if err != nil { return false, err } } - if userAccess(userID, endpoint.UserAccessPolicies, endpoint.TeamAccessPolicies, memberships) { + if userAccess(tx, userID, endpoint.UserAccessPolicies, endpoint.TeamAccessPolicies, memberships) { return true, nil } - if userAccess(userID, endpointGroup.UserAccessPolicies, endpointGroup.TeamAccessPolicies, memberships) { + if userAccess(tx, userID, endpointGroup.UserAccessPolicies, endpointGroup.TeamAccessPolicies, memberships) { return true, nil } @@ -80,6 +87,7 @@ func (service *Service) getUserEndpointAccessWithPolicies( } func userAccess( + tx dataservices.DataStoreTx, userID portainer.UserID, userAccessPolicies portainer.UserAccessPolicies, teamAccessPolicies portainer.TeamAccessPolicies, @@ -99,13 +107,14 @@ func userAccess( } func (service *Service) getTeamEndpointAccessWithPolicies( + tx dataservices.DataStoreTx, teamID portainer.TeamID, endpoint *portainer.Endpoint, endpointGroup *portainer.EndpointGroup, ) (bool, error) { if endpointGroup == nil { var err error - endpointGroup, err = service.dataStore.EndpointGroup().EndpointGroup(endpoint.GroupID) + endpointGroup, err = tx.EndpointGroup().EndpointGroup(endpoint.GroupID) if err != nil { return false, err }