fix(api): fix endpoint snapshot process at endpoint creation time (#2072)

* fix(api): fix endpoint snapshot process at endpoint creation time

* refactor(api): remove comments
pull/2077/head
Anthony Lapenna 2018-07-24 14:47:19 +02:00 committed by GitHub
parent 1cf77bf9e9
commit 1f24320fa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 21 deletions

View File

@ -82,8 +82,11 @@ func (service *Service) CreateEndpoint(endpoint *portainer.Endpoint) error {
return service.db.Update(func(tx *bolt.Tx) error { return service.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(BucketName)) bucket := tx.Bucket([]byte(BucketName))
id, _ := bucket.NextSequence() // We manually manage sequences for endpoints
endpoint.ID = portainer.EndpointID(id) err := bucket.SetSequence(uint64(endpoint.ID))
if err != nil {
return err
}
data, err := internal.MarshalObject(endpoint) data, err := internal.MarshalObject(endpoint)
if err != nil { if err != nil {
@ -94,6 +97,11 @@ func (service *Service) CreateEndpoint(endpoint *portainer.Endpoint) error {
}) })
} }
// GetNextIdentifier returns the next identifier for an endpoint.
func (service *Service) GetNextIdentifier() int {
return internal.GetNextIdentifier(service.db, BucketName)
}
// Synchronize creates, updates and deletes endpoints inside a single transaction. // Synchronize creates, updates and deletes endpoints inside a single transaction.
func (service *Service) Synchronize(toCreate, toUpdate, toDelete []*portainer.Endpoint) error { func (service *Service) Synchronize(toCreate, toUpdate, toDelete []*portainer.Endpoint) error {
return service.db.Update(func(tx *bolt.Tx) error { return service.db.Update(func(tx *bolt.Tx) error {

View File

@ -281,7 +281,7 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
return generateAndStoreKeyPair(fileService, signatureService) return generateAndStoreKeyPair(fileService, signatureService)
} }
func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService) error { func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error {
tlsConfiguration := portainer.TLSConfiguration{ tlsConfiguration := portainer.TLSConfiguration{
TLS: *flags.TLS, TLS: *flags.TLS,
TLSSkipVerify: *flags.TLSSkipVerify, TLSSkipVerify: *flags.TLSSkipVerify,
@ -295,7 +295,9 @@ func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portain
tlsConfiguration.TLS = true tlsConfiguration.TLS = true
} }
endpointID := endpointService.GetNextIdentifier()
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: "primary", Name: "primary",
URL: *flags.EndpointURL, URL: *flags.EndpointURL,
GroupID: portainer.EndpointGroupID(1), GroupID: portainer.EndpointGroupID(1),
@ -325,10 +327,10 @@ func createTLSSecuredEndpoint(flags *portainer.CLIFlags, endpointService portain
} }
} }
return endpointService.CreateEndpoint(endpoint) return snapshotAndPersistEndpoint(endpoint, endpointService, snapshotter)
} }
func createUnsecuredEndpoint(endpointURL string, endpointService portainer.EndpointService) error { func createUnsecuredEndpoint(endpointURL string, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error {
if strings.HasPrefix(endpointURL, "tcp://") { if strings.HasPrefix(endpointURL, "tcp://") {
_, err := client.ExecutePingOperation(endpointURL, nil) _, err := client.ExecutePingOperation(endpointURL, nil)
if err != nil { if err != nil {
@ -336,7 +338,9 @@ func createUnsecuredEndpoint(endpointURL string, endpointService portainer.Endpo
} }
} }
endpointID := endpointService.GetNextIdentifier()
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: "primary", Name: "primary",
URL: endpointURL, URL: endpointURL,
GroupID: portainer.EndpointGroupID(1), GroupID: portainer.EndpointGroupID(1),
@ -350,10 +354,25 @@ func createUnsecuredEndpoint(endpointURL string, endpointService portainer.Endpo
Snapshots: []portainer.Snapshot{}, Snapshots: []portainer.Snapshot{},
} }
return snapshotAndPersistEndpoint(endpoint, endpointService, snapshotter)
}
func snapshotAndPersistEndpoint(endpoint *portainer.Endpoint, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error {
snapshot, err := snapshotter.CreateSnapshot(endpoint)
endpoint.Status = portainer.EndpointStatusUp
if err != nil {
log.Printf("http error: endpoint snapshot error (endpoint=%s, URL=%s) (err=%s)\n", endpoint.Name, endpoint.URL, err)
endpoint.Status = portainer.EndpointStatusDown
}
if snapshot != nil {
endpoint.Snapshots = []portainer.Snapshot{*snapshot}
}
return endpointService.CreateEndpoint(endpoint) return endpointService.CreateEndpoint(endpoint)
} }
func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService) error { func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointService, snapshotter portainer.Snapshotter) error {
if *flags.EndpointURL == "" { if *flags.EndpointURL == "" {
return nil return nil
} }
@ -369,9 +388,9 @@ func initEndpoint(flags *portainer.CLIFlags, endpointService portainer.EndpointS
} }
if *flags.TLS || *flags.TLSSkipVerify { if *flags.TLS || *flags.TLSSkipVerify {
return createTLSSecuredEndpoint(flags, endpointService) return createTLSSecuredEndpoint(flags, endpointService, snapshotter)
} }
return createUnsecuredEndpoint(*flags.EndpointURL, endpointService) return createUnsecuredEndpoint(*flags.EndpointURL, endpointService, snapshotter)
} }
func main() { func main() {
@ -437,7 +456,7 @@ func main() {
applicationStatus := initStatus(endpointManagement, *flags.Snapshot, flags) applicationStatus := initStatus(endpointManagement, *flags.Snapshot, flags)
err = initEndpoint(flags, store.EndpointService) err = initEndpoint(flags, store.EndpointService, snapshotter)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -167,7 +167,9 @@ func (handler *Handler) createAzureEndpoint(payload *endpointCreatePayload) (*po
return nil, &httperror.HandlerError{http.StatusInternalServerError, "Unable to authenticate against Azure", err} return nil, &httperror.HandlerError{http.StatusInternalServerError, "Unable to authenticate against Azure", err}
} }
endpointID := handler.EndpointService.GetNextIdentifier()
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: payload.Name, Name: payload.Name,
URL: payload.URL, URL: payload.URL,
Type: portainer.AzureEnvironment, Type: portainer.AzureEnvironment,
@ -208,7 +210,9 @@ func (handler *Handler) createUnsecuredEndpoint(payload *endpointCreatePayload)
} }
} }
endpointID := handler.EndpointService.GetNextIdentifier()
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: payload.Name, Name: payload.Name,
URL: payload.URL, URL: payload.URL,
Type: endpointType, Type: endpointType,
@ -249,7 +253,9 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload)
endpointType = portainer.AgentOnDockerEnvironment endpointType = portainer.AgentOnDockerEnvironment
} }
endpointID := handler.EndpointService.GetNextIdentifier()
endpoint := &portainer.Endpoint{ endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: payload.Name, Name: payload.Name,
URL: payload.URL, URL: payload.URL,
Type: endpointType, Type: endpointType,
@ -267,20 +273,14 @@ func (handler *Handler) createTLSSecuredEndpoint(payload *endpointCreatePayload)
Snapshots: []portainer.Snapshot{}, Snapshots: []portainer.Snapshot{},
} }
endpointCreationError := handler.snapshotAndPersistEndpoint(endpoint)
if endpointCreationError != nil {
return nil, endpointCreationError
}
filesystemError := handler.storeTLSFiles(endpoint, payload) filesystemError := handler.storeTLSFiles(endpoint, payload)
if err != nil { if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return nil, filesystemError return nil, filesystemError
} }
err = handler.EndpointService.UpdateEndpoint(endpoint.ID, endpoint) endpointCreationError := handler.snapshotAndPersistEndpoint(endpoint)
if err != nil { if endpointCreationError != nil {
return nil, &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist endpoint changes inside the database", err} return nil, endpointCreationError
} }
return endpoint, nil return endpoint, nil
@ -312,7 +312,6 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
if !payload.TLSSkipVerify { if !payload.TLSSkipVerify {
caCertPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCA, payload.TLSCACertFile) caCertPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCA, payload.TLSCACertFile)
if err != nil { if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS CA certificate file on disk", err} return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS CA certificate file on disk", err}
} }
endpoint.TLSConfig.TLSCACertPath = caCertPath endpoint.TLSConfig.TLSCACertPath = caCertPath
@ -321,14 +320,12 @@ func (handler *Handler) storeTLSFiles(endpoint *portainer.Endpoint, payload *end
if !payload.TLSSkipClientVerify { if !payload.TLSSkipClientVerify {
certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile) certPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileCert, payload.TLSCertFile)
if err != nil { if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS certificate file on disk", err} return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS certificate file on disk", err}
} }
endpoint.TLSConfig.TLSCertPath = certPath endpoint.TLSConfig.TLSCertPath = certPath
keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile) keyPath, err := handler.FileService.StoreTLSFileFromBytes(folder, portainer.TLSFileKey, payload.TLSKeyFile)
if err != nil { if err != nil {
handler.EndpointService.DeleteEndpoint(endpoint.ID)
return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS key file on disk", err} return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist TLS key file on disk", err}
} }
endpoint.TLSConfig.TLSKeyPath = keyPath endpoint.TLSConfig.TLSKeyPath = keyPath

View File

@ -453,6 +453,7 @@ type (
UpdateEndpoint(ID EndpointID, endpoint *Endpoint) error UpdateEndpoint(ID EndpointID, endpoint *Endpoint) error
DeleteEndpoint(ID EndpointID) error DeleteEndpoint(ID EndpointID) error
Synchronize(toCreate, toUpdate, toDelete []*Endpoint) error Synchronize(toCreate, toUpdate, toDelete []*Endpoint) error
GetNextIdentifier() int
} }
// EndpointGroupService represents a service for managing endpoint group data. // EndpointGroupService represents a service for managing endpoint group data.