diff --git a/api/dataservices/edgestack/edgestack.go b/api/dataservices/edgestack/edgestack.go index 7f45d6146..69f4d5ea2 100644 --- a/api/dataservices/edgestack/edgestack.go +++ b/api/dataservices/edgestack/edgestack.go @@ -82,12 +82,22 @@ func (service *Service) Create(id portainer.EdgeStackID, edgeStack *portainer.Ed ) } -// UpdateEdgeStack updates an Edge stack. +// Deprecated: Use UpdateEdgeStackFunc instead. func (service *Service) UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error { identifier := service.connection.ConvertToKey(int(ID)) return service.connection.UpdateObject(BucketName, identifier, edgeStack) } +// UpdateEdgeStackFunc updates an Edge stack inside a transaction avoiding data races. +func (service *Service) UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error { + id := service.connection.ConvertToKey(int(ID)) + edgeStack := &portainer.EdgeStack{} + + return service.connection.UpdateObjectFunc(BucketName, id, edgeStack, func() { + updateFunc(edgeStack) + }) +} + // DeleteEdgeStack deletes an Edge stack. func (service *Service) DeleteEdgeStack(ID portainer.EdgeStackID) error { identifier := service.connection.ConvertToKey(int(ID)) diff --git a/api/dataservices/interface.go b/api/dataservices/interface.go index b6e2ed6ec..c6b0955a4 100644 --- a/api/dataservices/interface.go +++ b/api/dataservices/interface.go @@ -89,6 +89,7 @@ type ( EdgeStack(ID portainer.EdgeStackID) (*portainer.EdgeStack, error) Create(id portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error UpdateEdgeStack(ID portainer.EdgeStackID, edgeStack *portainer.EdgeStack) error + UpdateEdgeStackFunc(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error DeleteEdgeStack(ID portainer.EdgeStackID) error GetNextIdentifier() int BucketName() string diff --git a/api/http/handler/edgestacks/edgestack_status_update.go b/api/http/handler/edgestacks/edgestack_status_update.go index bbc7f46e6..001e0b551 100644 --- a/api/http/handler/edgestacks/edgestack_status_update.go +++ b/api/http/handler/edgestacks/edgestack_status_update.go @@ -4,11 +4,12 @@ import ( "errors" "net/http" - "github.com/asaskevich/govalidator" - httperror "github.com/portainer/libhttp/error" "github.com/portainer/libhttp/request" "github.com/portainer/libhttp/response" portainer "github.com/portainer/portainer/api" + + "github.com/asaskevich/govalidator" + httperror "github.com/portainer/libhttp/error" ) type updateStatusPayload struct { @@ -52,13 +53,6 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req return httperror.BadRequest("Invalid stack identifier route variable", err) } - stack, err := handler.DataStore.EdgeStack().EdgeStack(portainer.EdgeStackID(stackID)) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to find a stack with the specified identifier inside the database", err) - } - var payload updateStatusPayload err = request.DecodeAndValidateJSONPayload(r, &payload) if err != nil { @@ -77,17 +71,22 @@ func (handler *Handler) edgeStackStatusUpdate(w http.ResponseWriter, r *http.Req return httperror.Forbidden("Permission denied to access environment", err) } - stack.Status[payload.EndpointID] = portainer.EdgeStackStatus{ - Type: *payload.Status, - Error: payload.Error, - EndpointID: payload.EndpointID, - } + var stack portainer.EdgeStack - err = handler.DataStore.EdgeStack().UpdateEdgeStack(stack.ID, stack) - if err != nil { + err = handler.DataStore.EdgeStack().UpdateEdgeStackFunc(portainer.EdgeStackID(stackID), func(edgeStack *portainer.EdgeStack) { + edgeStack.Status[payload.EndpointID] = portainer.EdgeStackStatus{ + Type: *payload.Status, + Error: payload.Error, + EndpointID: payload.EndpointID, + } + + stack = *edgeStack + }) + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.NotFound("Unable to find a stack with the specified identifier inside the database", err) + } else if err != nil { return httperror.InternalServerError("Unable to persist the stack changes inside the database", err) } return response.JSON(w, stack) - } diff --git a/api/http/handler/edgestacks/edgestack_test.go b/api/http/handler/edgestacks/edgestack_test.go index f9e76d193..b013b58c2 100644 --- a/api/http/handler/edgestacks/edgestack_test.go +++ b/api/http/handler/edgestacks/edgestack_test.go @@ -103,10 +103,9 @@ func setupHandler(t *testing.T) (*Handler, string, func()) { return handler, rawAPIKey, storeTeardown } -func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint { +func createEndpointWithId(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.Endpoint { t.Helper() - endpointID := portainer.EndpointID(5) endpoint := portainer.Endpoint{ ID: endpointID, Name: "test-endpoint-" + strconv.Itoa(int(endpointID)), @@ -124,6 +123,10 @@ func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoi return endpoint } +func createEndpoint(t *testing.T, store dataservices.DataStore) portainer.Endpoint { + return createEndpointWithId(t, store, 5) +} + func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID portainer.EndpointID) portainer.EdgeStack { t.Helper() @@ -203,7 +206,7 @@ func TestInspectInvalidEdgeID(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -263,7 +266,7 @@ func TestCreateAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } data := portainer.EdgeStack{} @@ -283,7 +286,7 @@ func TestCreateAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } data = portainer.EdgeStack{} @@ -293,7 +296,7 @@ func TestCreateAndInspect(t *testing.T) { } if payload.Name != data.Name { - t.Fatalf(fmt.Sprintf("expected EdgeStack Name %s, found %s", payload.Name, data.Name)) + t.Fatalf("expected EdgeStack Name %s, found %s", payload.Name, data.Name) } } @@ -421,7 +424,7 @@ func TestCreateWithInvalidPayload(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -447,7 +450,7 @@ func TestDeleteAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } data := portainer.EdgeStack{} @@ -457,7 +460,7 @@ func TestDeleteAndInspect(t *testing.T) { } if data.ID != edgeStack.ID { - t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID)) + t.Fatalf("expected EdgeStackID %d, found %d", int(edgeStack.ID), data.ID) } // Delete @@ -471,7 +474,7 @@ func TestDeleteAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNoContent, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusNoContent, rec.Code) } // Inspect @@ -485,7 +488,7 @@ func TestDeleteAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusNotFound, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusNotFound, rec.Code) } } @@ -514,7 +517,7 @@ func TestDeleteInvalidEdgeStack(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -530,14 +533,7 @@ func TestUpdateAndInspect(t *testing.T) { // Update edge stack: create new Endpoint, EndpointRelation and EdgeGroup endpointID := portainer.EndpointID(6) - newEndpoint := portainer.Endpoint{ - ID: endpointID, - Name: "test-endpoint-" + strconv.Itoa(int(endpointID)), - Type: portainer.EdgeAgentOnDockerEnvironment, - URL: "https://portainer.io:9443", - EdgeID: "edge-id", - LastCheckInDate: time.Now().Unix(), - } + newEndpoint := createEndpointWithId(t, handler.DataStore, endpointID) err := handler.DataStore.Endpoint().Create(&newEndpoint) if err != nil { @@ -594,7 +590,7 @@ func TestUpdateAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } // Get updated edge stack @@ -608,7 +604,7 @@ func TestUpdateAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } data := portainer.EdgeStack{} @@ -618,11 +614,11 @@ func TestUpdateAndInspect(t *testing.T) { } if data.Version != *payload.Version { - t.Fatalf(fmt.Sprintf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version)) + t.Fatalf("expected EdgeStackID %d, found %d", edgeStack.Version, data.Version) } if data.DeploymentType != payload.DeploymentType { - t.Fatalf(fmt.Sprintf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType)) + t.Fatalf("expected DeploymentType %d, found %d", edgeStack.DeploymentType, data.DeploymentType) } if !reflect.DeepEqual(data.EdgeGroups, payload.EdgeGroups) { @@ -705,7 +701,7 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -764,7 +760,7 @@ func TestUpdateWithInvalidPayload(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -802,7 +798,7 @@ func TestUpdateStatusAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } // Get updated edge stack @@ -816,7 +812,7 @@ func TestUpdateStatusAndInspect(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } data := portainer.EdgeStack{} @@ -826,15 +822,15 @@ func TestUpdateStatusAndInspect(t *testing.T) { } if data.Status[endpoint.ID].Type != *payload.Status { - t.Fatalf(fmt.Sprintf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type)) + t.Fatalf("expected EdgeStackStatusType %d, found %d", payload.Status, data.Status[endpoint.ID].Type) } if data.Status[endpoint.ID].Error != payload.Error { - t.Fatalf(fmt.Sprintf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error)) + t.Fatalf("expected EdgeStackStatusError %s, found %s", payload.Error, data.Status[endpoint.ID].Error) } if data.Status[endpoint.ID].EndpointID != payload.EndpointID { - t.Fatalf(fmt.Sprintf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID)) + t.Fatalf("expected EndpointID %d, found %d", payload.EndpointID, data.Status[endpoint.ID].EndpointID) } } func TestUpdateStatusWithInvalidPayload(t *testing.T) { @@ -903,7 +899,7 @@ func TestUpdateStatusWithInvalidPayload(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != tc.ExpectedStatusCode { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code)) + t.Fatalf("expected a %d response, found: %d", tc.ExpectedStatusCode, rec.Code) } }) } @@ -927,6 +923,6 @@ func TestDeleteStatus(t *testing.T) { handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { - t.Fatalf(fmt.Sprintf("expected a %d response, found: %d", http.StatusOK, rec.Code)) + t.Fatalf("expected a %d response, found: %d", http.StatusOK, rec.Code) } }