diff --git a/client/admin_api.go b/client/admin_api.go index e775f52..84db31c 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -57,7 +57,7 @@ func (svr *Service) apiReload(w http.ResponseWriter, _ *http.Request) { } }() - cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile) + cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(svr.cfgFile, svr.strictConfig) if err != nil { res.Code = 400 res.Msg = err.Error() diff --git a/client/service.go b/client/service.go index 66a642c..0a25ae0 100644 --- a/client/service.go +++ b/client/service.go @@ -70,6 +70,9 @@ type Service struct { // string if no configuration file was used. cfgFile string + // Whether strict configuration parsing had been requested. + strictConfig bool + // service context ctx context.Context // call cancel to stop service @@ -82,14 +85,16 @@ func NewService( pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer, cfgFile string, + strictConfig bool, ) *Service { return &Service{ - authSetter: auth.NewAuthSetter(cfg.Auth), - cfg: cfg, - cfgFile: cfgFile, - pxyCfgs: pxyCfgs, - visitorCfgs: visitorCfgs, - ctx: context.Background(), + authSetter: auth.NewAuthSetter(cfg.Auth), + cfg: cfg, + cfgFile: cfgFile, + strictConfig: strictConfig, + pxyCfgs: pxyCfgs, + visitorCfgs: visitorCfgs, + ctx: context.Background(), } } diff --git a/cmd/frpc/sub/admin.go b/cmd/frpc/sub/admin.go index 2a5f283..d98b4d3 100644 --- a/cmd/frpc/sub/admin.go +++ b/cmd/frpc/sub/admin.go @@ -52,7 +52,7 @@ func NewAdminCommand(name, short string, handler func(*v1.ClientCommonConfig) er Use: name, Short: short, Run: func(cmd *cobra.Command, args []string) { - cfg, _, _, _, err := config.LoadClientConfig(cfgFile) + cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/nathole.go b/cmd/frpc/sub/nathole.go index 72b635f..eafea27 100644 --- a/cmd/frpc/sub/nathole.go +++ b/cmd/frpc/sub/nathole.go @@ -48,7 +48,7 @@ var natholeDiscoveryCmd = &cobra.Command{ Short: "Discover nathole information from stun server", RunE: func(cmd *cobra.Command, args []string) error { // ignore error here, because we can use command line pameters - cfg, _, _, _, err := config.LoadClientConfig(cfgFile) + cfg, _, _, _, err := config.LoadClientConfig(cfgFile, strictConfig) if err != nil { cfg = &v1.ClientCommonConfig{} } diff --git a/cmd/frpc/sub/proxy.go b/cmd/frpc/sub/proxy.go index 7ae8d35..41c20bc 100644 --- a/cmd/frpc/sub/proxy.go +++ b/cmd/frpc/sub/proxy.go @@ -84,7 +84,7 @@ func NewProxyCommand(name string, c v1.ProxyConfigurer, clientCfg *v1.ClientComm fmt.Println(err) os.Exit(1) } - err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "") + err := startService(clientCfg, []v1.ProxyConfigurer{c}, nil, "", strictConfig) if err != nil { fmt.Println(err) os.Exit(1) @@ -110,7 +110,7 @@ func NewVisitorCommand(name string, c v1.VisitorConfigurer, clientCfg *v1.Client fmt.Println(err) os.Exit(1) } - err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "") + err := startService(clientCfg, nil, []v1.VisitorConfigurer{c}, "", strictConfig) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index 125c88c..855c7ab 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -36,15 +36,17 @@ import ( ) var ( - cfgFile string - cfgDir string - showVersion bool + cfgFile string + cfgDir string + showVersion bool + strictConfig bool ) func init() { rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "./frpc.ini", "config file of frpc") rootCmd.PersistentFlags().StringVarP(&cfgDir, "config_dir", "", "", "config directory, run one frpc service for each file in config directory") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frpc") + rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode") } var rootCmd = &cobra.Command{ @@ -108,7 +110,7 @@ func handleTermSignal(svr *client.Service) { } func runClient(cfgFilePath string) error { - cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath) + cfg, pxyCfgs, visitorCfgs, isLegacyFormat, err := config.LoadClientConfig(cfgFilePath, strictConfig) if err != nil { return err } @@ -120,11 +122,14 @@ func runClient(cfgFilePath string) error { warning, err := validation.ValidateAllClientConfig(cfg, pxyCfgs, visitorCfgs) if warning != nil { fmt.Printf("WARNING: %v\n", warning) + if strictConfig { + return fmt.Errorf("warning: %v", warning) + } } if err != nil { return err } - return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) + return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath, strictConfig) } func startService( @@ -132,6 +137,7 @@ func startService( pxyCfgs []v1.ProxyConfigurer, visitorCfgs []v1.VisitorConfigurer, cfgFile string, + strictConfig bool, ) error { log.InitLog(cfg.Log.To, cfg.Log.Level, cfg.Log.MaxDays, cfg.Log.DisablePrintColor) @@ -139,7 +145,7 @@ func startService( log.Info("start frpc service for config file [%s]", cfgFile) defer log.Info("frpc service for config file [%s] stopped", cfgFile) } - svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile) + svr := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile, strictConfig) shouldGracefulClose := cfg.Transport.Protocol == "kcp" || cfg.Transport.Protocol == "quic" // Capture the exit signal if we use kcp or quic. diff --git a/cmd/frpc/sub/verify.go b/cmd/frpc/sub/verify.go index a84f54f..0e6adca 100644 --- a/cmd/frpc/sub/verify.go +++ b/cmd/frpc/sub/verify.go @@ -37,7 +37,7 @@ var verifyCmd = &cobra.Command{ return nil } - cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile) + cliCfg, pxyCfgs, visitorCfgs, _, err := config.LoadClientConfig(cfgFile, strictConfig) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frps/root.go b/cmd/frps/root.go index 4a6f011..adb8852 100644 --- a/cmd/frps/root.go +++ b/cmd/frps/root.go @@ -30,8 +30,9 @@ import ( ) var ( - cfgFile string - showVersion bool + cfgFile string + showVersion bool + strictConfig bool serverCfg v1.ServerConfig ) @@ -39,6 +40,7 @@ var ( func init() { rootCmd.PersistentFlags().StringVarP(&cfgFile, "config", "c", "", "config file of frps") rootCmd.PersistentFlags().BoolVarP(&showVersion, "version", "v", false, "version of frps") + rootCmd.PersistentFlags().BoolVarP(&strictConfig, "strict_config", "", false, "strict config parsing mode") RegisterServerConfigFlags(rootCmd, &serverCfg) } @@ -58,7 +60,7 @@ var rootCmd = &cobra.Command{ err error ) if cfgFile != "" { - svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile) + svrCfg, isLegacyFormat, err = config.LoadServerConfig(cfgFile, strictConfig) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frps/verify.go b/cmd/frps/verify.go index 4f0cefb..838ac7b 100644 --- a/cmd/frps/verify.go +++ b/cmd/frps/verify.go @@ -36,7 +36,7 @@ var verifyCmd = &cobra.Command{ fmt.Println("frps: the configuration file is not specified") return nil } - svrCfg, _, err := config.LoadServerConfig(cfgFile) + svrCfg, _, err := config.LoadServerConfig(cfgFile, strictConfig) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/pkg/config/load.go b/pkg/config/load.go index af2c3e8..a4013c3 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -27,7 +27,7 @@ import ( "github.com/samber/lo" "gopkg.in/ini.v1" "k8s.io/apimachinery/pkg/util/sets" - "k8s.io/apimachinery/pkg/util/yaml" + yaml "k8s.io/apimachinery/pkg/util/yaml" "github.com/fatedier/frp/pkg/config/legacy" v1 "github.com/fatedier/frp/pkg/config/v1" @@ -100,26 +100,39 @@ func LoadFileContentWithTemplate(path string, values *Values) ([]byte, error) { return RenderWithTemplate(b, values) } -func LoadConfigureFromFile(path string, c any) error { +func LoadConfigureFromFile(path string, c any, strict bool) error { content, err := LoadFileContentWithTemplate(path, GetValues()) if err != nil { return err } - return LoadConfigure(content, c) + return LoadConfigure(content, c, strict) } // LoadConfigure loads configuration from bytes and unmarshal into c. // Now it supports json, yaml and toml format. -func LoadConfigure(b []byte, c any) error { +func LoadConfigure(b []byte, c any, strict bool) error { var tomlObj interface{} + // Try to unmarshal as TOML first; swallow errors from that (assume it's not valid TOML). + // TODO: caller should probably be able to specify the format, so we don't need to swallow errors. if err := toml.Unmarshal(b, &tomlObj); err == nil { b, err = json.Marshal(&tomlObj) if err != nil { return err } } - decoder := yaml.NewYAMLOrJSONDecoder(bytes.NewBuffer(b), 4096) - return decoder.Decode(c) + // If the buffer smells like JSON (first non-whitespace character is '{'), unmarshal as JSON directly. + if yaml.IsJSONBuffer(b) { + decoder := json.NewDecoder(bytes.NewBuffer(b)) + if strict { + decoder.DisallowUnknownFields() + } + return decoder.Decode(c) + } + // It wasn't JSON. Unmarshal as YAML. + if strict { + return yaml.UnmarshalStrict(b, c) + } + return yaml.Unmarshal(b, c) } func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1.ProxyConfigurer, error) { @@ -139,7 +152,7 @@ func NewProxyConfigurerFromMsg(m *msg.NewProxy, serverCfg *v1.ServerConfig) (v1. return configurer, nil } -func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) { +func LoadServerConfig(path string, strict bool) (*v1.ServerConfig, bool, error) { var ( svrCfg *v1.ServerConfig isLegacyFormat bool @@ -158,7 +171,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) { isLegacyFormat = true } else { svrCfg = &v1.ServerConfig{} - if err := LoadConfigureFromFile(path, svrCfg); err != nil { + if err := LoadConfigureFromFile(path, svrCfg, strict); err != nil { return nil, false, err } } @@ -168,7 +181,7 @@ func LoadServerConfig(path string) (*v1.ServerConfig, bool, error) { return svrCfg, isLegacyFormat, nil } -func LoadClientConfig(path string) ( +func LoadClientConfig(path string, strict bool) ( *v1.ClientCommonConfig, []v1.ProxyConfigurer, []v1.VisitorConfigurer, @@ -196,7 +209,7 @@ func LoadClientConfig(path string) ( isLegacyFormat = true } else { allCfg := v1.ClientConfig{} - if err := LoadConfigureFromFile(path, &allCfg); err != nil { + if err := LoadConfigureFromFile(path, &allCfg, strict); err != nil { return nil, nil, nil, false, err } cliCfg = &allCfg.ClientCommonConfig @@ -211,7 +224,7 @@ func LoadClientConfig(path string) ( // Load additional config from includes. // legacy ini format alredy handle this in ParseClientConfig. if len(cliCfg.IncludeConfigFiles) > 0 && !isLegacyFormat { - extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat) + extPxyCfgs, extVisitorCfgs, err := LoadAdditionalClientConfigs(cliCfg.IncludeConfigFiles, isLegacyFormat, strict) if err != nil { return nil, nil, nil, isLegacyFormat, err } @@ -242,7 +255,7 @@ func LoadClientConfig(path string) ( return cliCfg, pxyCfgs, visitorCfgs, isLegacyFormat, nil } -func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) { +func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool, strict bool) ([]v1.ProxyConfigurer, []v1.VisitorConfigurer, error) { pxyCfgs := make([]v1.ProxyConfigurer, 0) visitorCfgs := make([]v1.VisitorConfigurer, 0) for _, path := range paths { @@ -265,7 +278,7 @@ func LoadAdditionalClientConfigs(paths []string, isLegacyFormat bool) ([]v1.Prox if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched { // support yaml/json/toml cfg := v1.ClientConfig{} - if err := LoadConfigureFromFile(absFile, &cfg); err != nil { + if err := LoadConfigureFromFile(absFile, &cfg, strict); err != nil { return nil, nil, fmt.Errorf("load additional config from %s error: %v", absFile, err) } for _, c := range cfg.Proxies { diff --git a/pkg/config/load_test.go b/pkg/config/load_test.go index eab4ba9..876d53e 100644 --- a/pkg/config/load_test.go +++ b/pkg/config/load_test.go @@ -15,6 +15,7 @@ package config import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -22,9 +23,7 @@ import ( v1 "github.com/fatedier/frp/pkg/config/v1" ) -func TestLoadConfigure(t *testing.T) { - require := require.New(t) - content := ` +const tomlServerContent = ` bindAddr = "127.0.0.1" kcpBindPort = 7000 quicBindPort = 7001 @@ -33,13 +32,60 @@ custom404Page = "/abc.html" transport.tcpKeepalive = 10 ` - svrCfg := v1.ServerConfig{} - err := LoadConfigure([]byte(content), &svrCfg) - require.NoError(err) - require.EqualValues("127.0.0.1", svrCfg.BindAddr) - require.EqualValues(7000, svrCfg.KCPBindPort) - require.EqualValues(7001, svrCfg.QUICBindPort) - require.EqualValues(7005, svrCfg.TCPMuxHTTPConnectPort) - require.EqualValues("/abc.html", svrCfg.Custom404Page) - require.EqualValues(10, svrCfg.Transport.TCPKeepAlive) +const yamlServerContent = ` +bindAddr: 127.0.0.1 +kcpBindPort: 7000 +quicBindPort: 7001 +tcpmuxHTTPConnectPort: 7005 +custom404Page: /abc.html +transport: + tcpKeepalive: 10 +` + +const jsonServerContent = ` +{ + "bindAddr": "127.0.0.1", + "kcpBindPort": 7000, + "quicBindPort": 7001, + "tcpmuxHTTPConnectPort": 7005, + "custom404Page": "/abc.html", + "transport": { + "tcpKeepalive": 10 + } +} +` + +func TestLoadServerConfig(t *testing.T) { + for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} { + svrCfg := v1.ServerConfig{} + err := LoadConfigure([]byte(content), &svrCfg, true) + require := require.New(t) + require.NoError(err) + require.EqualValues("127.0.0.1", svrCfg.BindAddr) + require.EqualValues(7000, svrCfg.KCPBindPort) + require.EqualValues(7001, svrCfg.QUICBindPort) + require.EqualValues(7005, svrCfg.TCPMuxHTTPConnectPort) + require.EqualValues("/abc.html", svrCfg.Custom404Page) + require.EqualValues(10, svrCfg.Transport.TCPKeepAlive) + } +} + +// Test that loading in strict mode fails when the config is invalid. +func TestLoadServerConfigErrorMode(t *testing.T) { + for strict := range []bool{false, true} { + for _, content := range []string{tomlServerContent, yamlServerContent, jsonServerContent} { + // Break the content with an innocent typo + brokenContent := strings.Replace(content, "bindAddr", "bindAdur", 1) + svrCfg := v1.ServerConfig{} + err := LoadConfigure([]byte(brokenContent), &svrCfg, strict == 1) + require := require.New(t) + if strict == 1 { + require.ErrorContains(err, "bindAdur") + } else { + require.NoError(err) + // BindAddr didn't get parsed because of the typo. + require.EqualValues("", svrCfg.BindAddr) + } + } + } }