diff --git a/api/datastore/backup_test.go b/api/datastore/backup_test.go index 6b7d35bb9..6a44b9df4 100644 --- a/api/datastore/backup_test.go +++ b/api/datastore/backup_test.go @@ -5,15 +5,14 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/database/models" + "github.com/stretchr/testify/require" "github.com/rs/zerolog/log" ) func TestStoreCreation(t *testing.T) { _, store := MustNewTestStore(t, true, true) - if store == nil { - t.Fatal("Expect to create a store") - } + require.NotNil(t, store) v, err := store.VersionService.Version() if err != nil { diff --git a/api/http/errors/tx.go b/api/http/errors/tx.go deleted file mode 100644 index 1db484247..000000000 --- a/api/http/errors/tx.go +++ /dev/null @@ -1,20 +0,0 @@ -package errors - -import ( - "errors" - - httperror "github.com/portainer/portainer/pkg/libhttp/error" -) - -func TxResponse(err error, validResponse func() *httperror.HandlerError) *httperror.HandlerError { - if err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } - - return httperror.InternalServerError("Unexpected error", err) - } - - return validResponse() -} diff --git a/api/http/handler/docker/dashboard.go b/api/http/handler/docker/dashboard.go index f903711ec..7963c52d1 100644 --- a/api/http/handler/docker/dashboard.go +++ b/api/http/handler/docker/dashboard.go @@ -12,7 +12,6 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/docker/stats" - "github.com/portainer/portainer/api/http/errors" "github.com/portainer/portainer/api/http/handler/docker/utils" "github.com/portainer/portainer/api/http/middlewares" "github.com/portainer/portainer/api/http/security" @@ -164,7 +163,5 @@ func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.H return nil }) - return errors.TxResponse(err, func() *httperror.HandlerError { - return response.JSON(w, resp) - }) + return response.TxResponse(w, resp, err) } diff --git a/api/http/handler/edgegroups/edgegroup_create.go b/api/http/handler/edgegroups/edgegroup_create.go index c074bffde..d52ab6757 100644 --- a/api/http/handler/edgegroups/edgegroup_create.go +++ b/api/http/handler/edgegroups/edgegroup_create.go @@ -10,6 +10,7 @@ import ( "github.com/portainer/portainer/api/roar" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" ) type edgeGroupCreatePayload struct { @@ -111,5 +112,5 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request) return nil }) - return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) + return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) } diff --git a/api/http/handler/edgegroups/edgegroup_delete.go b/api/http/handler/edgegroups/edgegroup_delete.go index 46890dc81..37b61a767 100644 --- a/api/http/handler/edgegroups/edgegroup_delete.go +++ b/api/http/handler/edgegroups/edgegroup_delete.go @@ -32,16 +32,8 @@ func (handler *Handler) edgeGroupDelete(w http.ResponseWriter, r *http.Request) err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return deleteEdgeGroup(tx, portainer.EdgeGroupID(edgeGroupID)) }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func deleteEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) error { diff --git a/api/http/handler/edgegroups/edgegroup_inspect.go b/api/http/handler/edgegroups/edgegroup_inspect.go index 76780ec1d..537abe06d 100644 --- a/api/http/handler/edgegroups/edgegroup_inspect.go +++ b/api/http/handler/edgegroups/edgegroup_inspect.go @@ -8,6 +8,7 @@ import ( "github.com/portainer/portainer/api/roar" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" ) // @id EdgeGroupInspect @@ -36,7 +37,7 @@ func (handler *Handler) edgeGroupInspect(w http.ResponseWriter, r *http.Request) edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice() - return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) + return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) } func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) { diff --git a/api/http/handler/edgegroups/edgegroup_list.go b/api/http/handler/edgegroups/edgegroup_list.go index 87de867eb..610da85da 100644 --- a/api/http/handler/edgegroups/edgegroup_list.go +++ b/api/http/handler/edgegroups/edgegroup_list.go @@ -9,6 +9,7 @@ import ( "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/roar" httperror "github.com/portainer/portainer/pkg/libhttp/error" + "github.com/portainer/portainer/pkg/libhttp/response" ) type shadowedEdgeGroup struct { @@ -44,7 +45,7 @@ func (handler *Handler) edgeGroupList(w http.ResponseWriter, r *http.Request) *h return err }) - return txResponse(w, decoratedEdgeGroups, err) + return response.TxResponse(w, decoratedEdgeGroups, err) } func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error) { diff --git a/api/http/handler/edgegroups/edgegroup_update.go b/api/http/handler/edgegroups/edgegroup_update.go index 270bd10df..989ac452c 100644 --- a/api/http/handler/edgegroups/edgegroup_update.go +++ b/api/http/handler/edgegroups/edgegroup_update.go @@ -13,6 +13,7 @@ import ( "github.com/portainer/portainer/api/slicesx" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" ) type edgeGroupUpdatePayload struct { @@ -158,7 +159,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request) return nil }) - return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) + return response.TxResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err) } func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, edgeGroups []portainer.EdgeGroup, edgeStacks []portainer.EdgeStack) error { diff --git a/api/http/handler/edgegroups/handler.go b/api/http/handler/edgegroups/handler.go index 77075c772..b81387252 100644 --- a/api/http/handler/edgegroups/handler.go +++ b/api/http/handler/edgegroups/handler.go @@ -1,14 +1,12 @@ package edgegroups import ( - "errors" "net/http" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" - "github.com/portainer/portainer/pkg/libhttp/response" "github.com/gorilla/mux" ) @@ -38,16 +36,3 @@ func NewHandler(bouncer security.BouncerService) *Handler { return h } - -func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { - if err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } - - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, r) -} diff --git a/api/http/handler/edgejobs/edgejob_create.go b/api/http/handler/edgejobs/edgejob_create.go index dd6b4d5df..dab6b11f9 100644 --- a/api/http/handler/edgejobs/edgejob_create.go +++ b/api/http/handler/edgejobs/edgejob_create.go @@ -15,6 +15,7 @@ import ( "github.com/portainer/portainer/api/internal/endpointutils" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/validate" ) @@ -85,19 +86,18 @@ func (payload *edgeJobCreateFromFileContentPayload) Validate(r *http.Request) er // @router /edge_jobs/create/string [post] func (handler *Handler) createEdgeJobFromFileContent(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload edgeJobCreateFromFileContentPayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } var edgeJob *portainer.EdgeJob + var err error err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, []byte(payload.FileContent)) - return err }) - return txResponse(w, edgeJob, err) + return response.TxResponse(w, edgeJob, err) } func (handler *Handler) createEdgeJob(tx dataservices.DataStoreTx, payload *edgeJobBasePayload, fileContent []byte) (*portainer.EdgeJob, error) { @@ -191,19 +191,18 @@ func (payload *edgeJobCreateFromFilePayload) Validate(r *http.Request) error { // @router /edge_jobs/create/file [post] func (handler *Handler) createEdgeJobFromFile(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { payload := &edgeJobCreateFromFilePayload{} - err := payload.Validate(r) - if err != nil { + if err := payload.Validate(r); err != nil { return httperror.BadRequest("Invalid request payload", err) } var edgeJob *portainer.EdgeJob + var err error err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { edgeJob, err = handler.createEdgeJob(tx, &payload.edgeJobBasePayload, payload.File) - return err }) - return txResponse(w, edgeJob, err) + return response.TxResponse(w, edgeJob, err) } func (handler *Handler) createEdgeJobObjectFromPayload(tx dataservices.DataStoreTx, payload *edgeJobBasePayload) *portainer.EdgeJob { diff --git a/api/http/handler/edgejobs/edgejob_create_test.go b/api/http/handler/edgejobs/edgejob_create_test.go new file mode 100644 index 000000000..db9d69ff1 --- /dev/null +++ b/api/http/handler/edgejobs/edgejob_create_test.go @@ -0,0 +1,158 @@ +package edgejobs + +import ( + "bytes" + "encoding/json" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gorilla/mux" + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/api/datastore" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockFileService struct { + mock.Mock + portainer.FileService +} + +func (m *mockFileService) StoreEdgeJobFileFromBytes(id string, file []byte) (string, error) { + args := m.Called(id, file) + return args.String(0), args.Error(1) +} + +func (m *mockFileService) GetEdgeJobFolder(id string) string { + args := m.Called(id) + + return args.String(0) +} + +func (m *mockFileService) RemoveDirectory(path string) error { + args := m.Called(path) + + return args.Error(0) +} + +func initStore(t *testing.T) *datastore.Store { + _, store := datastore.MustNewTestStore(t, true, true) + require.NotNil(t, store) + + require.NoError(t, store.UpdateTx(func(tx dataservices.DataStoreTx) error { + require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ + ID: 1, + Name: "endpoint-1", + EdgeID: "edge-id-1", + GroupID: 1, + Type: portainer.EdgeAgentOnDockerEnvironment, + UserTrusted: true, + })) + + require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ + ID: 2, + Name: "endpoint-2", + EdgeID: "edge-id-2", + GroupID: 1, + Type: portainer.EdgeAgentOnDockerEnvironment, + UserTrusted: false, + })) + return nil + })) + + return store +} + +func Test_edgeJobCreate_StringMethod_Success(t *testing.T) { + store := initStore(t) + + fileService := &mockFileService{} + fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil) + + handler := &Handler{ + DataStore: store, + FileService: fileService, + } + + payload := edgeJobCreateFromFileContentPayload{ + edgeJobBasePayload: edgeJobBasePayload{ + Name: "testjob", + CronExpression: "* * * * *", + Endpoints: []portainer.EndpointID{1, 2}, + }, + FileContent: "echo hello", + } + + body, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/string", bytes.NewReader(body)) + req = mux.SetURLVars(req, map[string]string{"method": "string"}) + w := httptest.NewRecorder() + + // Call handler + errh := handler.edgeJobCreate(w, req) + require.Nil(t, errh) + require.Equal(t, http.StatusOK, w.Result().StatusCode) + + // Get edge job ID from response + var resp struct { + ID int `json:"Id"` + } + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID)) + require.NoError(t, err) + + require.Len(t, edgeJob.Endpoints, 2) + require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1)) +} + +func Test_edgeJobCreate_FileMethod_Success(t *testing.T) { + store := initStore(t) + + fileService := &mockFileService{} + fileService.On("StoreEdgeJobFileFromBytes", mock.Anything, mock.Anything).Return("testfile.txt", nil) + + handler := &Handler{ + DataStore: store, + FileService: fileService, + } + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("Name", "testjob")) + require.NoError(t, writer.WriteField("CronExpression", "* * * * *")) + require.NoError(t, writer.WriteField("Endpoints", "[1,2]")) + + fileWriter, err := writer.CreateFormFile("file", "test.txt") + require.NoError(t, err) + + _, err = io.Copy(fileWriter, strings.NewReader("echo hello")) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/edge_jobs/create/file", &body) + req = mux.SetURLVars(req, map[string]string{"method": "file"}) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + w := httptest.NewRecorder() + handlerErr := handler.edgeJobCreate(w, req) + require.Nil(t, handlerErr) + require.Equal(t, http.StatusOK, w.Result().StatusCode) + + var resp struct { + ID int `json:"Id"` + } + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + edgeJob, err := store.EdgeJob().Read(portainer.EdgeJobID(resp.ID)) + require.NoError(t, err) + + require.Len(t, edgeJob.Endpoints, 2) + require.Contains(t, edgeJob.Endpoints, portainer.EndpointID(1)) +} diff --git a/api/http/handler/edgejobs/edgejob_delete.go b/api/http/handler/edgejobs/edgejob_delete.go index fa78e6a50..6b2650704 100644 --- a/api/http/handler/edgejobs/edgejob_delete.go +++ b/api/http/handler/edgejobs/edgejob_delete.go @@ -1,7 +1,6 @@ package edgejobs import ( - "errors" "maps" "net/http" "strconv" @@ -35,18 +34,11 @@ func (handler *Handler) edgeJobDelete(w http.ResponseWriter, r *http.Request) *h return httperror.BadRequest("Invalid Edge job identifier route variable", err) } - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.deleteEdgeJob(tx, portainer.EdgeJobID(edgeJobID)) - }); err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) deleteEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) error { diff --git a/api/http/handler/edgejobs/edgejob_tasklogs_clear.go b/api/http/handler/edgejobs/edgejob_tasklogs_clear.go index dfa99042f..61b422aa3 100644 --- a/api/http/handler/edgejobs/edgejob_tasklogs_clear.go +++ b/api/http/handler/edgejobs/edgejob_tasklogs_clear.go @@ -1,7 +1,6 @@ package edgejobs import ( - "errors" "net/http" "slices" "strconv" @@ -54,7 +53,7 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request } } - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { updateEdgeJobFn := func(edgeJob *portainer.EdgeJob, endpointID portainer.EndpointID, endpointsFromGroups []portainer.EndpointID) error { mutationFn(edgeJob, endpointID, endpointsFromGroups) @@ -62,16 +61,9 @@ func (handler *Handler) edgeJobTasksClear(w http.ResponseWriter, r *http.Request } return handler.clearEdgeJobTaskLogs(tx, portainer.EdgeJobID(edgeJobID), portainer.EndpointID(taskID), updateEdgeJobFn) - }); err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) clearEdgeJobTaskLogs(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, endpointID portainer.EndpointID, updateEdgeJob func(*portainer.EdgeJob, portainer.EndpointID, []portainer.EndpointID) error) error { diff --git a/api/http/handler/edgejobs/edgejob_tasklogs_collect.go b/api/http/handler/edgejobs/edgejob_tasklogs_collect.go index e6bc53e46..6f1e8023c 100644 --- a/api/http/handler/edgejobs/edgejob_tasklogs_collect.go +++ b/api/http/handler/edgejobs/edgejob_tasklogs_collect.go @@ -1,7 +1,6 @@ package edgejobs import ( - "errors" "net/http" "slices" @@ -39,7 +38,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque return httperror.BadRequest("Invalid Task identifier route variable", err) } - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { edgeJob, err := tx.EdgeJob().Read(portainer.EdgeJobID(edgeJobID)) if tx.IsErrObjectNotFound(err) { return httperror.NotFound("Unable to find an Edge job with the specified identifier inside the database", err) @@ -81,14 +80,7 @@ func (handler *Handler) edgeJobTasksCollect(w http.ResponseWriter, r *http.Reque } return nil - }); err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } diff --git a/api/http/handler/edgejobs/edgejob_tasks_list.go b/api/http/handler/edgejobs/edgejob_tasks_list.go index 64d50137b..afde439f5 100644 --- a/api/http/handler/edgejobs/edgejob_tasks_list.go +++ b/api/http/handler/edgejobs/edgejob_tasks_list.go @@ -13,6 +13,7 @@ import ( "github.com/portainer/portainer/api/internal/edge" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" ) type taskContainer struct { @@ -49,31 +50,33 @@ func (handler *Handler) edgeJobTasksList(w http.ResponseWriter, r *http.Request) return err }) - results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{ - SearchAccessors: []filters.SearchAccessor[*taskContainer]{ - func(tc *taskContainer) (string, error) { - switch tc.LogsStatus { - case portainer.EdgeJobLogsStatusPending: - return "pending", nil - case 0, portainer.EdgeJobLogsStatusIdle: - return "idle", nil - case portainer.EdgeJobLogsStatusCollected: - return "collected", nil - } - return "", errors.New("unknown state") + return response.TxFuncResponse(err, func() *httperror.HandlerError { + results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{ + SearchAccessors: []filters.SearchAccessor[*taskContainer]{ + func(tc *taskContainer) (string, error) { + switch tc.LogsStatus { + case portainer.EdgeJobLogsStatusPending: + return "pending", nil + case 0, portainer.EdgeJobLogsStatusIdle: + return "idle", nil + case portainer.EdgeJobLogsStatusCollected: + return "collected", nil + } + return "", errors.New("unknown state") + }, + func(tc *taskContainer) (string, error) { + return tc.EndpointName, nil + }, }, - func(tc *taskContainer) (string, error) { - return tc.EndpointName, nil + SortBindings: []filters.SortBinding[*taskContainer]{ + {Key: "EndpointName", Fn: func(a, b *taskContainer) int { return strings.Compare(a.EndpointName, b.EndpointName) }}, }, - }, - SortBindings: []filters.SortBinding[*taskContainer]{ - {Key: "EndpointName", Fn: func(a, b *taskContainer) int { return strings.Compare(a.EndpointName, b.EndpointName) }}, - }, + }) + + filters.ApplyFilterResultsHeaders(&w, results) + + return response.JSON(w, results.Items) }) - - filters.ApplyFilterResultsHeaders(&w, results) - - return txResponse(w, results.Items, err) } func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]*taskContainer, error) { diff --git a/api/http/handler/edgejobs/edgejob_update.go b/api/http/handler/edgejobs/edgejob_update.go index 6f2b8e382..bd9284d39 100644 --- a/api/http/handler/edgejobs/edgejob_update.go +++ b/api/http/handler/edgejobs/edgejob_update.go @@ -14,6 +14,7 @@ import ( "github.com/portainer/portainer/api/internal/endpointutils" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/validate" ) @@ -66,7 +67,7 @@ func (handler *Handler) edgeJobUpdate(w http.ResponseWriter, r *http.Request) *h return err }) - return txResponse(w, edgeJob, err) + return response.TxResponse(w, edgeJob, err) } func (handler *Handler) updateEdgeJob(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID, payload edgeJobUpdatePayload) (*portainer.EdgeJob, error) { diff --git a/api/http/handler/edgejobs/handler.go b/api/http/handler/edgejobs/handler.go index ab3d66b3b..c91e3dd2a 100644 --- a/api/http/handler/edgejobs/handler.go +++ b/api/http/handler/edgejobs/handler.go @@ -1,14 +1,12 @@ package edgejobs import ( - "errors" "net/http" portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" - "github.com/portainer/portainer/pkg/libhttp/response" "github.com/gorilla/mux" ) @@ -60,16 +58,3 @@ func convertEndpointsToMetaObject(endpoints []portainer.EndpointID) map[portaine return endpointsMap } - -func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { - if err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } - - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, r) -} diff --git a/api/http/handler/edgestacks/edgestack_delete.go b/api/http/handler/edgestacks/edgestack_delete.go index 0e6307684..9371a68a9 100644 --- a/api/http/handler/edgestacks/edgestack_delete.go +++ b/api/http/handler/edgestacks/edgestack_delete.go @@ -1,7 +1,6 @@ package edgestacks import ( - "errors" "net/http" "strconv" @@ -30,18 +29,11 @@ func (handler *Handler) edgeStackDelete(w http.ResponseWriter, r *http.Request) return httperror.BadRequest("Invalid edge stack identifier route variable", err) } - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.deleteEdgeStack(tx, portainer.EdgeStackID(edgeStackID)) - }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) deleteEdgeStack(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID) error { diff --git a/api/http/handler/edgestacks/edgestack_status_update.go b/api/http/handler/edgestacks/edgestack_status_update.go index 0ff6a9eff..c773b1b6e 100644 --- a/api/http/handler/edgestacks/edgestack_status_update.go +++ b/api/http/handler/edgestacks/edgestack_status_update.go @@ -96,12 +96,7 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req return nil }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - - return httperror.InternalServerError("Unexpected error", err) + return response.TxErrorResponse(err) } if ok, _ := strconv.ParseBool(r.Header.Get("X-Portainer-No-Body")); ok { diff --git a/api/http/handler/edgestacks/edgestack_update.go b/api/http/handler/edgestacks/edgestack_update.go index 7dc8915f6..d859a0c05 100644 --- a/api/http/handler/edgestacks/edgestack_update.go +++ b/api/http/handler/edgestacks/edgestack_update.go @@ -66,12 +66,7 @@ func (handler *Handler) edgeStackUpdate(w http.ResponseWriter, r *http.Request) stack, err = handler.updateEdgeStack(tx, portainer.EdgeStackID(stackID), payload) return err }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - - return httperror.InternalServerError("Unexpected error", err) + return response.TxErrorResponse(err) } if err := fillEdgeStackStatus(handler.DataStore, stack); err != nil { diff --git a/api/http/handler/endpointgroups/endpointgroup_create.go b/api/http/handler/endpointgroups/endpointgroup_create.go index dde16910f..59f00abe3 100644 --- a/api/http/handler/endpointgroups/endpointgroup_create.go +++ b/api/http/handler/endpointgroups/endpointgroup_create.go @@ -49,26 +49,18 @@ func (payload *endpointGroupCreatePayload) Validate(r *http.Request) error { // @router /endpoint_groups [post] func (handler *Handler) endpointGroupCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload endpointGroupCreatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } var endpointGroup *portainer.EndpointGroup + var err error err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { endpointGroup, err = handler.createEndpointGroup(tx, payload) return err }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, endpointGroup) + return response.TxResponse(w, endpointGroup, err) } func (handler *Handler) createEndpointGroup(tx dataservices.DataStoreTx, payload endpointGroupCreatePayload) (*portainer.EndpointGroup, error) { diff --git a/api/http/handler/endpointgroups/endpointgroup_delete.go b/api/http/handler/endpointgroups/endpointgroup_delete.go index b71778446..974771bad 100644 --- a/api/http/handler/endpointgroups/endpointgroup_delete.go +++ b/api/http/handler/endpointgroups/endpointgroup_delete.go @@ -37,16 +37,8 @@ func (handler *Handler) endpointGroupDelete(w http.ResponseWriter, r *http.Reque err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.deleteEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID)) }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) deleteEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID) error { diff --git a/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go b/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go index f64787566..ba66564af 100644 --- a/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go +++ b/api/http/handler/endpointgroups/endpointgroup_endpoint_add.go @@ -1,7 +1,6 @@ package endpointgroups import ( - "errors" "net/http" portainer "github.com/portainer/portainer/api" @@ -39,16 +38,8 @@ func (handler *Handler) endpointGroupAddEndpoint(w http.ResponseWriter, r *http. err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.addEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) addEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error { diff --git a/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go b/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go index e5d4e69d4..b37e092d9 100644 --- a/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go +++ b/api/http/handler/endpointgroups/endpointgroup_endpoint_delete.go @@ -1,7 +1,6 @@ package endpointgroups import ( - "errors" "net/http" portainer "github.com/portainer/portainer/api" @@ -38,16 +37,8 @@ func (handler *Handler) endpointGroupDeleteEndpoint(w http.ResponseWriter, r *ht err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.removeEndpoint(tx, portainer.EndpointGroupID(endpointGroupID), portainer.EndpointID(endpointID)) }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) removeEndpoint(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, endpointID portainer.EndpointID) error { diff --git a/api/http/handler/endpointgroups/endpointgroup_update.go b/api/http/handler/endpointgroups/endpointgroup_update.go index b50af9044..df6f2fb90 100644 --- a/api/http/handler/endpointgroups/endpointgroup_update.go +++ b/api/http/handler/endpointgroups/endpointgroup_update.go @@ -1,7 +1,6 @@ package endpointgroups import ( - "errors" "net/http" "reflect" @@ -61,20 +60,12 @@ func (handler *Handler) endpointGroupUpdate(w http.ResponseWriter, r *http.Reque var endpointGroup *portainer.EndpointGroup - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { endpointGroup, err = handler.updateEndpointGroup(tx, portainer.EndpointGroupID(endpointGroupID), payload) - return err - }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, endpointGroup) + return response.TxResponse(w, endpointGroup, err) } func (handler *Handler) updateEndpointGroup(tx dataservices.DataStoreTx, endpointGroupID portainer.EndpointGroupID, payload endpointGroupUpdatePayload) (*portainer.EndpointGroup, error) { diff --git a/api/http/handler/endpoints/endpoint_delete.go b/api/http/handler/endpoints/endpoint_delete.go index a9b4ae5dc..83728bf1c 100644 --- a/api/http/handler/endpoints/endpoint_delete.go +++ b/api/http/handler/endpoints/endpoint_delete.go @@ -62,18 +62,11 @@ func (handler *Handler) endpointDelete(w http.ResponseWriter, r *http.Request) * return httperror.BadRequest("Invalid boolean query parameter", err) } - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.deleteEndpoint(tx, portainer.EndpointID(endpointID), deleteCluster) - }); err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } // @id EndpointDeleteBatch diff --git a/api/http/handler/endpoints/endpoint_registries_list.go b/api/http/handler/endpoints/endpoint_registries_list.go index 5bc4a930d..49fd2f37d 100644 --- a/api/http/handler/endpoints/endpoint_registries_list.go +++ b/api/http/handler/endpoints/endpoint_registries_list.go @@ -35,19 +35,12 @@ func (handler *Handler) endpointRegistriesList(w http.ResponseWriter, r *http.Re } var registries []portainer.Registry - if err := handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { + err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { registries, err = handler.listRegistries(tx, r, portainer.EndpointID(endpointID)) return err - }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, registries) + return response.TxResponse(w, registries, err) } func (handler *Handler) listRegistries(tx dataservices.DataStoreTx, r *http.Request, endpointID portainer.EndpointID) ([]portainer.Registry, error) { diff --git a/api/http/handler/endpoints/endpoint_registry_access.go b/api/http/handler/endpoints/endpoint_registry_access.go index b931ffa62..a29e06094 100644 --- a/api/http/handler/endpoints/endpoint_registry_access.go +++ b/api/http/handler/endpoints/endpoint_registry_access.go @@ -1,7 +1,6 @@ package endpoints import ( - "errors" "net/http" portainer "github.com/portainer/portainer/api" @@ -53,16 +52,8 @@ func (handler *Handler) endpointRegistryAccess(w http.ResponseWriter, r *http.Re err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return handler.updateRegistryAccess(tx, r, portainer.EndpointID(endpointID), portainer.RegistryID(registryID)) }) - if err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func (handler *Handler) updateRegistryAccess(tx dataservices.DataStoreTx, r *http.Request, endpointID portainer.EndpointID, registryID portainer.RegistryID) error { diff --git a/api/http/handler/settings/settings_update.go b/api/http/handler/settings/settings_update.go index 98da8da7d..97e930236 100644 --- a/api/http/handler/settings/settings_update.go +++ b/api/http/handler/settings/settings_update.go @@ -128,12 +128,7 @@ func (handler *Handler) settingsUpdate(w http.ResponseWriter, r *http.Request) * return err }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } - - return httperror.InternalServerError("Unexpected error", err) + return response.TxErrorResponse(err) } hideFields(settings) diff --git a/api/http/handler/tags/handler.go b/api/http/handler/tags/handler.go index 5a1bd7623..3dbac0cc9 100644 --- a/api/http/handler/tags/handler.go +++ b/api/http/handler/tags/handler.go @@ -1,13 +1,11 @@ package tags import ( - "errors" "net/http" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/http/security" httperror "github.com/portainer/portainer/pkg/libhttp/error" - "github.com/portainer/portainer/pkg/libhttp/response" "github.com/gorilla/mux" ) @@ -32,16 +30,3 @@ func NewHandler(bouncer security.BouncerService) *Handler { return h } - -func txResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { - if err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } - - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, r) -} diff --git a/api/http/handler/tags/tag_create.go b/api/http/handler/tags/tag_create.go index d2daa5f85..60d72fcf4 100644 --- a/api/http/handler/tags/tag_create.go +++ b/api/http/handler/tags/tag_create.go @@ -8,6 +8,7 @@ import ( "github.com/portainer/portainer/api/dataservices" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" + "github.com/portainer/portainer/pkg/libhttp/response" ) type tagCreatePayload struct { @@ -38,18 +39,18 @@ func (payload *tagCreatePayload) Validate(r *http.Request) error { // @router /tags [post] func (handler *Handler) tagCreate(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { var payload tagCreatePayload - err := request.DecodeAndValidateJSONPayload(r, &payload) - if err != nil { + if err := request.DecodeAndValidateJSONPayload(r, &payload); err != nil { return httperror.BadRequest("Invalid request payload", err) } var tag *portainer.Tag + var err error err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { tag, err = createTag(tx, payload) return err }) - return txResponse(w, tag, err) + return response.TxResponse(w, tag, err) } func createTag(tx dataservices.DataStoreTx, payload tagCreatePayload) (*portainer.Tag, error) { diff --git a/api/http/handler/tags/tag_delete.go b/api/http/handler/tags/tag_delete.go index f8f1b7786..afc612726 100644 --- a/api/http/handler/tags/tag_delete.go +++ b/api/http/handler/tags/tag_delete.go @@ -1,7 +1,6 @@ package tags import ( - "errors" "net/http" "slices" @@ -37,16 +36,8 @@ func (handler *Handler) tagDelete(w http.ResponseWriter, r *http.Request) *httpe err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { return deleteTag(tx, portainer.TagID(id)) }) - if err != nil { - var handlerError *httperror.HandlerError - if errors.As(err, &handlerError) { - return handlerError - } - return httperror.InternalServerError("Unexpected error", err) - } - - return response.Empty(w) + return response.TxEmptyResponse(w, err) } func deleteTag(tx dataservices.DataStoreTx, tagID portainer.TagID) error { diff --git a/api/http/handler/teams/team_create.go b/api/http/handler/teams/team_create.go index 130cc1d37..f0731964a 100644 --- a/api/http/handler/teams/team_create.go +++ b/api/http/handler/teams/team_create.go @@ -48,22 +48,13 @@ func (handler *Handler) teamCreate(w http.ResponseWriter, r *http.Request) *http } var team *portainer.Team - - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { - var err error + var err error + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { team, err = createTeam(tx, payload) - return err - }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, team) + return response.TxResponse(w, team, err) } func createTeam(tx dataservices.DataStoreTx, payload teamCreatePayload) (*portainer.Team, error) { diff --git a/api/http/handler/users/user_create.go b/api/http/handler/users/user_create.go index a5afb1b8b..daa2f6ca0 100644 --- a/api/http/handler/users/user_create.go +++ b/api/http/handler/users/user_create.go @@ -55,22 +55,13 @@ func (handler *Handler) userCreate(w http.ResponseWriter, r *http.Request) *http } var user *portainer.User - - if err := handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { - var err error + var err error + err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error { user, err = handler.createUser(tx, payload) - return err - }); err != nil { - var httpErr *httperror.HandlerError - if errors.As(err, &httpErr) { - return httpErr - } + }) - return httperror.InternalServerError("Unexpected error", err) - } - - return response.JSON(w, user) + return response.TxResponse(w, user, err) } func (handler *Handler) createUser(tx dataservices.DataStoreTx, payload userCreatePayload) (*portainer.User, error) { diff --git a/pkg/libhttp/response/txresponse.go b/pkg/libhttp/response/txresponse.go new file mode 100644 index 000000000..631dfde78 --- /dev/null +++ b/pkg/libhttp/response/txresponse.go @@ -0,0 +1,47 @@ +package response + +import ( + "errors" + "net/http" + + httperror "github.com/portainer/portainer/pkg/libhttp/error" +) + +func TxResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { + return TxFuncResponse(err, func() *httperror.HandlerError { return JSON(w, r) }) +} + +func TxEmptyResponse(w http.ResponseWriter, err error) *httperror.HandlerError { + if err != nil { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return Empty(w) +} + +func TxFuncResponse(err error, validResponse func() *httperror.HandlerError) *httperror.HandlerError { + if err != nil { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return validResponse() +} + +func TxErrorResponse(err error) *httperror.HandlerError { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) +} diff --git a/pkg/libhttp/response/txresponse_test.go b/pkg/libhttp/response/txresponse_test.go new file mode 100644 index 000000000..5094b4c4f --- /dev/null +++ b/pkg/libhttp/response/txresponse_test.go @@ -0,0 +1,86 @@ +package response + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + httperrors "github.com/portainer/portainer/api/http/errors" + httperror "github.com/portainer/portainer/pkg/libhttp/error" + "github.com/stretchr/testify/require" +) + +func TestTxResponse(t *testing.T) { + type sample struct { + Name string `json:"name"` + } + + w := httptest.NewRecorder() + got := TxResponse(w, sample{Name: "Alice"}, nil) + require.Nil(t, got) + require.Equal(t, http.StatusOK, w.Result().StatusCode) + + w = httptest.NewRecorder() + got = TxResponse(w, sample{}, httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + w = httptest.NewRecorder() + got = TxResponse(w, sample{}, errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxEmptyResponse(t *testing.T) { + w := httptest.NewRecorder() + got := TxEmptyResponse(w, nil) + require.Nil(t, got) + require.Equal(t, http.StatusNoContent, w.Result().StatusCode) + + w = httptest.NewRecorder() + got = TxEmptyResponse(w, httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + w = httptest.NewRecorder() + got = TxEmptyResponse(w, errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxFuncResponse(t *testing.T) { + got := TxFuncResponse(nil, func() *httperror.HandlerError { return nil }) + require.Nil(t, got) + + got = TxFuncResponse(httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied), func() *httperror.HandlerError { return nil }) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + got = TxFuncResponse(errors.New("Some error"), func() *httperror.HandlerError { return nil }) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxErrorResponse(t *testing.T) { + got := TxErrorResponse(nil) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) + + got = TxErrorResponse(httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + got = TxErrorResponse(errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +}