From 1f24320fa7101af7c7ff0dca6ab195f05683aa0c Mon Sep 17 00:00:00 2001 From: Anthony Lapenna Date: Tue, 24 Jul 2018 14:47:19 +0200 Subject: [PATCH] fix(api): fix endpoint snapshot process at endpoint creation time (#2072) * fix(api): fix endpoint snapshot process at endpoint creation time * refactor(api): remove comments --- api/bolt/endpoint/endpoint.go | 12 +++++-- api/cmd/portainer/main.go | 33 +++++++++++++++---- api/http/handler/endpoints/endpoint_create.go | 21 +++++------- api/portainer.go | 1 + 4 files changed, 46 insertions(+), 21 deletions(-) diff --git a/api/bolt/endpoint/endpoint.go b/api/bolt/endpoint/endpoint.go index 79672eb36..6dfff48df 100644 --- a/api/bolt/endpoint/endpoint.go +++ b/api/bolt/endpoint/endpoint.go @@ -82,8 +82,11 @@ func (service *Service) CreateEndpoint(endpoint *portainer.Endpoint) error { return service.db.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(BucketName)) - id, _ := bucket.NextSequence() - endpoint.ID = portainer.EndpointID(id) + // We manually manage sequences for endpoints + err := bucket.SetSequence(uint64(endpoint.ID)) + if err != nil { + return err + } data, err := internal.MarshalObject(endpoint) if err != nil { @@ -94,6 +97,11 @@ func (service *Service) CreateEndpoint(endpoint *portainer.Endpoint) error { }) } +// GetNextIdentifier returns the next identifier for an endpoint. +func (service *Service) GetNextIdentifier() int { + return internal.GetNextIdentifier(service.db, BucketName) +} + // Synchronize creates, updates and deletes endpoints inside a single transaction. func (service *Service) Synchronize(toCreate, toUpdate, toDelete []*portainer.Endpoint) error { return service.db.Update(func(tx *bolt.Tx) error { diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 567ff4c9e..53e9a8d75 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -281,7 +281,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D return generateAndStoreKeyPair(fileService, signatureService) } -func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService) error { +func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error { tlsConfiguration := portainer.TLSConfiguration{ TLS: *flags.TLS, TLSSkipVerify: *flags.TLSSkipVerify, @@ -295,7 +295,9 @@ func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portain tlsConfiguration.TLS = true } + endpointID := endpointService.GetNextIdentifier() endpoint := &portainer.Endpoint{ + ID: portainer.EndpointID(endpointID), Name: "primary", URL: *flags.EndpointURL, GroupID: portainer.EndpointGroupID(1), @@ -325,10 +327,10 @@ func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portain } } - return endpointService.CreateEndpoint(endpoint) + return snapshotAndPersistEndpoint(endpoint, endpointService, snapshotter) } -func createUnsecuredEndpoint(endpointURL string, endpointService portainer.EndpointService) error { +func createUnsecuredEndpoint(endpointURL string, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error { if strings.HasPrefix(endpointURL, "tcp://") { _, err := client.ExecutePingOperation(endpointURL, nil) if err != nil { @@ -336,7 +338,9 @@ func createUnsecuredEndpoint(endpointURL string, endpointService portainer.Endpo } } + endpointID := endpointService.GetNextIdentifier() endpoint := &portainer.Endpoint{ + ID: portainer.EndpointID(endpointID), Name: "primary", URL: endpointURL, GroupID: portainer.EndpointGroupID(1), @@ -350,10 +354,25 @@ func createUnsecuredEndpoint(endpointURL string, endpointService portainer.Endpo Snapshots: []portainer.Snapshot{}, } + return snapshotAndPersistEndpoint(endpoint, endpointService, snapshotter) +} + +func snapshotAndPersistEndpoint(endpoint *portainer.Endpoint, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error { + snapshot, err := snapshotter.CreateSnapshot(endpoint) + endpoint.Status = portainer.EndpointStatusUp + if err != nil { + log.Printf("http error: endpoint snapshot error (endpoint=%s, URL=%s) (err=%s)\n", endpoint.Name, endpoint.URL, err) + endpoint.Status = portainer.EndpointStatusDown + } + + if snapshot != nil { + endpoint.Snapshots = []portainer.Snapshot{*snapshot} + } + return endpointService.CreateEndpoint(endpoint) } -func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService) error { +func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error { if *flags.EndpointURL == "" { return nil } @@ -369,9 +388,9 @@ func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointS } if *flags.TLS || *flags.TLSSkipVerify { - return createTLSSecuredEndpoint(flags, endpointService) + return createTLSSecuredEndpoint(flags, endpointService, snapshotter) } - return createUnsecuredEndpoint(*flags.EndpointURL, endpointService) + return createUnsecuredEndpoint(*flags.EndpointURL, endpointService, snapshotter) } func main() { @@ -437,7 +456,7 @@ func main() { applicationStatus := initStatus(endpointManagement, *flags.Snapshot, flags) - err = initEndpoint(flags, store.EndpointService) + err = initEndpoint(flags, store.EndpointService, snapshotter) if err != nil { log.Fatal(err) } diff --git a/api/http/handler/endpoints/endpoint_create.go b/api/http/handler/endpoints/endpoint_create.go index 9ac29b649..914cbdcdc 100644 --- a/api/http/handler/endpoints/endpoint_create.go +++ b/api/http/handler/endpoints/endpoint_create.go @@ -167,7 +167,9 @@ func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*po return nil, &httperror.HandlerError{http.StatusInternalServerError, "Unable to authenticate against Azure", err} } + endpointID := handler.EndpointService.GetNextIdentifier() endpoint := &portainer.Endpoint{ + ID: portainer.EndpointID(endpointID), Name: payload.Name, URL: payload.URL, Type: portainer.AzureEnvironment, @@ -208,7 +210,9 @@ func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload) } } + endpointID := handler.EndpointService.GetNextIdentifier() endpoint := &portainer.Endpoint{ + ID: portainer.EndpointID(endpointID), Name: payload.Name, URL: payload.URL, Type: endpointType, @@ -249,7 +253,9 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload) endpointType = portainer.AgentOnDockerEnvironment } + endpointID := handler.EndpointService.GetNextIdentifier() endpoint := &portainer.Endpoint{ + ID: portainer.EndpointID(endpointID), Name: payload.Name, URL: payload.URL, Type: endpointType, @@ -267,20 +273,14 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload) Snapshots: []portainer.Snapshot{}, } - endpointCreationError := handler.snapshotAndPersistEndpoint(endpoint) - if endpointCreationError != nil { - return nil, endpointCreationError - } - filesystemError := handler.storeTLSFiles(endpoint, payload) if err != nil { - handler.EndpointService.DeleteEndpoint(endpoint.ID) return nil, filesystemError } - err = handler.EndpointService.UpdateEndpoint(endpoint.ID, endpoint) - if err != nil { - return nil, &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist endpoint changes inside the database", err} + endpointCreationError := handler.snapshotAndPersistEndpoint(endpoint) + if endpointCreationError != nil { + return nil, endpointCreationError } return endpoint, nil @@ -312,7 +312,6 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end if !payload.TLSSkipVerify { caCertPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCA, payload.TLSCACertFile) if err != nil { - handler.EndpointService.DeleteEndpoint(endpoint.ID) return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS CA certificate file on disk", err} } endpoint.TLSConfig.TLSCACertPath = caCertPath @@ -321,14 +320,12 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end if !payload.TLSSkipClientVerify { certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile) if err != nil { - handler.EndpointService.DeleteEndpoint(endpoint.ID) return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS certificate file on disk", err} } endpoint.TLSConfig.TLSCertPath = certPath keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile) if err != nil { - handler.EndpointService.DeleteEndpoint(endpoint.ID) return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS key file on disk", err} } endpoint.TLSConfig.TLSKeyPath = keyPath diff --git a/api/portainer.go b/api/portainer.go index d2a75e345..9c6e0f21e 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -453,6 +453,7 @@ type ( UpdateEndpoint(ID EndpointID, endpoint *Endpoint) error DeleteEndpoint(ID EndpointID) error Synchronize(toCreate, toUpdate, toDelete []*Endpoint) error + GetNextIdentifier() int } // EndpointGroupService represents a service for managing endpoint group data.