From 20e3d3a15b01f401b7fcec8382462d60863bc9e1 Mon Sep 17 00:00:00 2001 From: Anthony Lapenna Date: Mon, 25 Nov 2024 11:03:12 +1300 Subject: [PATCH] fix: review snapshot and post init migration logic (#158) --- api/datastore/postinit/migrate_post_init.go | 31 ++-- api/internal/endpointutils/endpointutils.go | 2 + api/internal/snapshot/snapshot.go | 4 +- pkg/endpoints/utils.go | 20 +++ pkg/endpoints/utils_test.go | 160 ++++++++++++++++++++ 5 files changed, 206 insertions(+), 11 deletions(-) create mode 100644 pkg/endpoints/utils.go create mode 100644 pkg/endpoints/utils_test.go diff --git a/api/datastore/postinit/migrate_post_init.go b/api/datastore/postinit/migrate_post_init.go index c7dc11cae..8967f4257 100644 --- a/api/datastore/postinit/migrate_post_init.go +++ b/api/datastore/postinit/migrate_post_init.go @@ -11,6 +11,7 @@ import ( "github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/kubernetes/cli" "github.com/portainer/portainer/api/pendingactions/actions" + "github.com/portainer/portainer/pkg/endpoints" "github.com/rs/zerolog/log" ) @@ -49,17 +50,29 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error { for _, environment := range environments { // edge environments will run after the server starts, in pending actions - if endpointutils.IsEdgeEndpoint(&environment) { - log.Info().Msgf("Adding pending action 'PostInitMigrateEnvironment' for environment %d", environment.ID) - err = postInitMigrator.createPostInitMigrationPendingAction(environment.ID) - if err != nil { - log.Error().Err(err).Msgf("Error creating pending action for environment %d", environment.ID) + if endpoints.IsEdgeEndpoint(&environment) { + // Skip edge environments that do not have direct connectivity + if !endpoints.HasDirectConnectivity(&environment) { + continue + } + + log.Info(). + Int("endpoint_id", int(environment.ID)). + Msg("adding pending action 'PostInitMigrateEnvironment' for environment") + + if err := postInitMigrator.createPostInitMigrationPendingAction(environment.ID); err != nil { + log.Error(). + Err(err). + Int("endpoint_id", int(environment.ID)). + Msg("error creating pending action for environment") } } else { - // non-edge environments will run before the server starts. - err = postInitMigrator.MigrateEnvironment(&environment) - if err != nil { - log.Error().Err(err).Msgf("Error running post-init migrations for non-edge environment %d", environment.ID) + // Non-edge environments will run before the server starts. + if err := postInitMigrator.MigrateEnvironment(&environment); err != nil { + log.Error(). + Err(err). + Int("endpoint_id", int(environment.ID)). + Msg("error running post-init migrations for non-edge environment") } } diff --git a/api/internal/endpointutils/endpointutils.go b/api/internal/endpointutils/endpointutils.go index 1ab31e01d..6b7eb1c2d 100644 --- a/api/internal/endpointutils/endpointutils.go +++ b/api/internal/endpointutils/endpointutils.go @@ -11,6 +11,8 @@ import ( log "github.com/rs/zerolog/log" ) +// TODO: this file should be migrated to package/server-ce/pkg/endpoints + // IsLocalEndpoint returns true if this is a local environment(endpoint) func IsLocalEndpoint(endpoint *portainer.Endpoint) bool { return strings.HasPrefix(endpoint.URL, "unix://") || diff --git a/api/internal/snapshot/snapshot.go b/api/internal/snapshot/snapshot.go index d51c5b440..6c1aff49b 100644 --- a/api/internal/snapshot/snapshot.go +++ b/api/internal/snapshot/snapshot.go @@ -10,8 +10,8 @@ import ( "github.com/portainer/portainer/api/agent" "github.com/portainer/portainer/api/crypto" "github.com/portainer/portainer/api/dataservices" - "github.com/portainer/portainer/api/internal/endpointutils" "github.com/portainer/portainer/api/pendingactions" + endpointsutils "github.com/portainer/portainer/pkg/endpoints" "github.com/rs/zerolog/log" ) @@ -64,7 +64,7 @@ func NewBackgroundSnapshotter(dataStore dataservices.DataStore, tunnelService po } for _, e := range endpoints { - if !endpointutils.IsEdgeEndpoint(&e) || e.Edge.AsyncMode || !e.UserTrusted { + if !endpointsutils.HasDirectConnectivity(&e) { continue } diff --git a/pkg/endpoints/utils.go b/pkg/endpoints/utils.go new file mode 100644 index 000000000..8c4ad7ce8 --- /dev/null +++ b/pkg/endpoints/utils.go @@ -0,0 +1,20 @@ +package endpoints + +import portainer "github.com/portainer/portainer/api" + +// IsEdgeEndpoint returns true if this is an Edge endpoint +func IsEdgeEndpoint(endpoint *portainer.Endpoint) bool { + return endpoint.Type == portainer.EdgeAgentOnDockerEnvironment || endpoint.Type == portainer.EdgeAgentOnKubernetesEnvironment +} + +// IsAssociatedEdgeEndpoint returns true if the environment is an Edge environment +// and has a set EdgeID and UserTrusted is true. +func IsAssociatedEdgeEndpoint(endpoint *portainer.Endpoint) bool { + return IsEdgeEndpoint(endpoint) && endpoint.EdgeID != "" && endpoint.UserTrusted +} + +// HasDirectConnectivity returns true if the environment is a non-Edge environment +// or is an associated Edge environment that is not in async mode. +func HasDirectConnectivity(endpoint *portainer.Endpoint) bool { + return !IsEdgeEndpoint(endpoint) || (IsAssociatedEdgeEndpoint(endpoint) && !endpoint.Edge.AsyncMode) +} diff --git a/pkg/endpoints/utils_test.go b/pkg/endpoints/utils_test.go new file mode 100644 index 000000000..b69686ffe --- /dev/null +++ b/pkg/endpoints/utils_test.go @@ -0,0 +1,160 @@ +package endpoints + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/stretchr/testify/assert" +) + +func TestIsEdgeEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint *portainer.Endpoint + expected bool + }{ + { + name: "EdgeAgentOnDockerEnvironment", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + }, + expected: true, + }, + { + name: "EdgeAgentOnKubernetesEnvironment", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnKubernetesEnvironment, + }, + expected: true, + }, + { + name: "NonEdgeEnvironment", + endpoint: &portainer.Endpoint{ + Type: portainer.DockerEnvironment, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsEdgeEndpoint(tt.endpoint) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsAssociatedEdgeEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint *portainer.Endpoint + expected bool + }{ + { + name: "AssociatedEdgeEndpoint", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: true, + }, + expected: true, + }, + { + name: "NonAssociatedEdgeEndpoint", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "", + UserTrusted: true, + }, + expected: false, + }, + { + name: "EdgeEndpointInWaitingRoom", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: false, + }, + expected: false, + }, + { + name: "NonEdgeEnvironment", + endpoint: &portainer.Endpoint{ + Type: portainer.DockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: true, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsAssociatedEdgeEndpoint(tt.endpoint) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHasDirectConnectivity(t *testing.T) { + tests := []struct { + name string + endpoint *portainer.Endpoint + expected bool + }{ + { + name: "NonEdgeEnvironment", + endpoint: &portainer.Endpoint{ + Type: portainer.DockerEnvironment, + }, + expected: true, + }, + { + name: "AssociatedEdgeEndpoint", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: true, + Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, + }, + expected: true, + }, + { + name: "AssociatedAsyncEdgeEndpoint", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: true, + Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, + }, + expected: false, + }, + { + name: "EdgeEndpointInWaitingRoom", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: false, + Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, + }, + expected: false, + }, + { + name: "AsyncEdgeEndpointInWaitingRoom", + endpoint: &portainer.Endpoint{ + Type: portainer.EdgeAgentOnDockerEnvironment, + EdgeID: "some-edge-id", + UserTrusted: false, + Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasDirectConnectivity(tt.endpoint) + assert.Equal(t, tt.expected, result) + }) + } +}