portainer/api/datastore/postinit/migrate_post_init_test.go

175 lines
5.0 KiB
Go

package postinit
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMigrateGPUs(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/containers/json") {
containerSummary := []container.Summary{{ID: "container1"}}
if err := json.NewEncoder(w).Encode(containerSummary); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
}
container := container.InspectResponse{
ContainerJSONBase: &container.ContainerJSONBase{
ID: "container1",
HostConfig: &container.HostConfig{
Resources: container.Resources{
DeviceRequests: []container.DeviceRequest{
{Driver: "nvidia"},
},
},
},
},
}
if err := json.NewEncoder(w).Encode(container); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
}))
defer srv.Close()
_, store := datastore.MustNewTestStore(t, true, false)
migrator := &PostInitMigrator{dataStore: store}
dockerCli, err := client.NewClientWithOpts(client.WithHost(srv.URL), client.WithHTTPClient(http.DefaultClient))
require.NoError(t, err)
// Nonexistent endpoint
err = migrator.MigrateGPUs(portainer.Endpoint{}, dockerCli)
require.Error(t, err)
// Valid endpoint
endpoint := portainer.Endpoint{ID: 1, PostInitMigrations: portainer.EndpointPostInitMigrations{MigrateGPUs: true}}
err = store.Endpoint().Create(&endpoint)
require.NoError(t, err)
err = migrator.MigrateGPUs(endpoint, dockerCli)
require.NoError(t, err)
migratedEndpoint, err := store.Endpoint().Endpoint(endpoint.ID)
require.NoError(t, err)
require.Equal(t, endpoint.ID, migratedEndpoint.ID)
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
require.True(t, migratedEndpoint.EnableGPUManagement)
}
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
tests := []struct {
name string
existingPendingActions []*portainer.PendingAction
expectedPendingActions int
expectedAction string
}{
{
name: "when existing non-matching action exists, should add migration action",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: "some-other-action",
},
},
expectedPendingActions: 2,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when matching action exists, should not add duplicate",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: actions.PostInitMigrateEnvironment,
},
},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when no actions exist, should add migration action",
existingPendingActions: []*portainer.PendingAction{},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
is := assert.New(t)
_, store := datastore.MustNewTestStore(t, true, true)
// Create test endpoint
endpoint := &portainer.Endpoint{
ID: 7,
UserTrusted: true,
Type: portainer.EdgeAgentOnDockerEnvironment,
Edge: portainer.EnvironmentEdgeSettings{
AsyncMode: false,
},
EdgeID: "edgeID",
}
err := store.Endpoint().Create(endpoint)
require.NoError(t, err, "error creating endpoint")
// Create any existing pending actions
for _, action := range tt.existingPendingActions {
err = store.PendingActions().Create(action)
require.NoError(t, err, "error creating pending action")
}
migrator := NewPostInitMigrator(
nil, // kubeFactory not needed for this test
nil, // dockerFactory not needed for this test
store,
"", // assetsPath not needed for this test
nil, // kubernetesDeployer not needed for this test
)
err = migrator.PostInitMigrate()
require.NoError(t, err, "PostInitMigrate should not return error")
// Verify the results
pendingActions, err := store.PendingActions().ReadAll()
require.NoError(t, err, "error reading pending actions")
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
// If we expect any actions, verify at least one has the expected action type
if tt.expectedPendingActions > 0 {
hasExpectedAction := false
for _, action := range pendingActions {
if action.Action == tt.expectedAction {
hasExpectedAction = true
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
break
}
}
is.True(hasExpectedAction, "should have found action of expected type")
}
})
}
}