fix(pending-action): pending action data format [EE-7064] (#11766)

pull/11777/head
Prabhat Khera 2024-05-06 15:46:51 +12:00 committed by GitHub
parent e75e6cb7f7
commit f22aed34b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 151 additions and 3 deletions

View File

@ -0,0 +1,95 @@
package datastore
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/pendingactions/actions"
)
func Test_ConvertCleanNAPWithOverridePoliciesPayload(t *testing.T) {
t.Run("test ConvertCleanNAPWithOverridePoliciesPayload", func(t *testing.T) {
_, store := MustNewTestStore(t, true, false)
defer store.Close()
testData := []struct {
Name string
PendingAction portainer.PendingActions
Expected *actions.CleanNAPWithOverridePoliciesPayload
Err bool
}{
{
Name: "test actiondata with EndpointGroupID 1",
PendingAction: portainer.PendingActions{
EndpointID: 1,
Action: "CleanNAPWithOverridePolicies",
ActionData: &actions.CleanNAPWithOverridePoliciesPayload{
EndpointGroupID: 1,
},
},
Expected: &actions.CleanNAPWithOverridePoliciesPayload{
EndpointGroupID: 1,
},
},
{
Name: "test actionData nil",
PendingAction: portainer.PendingActions{
EndpointID: 2,
Action: "CleanNAPWithOverridePolicies",
ActionData: nil,
},
Expected: nil,
},
{
Name: "test actionData empty and expected error",
PendingAction: portainer.PendingActions{
EndpointID: 2,
Action: "CleanNAPWithOverridePolicies",
ActionData: "",
},
Expected: nil,
Err: true,
},
}
for _, d := range testData {
err := store.PendingActions().Create(&d.PendingAction)
if err != nil {
t.Error(err)
return
}
pendingActions, err := store.PendingActions().ReadAll()
if err != nil {
t.Error(err)
return
}
for _, endpointPendingAction := range pendingActions {
t.Run(d.Name, func(t *testing.T) {
if endpointPendingAction.Action == "CleanNAPWithOverridePolicies" {
actionData, err := actions.ConvertCleanNAPWithOverridePoliciesPayload(endpointPendingAction.ActionData)
if d.Err && err == nil {
t.Error(err)
}
if d.Expected == nil && actionData != nil {
t.Errorf("expected nil , got %d", actionData)
}
if d.Expected != nil && actionData == nil {
t.Errorf("expected not nil , got %d", actionData)
}
if d.Expected != nil && actionData.EndpointGroupID != d.Expected.EndpointGroupID {
t.Errorf("expected EndpointGroupID %d , got %d", d.Expected.EndpointGroupID, actionData.EndpointGroupID)
}
}
})
}
store.PendingActions().Delete(d.PendingAction.ID)
}
})
}

View File

@ -8,6 +8,7 @@ import (
portainer "github.com/portainer/portainer/api" portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices" "github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/tag" "github.com/portainer/portainer/api/internal/tag"
pendingActionActions "github.com/portainer/portainer/api/pendingactions/actions"
httperror "github.com/portainer/portainer/pkg/libhttp/error" httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/request"
"github.com/portainer/portainer/pkg/libhttp/response" "github.com/portainer/portainer/pkg/libhttp/response"
@ -159,7 +160,9 @@ func (handler *Handler) updateEndpointGroup(tx dataservices.DataStoreTx, endpoin
err := handler.PendingActionsService.Create(portainer.PendingActions{ err := handler.PendingActionsService.Create(portainer.PendingActions{
EndpointID: endpointID, EndpointID: endpointID,
Action: "CleanNAPWithOverridePolicies", Action: "CleanNAPWithOverridePolicies",
ActionData: endpointGroupID, ActionData: &pendingActionActions.CleanNAPWithOverridePoliciesPayload{
EndpointGroupID: endpointGroupID,
},
}) })
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Unable to create pending action to clean NAP with override policies for endpoint (%d) and endpoint group (%d).", endpointID, endpointGroupID) log.Error().Err(err).Msgf("Unable to create pending action to clean NAP with override policies for endpoint (%d) and endpoint group (%d).", endpointID, endpointGroupID)

View File

@ -0,0 +1,44 @@
package actions
import (
"fmt"
portainer "github.com/portainer/portainer/api"
)
type (
CleanNAPWithOverridePoliciesPayload struct {
EndpointGroupID portainer.EndpointGroupID
}
)
func ConvertCleanNAPWithOverridePoliciesPayload(actionData interface{}) (*CleanNAPWithOverridePoliciesPayload, error) {
var payload CleanNAPWithOverridePoliciesPayload
if actionData == nil {
return nil, nil
}
// backward compatible with old data format
if endpointGroupId, ok := actionData.(float64); ok {
payload.EndpointGroupID = portainer.EndpointGroupID(endpointGroupId)
return &payload, nil
}
data, ok := actionData.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("failed to convert actionData to map[string]interface{}")
}
for key, value := range data {
switch key {
case "EndpointGroupID":
if endpointGroupID, ok := value.(float64); ok {
payload.EndpointGroupID = portainer.EndpointGroupID(endpointGroupID)
}
}
}
return &payload, nil
}

View File

@ -117,12 +117,18 @@ func (service *PendingActionsService) executePendingAction(pendingAction portain
switch pendingAction.Action { switch pendingAction.Action {
case actions.CleanNAPWithOverridePolicies: case actions.CleanNAPWithOverridePolicies:
if (pendingAction.ActionData == nil) || (pendingAction.ActionData.(portainer.EndpointGroupID) == 0) { pendingActionData, err := actions.ConvertCleanNAPWithOverridePoliciesPayload(pendingAction.ActionData)
if err != nil {
return fmt.Errorf("failed to parse pendingActionData for CleanNAPWithOverridePoliciesPayload")
}
if pendingActionData == nil || pendingActionData.EndpointGroupID == 0 {
service.authorizationService.CleanNAPWithOverridePolicies(service.dataStore, endpoint, nil) service.authorizationService.CleanNAPWithOverridePolicies(service.dataStore, endpoint, nil)
return nil return nil
} }
endpointGroupID := pendingAction.ActionData.(portainer.EndpointGroupID) endpointGroupID := pendingActionData.EndpointGroupID
endpointGroup, err := service.dataStore.EndpointGroup().Read(portainer.EndpointGroupID(endpointGroupID)) endpointGroup, err := service.dataStore.EndpointGroup().Read(portainer.EndpointGroupID(endpointGroupID))
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Error reading environment group to clean NAP with override policies for environment %d and environment group %d", endpoint.ID, endpointGroup.ID) log.Error().Err(err).Msgf("Error reading environment group to clean NAP with override policies for environment %d and environment group %d", endpoint.ID, endpointGroup.ID)