diff --git a/api/datastore/pendingactions_test.go b/api/datastore/pendingactions_test.go new file mode 100644 index 000000000..c67d59e21 --- /dev/null +++ b/api/datastore/pendingactions_test.go @@ -0,0 +1,95 @@ +package datastore + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/pendingactions/actions" +) + +func Test_ConvertCleanNAPWithOverridePoliciesPayload(t *testing.T) { + t.Run("test ConvertCleanNAPWithOverridePoliciesPayload", func(t *testing.T) { + + _, store := MustNewTestStore(t, true, false) + defer store.Close() + + testData := []struct { + Name string + PendingAction portainer.PendingActions + Expected *actions.CleanNAPWithOverridePoliciesPayload + Err bool + }{ + { + Name: "test actiondata with EndpointGroupID 1", + PendingAction: portainer.PendingActions{ + EndpointID: 1, + Action: "CleanNAPWithOverridePolicies", + ActionData: &actions.CleanNAPWithOverridePoliciesPayload{ + EndpointGroupID: 1, + }, + }, + Expected: &actions.CleanNAPWithOverridePoliciesPayload{ + EndpointGroupID: 1, + }, + }, + { + Name: "test actionData nil", + PendingAction: portainer.PendingActions{ + EndpointID: 2, + Action: "CleanNAPWithOverridePolicies", + ActionData: nil, + }, + Expected: nil, + }, + { + Name: "test actionData empty and expected error", + PendingAction: portainer.PendingActions{ + EndpointID: 2, + Action: "CleanNAPWithOverridePolicies", + ActionData: "", + }, + Expected: nil, + Err: true, + }, + } + + for _, d := range testData { + err := store.PendingActions().Create(&d.PendingAction) + if err != nil { + t.Error(err) + return + } + + pendingActions, err := store.PendingActions().ReadAll() + if err != nil { + t.Error(err) + return + } + + for _, endpointPendingAction := range pendingActions { + t.Run(d.Name, func(t *testing.T) { + if endpointPendingAction.Action == "CleanNAPWithOverridePolicies" { + actionData, err := actions.ConvertCleanNAPWithOverridePoliciesPayload(endpointPendingAction.ActionData) + if d.Err && err == nil { + t.Error(err) + } + + if d.Expected == nil && actionData != nil { + t.Errorf("expected nil , got %d", actionData) + } + + if d.Expected != nil && actionData == nil { + t.Errorf("expected not nil , got %d", actionData) + } + + if d.Expected != nil && actionData.EndpointGroupID != d.Expected.EndpointGroupID { + t.Errorf("expected EndpointGroupID %d , got %d", d.Expected.EndpointGroupID, actionData.EndpointGroupID) + } + } + }) + } + + store.PendingActions().Delete(d.PendingAction.ID) + } + }) +} diff --git a/api/http/handler/endpointgroups/endpointgroup_update.go b/api/http/handler/endpointgroups/endpointgroup_update.go index bf7410cfd..3cda7028e 100644 --- a/api/http/handler/endpointgroups/endpointgroup_update.go +++ b/api/http/handler/endpointgroups/endpointgroup_update.go @@ -8,6 +8,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/internal/tag" + pendingActionActions "github.com/portainer/portainer/api/pendingactions/actions" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" @@ -159,7 +160,9 @@ func (handler *Handler) updateEndpointGroup(tx dataservices.DataStoreTx, endpoin err := handler.PendingActionsService.Create(portainer.PendingActions{ EndpointID: endpointID, Action: "CleanNAPWithOverridePolicies", - ActionData: endpointGroupID, + ActionData: &pendingActionActions.CleanNAPWithOverridePoliciesPayload{ + EndpointGroupID: endpointGroupID, + }, }) if err != nil { log.Error().Err(err).Msgf("Unable to create pending action to clean NAP with override policies for endpoint (%d) and endpoint group (%d).", endpointID, endpointGroupID) diff --git a/api/pendingactions/actions/converters.go b/api/pendingactions/actions/converters.go new file mode 100644 index 000000000..f48a99932 --- /dev/null +++ b/api/pendingactions/actions/converters.go @@ -0,0 +1,44 @@ +package actions + +import ( + "fmt" + + portainer "github.com/portainer/portainer/api" +) + +type ( + CleanNAPWithOverridePoliciesPayload struct { + EndpointGroupID portainer.EndpointGroupID + } +) + +func ConvertCleanNAPWithOverridePoliciesPayload(actionData interface{}) (*CleanNAPWithOverridePoliciesPayload, error) { + var payload CleanNAPWithOverridePoliciesPayload + + if actionData == nil { + return nil, nil + } + + // backward compatible with old data format + if endpointGroupId, ok := actionData.(float64); ok { + payload.EndpointGroupID = portainer.EndpointGroupID(endpointGroupId) + return &payload, nil + } + + data, ok := actionData.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("failed to convert actionData to map[string]interface{}") + + } + + for key, value := range data { + switch key { + case "EndpointGroupID": + if endpointGroupID, ok := value.(float64); ok { + payload.EndpointGroupID = portainer.EndpointGroupID(endpointGroupID) + } + } + } + + return &payload, nil +} diff --git a/api/pendingactions/pendingactions.go b/api/pendingactions/pendingactions.go index 5e9a26706..1ad1f5bef 100644 --- a/api/pendingactions/pendingactions.go +++ b/api/pendingactions/pendingactions.go @@ -117,12 +117,18 @@ func (service *PendingActionsService) executePendingAction(pendingAction portain switch pendingAction.Action { case actions.CleanNAPWithOverridePolicies: - if (pendingAction.ActionData == nil) || (pendingAction.ActionData.(portainer.EndpointGroupID) == 0) { + pendingActionData, err := actions.ConvertCleanNAPWithOverridePoliciesPayload(pendingAction.ActionData) + if err != nil { + return fmt.Errorf("failed to parse pendingActionData for CleanNAPWithOverridePoliciesPayload") + } + + if pendingActionData == nil || pendingActionData.EndpointGroupID == 0 { service.authorizationService.CleanNAPWithOverridePolicies(service.dataStore, endpoint, nil) return nil } - endpointGroupID := pendingAction.ActionData.(portainer.EndpointGroupID) + endpointGroupID := pendingActionData.EndpointGroupID + endpointGroup, err := service.dataStore.EndpointGroup().Read(portainer.EndpointGroupID(endpointGroupID)) if err != nil { log.Error().Err(err).Msgf("Error reading environment group to clean NAP with override policies for environment %d and environment group %d", endpoint.ID, endpointGroup.ID)