From 89be0b1873527987dd2dddac746e93b8bc684d46 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Mon, 17 Nov 2025 09:16:54 +0100 Subject: [PATCH] refactor: reuse logic for config init and set --- cmd/config.go | 125 +++++++++++++++++++++++++++++++++++++++-- cmd/config_init.go | 136 ++------------------------------------------- cmd/config_set.go | 100 ++------------------------------- 3 files changed, 132 insertions(+), 229 deletions(-) diff --git a/cmd/config.go b/cmd/config.go index 8dc8baa7..7f9b1c19 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -37,6 +37,11 @@ func addConfigFlags(flags *pflag.FlagSet) { flags.Uint("minimumPasswordLength", settings.DefaultMinimumPasswordLength, "minimum password length for new users") flags.String("shell", "", "shell command to which other commands should be appended") + // NB: these are string so they can be presented as octal in the help text + // as that's the conventional representation for modes in Unix. + flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with") + flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with") + flags.String("auth.method", string(auth.MethodJSONAuth), "authentication type") flags.String("auth.header", "", "HTTP header for auth.method=proxy") flags.String("auth.command", "", "command for auth.method=hook") @@ -52,11 +57,6 @@ func addConfigFlags(flags *pflag.FlagSet) { flags.Bool("branding.disableExternal", false, "disable external links such as GitHub links") flags.Bool("branding.disableUsedPercentage", false, "disable used disk percentage graph") - // NB: these are string so they can be presented as octal in the help text - // as that's the conventional representation for modes in Unix. - flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with") - flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with") - flags.Uint64("tus.chunkSize", settings.DefaultTusChunkSize, "the tus chunk size") flags.Uint16("tus.retryCount", settings.DefaultTusRetryCount, "the tus retry count") } @@ -266,3 +266,118 @@ func printSettings(ser *settings.Server, set *settings.Settings, auther auth.Aut fmt.Printf("\nAuther configuration (raw):\n\n%s\n\n", string(b)) return nil } + +func getSettings(flags *pflag.FlagSet, set *settings.Settings, ser *settings.Server, auther auth.Auther, all bool) (auth.Auther, error) { + errs := []error{} + hasAuth := false + + visit := func(flag *pflag.Flag) { + var err error + + switch flag.Name { + // Server flags from [addServerFlags] + case "address": + ser.Address, err = flags.GetString(flag.Name) + case "log": + ser.Log, err = flags.GetString(flag.Name) + case "port": + ser.Port, err = flags.GetString(flag.Name) + case "cert": + ser.TLSCert, err = flags.GetString(flag.Name) + case "key": + ser.TLSKey, err = flags.GetString(flag.Name) + case "root": + ser.Root, err = flags.GetString(flag.Name) + case "socket": + ser.Socket, err = flags.GetString(flag.Name) + case "baseURL": + ser.BaseURL, err = flags.GetString(flag.Name) + case "tokenExpirationTime": + ser.TokenExpirationTime, err = flags.GetString(flag.Name) + case "disableThumbnails": + ser.EnableThumbnails, err = flags.GetBool(flag.Name) + ser.EnableThumbnails = !ser.EnableThumbnails + case "disablePreviewResize": + ser.ResizePreview, err = flags.GetBool(flag.Name) + ser.ResizePreview = !ser.ResizePreview + case "disableExec": + ser.EnableExec, err = flags.GetBool(flag.Name) + ser.EnableExec = !ser.EnableExec + case "disableTypeDetectionByHeader": + ser.TypeDetectionByHeader, err = flags.GetBool(flag.Name) + ser.TypeDetectionByHeader = !ser.TypeDetectionByHeader + + // Settings flags from [addConfigFlags] + case "signup": + set.Signup, err = flags.GetBool(flag.Name) + case "hideLoginButton": + set.HideLoginButton, err = flags.GetBool(flag.Name) + case "createUserDir": + set.CreateUserDir, err = flags.GetBool(flag.Name) + case "minimumPasswordLength": + set.MinimumPasswordLength, err = flags.GetUint(flag.Name) + case "shell": + var shell string + shell, err = flags.GetString(flag.Name) + if err == nil { + set.Shell = convertCmdStrToCmdArray(shell) + } + case "auth.method": + hasAuth = true + case "branding.name": + set.Branding.Name, err = flags.GetString(flag.Name) + case "branding.theme": + set.Branding.Theme, err = flags.GetString(flag.Name) + case "branding.color": + set.Branding.Color, err = flags.GetString(flag.Name) + case "branding.files": + set.Branding.Files, err = flags.GetString(flag.Name) + case "branding.disableExternal": + set.Branding.DisableExternal, err = flags.GetBool(flag.Name) + case "branding.disableUsedPercentage": + set.Branding.DisableUsedPercentage, err = flags.GetBool(flag.Name) + case "fileMode": + set.FileMode, err = getAndParseFileMode(flags, flag.Name) + case "dirMode": + set.DirMode, err = getAndParseFileMode(flags, flag.Name) + case "tus.chunkSize": + set.Tus.ChunkSize, err = flags.GetUint64(flag.Name) + case "tus.retryCount": + set.Tus.RetryCount, err = flags.GetUint16(flag.Name) + } + + if err != nil { + errs = append(errs, err) + } + } + + if all { + flags.VisitAll(visit) + } else { + flags.Visit(visit) + } + + err := nerrors.Join(errs...) + if err != nil { + return nil, err + } + + err = getUserDefaults(flags, &set.Defaults, all) + if err != nil { + return nil, err + } + + if all { + set.AuthMethod, auther, err = getAuthentication(flags) + if err != nil { + return nil, err + } + } else { + set.AuthMethod, auther, err = getAuthentication(flags, hasAuth, set, auther) + if err != nil { + return nil, err + } + } + + return auther, nil +} diff --git a/cmd/config_init.go b/cmd/config_init.go index 4c9aab63..2787f080 100644 --- a/cmd/config_init.go +++ b/cmd/config_init.go @@ -5,7 +5,6 @@ import ( "github.com/spf13/cobra" - "github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/settings" ) @@ -26,140 +25,17 @@ override the options.`, RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { flags := cmd.Flags() - // General Settings - s := &settings.Settings{ - Key: generateKey(), - } - - err := getUserDefaults(flags, &s.Defaults, true) - if err != nil { - return err - } - - s.Signup, err = flags.GetBool("signup") - if err != nil { - return err - } - - s.HideLoginButton, err = flags.GetBool("hideLoginButton") - if err != nil { - return err - } - - s.CreateUserDir, err = flags.GetBool("createUserDir") - if err != nil { - return err - } - - s.MinimumPasswordLength, err = flags.GetUint("minimumPasswordLength") - if err != nil { - return err - } - - shell, err := flags.GetString("shell") - if err != nil { - return err - } - s.Shell = convertCmdStrToCmdArray(shell) - - s.FileMode, err = getAndParseFileMode(flags, "fileMode") - if err != nil { - return err - } - - s.DirMode, err = getAndParseFileMode(flags, "dirMode") - if err != nil { - return err - } - - s.Branding.Name, err = flags.GetString("branding.name") - if err != nil { - return err - } - - s.Branding.DisableExternal, err = flags.GetBool("branding.disableExternal") - if err != nil { - return err - } - - s.Branding.DisableUsedPercentage, err = flags.GetBool("branding.disableUsedPercentage") - if err != nil { - return err - } - - s.Branding.Theme, err = flags.GetString("branding.themes") - if err != nil { - return err - } - - s.Branding.Files, err = flags.GetString("branding.files") - if err != nil { - return err - } - - s.Tus.ChunkSize, err = flags.GetUint64("tus.chunkSize") - if err != nil { - return err - } - - s.Tus.RetryCount, err = flags.GetUint16("tus.retryCount") - if err != nil { - return err - } - - var auther auth.Auther - s.AuthMethod, auther, err = getAuthentication(flags) - if err != nil { - return err - } - - // Server Settings + // Initialize config + s := &settings.Settings{Key: generateKey()} ser := &settings.Server{} - ser.Address, err = flags.GetString("address") - if err != nil { - return err - } - - ser.Socket, err = flags.GetString("socket") - if err != nil { - return err - } - - ser.Root, err = flags.GetString("root") - if err != nil { - return err - } - - ser.BaseURL, err = flags.GetString("baseURL") - if err != nil { - return err - } - - ser.TLSKey, err = flags.GetString("key") - if err != nil { - return err - } - - ser.TLSCert, err = flags.GetString("cert") - if err != nil { - return err - } - - ser.Port, err = flags.GetString("port") - if err != nil { - return err - } - - ser.Log, err = flags.GetString("log") - if err != nil { - return err - } - - ser.TokenExpirationTime, err = flags.GetString("tokenExpirationTime") + + // Fill config with options + auther, err := getSettings(flags, s, ser, nil, true) if err != nil { return err } + // Save updated config err = d.store.Settings.Save(s) if err != nil { return err diff --git a/cmd/config_set.go b/cmd/config_set.go index 74fae9ea..d25b6596 100644 --- a/cmd/config_set.go +++ b/cmd/config_set.go @@ -2,7 +2,6 @@ package cmd import ( "github.com/spf13/cobra" - "github.com/spf13/pflag" ) func init() { @@ -19,6 +18,7 @@ you want to change. Other options will remain unchanged.`, RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { flags := cmd.Flags() + // Read existing config set, err := d.store.Settings.Get() if err != nil { return err @@ -29,116 +29,28 @@ you want to change. Other options will remain unchanged.`, return err } - hasAuth := false - - flags.Visit(func(flag *pflag.Flag) { - if err != nil { - return - } - - switch flag.Name { - // Server flags from [addServerFlags] - case "address": - ser.Address, err = flags.GetString(flag.Name) - case "log": - ser.Log, err = flags.GetString(flag.Name) - case "port": - ser.Port, err = flags.GetString(flag.Name) - case "cert": - ser.TLSCert, err = flags.GetString(flag.Name) - case "key": - ser.TLSKey, err = flags.GetString(flag.Name) - case "root": - ser.Root, err = flags.GetString(flag.Name) - case "socket": - ser.Socket, err = flags.GetString(flag.Name) - case "baseURL": - ser.BaseURL, err = flags.GetString(flag.Name) - case "tokenExpirationTime": - ser.TokenExpirationTime, err = flags.GetString(flag.Name) - case "disableThumbnails": - ser.EnableThumbnails, err = flags.GetBool(flag.Name) - ser.EnableThumbnails = !ser.EnableThumbnails - case "disablePreviewResize": - ser.ResizePreview, err = flags.GetBool(flag.Name) - ser.ResizePreview = !ser.ResizePreview - case "disableExec": - ser.EnableExec, err = flags.GetBool(flag.Name) - ser.EnableExec = !ser.EnableExec - case "disableTypeDetectionByHeader": - ser.TypeDetectionByHeader, err = flags.GetBool(flag.Name) - ser.TypeDetectionByHeader = !ser.TypeDetectionByHeader - - // Settings flags from [addConfigFlags] - case "signup": - set.Signup, err = flags.GetBool(flag.Name) - case "hideLoginButton": - set.HideLoginButton, err = flags.GetBool(flag.Name) - case "createUserDir": - set.CreateUserDir, err = flags.GetBool(flag.Name) - case "minimumPasswordLength": - set.MinimumPasswordLength, err = flags.GetUint(flag.Name) - case "shell": - var shell string - shell, err = flags.GetString(flag.Name) - if err != nil { - return - } - set.Shell = convertCmdStrToCmdArray(shell) - case "auth.method": - hasAuth = true - case "branding.name": - set.Branding.Name, err = flags.GetString(flag.Name) - case "branding.theme": - set.Branding.Theme, err = flags.GetString(flag.Name) - case "branding.color": - set.Branding.Color, err = flags.GetString(flag.Name) - case "branding.files": - set.Branding.Files, err = flags.GetString(flag.Name) - case "branding.disableExternal": - set.Branding.DisableExternal, err = flags.GetBool(flag.Name) - case "branding.disableUsedPercentage": - set.Branding.DisableUsedPercentage, err = flags.GetBool(flag.Name) - case "fileMode": - set.FileMode, err = getAndParseFileMode(flags, flag.Name) - case "dirMode": - set.DirMode, err = getAndParseFileMode(flags, flag.Name) - case "tus.chunkSize": - set.Tus.ChunkSize, err = flags.GetUint64(flag.Name) - case "tus.retryCount": - set.Tus.RetryCount, err = flags.GetUint16(flag.Name) - } - - }) - if err != nil { - return err - } - - err = getUserDefaults(flags, &set.Defaults, false) - if err != nil { - return err - } - - // read the defaults auther, err := d.store.Auth.Get(set.AuthMethod) if err != nil { return err } - // check if there are new flags for existing auth method - set.AuthMethod, auther, err = getAuthentication(flags, hasAuth, set, auther) + // Get updated config + auther, err = getSettings(flags, set, ser, auther, false) if err != nil { return err } + // Save updated config err = d.store.Auth.Save(auther) if err != nil { return err } + err = d.store.Settings.Save(set) if err != nil { return err } + err = d.store.Settings.SaveServer(ser) if err != nil { return err