refactor(api): refactor

pull/572/head
Anthony Lapenna 2017-02-07 16:26:12 +13:00
parent dc78ec5135
commit f129bf3e97
5 changed files with 152 additions and 81 deletions

View File

@ -73,27 +73,14 @@ func (service *EndpointService) Synchronize(toCreate, toUpdate, toDelete []*port
bucket := tx.Bucket([]byte(endpointBucketName)) bucket := tx.Bucket([]byte(endpointBucketName))
for _, endpoint := range toCreate { for _, endpoint := range toCreate {
id, _ := bucket.NextSequence() err := storeNewEndpoint(endpoint, bucket)
endpoint.ID = portainer.EndpointID(id)
data, err := internal.MarshalEndpoint(endpoint)
if err != nil {
return err
}
err = bucket.Put(internal.Itob(int(endpoint.ID)), data)
if err != nil { if err != nil {
return err return err
} }
} }
for _, endpoint := range toUpdate { for _, endpoint := range toUpdate {
data, err := internal.MarshalEndpoint(endpoint) err := marshalAndStoreEndpoint(endpoint, bucket)
if err != nil {
return err
}
err = bucket.Put(internal.Itob(int(endpoint.ID)), data)
if err != nil { if err != nil {
return err return err
} }
@ -114,16 +101,7 @@ func (service *EndpointService) Synchronize(toCreate, toUpdate, toDelete []*port
func (service *EndpointService) CreateEndpoint(endpoint *portainer.Endpoint) error { func (service *EndpointService) CreateEndpoint(endpoint *portainer.Endpoint) error {
return service.store.db.Update(func(tx *bolt.Tx) error { return service.store.db.Update(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte(endpointBucketName)) bucket := tx.Bucket([]byte(endpointBucketName))
err := storeNewEndpoint(endpoint, bucket)
id, _ := bucket.NextSequence()
endpoint.ID = portainer.EndpointID(id)
data, err := internal.MarshalEndpoint(endpoint)
if err != nil {
return err
}
err = bucket.Put(internal.Itob(int(endpoint.ID)), data)
if err != nil { if err != nil {
return err return err
} }
@ -215,3 +193,23 @@ func (service *EndpointService) DeleteActive() error {
return nil return nil
}) })
} }
func marshalAndStoreEndpoint(endpoint *portainer.Endpoint, bucket *bolt.Bucket) error {
data, err := internal.MarshalEndpoint(endpoint)
if err != nil {
return err
}
err = bucket.Put(internal.Itob(int(endpoint.ID)), data)
if err != nil {
return err
}
return nil
}
func storeNewEndpoint(endpoint *portainer.Endpoint, bucket *bolt.Bucket) error {
id, _ := bucket.NextSequence()
endpoint.ID = portainer.EndpointID(id)
return marshalAndStoreEndpoint(endpoint, bucket)
}

View File

@ -15,10 +15,11 @@ import (
type Service struct{} type Service struct{}
const ( const (
errInvalidEnpointProtocol = portainer.Error("Invalid endpoint protocol: Portainer only supports unix:// or tcp://") errInvalidEnpointProtocol = portainer.Error("Invalid endpoint protocol: Portainer only supports unix:// or tcp://")
errSocketNotFound = portainer.Error("Unable to locate Unix socket") errSocketNotFound = portainer.Error("Unable to locate Unix socket")
errEndpointsFileNotFound = portainer.Error("Unable to locate external endpoints file") errEndpointsFileNotFound = portainer.Error("Unable to locate external endpoints file")
errInvalidSyncInterval = portainer.Error("Invalid synchronization interval") errInvalidSyncInterval = portainer.Error("Invalid synchronization interval")
errEndpointExcludeExternal = portainer.Error("Cannot use the -H mutually with --external-endpoints")
) )
// ParseFlags parse the CLI flags and return a portainer.Flags struct // ParseFlags parse the CLI flags and return a portainer.Flags struct
@ -48,13 +49,37 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
// ValidateFlags validates the values of the flags. // ValidateFlags validates the values of the flags.
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error { func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
if *flags.Endpoint != "" {
if !strings.HasPrefix(*flags.Endpoint, "unix://") && !strings.HasPrefix(*flags.Endpoint, "tcp://") { if *flags.Endpoint != "" && *flags.ExternalEndpoints != "" {
return errEndpointExcludeExternal
}
err := validateEndpoint(*flags.Endpoint)
if err != nil {
return err
}
err = validateExternalEndpoints(*flags.ExternalEndpoints)
if err != nil {
return err
}
err = validateSyncInterval(*flags.SyncInterval)
if err != nil {
return err
}
return nil
}
func validateEndpoint(endpoint string) error {
if endpoint != "" {
if !strings.HasPrefix(endpoint, "unix://") && !strings.HasPrefix(endpoint, "tcp://") {
return errInvalidEnpointProtocol return errInvalidEnpointProtocol
} }
if strings.HasPrefix(*flags.Endpoint, "unix://") { if strings.HasPrefix(endpoint, "unix://") {
socketPath := strings.TrimPrefix(*flags.Endpoint, "unix://") socketPath := strings.TrimPrefix(endpoint, "unix://")
if _, err := os.Stat(socketPath); err != nil { if _, err := os.Stat(socketPath); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return errSocketNotFound return errSocketNotFound
@ -63,22 +88,27 @@ func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
} }
} }
} }
return nil
}
if *flags.ExternalEndpoints != "" { func validateExternalEndpoints(externalEndpoints string) error {
if _, err := os.Stat(*flags.ExternalEndpoints); err != nil { if externalEndpoints != "" {
if _, err := os.Stat(externalEndpoints); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
return errEndpointsFileNotFound return errEndpointsFileNotFound
} }
return err return err
} }
} }
return nil
}
if *flags.SyncInterval != defaultSyncInterval { func validateSyncInterval(syncInterval string) error {
_, err := time.ParseDuration(*flags.SyncInterval) if syncInterval != defaultSyncInterval {
_, err := time.ParseDuration(syncInterval)
if err != nil { if err != nil {
return errInvalidSyncInterval return errInvalidSyncInterval
} }
} }
return nil return nil
} }

