Cloudreve/pkg/conf/conf.go

191 lines
4.4 KiB
Go

package conf
import (
"fmt"
"os"
"strings"
"github.com/cloudreve/Cloudreve/v4/pkg/logging"
"github.com/cloudreve/Cloudreve/v4/pkg/util"
"github.com/go-ini/ini"
"github.com/go-playground/validator/v10"
)
// envConfOverrideKey is the prefix marking environment variables that override
// config file values, e.g. CR_CONF_System.Listen=:8080 sets Listen in [System].
const envConfOverrideKey = "CR_CONF_"
// ConfigProvider gives read access to the static configuration: one accessor
// per config section, plus the raw OptionOverwrite entries.
type ConfigProvider interface {
	// Database returns the database settings.
	Database() *Database
	// System returns the general system settings.
	System() *System
	// SSL returns the SSL settings.
	SSL() *SSL
	// Unix returns the Unix socket settings.
	Unix() *Unix
	// Slave returns the slave-node settings.
	Slave() *Slave
	// Redis returns the Redis settings.
	Redis() *Redis
	// Cors returns the CORS settings.
	Cors() *Cors
	// OptionOverwrite returns key/value pairs meant to override database-backed settings.
	OptionOverwrite() map[string]any
}
// NewIniConfigProvider initializes a new INI config file provider. If
// configPath does not exist, a default config file is written there first,
// with its placeholders filled by util.RandStringRunes(64).
//
// INI text built from CR_CONF_<Section>.<Key> environment variables is loaded
// as a second source after the file, so its values take precedence. Each known
// section is decoded onto a copy of its package-level default (*DatabaseConfig,
// *SystemConfig, ...) and validated. Entries of [OptionOverwrite] are kept as
// raw strings.
//
// An error is returned if configPath is empty, the default file cannot be
// written, the INI content cannot be parsed, or a section fails to decode or
// validate.
func NewIniConfigProvider(configPath string, l logging.Logger) (ConfigProvider, error) {
	// An empty path can never be created; fail clearly up front instead of
	// logging "not found, creating" and then failing inside file creation.
	if configPath == "" {
		return nil, fmt.Errorf("config file path is empty")
	}

	if !util.Exists(configPath) {
		l.Info("Config file %q not found, creating a new one.", configPath)
		if err := writeDefaultConf(configPath); err != nil {
			return nil, err
		}
	}

	cfg, err := ini.Load(configPath, []byte(getOverrideConfFromEnv(l)))
	if err != nil {
		return nil, fmt.Errorf("failed to parse config file %q: %w", configPath, err)
	}

	provider := &iniConfigProvider{
		database:        *DatabaseConfig,
		system:          *SystemConfig,
		ssl:             *SSLConfig,
		unix:            *UnixConfig,
		slave:           *SlaveConfig,
		redis:           *RedisConfig,
		cors:            *CORSConfig,
		optionOverwrite: make(map[string]any),
	}

	// Decode sections in a fixed order so that, when several are invalid,
	// the same one is reported on every run (map iteration order is random).
	sections := []struct {
		name   string
		target any
	}{
		{"Database", &provider.database},
		{"System", &provider.system},
		{"SSL", &provider.ssl},
		{"UnixSocket", &provider.unix},
		{"Redis", &provider.redis},
		{"CORS", &provider.cors},
		{"Slave", &provider.slave},
	}
	for _, s := range sections {
		if err := mapSection(cfg, s.name, s.target); err != nil {
			return nil, fmt.Errorf("failed to parse config section %q: %w", s.name, err)
		}
	}

	// Collect the database config overrides from [OptionOverwrite] verbatim.
	for _, key := range cfg.Section("OptionOverwrite").Keys() {
		provider.optionOverwrite[key.Name()] = key.Value()
	}

	return provider, nil
}

