From 2c526077c05b298a838b4468881e2c460369d5bc Mon Sep 17 00:00:00 2001 From: 1138-4EB <1138-4EB@users.noreply.github.com> Date: Tue, 8 Jan 2019 21:10:27 +0100 Subject: [PATCH] style: rename functions in root.go (#623) * replace isFlagSet with Changed * style: rename mustGetStringViperFlag and getStringViperFlag, use getParamB to read noauth * style * style: move * fix build error * rename getServerWithViper to getRunParams --- cmd/root.go | 118 ++++++++++++++++++++++--------------------------- cmd/upgrade.go | 2 +- cmd/utils.go | 2 +- 3 files changed, 54 insertions(+), 68 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index de260a5d..0c71df09 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -54,43 +54,6 @@ func addServerFlags(flags *pflag.FlagSet) { flags.StringP("baseurl", "b", "", "base url") } -func isFlagSet(flags *pflag.FlagSet, key string) bool { - set:= false - flags.Visit(func(flag *pflag.Flag) { - if flag.Name == key { - set = true - } - }) - return set -} - -// NOTE: we could simply bind the flags to viper and use IsSet. -// Although there is a bug on Viper that always returns true on IsSet -// if a flag is binded. Our alternative way is to manually check -// the flag and then the value from env/config/gotten by viper. -// https://github.com/spf13/viper/pull/331 -func getStringViperFlag(flags *pflag.FlagSet, key string) (string, bool) { - value, _ := flags.GetString(key) - - // If set on Flags, use it. - if isFlagSet(flags, key) { - return value, true - } - - // If set through viper (env, config), return it. - if v.IsSet(key) { - return v.GetString(key), true - } - - // Otherwise use default value on flags. - return value, false -} - -func mustGetStringViperFlag(flags *pflag.FlagSet, key string) string { - val, _ := getStringViperFlag(flags, key) - return val -} - var rootCmd = &cobra.Command{ Use: "filebrowser", Version: version.Version, @@ -137,7 +100,7 @@ user created with the credentials from options "username" and "password".`, quickSetup(cmd.Flags(), d) } - server := getServerWithViper(cmd.Flags(), d.store) + server := getRunParams(cmd.Flags(), d.store) setupLog(server.Log) root, err := filepath.Abs(server.Root) @@ -168,41 +131,70 @@ user created with the credentials from options "username" and "password".`, }, pythonConfig{allowNoDB: true}), } -func getServerWithViper(flags *pflag.FlagSet, st *storage.Storage) *settings.Server { +func getRunParams(flags *pflag.FlagSet, st *storage.Storage) *settings.Server { server, err := st.Settings.GetServer() checkErr(err) - if val, set := getStringViperFlag(flags, "root"); set { + if val, set := getParamB(flags, "root"); set { server.Root = val } - if val, set := getStringViperFlag(flags, "baseurl"); set { + if val, set := getParamB(flags, "baseurl"); set { server.BaseURL = val } - if val, set := getStringViperFlag(flags, "address"); set { + if val, set := getParamB(flags, "address"); set { server.Address = val } - if val, set := getStringViperFlag(flags, "port"); set { + if val, set := getParamB(flags, "port"); set { server.Port = val } - if val, set := getStringViperFlag(flags, "log"); set { + if val, set := getParamB(flags, "log"); set { server.Log = val } - if val, set := getStringViperFlag(flags, "key"); set { + if val, set := getParamB(flags, "key"); set { server.TLSKey = val } - if val, set := getStringViperFlag(flags, "cert"); set { + if val, set := getParamB(flags, "cert"); set { server.TLSCert = val } return server } +// getParamB returns a parameter as a string and a boolean to tell if it is different from the default +// +// NOTE: we could simply bind the flags to viper and use IsSet. +// Although there is a bug on Viper that always returns true on IsSet +// if a flag is binded. Our alternative way is to manually check +// the flag and then the value from env/config/gotten by viper. +// https://github.com/spf13/viper/pull/331 +func getParamB(flags *pflag.FlagSet, key string) (string, bool) { + value, _ := flags.GetString(key) + + // If set on Flags, use it. + if flags.Changed(key) { + return value, true + } + + // If set through viper (env, config), return it. + if v.IsSet(key) { + return v.GetString(key), true + } + + // Otherwise use default value on flags. + return value, false +} + +func getParam(flags *pflag.FlagSet, key string) string { + val, _ := getParamB(flags, key) + return val +} + func setupLog(logMethod string) { switch logMethod { case "stdout": @@ -223,8 +215,8 @@ func setupLog(logMethod string) { func quickSetup(flags *pflag.FlagSet, d pythonData) { set := &settings.Settings{ - Key: generateRandomBytes(64), // 256 bit - Signup: false, + Key: generateRandomBytes(64), // 256 bit + Signup: false, Defaults: settings.UserDefaults{ Scope: ".", Locale: "en", @@ -241,40 +233,34 @@ func quickSetup(flags *pflag.FlagSet, d pythonData) { }, } - noauth, err := flags.GetBool("noauth") - checkErr(err) - - if !isFlagSet(flags, "noauth") && v.IsSet("noauth") { - noauth = v.GetBool("noauth") - } - - if noauth { + var err error + if _, noauth := getParamB(flags, "noauth"); noauth { set.AuthMethod = auth.MethodNoAuth err = d.store.Auth.Save(&auth.NoAuth{}) } else { set.AuthMethod = auth.MethodJSONAuth err = d.store.Auth.Save(&auth.JSONAuth{}) } - + checkErr(err) err = d.store.Settings.Save(set) checkErr(err) ser := &settings.Server{ - BaseURL: mustGetStringViperFlag(flags, "baseurl"), - Port: mustGetStringViperFlag(flags, "port"), - Log: mustGetStringViperFlag(flags, "log"), - TLSKey: mustGetStringViperFlag(flags, "key"), - TLSCert: mustGetStringViperFlag(flags, "cert"), - Address: mustGetStringViperFlag(flags, "address"), - Root: mustGetStringViperFlag(flags, "root"), + BaseURL: getParam(flags, "baseurl"), + Port: getParam(flags, "port"), + Log: getParam(flags, "log"), + TLSKey: getParam(flags, "key"), + TLSCert: getParam(flags, "cert"), + Address: getParam(flags, "address"), + Root: getParam(flags, "root"), } err = d.store.Settings.SaveServer(ser) checkErr(err) - username := mustGetStringViperFlag(flags, "username") - password := mustGetStringViperFlag(flags, "password") + username := getParam(flags, "username") + password := getParam(flags, "password") if password == "" { password, err = users.HashPwd("admin") diff --git a/cmd/upgrade.go b/cmd/upgrade.go index 0e3dd290..d46d4fe9 100644 --- a/cmd/upgrade.go +++ b/cmd/upgrade.go @@ -25,7 +25,7 @@ this version.`, flags := cmd.Flags() oldDB := mustGetString(flags, "old.database") oldConf := mustGetString(flags, "old.config") - err := importer.Import(oldDB, oldConf, mustGetStringViperFlag(flags, "database")) + err := importer.Import(oldDB, oldConf, getParam(flags, "database")) checkErr(err) }, } diff --git a/cmd/utils.go b/cmd/utils.go index bd741998..b52fcf5a 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -66,7 +66,7 @@ func python(fn pythonFunc, cfg pythonConfig) cobraFunc { return func(cmd *cobra.Command, args []string) { data := pythonData{hadDB: true} - path := mustGetStringViperFlag(cmd.Flags(), "database") + path := getParam(cmd.Flags(), "database") _, err := os.Stat(path) if os.IsNotExist(err) {