mirror of https://github.com/cloudreve/Cloudreve
191 lines
4.4 KiB
Go
191 lines
4.4 KiB
Go
package conf
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
|
|
"github.com/cloudreve/Cloudreve/v4/pkg/util"
|
|
"github.com/go-ini/ini"
|
|
"github.com/go-playground/validator/v10"
|
|
)
|
|
|
|
// envConfOverrideKey is the prefix of environment variables that override
// config file entries: a variable named CR_CONF_<Section>.<Key> sets <Key>
// inside section [<Section>].
const envConfOverrideKey = "CR_CONF_"
|
|
|
|
type ConfigProvider interface {
|
|
Database() *Database
|
|
System() *System
|
|
SSL() *SSL
|
|
Unix() *Unix
|
|
Slave() *Slave
|
|
Redis() *Redis
|
|
Cors() *Cors
|
|
OptionOverwrite() map[string]any
|
|
}
|
|
|
|
// NewIniConfigProvider initializes a new INI config file provider. If configPath
// is empty or does not exist, a default config file is created there first, with
// freshly generated random values filled into the template's placeholders.
//
// Overrides derived from CR_CONF_<Section>.<Key> environment variables are loaded
// as a second INI source on top of the file. Each known section is then mapped
// onto a copy of its package-level default and validated; the first section that
// fails aborts loading with an error naming that section. All keys of the
// [OptionOverwrite] section are collected verbatim.
func NewIniConfigProvider(configPath string, l logging.Logger) (ConfigProvider, error) {
	if configPath == "" || !util.Exists(configPath) {
		l.Info("Config file %q not found, creating a new one.", configPath)
		if err := createDefaultConfigFile(configPath); err != nil {
			return nil, err
		}
	}

	// The env-derived INI text is passed as a second data source so it can
	// override values from the file (go-ini lets later sources win).
	cfg, err := ini.Load(configPath, []byte(getOverrideConfFromEnv(l)))
	if err != nil {
		return nil, fmt.Errorf("failed to parse config file %q: %w", configPath, err)
	}

	// Start from copies of the package-level defaults so keys missing from the
	// file keep their default values after mapping.
	provider := &iniConfigProvider{
		database:        *DatabaseConfig,
		system:          *SystemConfig,
		ssl:             *SSLConfig,
		unix:            *UnixConfig,
		slave:           *SlaveConfig,
		redis:           *RedisConfig,
		cors:            *CORSConfig,
		optionOverwrite: make(map[string]any),
	}

	// An ordered slice (instead of a map) makes the mapping order, and thus
	// which section is reported when several are invalid, deterministic.
	sections := []struct {
		name   string
		target any
	}{
		{"Database", &provider.database},
		{"System", &provider.system},
		{"SSL", &provider.ssl},
		{"UnixSocket", &provider.unix},
		{"Redis", &provider.redis},
		{"CORS", &provider.cors},
		{"Slave", &provider.slave},
	}
	for _, s := range sections {
		if err := mapSection(cfg, s.name, s.target); err != nil {
			return nil, fmt.Errorf("failed to parse config section %q: %w", s.name, err)
		}
	}

	// Collect database option overrides verbatim.
	for _, key := range cfg.Section("OptionOverwrite").Keys() {
		provider.optionOverwrite[key.Name()] = key.Value()
	}

	return provider, nil
}