// writeDefaultConf writes defaultConf to path (created via util.CreatNestedFile).
// Every placeholder in it is replaced with its own util.RandStringRunes(64) value.
func writeDefaultConf(path string) error {
	// Both placeholders must be substituted; otherwise the literal
	// "{HashIDSalt}" token would end up in every generated config file.
	content := util.Replace(map[string]string{
		"{SessionSecret}": util.RandStringRunes(64),
		"{HashIDSalt}":    util.RandStringRunes(64),
	}, defaultConf)

	f, err := util.CreatNestedFile(path)
	if err != nil {
		return fmt.Errorf("failed to create config file: %w", err)
	}
	if _, err := f.WriteString(content); err != nil {
		_ = f.Close() // best effort; the write error is the one worth reporting
		return fmt.Errorf("failed to write config file: %w", err)
	}
	// Close can surface deferred write failures, so its error is checked too.
	if err := f.Close(); err != nil {
		return fmt.Errorf("failed to close config file: %w", err)
	}
	return nil
}
// iniConfigProvider is the ConfigProvider produced by NewIniConfigProvider.
// It stores each decoded section by value. The accessors return pointers to
// those stored copies, so no copy is made per call.
type iniConfigProvider struct {
	database        Database
	system          System
	ssl             SSL
	unix            Unix
	slave           Slave
	redis           Redis
	cors            Cors
	optionOverwrite map[string]any
}

// Database returns the decoded [Database] section.
func (p *iniConfigProvider) Database() *Database { return &p.database }

// System returns the decoded [System] section.
func (p *iniConfigProvider) System() *System { return &p.system }

// SSL returns the decoded [SSL] section.
func (p *iniConfigProvider) SSL() *SSL { return &p.ssl }

// Unix returns the decoded [UnixSocket] section.
func (p *iniConfigProvider) Unix() *Unix { return &p.unix }

// Slave returns the decoded [Slave] section.
func (p *iniConfigProvider) Slave() *Slave { return &p.slave }

// Redis returns the decoded [Redis] section.
func (p *iniConfigProvider) Redis() *Redis { return &p.redis }

// Cors returns the decoded [CORS] section.
func (p *iniConfigProvider) Cors() *Cors { return &p.cors }

// OptionOverwrite returns the [OptionOverwrite] entries. The map itself is
// returned, not a copy.
func (p *iniConfigProvider) OptionOverwrite() map[string]any { return p.optionOverwrite }
// defaultConf is the content written to a freshly created config file.
// Brace-wrapped tokens such as {SessionSecret} are placeholders that
// NewIniConfigProvider replaces before writing the file.
const defaultConf = "[System]\n" +
	"Debug = false\n" +
	"Mode = master\n" +
	"Listen = :5212\n" +
	"SessionSecret = {SessionSecret}\n" +
	"HashIDSalt = {HashIDSalt}\n"
// mapSection 将配置文件的 Section 映射到结构体上
func mapSection(cfg *ini.File, section string, confStruct interface{}) error {
err := cfg.Section(section).MapTo(confStruct)
if err != nil {
return err
}
// 验证合法性
validate := validator.New()
err = validate.Struct(confStruct)
if err != nil {
return err
}
return nil
}
// getOverrideConfFromEnv collects config overrides from environment variables
// of the form CR_CONF_<Section>.<Key>=<value> and renders them as INI text.
// NewIniConfigProvider loads that text on top of the config file. For example,
// CR_CONF_System.Listen=:8080 becomes:
//
//	[System]
//	Listen = :8080
//
// Some variables are skipped with a log message: those whose name has no "."
// separator, or has an empty section or key. Previously, a missing "." caused
// an index-out-of-range panic at startup. Each applied override is logged.
func getOverrideConfFromEnv(l logging.Logger) string {
	confMaps := make(map[string]map[string]string)
	for _, env := range os.Environ() {
		if !strings.HasPrefix(env, envConfOverrideKey) {
			continue
		}

		// Split "NAME=value" at the first "=" only; the value may contain "=".
		name, configValue, _ := strings.Cut(env, "=")
		configKey := strings.TrimPrefix(name, envConfOverrideKey)

		// The section is everything before the first "."; the key is the rest.
		section, key, ok := strings.Cut(configKey, ".")
		if !ok || section == "" || key == "" {
			l.Info("Ignoring config override %q: expected %s<Section>.<Key>", name, envConfOverrideKey)
			continue
		}

		if confMaps[section] == nil {
			confMaps[section] = make(map[string]string)
		}
		confMaps[section][key] = configValue
		// NOTE(review): this logs the value verbatim, which may include secrets
		// such as passwords; consider masking.
		l.Info("Override config %q = %q", configKey, configValue)
	}

	// Generate the INI content. Grouping by section emits each header once.
	var sb strings.Builder
	for section, kvs := range confMaps {
		fmt.Fprintf(&sb, "[%s]\n", section)
		for k, v := range kvs {
			fmt.Fprintf(&sb, "%s = %s\n", k, v)
		}
	}
	return sb.String()
}