refactor: reuse logic for config init and set

pull/5560/head
Henrique Dias 2025-11-17 09:16:54 +01:00
parent 8c5dc7641e
commit 89be0b1873
No known key found for this signature in database
3 changed files with 132 additions and 229 deletions

View File

@ -37,6 +37,11 @@ func addConfigFlags(flags *pflag.FlagSet) {
flags.Uint("minimumPasswordLength", settings.DefaultMinimumPasswordLength, "minimum password length for new users") flags.Uint("minimumPasswordLength", settings.DefaultMinimumPasswordLength, "minimum password length for new users")
flags.String("shell", "", "shell command to which other commands should be appended") flags.String("shell", "", "shell command to which other commands should be appended")
// NB: these are string so they can be presented as octal in the help text
// as that's the conventional representation for modes in Unix.
flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with")
flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with")
flags.String("auth.method", string(auth.MethodJSONAuth), "authentication type") flags.String("auth.method", string(auth.MethodJSONAuth), "authentication type")
flags.String("auth.header", "", "HTTP header for auth.method=proxy") flags.String("auth.header", "", "HTTP header for auth.method=proxy")
flags.String("auth.command", "", "command for auth.method=hook") flags.String("auth.command", "", "command for auth.method=hook")
@ -52,11 +57,6 @@ func addConfigFlags(flags *pflag.FlagSet) {
flags.Bool("branding.disableExternal", false, "disable external links such as GitHub links") flags.Bool("branding.disableExternal", false, "disable external links such as GitHub links")
flags.Bool("branding.disableUsedPercentage", false, "disable used disk percentage graph") flags.Bool("branding.disableUsedPercentage", false, "disable used disk percentage graph")
// NB: these are string so they can be presented as octal in the help text
// as that's the conventional representation for modes in Unix.
flags.String("fileMode", fmt.Sprintf("%O", settings.DefaultFileMode), "mode bits that new files are created with")
flags.String("dirMode", fmt.Sprintf("%O", settings.DefaultDirMode), "mode bits that new directories are created with")
flags.Uint64("tus.chunkSize", settings.DefaultTusChunkSize, "the tus chunk size") flags.Uint64("tus.chunkSize", settings.DefaultTusChunkSize, "the tus chunk size")
flags.Uint16("tus.retryCount", settings.DefaultTusRetryCount, "the tus retry count") flags.Uint16("tus.retryCount", settings.DefaultTusRetryCount, "the tus retry count")
} }
@ -266,3 +266,118 @@ func printSettings(ser *settings.Server, set *settings.Settings, auther auth.Aut
fmt.Printf("\nAuther configuration (raw):\n\n%s\n\n", string(b)) fmt.Printf("\nAuther configuration (raw):\n\n%s\n\n", string(b))
return nil return nil
} }
func getSettings(flags *pflag.FlagSet, set *settings.Settings, ser *settings.Server, auther auth.Auther, all bool) (auth.Auther, error) {
errs := []error{}
hasAuth := false
visit := func(flag *pflag.Flag) {
var err error
switch flag.Name {
// Server flags from [addServerFlags]
case "address":
ser.Address, err = flags.GetString(flag.Name)
case "log":
ser.Log, err = flags.GetString(flag.Name)
case "port":
ser.Port, err = flags.GetString(flag.Name)
case "cert":
ser.TLSCert, err = flags.GetString(flag.Name)
case "key":
ser.TLSKey, err = flags.GetString(flag.Name)
case "root":
ser.Root, err = flags.GetString(flag.Name)
case "socket":
ser.Socket, err = flags.GetString(flag.Name)
case "baseURL":
ser.BaseURL, err = flags.GetString(flag.Name)
case "tokenExpirationTime":
ser.TokenExpirationTime, err = flags.GetString(flag.Name)
case "disableThumbnails":
ser.EnableThumbnails, err = flags.GetBool(flag.Name)
ser.EnableThumbnails = !ser.EnableThumbnails
case "disablePreviewResize":
ser.ResizePreview, err = flags.GetBool(flag.Name)
ser.ResizePreview = !ser.ResizePreview
case "disableExec":
ser.EnableExec, err = flags.GetBool(flag.Name)
ser.EnableExec = !ser.EnableExec
case "disableTypeDetectionByHeader":
ser.TypeDetectionByHeader, err = flags.GetBool(flag.Name)
ser.TypeDetectionByHeader = !ser.TypeDetectionByHeader
// Settings flags from [addConfigFlags]
case "signup":
set.Signup, err = flags.GetBool(flag.Name)
case "hideLoginButton":
set.HideLoginButton, err = flags.GetBool(flag.Name)
case "createUserDir":
set.CreateUserDir, err = flags.GetBool(flag.Name)
case "minimumPasswordLength":
set.MinimumPasswordLength, err = flags.GetUint(flag.Name)
case "shell":
var shell string
shell, err = flags.GetString(flag.Name)
if err == nil {
set.Shell = convertCmdStrToCmdArray(shell)
}
case "auth.method":
hasAuth = true
case "branding.name":
set.Branding.Name, err = flags.GetString(flag.Name)
case "branding.theme":
set.Branding.Theme, err = flags.GetString(flag.Name)
case "branding.color":
set.Branding.Color, err = flags.GetString(flag.Name)
case "branding.files":
set.Branding.Files, err = flags.GetString(flag.Name)
case "branding.disableExternal":
set.Branding.DisableExternal, err = flags.GetBool(flag.Name)
case "branding.disableUsedPercentage":
set.Branding.DisableUsedPercentage, err = flags.GetBool(flag.Name)
case "fileMode":
set.FileMode, err = getAndParseFileMode(flags, flag.Name)
case "dirMode":
set.DirMode, err = getAndParseFileMode(flags, flag.Name)
case "tus.chunkSize":
set.Tus.ChunkSize, err = flags.GetUint64(flag.Name)
case "tus.retryCount":
set.Tus.RetryCount, err = flags.GetUint16(flag.Name)
}
if err != nil {
errs = append(errs, err)
}
}
if all {
flags.VisitAll(visit)
} else {
flags.Visit(visit)
}
err := nerrors.Join(errs...)
if err != nil {
return nil, err
}
err = getUserDefaults(flags, &set.Defaults, all)
if err != nil {
return nil, err
}
if all {
set.AuthMethod, auther, err = getAuthentication(flags)
if err != nil {
return nil, err
}
} else {
set.AuthMethod, auther, err = getAuthentication(flags, hasAuth, set, auther)
if err != nil {
return nil, err
}
}
return auther, nil
}