View File

@ -13,7 +13,7 @@ import (
"log" "log"
) )
func main() { func initCLI() *portainer.CLIFlags {
var cli portainer.CLIService = &cli.Service{} var cli portainer.CLIService = &cli.Service{}
flags, err := cli.ParseFlags(portainer.APIVersion) flags, err := cli.ParseFlags(portainer.APIVersion)
if err != nil { if err != nil {
@ -24,54 +24,80 @@ func main() {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return flags
}
fileService, err := file.NewService(*flags.Data, "") func initFileService(dataStorePath string) portainer.FileService {
fileService, err := file.NewService(dataStorePath, "")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return fileService
}
var store = bolt.NewStore(*flags.Data) func initStore(dataStorePath string) *bolt.Store {
err = store.Open() var store = bolt.NewStore(dataStorePath)
err := store.Open()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer store.Close() return store
}
var jwtService portainer.JWTService func initJWTService(authenticationEnabled bool) portainer.JWTService {
if !*flags.NoAuth { if authenticationEnabled {
jwtService, err = jwt.NewService() jwtService, err := jwt.NewService()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return jwtService
} }
return nil
}
var cryptoService portainer.CryptoService = &crypto.Service{} func initCryptoService() portainer.CryptoService {
return &crypto.Service{}
}
var endpointWatcher portainer.EndpointWatcher func initEndpointWatcher(endpointService portainer.EndpointService, externalEnpointFile string, syncInterval string) bool {
authorizeEndpointMgmt := true authorizeEndpointMgmt := true
if *flags.ExternalEndpoints != "" { if externalEnpointFile != "" {
log.Println("Using external endpoint definition. Disabling endpoint management via API.")
authorizeEndpointMgmt = false authorizeEndpointMgmt = false
endpointWatcher = cron.NewWatcher(store.EndpointService, *flags.SyncInterval) log.Println("Using external endpoint definition. Endpoint management via the API will be disabled.")
err = endpointWatcher.WatchEndpointFile(*flags.ExternalEndpoints) endpointWatcher := cron.NewWatcher(endpointService, syncInterval)
err := endpointWatcher.WatchEndpointFile(externalEnpointFile)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
return authorizeEndpointMgmt
}
settings := &portainer.Settings{ func initSettings(authorizeEndpointMgmt bool, flags *portainer.CLIFlags) *portainer.Settings {
return &portainer.Settings{
HiddenLabels: *flags.Labels, HiddenLabels: *flags.Labels,
Logo: *flags.Logo, Logo: *flags.Logo,
Authentication: !*flags.NoAuth, Authentication: !*flags.NoAuth,
EndpointManagement: authorizeEndpointMgmt, EndpointManagement: authorizeEndpointMgmt,
} }
}
// Initialize the active endpoint from the CLI only if there is no func initActiveEndpointFromFirstEndpointInDatabase(endpointService portainer.EndpointService) {
// active endpoint defined yet.
var activeEndpoint *portainer.Endpoint }
if *flags.Endpoint != "" {
activeEndpoint, err = store.EndpointService.GetActive() func retrieveFirstEndpointFromDatabase(endpointService portainer.EndpointService) *portainer.Endpoint {
if err == portainer.ErrEndpointNotFound { endpoints, err := endpointService.Endpoints()
if err != nil {
log.Fatal(err)
}
return &endpoints[0]
}
func initActiveEndpoint(endpointService portainer.EndpointService, flags *portainer.CLIFlags) *portainer.Endpoint {
activeEndpoint, err := endpointService.GetActive()
if err == portainer.ErrEndpointNotFound {
if *flags.Endpoint != "" {
activeEndpoint = &portainer.Endpoint{ activeEndpoint = &portainer.Endpoint{
Name: "primary", Name: "primary",
URL: *flags.Endpoint, URL: *flags.Endpoint,
@ -80,30 +106,36 @@ func main() {
TLSCertPath: *flags.TLSCert, TLSCertPath: *flags.TLSCert,
TLSKeyPath: *flags.TLSKey, TLSKeyPath: *flags.TLSKey,
} }
err = store.EndpointService.CreateEndpoint(activeEndpoint) err = endpointService.CreateEndpoint(activeEndpoint)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
} else if err != nil { } else if *flags.ExternalEndpoints != "" {
log.Fatal(err) activeEndpoint = retrieveFirstEndpointFromDatabase(endpointService)
}
}
if *flags.ExternalEndpoints != "" {
activeEndpoint, err = store.EndpointService.GetActive()
if err == portainer.ErrEndpointNotFound {
var endpoints []portainer.Endpoint
endpoints, err = store.EndpointService.Endpoints()
if err != nil {
log.Fatal(err)
}
err = store.EndpointService.SetActive(&endpoints[0])
if err != nil {
log.Fatal(err)
}
} else if err != nil {
log.Fatal(err)
} }
} else if err != nil {
log.Fatal(err)
} }
return activeEndpoint
}
func main() {
flags := initCLI()
fileService := initFileService(*flags.Data)
store := initStore(*flags.Data)
defer store.Close()
jwtService := initJWTService(!*flags.NoAuth)
cryptoService := initCryptoService()
authorizeEndpointMgmt := initEndpointWatcher(store.EndpointService, *flags.ExternalEndpoints, *flags.SyncInterval)
settings := initSettings(authorizeEndpointMgmt, flags)
activeEndpoint := initActiveEndpoint(store.EndpointService, flags)
var server portainer.Server = &http.Server{ var server portainer.Server = &http.Server{
BindAddress: *flags.Addr, BindAddress: *flags.Addr,
@ -121,7 +153,7 @@ func main() {
} }
log.Printf("Starting Portainer on %s", *flags.Addr) log.Printf("Starting Portainer on %s", *flags.Addr)
err = server.Start() err := server.Start()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"os" "os"
"strings"
"github.com/portainer/portainer" "github.com/portainer/portainer"
) )
@ -46,6 +47,9 @@ func endpointSyncError(err error, logger *log.Logger) bool {
func isValidEndpoint(endpoint *portainer.Endpoint) bool { func isValidEndpoint(endpoint *portainer.Endpoint) bool {
if endpoint.Name != "" && endpoint.URL != "" { if endpoint.Name != "" && endpoint.URL != "" {
if !strings.HasPrefix(endpoint.URL, "unix://") && !strings.HasPrefix(endpoint.URL, "tcp://") {
return false
}
return true return true
} }
return false return false
@ -155,12 +159,13 @@ func (job endpointSyncJob) Sync() error {
if endpointSyncError(err, job.logger) { if endpointSyncError(err, job.logger) {
return err return err
} }
job.logger.Printf("Endpoint synchronization ended. [created: %v] [updated: %v] [deleted: %v]", len(sync.endpointsToCreate), len(sync.endpointsToUpdate), len(sync.endpointsToDelete))
} }
return nil return nil
} }
func (job endpointSyncJob) Run() { func (job endpointSyncJob) Run() {
job.logger.Printf("Endpoint synchronization job started") job.logger.Println("Endpoint synchronization job started.")
err := job.Sync() err := job.Sync()
endpointSyncError(err, job.logger) endpointSyncError(err, job.logger)
} }

View File

@ -24,11 +24,17 @@ func NewWatcher(endpointService portainer.EndpointService, syncInterval string)
// WatchEndpointFile starts a cron job to synchronize the endpoints from a file // WatchEndpointFile starts a cron job to synchronize the endpoints from a file
func (watcher *Watcher) WatchEndpointFile(endpointFilePath string) error { func (watcher *Watcher) WatchEndpointFile(endpointFilePath string) error {
job := newEndpointSyncJob(endpointFilePath, watcher.EndpointService) job := newEndpointSyncJob(endpointFilePath, watcher.EndpointService)
err := job.Sync() err := job.Sync()
if err != nil { if err != nil {
return err return err
} }
watcher.Cron.AddJob("@every "+watcher.syncInterval, job)
err = watcher.Cron.AddJob("@every "+watcher.syncInterval, job)
if err != nil {
return err
}
watcher.Cron.Start() watcher.Cron.Start()
return nil return nil
} }