// createDefaultConfigFile writes the defaultConf template to configPath,
// substituting a fresh random value for every placeholder it contains.
func createDefaultConfigFile(configPath string) error {
	// Both placeholders in defaultConf must be replaced; otherwise the literal
	// token (e.g. "{HashIDSalt}") would end up in the generated file.
	content := util.Replace(map[string]string{
		"{SessionSecret}": util.RandStringRunes(64),
		"{HashIDSalt}":    util.RandStringRunes(64),
	}, defaultConf)

	f, err := util.CreatNestedFile(configPath)
	if err != nil {
		return fmt.Errorf("failed to create config file: %w", err)
	}

	if _, err := f.WriteString(content); err != nil {
		// Close on the error path too, so the handle is not leaked; the write
		// error is the one worth reporting.
		_ = f.Close()
		return fmt.Errorf("failed to write config file: %w", err)
	}

	// Some write failures may only surface on Close, so its error is checked.
	if err := f.Close(); err != nil {
		return fmt.Errorf("failed to close config file: %w", err)
	}
	return nil
}
|
|
|
|
type iniConfigProvider struct {
|
|
database Database
|
|
system System
|
|
ssl SSL
|
|
unix Unix
|
|
slave Slave
|
|
redis Redis
|
|
cors Cors
|
|
optionOverwrite map[string]any
|
|
}
|
|
|
|
func (i *iniConfigProvider) Database() *Database {
|
|
return &i.database
|
|
}
|
|
|
|
func (i *iniConfigProvider) System() *System {
|
|
return &i.system
|
|
}
|
|
|
|
func (i *iniConfigProvider) SSL() *SSL {
|
|
return &i.ssl
|
|
}
|
|
|
|
func (i *iniConfigProvider) Unix() *Unix {
|
|
return &i.unix
|
|
}
|
|
|
|
func (i *iniConfigProvider) Slave() *Slave {
|
|
return &i.slave
|
|
}
|
|
|
|
func (i *iniConfigProvider) Redis() *Redis {
|
|
return &i.redis
|
|
}
|
|
|
|
func (i *iniConfigProvider) Cors() *Cors {
|
|
return &i.cors
|
|
}
|
|
|
|
func (i *iniConfigProvider) OptionOverwrite() map[string]any {
|
|
return i.optionOverwrite
|
|
}
|
|
|
|
// defaultConf is the INI template written when no config file exists yet.
// {SessionSecret} and {HashIDSalt} are placeholder tokens meant to be replaced
// with generated secrets before the template is written to disk.
const defaultConf = "[System]\n" +
	"Debug = false\n" +
	"Mode = master\n" +
	"Listen = :5212\n" +
	"SessionSecret = {SessionSecret}\n" +
	"HashIDSalt = {HashIDSalt}\n"
|
|
|
|
// mapSection 将配置文件的 Section 映射到结构体上
|
|
func mapSection(cfg *ini.File, section string, confStruct interface{}) error {
|
|
err := cfg.Section(section).MapTo(confStruct)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 验证合法性
|
|
validate := validator.New()
|
|
err = validate.Struct(confStruct)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getOverrideConfFromEnv builds an INI document from environment variables of
// the form CR_CONF_<Section>.<Key>=<Value>, intended to be loaded on top of the
// config file. For example, CR_CONF_System.Listen=:8080 yields:
//
//	[System]
//	Listen = :8080
//
// A variable whose name (after the prefix) has no "." separator, or has an
// empty section or key, is skipped with a log message. Previously such a name
// caused an index-out-of-range panic or an unparsable document. Each accepted
// override is logged. Section and key order in the output follows map iteration
// and is therefore unspecified.
func getOverrideConfFromEnv(l logging.Logger) string {
	confMaps := make(map[string]map[string]string)
	for _, env := range os.Environ() {
		name, value, ok := strings.Cut(env, "=")
		if !ok || !strings.HasPrefix(name, envConfOverrideKey) {
			continue
		}

		configKey := strings.TrimPrefix(name, envConfOverrideKey)
		section, key, ok := strings.Cut(configKey, ".")
		if !ok || section == "" || key == "" {
			l.Info("Ignoring config override %q: expected the form %s<Section>.<Key>", name, envConfOverrideKey)
			continue
		}

		if confMaps[section] == nil {
			confMaps[section] = make(map[string]string)
		}
		confMaps[section][key] = value
		// NOTE(review): the raw value is logged, which can expose secrets such
		// as database passwords supplied through the environment.
		l.Info("Override config %q = %q", configKey, value)
	}

	// Render the collected overrides as INI text.
	var sb strings.Builder
	for section, kvs := range confMaps {
		fmt.Fprintf(&sb, "[%s]\n", section)
		for k, v := range kvs {
			fmt.Fprintf(&sb, "%s = %s\n", k, v)
		}
	}

	return sb.String()
}
|