View File

@ -5,7 +5,6 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/filebrowser/filebrowser/v2/auth"
"github.com/filebrowser/filebrowser/v2/settings" "github.com/filebrowser/filebrowser/v2/settings"
) )
@ -26,140 +25,17 @@ override the options.`,
RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error {
flags := cmd.Flags() flags := cmd.Flags()
// General Settings // Initialize config
s := &settings.Settings{ s := &settings.Settings{Key: generateKey()}
Key: generateKey(),
}
err := getUserDefaults(flags, &s.Defaults, true)
if err != nil {
return err
}
s.Signup, err = flags.GetBool("signup")
if err != nil {
return err
}
s.HideLoginButton, err = flags.GetBool("hideLoginButton")
if err != nil {
return err
}
s.CreateUserDir, err = flags.GetBool("createUserDir")
if err != nil {
return err
}
s.MinimumPasswordLength, err = flags.GetUint("minimumPasswordLength")
if err != nil {
return err
}
shell, err := flags.GetString("shell")
if err != nil {
return err
}
s.Shell = convertCmdStrToCmdArray(shell)
s.FileMode, err = getAndParseFileMode(flags, "fileMode")
if err != nil {
return err
}
s.DirMode, err = getAndParseFileMode(flags, "dirMode")
if err != nil {
return err
}
s.Branding.Name, err = flags.GetString("branding.name")
if err != nil {
return err
}
s.Branding.DisableExternal, err = flags.GetBool("branding.disableExternal")
if err != nil {
return err
}
s.Branding.DisableUsedPercentage, err = flags.GetBool("branding.disableUsedPercentage")
if err != nil {
return err
}
s.Branding.Theme, err = flags.GetString("branding.themes")
if err != nil {
return err
}
s.Branding.Files, err = flags.GetString("branding.files")
if err != nil {
return err
}
s.Tus.ChunkSize, err = flags.GetUint64("tus.chunkSize")
if err != nil {
return err
}
s.Tus.RetryCount, err = flags.GetUint16("tus.retryCount")
if err != nil {
return err
}
var auther auth.Auther
s.AuthMethod, auther, err = getAuthentication(flags)
if err != nil {
return err
}
// Server Settings
ser := &settings.Server{} ser := &settings.Server{}
ser.Address, err = flags.GetString("address")
if err != nil { // Fill config with options
return err auther, err := getSettings(flags, s, ser, nil, true)
}
ser.Socket, err = flags.GetString("socket")
if err != nil {
return err
}
ser.Root, err = flags.GetString("root")
if err != nil {
return err
}
ser.BaseURL, err = flags.GetString("baseURL")
if err != nil {
return err
}
ser.TLSKey, err = flags.GetString("key")
if err != nil {
return err
}
ser.TLSCert, err = flags.GetString("cert")
if err != nil {
return err
}
ser.Port, err = flags.GetString("port")
if err != nil {
return err
}
ser.Log, err = flags.GetString("log")
if err != nil {
return err
}
ser.TokenExpirationTime, err = flags.GetString("tokenExpirationTime")
if err != nil { if err != nil {
return err return err
} }
// Save updated config
err = d.store.Settings.Save(s) err = d.store.Settings.Save(s)
if err != nil { if err != nil {
return err return err

View File

@ -2,7 +2,6 @@ package cmd
import ( import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
) )
func init() { func init() {
@ -19,6 +18,7 @@ you want to change. Other options will remain unchanged.`,
RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error { RunE: python(func(cmd *cobra.Command, _ []string, d *pythonData) error {
flags := cmd.Flags() flags := cmd.Flags()
// Read existing config
set, err := d.store.Settings.Get() set, err := d.store.Settings.Get()
if err != nil { if err != nil {
return err return err
@ -29,116 +29,28 @@ you want to change. Other options will remain unchanged.`,
return err return err
} }
hasAuth := false
flags.Visit(func(flag *pflag.Flag) {
if err != nil {
return
}
switch flag.Name {
// Server flags from [addServerFlags]
case "address":
ser.Address, err = flags.GetString(flag.Name)
case "log":
ser.Log, err = flags.GetString(flag.Name)
case "port":
ser.Port, err = flags.GetString(flag.Name)
case "cert":
ser.TLSCert, err = flags.GetString(flag.Name)
case "key":
ser.TLSKey, err = flags.GetString(flag.Name)
case "root":
ser.Root, err = flags.GetString(flag.Name)
case "socket":
ser.Socket, err = flags.GetString(flag.Name)
case "baseURL":
ser.BaseURL, err = flags.GetString(flag.Name)
case "tokenExpirationTime":
ser.TokenExpirationTime, err = flags.GetString(flag.Name)
case "disableThumbnails":
ser.EnableThumbnails, err = flags.GetBool(flag.Name)
ser.EnableThumbnails = !ser.EnableThumbnails
case "disablePreviewResize":
ser.ResizePreview, err = flags.GetBool(flag.Name)
ser.ResizePreview = !ser.ResizePreview
case "disableExec":
ser.EnableExec, err = flags.GetBool(flag.Name)
ser.EnableExec = !ser.EnableExec
case "disableTypeDetectionByHeader":
ser.TypeDetectionByHeader, err = flags.GetBool(flag.Name)
ser.TypeDetectionByHeader = !ser.TypeDetectionByHeader
// Settings flags from [addConfigFlags]
case "signup":
set.Signup, err = flags.GetBool(flag.Name)
case "hideLoginButton":
set.HideLoginButton, err = flags.GetBool(flag.Name)
case "createUserDir":
set.CreateUserDir, err = flags.GetBool(flag.Name)
case "minimumPasswordLength":
set.MinimumPasswordLength, err = flags.GetUint(flag.Name)
case "shell":
var shell string
shell, err = flags.GetString(flag.Name)
if err != nil {
return
}
set.Shell = convertCmdStrToCmdArray(shell)
case "auth.method":
hasAuth = true
case "branding.name":
set.Branding.Name, err = flags.GetString(flag.Name)
case "branding.theme":
set.Branding.Theme, err = flags.GetString(flag.Name)
case "branding.color":
set.Branding.Color, err = flags.GetString(flag.Name)
case "branding.files":
set.Branding.Files, err = flags.GetString(flag.Name)
case "branding.disableExternal":
set.Branding.DisableExternal, err = flags.GetBool(flag.Name)
case "branding.disableUsedPercentage":
set.Branding.DisableUsedPercentage, err = flags.GetBool(flag.Name)
case "fileMode":
set.FileMode, err = getAndParseFileMode(flags, flag.Name)
case "dirMode":
set.DirMode, err = getAndParseFileMode(flags, flag.Name)
case "tus.chunkSize":
set.Tus.ChunkSize, err = flags.GetUint64(flag.Name)
case "tus.retryCount":
set.Tus.RetryCount, err = flags.GetUint16(flag.Name)
}
})
if err != nil {
return err
}
err = getUserDefaults(flags, &set.Defaults, false)
if err != nil {
return err
}
// read the defaults
auther, err := d.store.Auth.Get(set.AuthMethod) auther, err := d.store.Auth.Get(set.AuthMethod)
if err != nil { if err != nil {
return err return err
} }
// check if there are new flags for existing auth method // Get updated config
set.AuthMethod, auther, err = getAuthentication(flags, hasAuth, set, auther) auther, err = getSettings(flags, set, ser, auther, false)
if err != nil { if err != nil {
return err return err
} }
// Save updated config
err = d.store.Auth.Save(auther) err = d.store.Auth.Save(auther)
if err != nil { if err != nil {
return err return err
} }
err = d.store.Settings.Save(set) err = d.store.Settings.Save(set)
if err != nil { if err != nil {
return err return err
} }
err = d.store.Settings.SaveServer(ser) err = d.store.Settings.SaveServer(ser)
if err != nil { if err != nil {
return err return err