From 8c533bee67f34f262bbdfb8477cbc74e0da6ad12 Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:36:00 -0300 Subject: [PATCH] feat(transactions): migrate some parts to use transactional code EE-5494 (#9213) --- api/dataservices/stack/stack.go | 6 ++ api/dataservices/stack/tx.go | 98 +++++++++++++++++ api/http/handler/endpoints/endpoint_create.go | 54 +++++----- api/http/handler/endpoints/endpoint_delete.go | 102 ++++++++++++++---- .../endpoints/endpoint_registries_list.go | 64 +++++++---- .../endpoints/endpoint_registry_access.go | 36 +++++-- api/http/handler/stacks/stack_create.go | 2 +- api/http/handler/webhooks/handler.go | 1 + api/internal/authorization/authorizations.go | 30 +++--- 9 files changed, 305 insertions(+), 88 deletions(-) create mode 100644 api/dataservices/stack/tx.go diff --git a/api/dataservices/stack/stack.go b/api/dataservices/stack/stack.go index 57e28e4d5..4114fdc8b 100644 --- a/api/dataservices/stack/stack.go +++ b/api/dataservices/stack/stack.go @@ -32,6 +32,12 @@ func NewService(connection portainer.Connection) (*Service, error) { }, nil } +func (service *Service) Tx(tx portainer.Transaction) ServiceTx { + return ServiceTx{ + BaseDataServiceTx: service.BaseDataService.Tx(tx), + } +} + // StackByName returns a stack object by name. func (service *Service) StackByName(name string) (*portainer.Stack, error) { var s portainer.Stack diff --git a/api/dataservices/stack/tx.go b/api/dataservices/stack/tx.go new file mode 100644 index 000000000..478c286ea --- /dev/null +++ b/api/dataservices/stack/tx.go @@ -0,0 +1,98 @@ +package stack + +import ( + "errors" + "strings" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + dserrors "github.com/portainer/portainer/api/dataservices/errors" +) + +type ServiceTx struct { + dataservices.BaseDataServiceTx[portainer.Stack, portainer.StackID] +} + +// StackByName returns a stack object by name. +func (service ServiceTx) StackByName(name string) (*portainer.Stack, error) { + var s portainer.Stack + + err := service.Tx.GetAll( + BucketName, + &portainer.Stack{}, + dataservices.FirstFn(&s, func(e portainer.Stack) bool { + return e.Name == name + }), + ) + + if errors.Is(err, dataservices.ErrStop) { + return &s, nil + } + + if err == nil { + return nil, dserrors.ErrObjectNotFound + } + + return nil, err +} + +// Stacks returns an array containing all the stacks with same name +func (service ServiceTx) StacksByName(name string) ([]portainer.Stack, error) { + var stacks = make([]portainer.Stack, 0) + + return stacks, service.Tx.GetAll( + BucketName, + &portainer.Stack{}, + dataservices.FilterFn(&stacks, func(e portainer.Stack) bool { + return e.Name == name + }), + ) +} + +// GetNextIdentifier returns the next identifier for a stack. +func (service ServiceTx) GetNextIdentifier() int { + return service.Tx.GetNextIdentifier(BucketName) +} + +// CreateStack creates a new stack. +func (service ServiceTx) Create(stack *portainer.Stack) error { + return service.Tx.CreateObjectWithId(BucketName, int(stack.ID), stack) +} + +// StackByWebhookID returns a pointer to a stack object by webhook ID. +// It returns nil, errors.ErrObjectNotFound if there's no stack associated with the webhook ID. +func (service ServiceTx) StackByWebhookID(id string) (*portainer.Stack, error) { + var s portainer.Stack + + err := service.Tx.GetAll( + BucketName, + &portainer.Stack{}, + dataservices.FirstFn(&s, func(e portainer.Stack) bool { + return e.AutoUpdate != nil && strings.EqualFold(e.AutoUpdate.Webhook, id) + }), + ) + + if errors.Is(err, dataservices.ErrStop) { + return &s, nil + } + + if err == nil { + return nil, dserrors.ErrObjectNotFound + } + + return nil, err + +} + +// RefreshableStacks returns stacks that are configured for a periodic update +func (service ServiceTx) RefreshableStacks() ([]portainer.Stack, error) { + stacks := make([]portainer.Stack, 0) + + return stacks, service.Tx.GetAll( + BucketName, + &portainer.Stack{}, + dataservices.FilterFn(&stacks, func(e portainer.Stack) bool { + return e.AutoUpdate != nil && e.AutoUpdate.Interval != "" + }), + ) +} diff --git a/api/http/handler/endpoints/endpoint_create.go b/api/http/handler/endpoints/endpoint_create.go index 15eeef5f7..25d6d4da0 100644 --- a/api/http/handler/endpoints/endpoint_create.go +++ b/api/http/handler/endpoints/endpoint_create.go @@ -15,6 +15,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/agent" "github.com/portainer/portainer/api/crypto" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/client" "github.com/portainer/portainer/api/internal/edge" "github.com/portainer/portainer/api/internal/endpointutils" @@ -217,7 +218,7 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) * return httperror.NewError(http.StatusConflict, "Name is not unique", nil) } - endpoint, endpointCreationError := handler.createEndpoint(payload) + endpoint, endpointCreationError := handler.createEndpoint(handler.DataStore, payload) if endpointCreationError != nil { return endpointCreationError } @@ -273,17 +274,17 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) * return response.JSON(w, endpoint) } -func (handler *Handler) createEndpoint(payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { +func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { var err error switch payload.EndpointCreationType { case azureEnvironment: - return handler.createAzureEndpoint(payload) + return handler.createAzureEndpoint(tx, payload) case edgeAgentEnvironment: - return handler.createEdgeAgentEndpoint(payload) + return handler.createEdgeAgentEndpoint(tx, payload) case localKubernetesEnvironment: - return handler.createKubernetesEndpoint(payload) + return handler.createKubernetesEndpoint(tx, payload) } endpointType := portainer.DockerEnvironment @@ -312,12 +313,13 @@ func (handler *Handler) createEndpoint(payload *endpointCreatePayload) (*portain } if payload.TLS { - return handler.createTLSSecuredEndpoint(payload, endpointType, agentVersion) + return handler.createTLSSecuredEndpoint(tx, payload, endpointType, agentVersion) } - return handler.createUnsecuredEndpoint(payload) + + return handler.createUnsecuredEndpoint(tx, payload) } -func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { +func (handler *Handler) createAzureEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { credentials := portainer.AzureCredentials{ ApplicationID: payload.AzureApplicationID, TenantID: payload.AzureTenantID, @@ -330,7 +332,7 @@ func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*po return nil, httperror.InternalServerError("Unable to authenticate against Azure", err) } - endpointID := handler.DataStore.Endpoint().GetNextIdentifier() + endpointID := tx.Endpoint().GetNextIdentifier() endpoint := &portainer.Endpoint{ ID: portainer.EndpointID(endpointID), Name: payload.Name, @@ -348,7 +350,7 @@ func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*po Kubernetes: portainer.KubernetesDefault(), } - err = handler.saveEndpointAndUpdateAuthorizations(endpoint) + err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err != nil { return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -356,7 +358,7 @@ func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*po return endpoint, nil } -func (handler *Handler) createEdgeAgentEndpoint(payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { +func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { endpointID := handler.DataStore.Endpoint().GetNextIdentifier() portainerHost, err := edge.ParseHostForEdge(payload.URL) @@ -401,7 +403,7 @@ func (handler *Handler) createEdgeAgentEndpoint(payload *endpointCreatePayload) endpoint.EdgeID = edgeID.String() } - err = handler.saveEndpointAndUpdateAuthorizations(endpoint) + err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err != nil { return nil, httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -409,7 +411,7 @@ func (handler *Handler) createEdgeAgentEndpoint(payload *endpointCreatePayload) return endpoint, nil } -func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { +func (handler *Handler) createUnsecuredEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { endpointType := portainer.DockerEnvironment if payload.URL == "" { @@ -419,7 +421,7 @@ func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload) } } - endpointID := handler.DataStore.Endpoint().GetNextIdentifier() + endpointID := tx.Endpoint().GetNextIdentifier() endpoint := &portainer.Endpoint{ ID: portainer.EndpointID(endpointID), Name: payload.Name, @@ -439,7 +441,7 @@ func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload) Kubernetes: portainer.KubernetesDefault(), } - err := handler.snapshotAndPersistEndpoint(endpoint) + err := handler.snapshotAndPersistEndpoint(tx, endpoint) if err != nil { return nil, err } @@ -447,12 +449,12 @@ func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload) return endpoint, nil } -func (handler *Handler) createKubernetesEndpoint(payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { +func (handler *Handler) createKubernetesEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) { if payload.URL == "" { payload.URL = "https://kubernetes.default.svc" } - endpointID := handler.DataStore.Endpoint().GetNextIdentifier() + endpointID := tx.Endpoint().GetNextIdentifier() endpoint := &portainer.Endpoint{ ID: portainer.EndpointID(endpointID), @@ -474,7 +476,7 @@ func (handler *Handler) createKubernetesEndpoint(payload *endpointCreatePayload) Kubernetes: portainer.KubernetesDefault(), } - err := handler.snapshotAndPersistEndpoint(endpoint) + err := handler.snapshotAndPersistEndpoint(tx, endpoint) if err != nil { return nil, err } @@ -482,8 +484,8 @@ func (handler *Handler) createKubernetesEndpoint(payload *endpointCreatePayload) return endpoint, nil } -func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload, endpointType portainer.EndpointType, agentVersion string) (*portainer.Endpoint, *httperror.HandlerError) { - endpointID := handler.DataStore.Endpoint().GetNextIdentifier() +func (handler *Handler) createTLSSecuredEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload, endpointType portainer.EndpointType, agentVersion string) (*portainer.Endpoint, *httperror.HandlerError) { + endpointID := tx.Endpoint().GetNextIdentifier() endpoint := &portainer.Endpoint{ ID: portainer.EndpointID(endpointID), Name: payload.Name, @@ -511,7 +513,7 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload, return nil, err } - err = handler.snapshotAndPersistEndpoint(endpoint) + err = handler.snapshotAndPersistEndpoint(tx, endpoint) if err != nil { return nil, err } @@ -519,7 +521,7 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload, return endpoint, nil } -func (handler *Handler) snapshotAndPersistEndpoint(endpoint *portainer.Endpoint) *httperror.HandlerError { +func (handler *Handler) snapshotAndPersistEndpoint(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) *httperror.HandlerError { err := handler.SnapshotService.SnapshotEndpoint(endpoint) if err != nil { if (endpoint.Type == portainer.AgentOnDockerEnvironment && strings.Contains(err.Error(), "Invalid request signature")) || @@ -529,7 +531,7 @@ func (handler *Handler) snapshotAndPersistEndpoint(endpoint *portainer.Endpoint) return httperror.InternalServerError("Unable to initiate communications with environment", err) } - err = handler.saveEndpointAndUpdateAuthorizations(endpoint) + err = handler.saveEndpointAndUpdateAuthorizations(tx, endpoint) if err != nil { return httperror.InternalServerError("An error occurred while trying to create the environment", err) } @@ -537,7 +539,7 @@ func (handler *Handler) snapshotAndPersistEndpoint(endpoint *portainer.Endpoint) return nil } -func (handler *Handler) saveEndpointAndUpdateAuthorizations(endpoint *portainer.Endpoint) error { +func (handler *Handler) saveEndpointAndUpdateAuthorizations(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) error { endpoint.SecuritySettings = portainer.EndpointSecuritySettings{ AllowVolumeBrowserForRegularUsers: false, EnableHostManagementFeatures: false, @@ -551,13 +553,13 @@ func (handler *Handler) saveEndpointAndUpdateAuthorizations(endpoint *portainer. AllowStackManagementForRegularUsers: true, } - err := handler.DataStore.Endpoint().Create(endpoint) + err := tx.Endpoint().Create(endpoint) if err != nil { return err } for _, tagID := range endpoint.TagIDs { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + err = tx.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { tag.Endpoints[endpoint.ID] = true }) if err != nil { diff --git a/api/http/handler/endpoints/endpoint_delete.go b/api/http/handler/endpoints/endpoint_delete.go index 664586bfe..2e8c3118d 100644 --- a/api/http/handler/endpoints/endpoint_delete.go +++ b/api/http/handler/endpoints/endpoint_delete.go @@ -1,6 +1,7 @@ package endpoints import ( + "errors" "net/http" "strconv" @@ -8,8 +9,11 @@ import ( "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" httperrors "github.com/portainer/portainer/api/http/errors" "github.com/portainer/portainer/api/internal/endpointutils" + "github.com/portainer/portainer/pkg/featureflags" + "github.com/rs/zerolog/log" ) @@ -33,41 +37,84 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * return httperror.BadRequest("Invalid environment identifier route variable", err) } + // This is a Portainer provisioned cloud environment + deleteCluster, err := request.RetrieveBooleanQueryParameter(r, "deleteCluster", true) + if err != nil { + return httperror.BadRequest("Invalid boolean query parameter", err) + } + if handler.demoService.IsDemoEnvironment(portainer.EndpointID(endpointID)) { return httperror.Forbidden(httperrors.ErrNotAvailableInDemo.Error(), httperrors.ErrNotAvailableInDemo) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) - if handler.DataStore.IsErrObjectNotFound(err) { + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.deleteEndpoint(handler.DataStore, portainer.EndpointID(endpointID), deleteCluster) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + return handler.deleteEndpoint(tx, portainer.EndpointID(endpointID), deleteCluster) + }) + } + + if err != nil { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.Empty(w) +} + +func (handler *Handler) deleteEndpoint(tx dataservices.DataStoreTx, endpointID portainer.EndpointID, deleteCluster bool) error { + endpoint, err := tx.Endpoint().Endpoint(portainer.EndpointID(endpointID)) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to read the environment record from the database", err) } if endpoint.TLSConfig.TLS { - folder := strconv.Itoa(endpointID) + folder := strconv.Itoa(int(endpointID)) err = handler.FileService.DeleteTLSFiles(folder) if err != nil { log.Error().Err(err).Msgf("Unable to remove TLS files from disk when deleting endpoint %d", endpointID) } } - err = handler.DataStore.Snapshot().Delete(portainer.EndpointID(endpointID)) + err = tx.Snapshot().Delete(endpointID) if err != nil { log.Warn().Err(err).Msgf("Unable to remove the snapshot from the database") } handler.ProxyManager.DeleteEndpointProxy(endpoint.ID) - err = handler.DataStore.EndpointRelation().DeleteEndpointRelation(endpoint.ID) + if len(endpoint.UserAccessPolicies) > 0 || len(endpoint.TeamAccessPolicies) > 0 { + err = handler.AuthorizationService.UpdateUsersAuthorizationsTx(tx) + if err != nil { + log.Warn().Err(err).Msgf("Unable to update user authorizations") + } + } + + err = tx.EndpointRelation().DeleteEndpointRelation(endpoint.ID) if err != nil { log.Warn().Err(err).Msgf("Unable to remove environment relation from the database") } for _, tagID := range endpoint.TagIDs { - err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { - delete(tag.Endpoints, endpoint.ID) - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.DataStore.Tag().UpdateTagFunc(tagID, func(tag *portainer.Tag) { + delete(tag.Endpoints, endpoint.ID) + }) + } else { + var tag *portainer.Tag + tag, err = tx.Tag().Read(tagID) + if err == nil { + delete(tag.Endpoints, endpoint.ID) + err = tx.Tag().Update(tagID, tag) + } + } if handler.DataStore.IsErrObjectNotFound(err) { log.Warn().Err(err).Msgf("Unable to find tag inside the database") @@ -76,21 +123,27 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * } } - edgeGroups, err := handler.DataStore.EdgeGroup().ReadAll() + edgeGroups, err := tx.EdgeGroup().ReadAll() if err != nil { log.Warn().Err(err).Msgf("Unable to retrieve edge groups from the database") } for _, edgeGroup := range edgeGroups { - err = handler.DataStore.EdgeGroup().UpdateEdgeGroupFunc(edgeGroup.ID, func(g *portainer.EdgeGroup) { - g.Endpoints = removeElement(g.Endpoints, endpoint.ID) - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.DataStore.EdgeGroup().UpdateEdgeGroupFunc(edgeGroup.ID, func(g *portainer.EdgeGroup) { + g.Endpoints = removeElement(g.Endpoints, endpoint.ID) + }) + } else { + edgeGroup.Endpoints = removeElement(edgeGroup.Endpoints, endpoint.ID) + tx.EdgeGroup().Update(edgeGroup.ID, &edgeGroup) + } + if err != nil { log.Warn().Err(err).Msgf("Unable to update edge group") } } - edgeStacks, err := handler.DataStore.EdgeStack().EdgeStacks() + edgeStacks, err := tx.EdgeStack().EdgeStacks() if err != nil { log.Warn().Err(err).Msgf("Unable to retrieve edge stacks from the database") } @@ -99,14 +152,14 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * edgeStack := &edgeStacks[idx] if _, ok := edgeStack.Status[endpoint.ID]; ok { delete(edgeStack.Status, endpoint.ID) - err = handler.DataStore.EdgeStack().UpdateEdgeStack(edgeStack.ID, edgeStack) + err = tx.EdgeStack().UpdateEdgeStack(edgeStack.ID, edgeStack) if err != nil { log.Warn().Err(err).Msgf("Unable to update edge stack") } } } - registries, err := handler.DataStore.Registry().ReadAll() + registries, err := tx.Registry().ReadAll() if err != nil { log.Warn().Err(err).Msgf("Unable to retrieve registries from the database") } @@ -115,7 +168,7 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * registry := ®istries[idx] if _, ok := registry.RegistryAccesses[endpoint.ID]; ok { delete(registry.RegistryAccesses, endpoint.ID) - err = handler.DataStore.Registry().Update(registry.ID, registry) + err = tx.Registry().Update(registry.ID, registry) if err != nil { log.Warn().Err(err).Msgf("Unable to update registry accesses") } @@ -131,9 +184,14 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * for idx := range edgeJobs { edgeJob := &edgeJobs[idx] if _, ok := edgeJob.Endpoints[endpoint.ID]; ok { - err = handler.DataStore.EdgeJob().UpdateEdgeJobFunc(edgeJob.ID, func(j *portainer.EdgeJob) { - delete(j.Endpoints, endpoint.ID) - }) + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = tx.EdgeJob().UpdateEdgeJobFunc(edgeJob.ID, func(j *portainer.EdgeJob) { + delete(j.Endpoints, endpoint.ID) + }) + } else { + delete(edgeJob.Endpoints, endpoint.ID) + err = tx.EdgeJob().Update(edgeJob.ID, edgeJob) + } if err != nil { log.Warn().Err(err).Msgf("Unable to update edge job") @@ -142,12 +200,12 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * } } - err = handler.DataStore.Endpoint().DeleteEndpoint(portainer.EndpointID(endpointID)) + err = tx.Endpoint().DeleteEndpoint(portainer.EndpointID(endpointID)) if err != nil { - return httperror.InternalServerError("Unable to remove environment from the database", err) + return httperror.InternalServerError("Unable to delete the environment from the database", err) } - return response.Empty(w) + return nil } func removeElement(slice []portainer.EndpointID, elem portainer.EndpointID) []portainer.EndpointID { diff --git a/api/http/handler/endpoints/endpoint_registries_list.go b/api/http/handler/endpoints/endpoint_registries_list.go index d2a0b0421..e0ef82275 100644 --- a/api/http/handler/endpoints/endpoint_registries_list.go +++ b/api/http/handler/endpoints/endpoint_registries_list.go @@ -3,13 +3,16 @@ package endpoints import ( "net/http" - "github.com/pkg/errors" httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" "github.com/portainer/portainer/api/internal/endpointutils" + "github.com/portainer/portainer/pkg/featureflags" + + "github.com/pkg/errors" ) // @id endpointRegistriesList @@ -26,45 +29,68 @@ import ( // @failure 500 "Server error" // @router /endpoints/{id}/registries [get] func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { - securityContext, err := security.RetrieveRestrictedRequestContext(r) - if err != nil { - return httperror.InternalServerError("Unable to retrieve info from request context", err) - } - - user, err := handler.DataStore.User().Read(securityContext.UserID) - if err != nil { - return httperror.InternalServerError("Unable to retrieve user from the database", err) - } - endpointID, err := request.RetrieveNumericRouteVariableValue(r, "id") if err != nil { return httperror.BadRequest("Invalid environment identifier route variable", err) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) + var registries []portainer.Registry + if featureflags.IsEnabled(portainer.FeatureNoTx) { + registries, err = handler.listRegistries(handler.DataStore, r, portainer.EndpointID(endpointID)) + } else { + err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { + registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID)) + return err + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.JSON(w, registries) +} + +func (handler *Handler) listRegistries(tx dataservices.DataStoreTx, r *http.Request, endpointID portainer.EndpointID) ([]portainer.Registry, error) { + securityContext, err := security.RetrieveRestrictedRequestContext(r) + if err != nil { + return nil, httperror.InternalServerError("Unable to retrieve info from request context", err) + } + + user, err := tx.User().Read(securityContext.UserID) + if err != nil { + return nil, httperror.InternalServerError("Unable to retrieve user from the database", err) + } + + endpoint, err := tx.Endpoint().Endpoint(portainer.EndpointID(endpointID)) + if tx.IsErrObjectNotFound(err) { + return nil, httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { - return httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) + return nil, httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) } isAdmin := securityContext.IsAdmin - registries, err := handler.DataStore.Registry().ReadAll() + registries, err := tx.Registry().ReadAll() if err != nil { - return httperror.InternalServerError("Unable to retrieve registries from the database", err) + return nil, httperror.InternalServerError("Unable to retrieve registries from the database", err) } registries, handleError := handler.filterRegistriesByAccess(r, registries, endpoint, user, securityContext.UserMemberships) if handleError != nil { - return handleError + return nil, handleError } for idx := range registries { hideRegistryFields(®istries[idx], !isAdmin) } - return response.JSON(w, registries) + return registries, err } func (handler *Handler) filterRegistriesByAccess(r *http.Request, registries []portainer.Registry, endpoint *portainer.Endpoint, user *portainer.User, memberships []portainer.TeamMembership) ([]portainer.Registry, *httperror.HandlerError) { diff --git a/api/http/handler/endpoints/endpoint_registry_access.go b/api/http/handler/endpoints/endpoint_registry_access.go index 7226fcf49..e6924d745 100644 --- a/api/http/handler/endpoints/endpoint_registry_access.go +++ b/api/http/handler/endpoints/endpoint_registry_access.go @@ -1,13 +1,16 @@ package endpoints import ( + "errors" "net/http" httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/pkg/featureflags" ) type registryAccessPayload struct { @@ -48,8 +51,29 @@ func (handler *Handler) endpointRegistryAccess(w http.ResponseWriter, r *http.Re return httperror.BadRequest("Invalid registry identifier route variable", err) } - endpoint, err := handler.DataStore.Endpoint().Endpoint(portainer.EndpointID(endpointID)) - if handler.DataStore.IsErrObjectNotFound(err) { + if featureflags.IsEnabled(portainer.FeatureNoTx) { + err = handler.updateRegistryAccess(handler.DataStore, r, portainer.EndpointID(endpointID), portainer.RegistryID(registryID)) + } else { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + return handler.updateRegistryAccess(tx, r, portainer.EndpointID(endpointID), portainer.RegistryID(registryID)) + }) + } + + if err != nil { + var httpErr *httperror.HandlerError + if errors.As(err, &httpErr) { + return httpErr + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return response.Empty(w) +} + +func (handler *Handler) updateRegistryAccess(tx dataservices.DataStoreTx, r *http.Request, endpointID portainer.EndpointID, registryID portainer.RegistryID) error { + endpoint, err := tx.Endpoint().Endpoint(endpointID) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) @@ -69,8 +93,8 @@ func (handler *Handler) endpointRegistryAccess(w http.ResponseWriter, r *http.Re return httperror.Forbidden("User is not authorized", err) } - registry, err := handler.DataStore.Registry().Read(portainer.RegistryID(registryID)) - if handler.DataStore.IsErrObjectNotFound(err) { + registry, err := tx.Registry().Read(registryID) + if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an environment with the specified identifier inside the database", err) } else if err != nil { return httperror.InternalServerError("Unable to find an environment with the specified identifier inside the database", err) @@ -106,9 +130,7 @@ func (handler *Handler) endpointRegistryAccess(w http.ResponseWriter, r *http.Re registry.RegistryAccesses[portainer.EndpointID(endpointID)] = registryAccess - handler.DataStore.Registry().Update(registry.ID, registry) - - return response.Empty(w) + return tx.Registry().Update(registry.ID, registry) } func (handler *Handler) updateKubeAccess(endpoint *portainer.Endpoint, registry *portainer.Registry, oldNamespaces, newNamespaces []string) error { diff --git a/api/http/handler/stacks/stack_create.go b/api/http/handler/stacks/stack_create.go index a391bc75f..f18d89981 100644 --- a/api/http/handler/stacks/stack_create.go +++ b/api/http/handler/stacks/stack_create.go @@ -73,7 +73,6 @@ func (handler *Handler) stackCreate(w http.ResponseWriter, r *http.Request) *htt } func (handler *Handler) createComposeStack(w http.ResponseWriter, r *http.Request, method string, endpoint *portainer.Endpoint, userID portainer.UserID) *httperror.HandlerError { - switch method { case "string": return handler.createComposeStackFromFileContent(w, r, endpoint, userID) @@ -108,6 +107,7 @@ func (handler *Handler) createKubernetesStack(w http.ResponseWriter, r *http.Req case "url": return handler.createKubernetesStackFromManifestURL(w, r, endpoint, userID) } + return httperror.BadRequest("Invalid value for query parameter: method. Value must be one of: string or repository", errors.New(request.ErrInvalidQueryParameter)) } diff --git a/api/http/handler/webhooks/handler.go b/api/http/handler/webhooks/handler.go index 0453c08f1..0566658c7 100644 --- a/api/http/handler/webhooks/handler.go +++ b/api/http/handler/webhooks/handler.go @@ -35,5 +35,6 @@ func NewHandler(bouncer security.BouncerService) *Handler { bouncer.AuthenticatedAccess(httperror.LoggerHandler(h.webhookDelete))).Methods(http.MethodDelete) h.Handle("/webhooks/{token}", bouncer.PublicAccess(httperror.LoggerHandler(h.webhookExecute))).Methods(http.MethodPost) + return h } diff --git a/api/internal/authorization/authorizations.go b/api/internal/authorization/authorizations.go index 08d827461..5934b16d2 100644 --- a/api/internal/authorization/authorizations.go +++ b/api/internal/authorization/authorizations.go @@ -432,13 +432,17 @@ func DefaultPortainerAuthorizations() portainer.Authorizations { // UpdateUsersAuthorizations will trigger an update of the authorizations for all the users. func (service *Service) UpdateUsersAuthorizations() error { - users, err := service.dataStore.User().ReadAll() + return service.UpdateUsersAuthorizationsTx(service.dataStore) +} + +func (service *Service) UpdateUsersAuthorizationsTx(tx dataservices.DataStoreTx) error { + users, err := tx.User().ReadAll() if err != nil { return err } for _, user := range users { - err := service.updateUserAuthorizations(user.ID) + err := service.updateUserAuthorizations(tx, user.ID) if err != nil { return err } @@ -447,44 +451,44 @@ func (service *Service) UpdateUsersAuthorizations() error { return nil } -func (service *Service) updateUserAuthorizations(userID portainer.UserID) error { - user, err := service.dataStore.User().Read(userID) +func (service *Service) updateUserAuthorizations(tx dataservices.DataStoreTx, userID portainer.UserID) error { + user, err := tx.User().Read(userID) if err != nil { return err } - endpointAuthorizations, err := service.getAuthorizations(user) + endpointAuthorizations, err := service.getAuthorizations(tx, user) if err != nil { return err } user.EndpointAuthorizations = endpointAuthorizations - return service.dataStore.User().Update(userID, user) + return tx.User().Update(userID, user) } -func (service *Service) getAuthorizations(user *portainer.User) (portainer.EndpointAuthorizations, error) { +func (service *Service) getAuthorizations(tx dataservices.DataStoreTx, user *portainer.User) (portainer.EndpointAuthorizations, error) { endpointAuthorizations := portainer.EndpointAuthorizations{} if user.Role == portainer.AdministratorRole { return endpointAuthorizations, nil } - userMemberships, err := service.dataStore.TeamMembership().TeamMembershipsByUserID(user.ID) + userMemberships, err := tx.TeamMembership().TeamMembershipsByUserID(user.ID) if err != nil { return endpointAuthorizations, err } - endpoints, err := service.dataStore.Endpoint().Endpoints() + endpoints, err := tx.Endpoint().Endpoints() if err != nil { return endpointAuthorizations, err } - endpointGroups, err := service.dataStore.EndpointGroup().ReadAll() + endpointGroups, err := tx.EndpointGroup().ReadAll() if err != nil { return endpointAuthorizations, err } - roles, err := service.dataStore.Role().ReadAll() + roles, err := tx.Role().ReadAll() if err != nil { return endpointAuthorizations, err } @@ -608,8 +612,8 @@ func getAuthorizationsFromRoles(roleIdentifiers []portainer.RoleID, roles []port return authorizations } -func (service *Service) UserIsAdminOrAuthorized(userID portainer.UserID, endpointID portainer.EndpointID, authorizations []portainer.Authorization) (bool, error) { - user, err := service.dataStore.User().Read(userID) +func (service *Service) UserIsAdminOrAuthorized(tx dataservices.DataStoreTx, userID portainer.UserID, endpointID portainer.EndpointID, authorizations []portainer.Authorization) (bool, error) { + user, err := tx.User().Read(userID) if err != nil { return false, err }