From f129bf3e97ecf48bd17535a3900fc4a82411d48c Mon Sep 17 00:00:00 2001 From: Anthony Lapenna Date: Tue, 7 Feb 2017 16:26:12 +1300 Subject: [PATCH] refactor(api): refactor --- api/bolt/endpoint_service.go | 48 +++++++-------- api/cli/cli.go | 56 +++++++++++++---- api/cmd/portainer/main.go | 114 ++++++++++++++++++++++------------- api/cron/endpoint_sync.go | 7 ++- api/cron/watcher.go | 8 ++- 5 files changed, 152 insertions(+), 81 deletions(-) diff --git a/api/bolt/endpoint_service.go b/api/bolt/endpoint_service.go index df086bea7..cb8349076 100644 --- a/api/bolt/endpoint_service.go +++ b/api/bolt/endpoint_service.go @@ -73,27 +73,14 @@ func (service *EndpointService) Synchronize(toCreate, toUpdate, toDelete []*port bucket := tx.Bucket([]byte(endpointBucketName)) for _, endpoint := range toCreate { - id, _ := bucket.NextSequence() - endpoint.ID = portainer.EndpointID(id) - - data, err := internal.MarshalEndpoint(endpoint) - if err != nil { - return err - } - - err = bucket.Put(internal.Itob(int(endpoint.ID)), data) + err := storeNewEndpoint(endpoint, bucket) if err != nil { return err } } for _, endpoint := range toUpdate { - data, err := internal.MarshalEndpoint(endpoint) - if err != nil { - return err - } - - err = bucket.Put(internal.Itob(int(endpoint.ID)), data) + err := marshalAndStoreEndpoint(endpoint, bucket) if err != nil { return err } @@ -114,16 +101,7 @@ func (service *EndpointService) Synchronize(toCreate, toUpdate, toDelete []*port func (service *EndpointService) CreateEndpoint(endpoint *portainer.Endpoint) error { return service.store.db.Update(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(endpointBucketName)) - - id, _ := bucket.NextSequence() - endpoint.ID = portainer.EndpointID(id) - - data, err := internal.MarshalEndpoint(endpoint) - if err != nil { - return err - } - - err = bucket.Put(internal.Itob(int(endpoint.ID)), data) + err := storeNewEndpoint(endpoint, bucket) if err != nil { return err } @@ -215,3 +193,23 @@ func (service *EndpointService) DeleteActive() error { return nil }) } + +func marshalAndStoreEndpoint(endpoint *portainer.Endpoint, bucket *bolt.Bucket) error { + data, err := internal.MarshalEndpoint(endpoint) + if err != nil { + return err + } + + err = bucket.Put(internal.Itob(int(endpoint.ID)), data) + if err != nil { + return err + } + return nil +} + +func storeNewEndpoint(endpoint *portainer.Endpoint, bucket *bolt.Bucket) error { + id, _ := bucket.NextSequence() + endpoint.ID = portainer.EndpointID(id) + + return marshalAndStoreEndpoint(endpoint, bucket) +} diff --git a/api/cli/cli.go b/api/cli/cli.go index 36c459225..20baa79f4 100644 --- a/api/cli/cli.go +++ b/api/cli/cli.go @@ -15,10 +15,11 @@ import ( type Service struct{} const ( - errInvalidEnpointProtocol = portainer.Error("Invalid endpoint protocol: Portainer only supports unix:// or tcp://") - errSocketNotFound = portainer.Error("Unable to locate Unix socket") - errEndpointsFileNotFound = portainer.Error("Unable to locate external endpoints file") - errInvalidSyncInterval = portainer.Error("Invalid synchronization interval") + errInvalidEnpointProtocol = portainer.Error("Invalid endpoint protocol: Portainer only supports unix:// or tcp://") + errSocketNotFound = portainer.Error("Unable to locate Unix socket") + errEndpointsFileNotFound = portainer.Error("Unable to locate external endpoints file") + errInvalidSyncInterval = portainer.Error("Invalid synchronization interval") + errEndpointExcludeExternal = portainer.Error("Cannot use the -H mutually with --external-endpoints") ) // ParseFlags parse the CLI flags and return a portainer.Flags struct @@ -48,13 +49,37 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) { // ValidateFlags validates the values of the flags. func (*Service) ValidateFlags(flags *portainer.CLIFlags) error { - if *flags.Endpoint != "" { - if !strings.HasPrefix(*flags.Endpoint, "unix://") && !strings.HasPrefix(*flags.Endpoint, "tcp://") { + + if *flags.Endpoint != "" && *flags.ExternalEndpoints != "" { + return errEndpointExcludeExternal + } + + err := validateEndpoint(*flags.Endpoint) + if err != nil { + return err + } + + err = validateExternalEndpoints(*flags.ExternalEndpoints) + if err != nil { + return err + } + + err = validateSyncInterval(*flags.SyncInterval) + if err != nil { + return err + } + + return nil +} + +func validateEndpoint(endpoint string) error { + if endpoint != "" { + if !strings.HasPrefix(endpoint, "unix://") && !strings.HasPrefix(endpoint, "tcp://") { return errInvalidEnpointProtocol } - if strings.HasPrefix(*flags.Endpoint, "unix://") { - socketPath := strings.TrimPrefix(*flags.Endpoint, "unix://") + if strings.HasPrefix(endpoint, "unix://") { + socketPath := strings.TrimPrefix(endpoint, "unix://") if _, err := os.Stat(socketPath); err != nil { if os.IsNotExist(err) { return errSocketNotFound @@ -63,22 +88,27 @@ func (*Service) ValidateFlags(flags *portainer.CLIFlags) error { } } } + return nil +} - if *flags.ExternalEndpoints != "" { - if _, err := os.Stat(*flags.ExternalEndpoints); err != nil { +func validateExternalEndpoints(externalEndpoints string) error { + if externalEndpoints != "" { + if _, err := os.Stat(externalEndpoints); err != nil { if os.IsNotExist(err) { return errEndpointsFileNotFound } return err } } + return nil +} - if *flags.SyncInterval != defaultSyncInterval { - _, err := time.ParseDuration(*flags.SyncInterval) +func validateSyncInterval(syncInterval string) error { + if syncInterval != defaultSyncInterval { + _, err := time.ParseDuration(syncInterval) if err != nil { return errInvalidSyncInterval } } - return nil } diff --git a/api/cmd/portainer/main.go b/api/cmd/portainer/main.go index 9b56f3238..ad5018c9d 100644 --- a/api/cmd/portainer/main.go +++ b/api/cmd/portainer/main.go @@ -13,7 +13,7 @@ import ( "log" ) -func main() { +func initCLI() *portainer.CLIFlags { var cli portainer.CLIService = &cli.Service{} flags, err := cli.ParseFlags(portainer.APIVersion) if err != nil { @@ -24,54 +24,80 @@ func main() { if err != nil { log.Fatal(err) } + return flags +} - fileService, err := file.NewService(*flags.Data, "") +func initFileService(dataStorePath string) portainer.FileService { + fileService, err := file.NewService(dataStorePath, "") if err != nil { log.Fatal(err) } + return fileService +} - var store = bolt.NewStore(*flags.Data) - err = store.Open() +func initStore(dataStorePath string) *bolt.Store { + var store = bolt.NewStore(dataStorePath) + err := store.Open() if err != nil { log.Fatal(err) } - defer store.Close() + return store +} - var jwtService portainer.JWTService - if !*flags.NoAuth { - jwtService, err = jwt.NewService() +func initJWTService(authenticationEnabled bool) portainer.JWTService { + if authenticationEnabled { + jwtService, err := jwt.NewService() if err != nil { log.Fatal(err) } + return jwtService } + return nil +} - var cryptoService portainer.CryptoService = &crypto.Service{} +func initCryptoService() portainer.CryptoService { + return &crypto.Service{} +} - var endpointWatcher portainer.EndpointWatcher +func initEndpointWatcher(endpointService portainer.EndpointService, externalEnpointFile string, syncInterval string) bool { authorizeEndpointMgmt := true - if *flags.ExternalEndpoints != "" { - log.Println("Using external endpoint definition. Disabling endpoint management via API.") + if externalEnpointFile != "" { authorizeEndpointMgmt = false - endpointWatcher = cron.NewWatcher(store.EndpointService, *flags.SyncInterval) - err = endpointWatcher.WatchEndpointFile(*flags.ExternalEndpoints) + log.Println("Using external endpoint definition. Endpoint management via the API will be disabled.") + endpointWatcher := cron.NewWatcher(endpointService, syncInterval) + err := endpointWatcher.WatchEndpointFile(externalEnpointFile) if err != nil { log.Fatal(err) } } + return authorizeEndpointMgmt +} - settings := &portainer.Settings{ +func initSettings(authorizeEndpointMgmt bool, flags *portainer.CLIFlags) *portainer.Settings { + return &portainer.Settings{ HiddenLabels: *flags.Labels, Logo: *flags.Logo, Authentication: !*flags.NoAuth, EndpointManagement: authorizeEndpointMgmt, } +} - // Initialize the active endpoint from the CLI only if there is no - // active endpoint defined yet. - var activeEndpoint *portainer.Endpoint - if *flags.Endpoint != "" { - activeEndpoint, err = store.EndpointService.GetActive() - if err == portainer.ErrEndpointNotFound { +func initActiveEndpointFromFirstEndpointInDatabase(endpointService portainer.EndpointService) { + +} + +func retrieveFirstEndpointFromDatabase(endpointService portainer.EndpointService) *portainer.Endpoint { + endpoints, err := endpointService.Endpoints() + if err != nil { + log.Fatal(err) + } + return &endpoints[0] +} + +func initActiveEndpoint(endpointService portainer.EndpointService, flags *portainer.CLIFlags) *portainer.Endpoint { + activeEndpoint, err := endpointService.GetActive() + if err == portainer.ErrEndpointNotFound { + if *flags.Endpoint != "" { activeEndpoint = &portainer.Endpoint{ Name: "primary", URL: *flags.Endpoint, @@ -80,30 +106,36 @@ func main() { TLSCertPath: *flags.TLSCert, TLSKeyPath: *flags.TLSKey, } - err = store.EndpointService.CreateEndpoint(activeEndpoint) + err = endpointService.CreateEndpoint(activeEndpoint) if err != nil { log.Fatal(err) } - } else if err != nil { - log.Fatal(err) - } - } - if *flags.ExternalEndpoints != "" { - activeEndpoint, err = store.EndpointService.GetActive() - if err == portainer.ErrEndpointNotFound { - var endpoints []portainer.Endpoint - endpoints, err = store.EndpointService.Endpoints() - if err != nil { - log.Fatal(err) - } - err = store.EndpointService.SetActive(&endpoints[0]) - if err != nil { - log.Fatal(err) - } - } else if err != nil { - log.Fatal(err) + } else if *flags.ExternalEndpoints != "" { + activeEndpoint = retrieveFirstEndpointFromDatabase(endpointService) } + } else if err != nil { + log.Fatal(err) } + return activeEndpoint +} + +func main() { + flags := initCLI() + + fileService := initFileService(*flags.Data) + + store := initStore(*flags.Data) + defer store.Close() + + jwtService := initJWTService(!*flags.NoAuth) + + cryptoService := initCryptoService() + + authorizeEndpointMgmt := initEndpointWatcher(store.EndpointService, *flags.ExternalEndpoints, *flags.SyncInterval) + + settings := initSettings(authorizeEndpointMgmt, flags) + + activeEndpoint := initActiveEndpoint(store.EndpointService, flags) var server portainer.Server = &http.Server{ BindAddress: *flags.Addr, @@ -121,7 +153,7 @@ func main() { } log.Printf("Starting Portainer on %s", *flags.Addr) - err = server.Start() + err := server.Start() if err != nil { log.Fatal(err) } diff --git a/api/cron/endpoint_sync.go b/api/cron/endpoint_sync.go index 221143087..9dcd4e290 100644 --- a/api/cron/endpoint_sync.go +++ b/api/cron/endpoint_sync.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "log" "os" + "strings" "github.com/portainer/portainer" ) @@ -46,6 +47,9 @@ func endpointSyncError(err error, logger *log.Logger) bool { func isValidEndpoint(endpoint *portainer.Endpoint) bool { if endpoint.Name != "" && endpoint.URL != "" { + if !strings.HasPrefix(endpoint.URL, "unix://") && !strings.HasPrefix(endpoint.URL, "tcp://") { + return false + } return true } return false @@ -155,12 +159,13 @@ func (job endpointSyncJob) Sync() error { if endpointSyncError(err, job.logger) { return err } + job.logger.Printf("Endpoint synchronization ended. [created: %v] [updated: %v] [deleted: %v]", len(sync.endpointsToCreate), len(sync.endpointsToUpdate), len(sync.endpointsToDelete)) } return nil } func (job endpointSyncJob) Run() { - job.logger.Printf("Endpoint synchronization job started") + job.logger.Println("Endpoint synchronization job started.") err := job.Sync() endpointSyncError(err, job.logger) } diff --git a/api/cron/watcher.go b/api/cron/watcher.go index 01852edb7..6b44ff5ce 100644 --- a/api/cron/watcher.go +++ b/api/cron/watcher.go @@ -24,11 +24,17 @@ func NewWatcher(endpointService portainer.EndpointService, syncInterval string) // WatchEndpointFile starts a cron job to synchronize the endpoints from a file func (watcher *Watcher) WatchEndpointFile(endpointFilePath string) error { job := newEndpointSyncJob(endpointFilePath, watcher.EndpointService) + err := job.Sync() if err != nil { return err } - watcher.Cron.AddJob("@every "+watcher.syncInterval, job) + + err = watcher.Cron.AddJob("@every "+watcher.syncInterval, job) + if err != nil { + return err + } + watcher.Cron.Start() return nil }