From 250cbdde7c18d8aa6f30b51b9b4dc405ffc038f5 Mon Sep 17 00:00:00 2001 From: ouqiang Date: Thu, 7 Sep 2017 18:01:48 +0800 Subject: [PATCH] =?UTF-8?q?refactor($auth):=20=E8=B0=83=E5=BA=A6=E5=99=A8?= =?UTF-8?q?=E4=B8=8E=E4=BB=BB=E5=8A=A1=E8=8A=82=E7=82=B9=E6=94=AF=E6=8C=81?= =?UTF-8?q?HTTPS=E5=8F=8C=E5=90=91=E8=AE=A4=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gocron-node.go | 41 ++++++++++++++++-- gocron.go | 2 +- models/model.go | 67 +++++++++-------------------- modules/app/app.go | 4 +- modules/rpc/auth/Certification.go | 71 +++++++++++++++++++++++++++++++ modules/rpc/grpcpool/grpc_pool.go | 23 +++++++++- modules/rpc/server/server.go | 29 ++++++++++--- modules/setting/setting.go | 69 +++++++++++++++++++++++++++++- routers/install/install.go | 16 +++---- routers/routers.go | 14 +++--- 10 files changed, 255 insertions(+), 81 deletions(-) create mode 100644 modules/rpc/auth/Certification.go diff --git a/gocron-node.go b/gocron-node.go index 163997e..812317e 100644 --- a/gocron-node.go +++ b/gocron-node.go @@ -9,29 +9,62 @@ import ( "runtime" "os" "fmt" + "strings" + "github.com/ouqiang/gocron/modules/rpc/auth" + "github.com/ouqiang/gocron/modules/utils" ) -const AppVersion = "1.2" +const AppVersion = "1.2.2" func main() { var serverAddr string var allowRoot bool var version bool + var CAFile string + var certFile string + var keyFile string + var enableTLS bool flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root") flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port") flag.BoolVar(&version, "v", false, "./gocron-node -v") + flag.BoolVar(&enableTLS, "enable-tls", false, "./gocron-node -enable-tls") + flag.StringVar(&CAFile, "ca-file", "", "./gocron-node -ca-file path") + flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path") + flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path") flag.Parse() if version { fmt.Println(AppVersion) - os.Exit(0) + return } + if (enableTLS) { + if !utils.FileExist(CAFile) { + fmt.Printf("failed to read ca cert file: %s", CAFile) + return + } + if !utils.FileExist(certFile) { + fmt.Printf("failed to read server cert file: %s", certFile) + return + } + if !utils.FileExist(keyFile) { + fmt.Printf("failed to read server key file: %s", keyFile) + return + } + } + + certificate := auth.Certificate{ + CAFile: strings.TrimSpace(CAFile), + CertFile: strings.TrimSpace(certFile), + KeyFile: strings.TrimSpace(keyFile), + } + + if runtime.GOOS != "windows" && os.Getuid() == 0 && !allowRoot { fmt.Println("Do not run gocron-node as root user") - os.Exit(1) + return } - server.Start(serverAddr) + server.Start(serverAddr, enableTLS, certificate) } \ No newline at end of file diff --git a/gocron.go b/gocron.go index cc58076..2f29406 100644 --- a/gocron.go +++ b/gocron.go @@ -10,7 +10,7 @@ import ( "github.com/ouqiang/gocron/cmd" ) -const AppVersion = "1.1" +const AppVersion = "1.2.2" func main() { app := cli.NewApp() diff --git a/models/model.go b/models/model.go index f5c5c20..2ebc993 100644 --- a/models/model.go +++ b/models/model.go @@ -9,8 +9,8 @@ import ( "strings" "github.com/ouqiang/gocron/modules/logger" "github.com/ouqiang/gocron/modules/app" - "strconv" "time" + "github.com/ouqiang/gocron/modules/setting" ) type Status int8 @@ -65,27 +65,18 @@ func (model *BaseModel) pageLimitOffset() int { // 创建Db func CreateDb() *xorm.Engine { - config := getDbConfig() - dsn := getDbEngineDSN(config["engine"], config) - engine, err := xorm.NewEngine(config["engine"], dsn) + dsn := getDbEngineDSN(app.Setting) + engine, err := xorm.NewEngine(app.Setting.Db.Engine, dsn) if err != nil { logger.Fatal("创建xorm引擎失败", err) } - maxIdleConns, err := strconv.Atoi(config["max_idle_conns"]) - maxOpenConns, err := strconv.Atoi(config["max_open_conns"]) - if maxIdleConns <= 0 { - maxIdleConns = 30 - } - if maxOpenConns <= 0 { - maxOpenConns = 100 - } - engine.SetMaxIdleConns(maxIdleConns) - engine.SetMaxOpenConns(maxOpenConns) + engine.SetMaxIdleConns(app.Setting.Db.MaxIdleConns) + engine.SetMaxOpenConns(app.Setting.Db.MaxOpenConns) - if config["prefix"] != "" { + if app.Setting.Db.Prefix != "" { // 设置表前缀 - TablePrefix = config["prefix"] - mapper := core.NewPrefixMapper(core.SnakeMapper{}, config["prefix"]) + TablePrefix = app.Setting.Db.Prefix + mapper := core.NewPrefixMapper(core.SnakeMapper{}, app.Setting.Db.Prefix) engine.SetTableMapper(mapper) } // 本地环境开启日志 @@ -100,48 +91,30 @@ func CreateDb() *xorm.Engine { } // 创建临时数据库连接 -func CreateTmpDb(config map[string]string) (*xorm.Engine, error) { - dsn := getDbEngineDSN(config["engine"], config) +func CreateTmpDb(setting *setting.Setting) (*xorm.Engine, error) { + dsn := getDbEngineDSN(setting) - return xorm.NewEngine(config["engine"], dsn) + return xorm.NewEngine(setting.Db.Engine, dsn) } // 获取数据库引擎DSN mysql,sqlite -func getDbEngineDSN(engine string, config map[string]string) string { - engine = strings.ToLower(engine) +func getDbEngineDSN(setting *setting.Setting) string { + engine := strings.ToLower(setting.Db.Engine) var dsn string = "" switch engine { case "mysql": - dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s", - config["user"], - config["password"], - config["host"], - config["port"], - config["database"], - config["charset"]) + dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s", + setting.Db.User, + setting.Db.Password, + setting.Db.Host, + setting.Db.Port , + setting.Db.Database, + setting.Db.Charset) } return dsn } - -// 获取数据库配置 -func getDbConfig() map[string]string { - var db map[string]string = make(map[string]string) - db["user"] = app.Setting.Key("db.user").String() - db["password"] = app.Setting.Key("db.password").String() - db["host"] = app.Setting.Key("db.host").String() - db["port"] = app.Setting.Key("db.port").String() - db["database"] = app.Setting.Key("db.database").String() - db["charset"] = app.Setting.Key("db.charset").String() - db["prefix"] = app.Setting.Key("db.prefix").String() - db["engine"] = app.Setting.Key("db.engine").String() - db["max_idle_conns"] = app.Setting.Key("db.max.idle.conns").String() - db["max_open_conns"] = app.Setting.Key("db.max.open.conns").String() - - return db -} - func keepDbAlived(engine *xorm.Engine) { t := time.Tick(180 * time.Second) for { diff --git a/modules/app/app.go b/modules/app/app.go index 262be05..f1df384 100644 --- a/modules/app/app.go +++ b/modules/app/app.go @@ -5,10 +5,10 @@ import ( "github.com/ouqiang/gocron/modules/logger" "github.com/ouqiang/gocron/modules/utils" - "gopkg.in/ini.v1" "io/ioutil" "strconv" "strings" + "github.com/ouqiang/gocron/modules/setting" ) var ( @@ -18,7 +18,7 @@ var ( DataDir string // 存放session等 AppConfig string // 应用配置文件 Installed bool // 应用是否安装过 - Setting *ini.Section // 应用配置 + Setting *setting.Setting // 应用配置 VersionId int // 版本号 VersionFile string // 版本号文件 ) diff --git a/modules/rpc/auth/Certification.go b/modules/rpc/auth/Certification.go new file mode 100644 index 0000000..767e109 --- /dev/null +++ b/modules/rpc/auth/Certification.go @@ -0,0 +1,71 @@ +package auth + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "errors" + "fmt" + "google.golang.org/grpc/credentials" +) + +type Certificate struct { + CAFile string + CertFile string + KeyFile string + ServerName string +} + +func (c Certificate) GetTLSConfigForServer() (*tls.Config, error) { + certificate, err := tls.LoadX509KeyPair( + c.CertFile, + c.KeyFile, + ) + + certPool := x509.NewCertPool() + bs, err := ioutil.ReadFile(c.CAFile) + if err != nil { + return nil, errors.New(fmt.Sprintf("failed to read client ca cert: %s", err)) + } + + ok := certPool.AppendCertsFromPEM(bs) + if !ok { + return nil, errors.New("failed to append client certs") + } + + + tlsConfig := &tls.Config{ + ClientAuth: tls.RequireAndVerifyClientCert, + Certificates: []tls.Certificate{certificate}, + ClientCAs: certPool, + } + + + return tlsConfig, nil +} + +func (c Certificate) GetTransportCredsForClient() (credentials.TransportCredentials, error) { + certificate, err := tls.LoadX509KeyPair( + c.CertFile, + c.KeyFile, + ) + + certPool := x509.NewCertPool() + bs, err := ioutil.ReadFile(c.CAFile) + if err != nil { + return nil, errors.New(fmt.Sprintf("failed to read ca cert: %s", err)) + } + + ok := certPool.AppendCertsFromPEM(bs) + if !ok { + return nil, errors.New("failed to append certs") + } + + transportCreds := credentials.NewTLS(&tls.Config{ + ServerName: c.ServerName, + Certificates: []tls.Certificate{certificate}, + RootCAs: certPool, + }) + + return transportCreds, nil +} \ No newline at end of file diff --git a/modules/rpc/grpcpool/grpc_pool.go b/modules/rpc/grpcpool/grpc_pool.go index ec08f4c..d6ce85e 100644 --- a/modules/rpc/grpcpool/grpc_pool.go +++ b/modules/rpc/grpcpool/grpc_pool.go @@ -6,6 +6,9 @@ import ( "time" "google.golang.org/grpc" "errors" + "github.com/ouqiang/gocron/modules/rpc/auth" + "github.com/ouqiang/gocron/modules/app" + "strings" ) @@ -97,7 +100,25 @@ func (p *GRPCPool) newCommonPool(addr string) (error) { InitialCap: 1, MaxCap: 30, Factory: func() (interface{}, error) { - return grpc.Dial(addr, grpc.WithInsecure()) + if !app.Setting.EnableTLS { + return grpc.Dial(addr, grpc.WithInsecure()) + } + + server := strings.Split(addr, ":") + + certificate := auth.Certificate{ + CAFile: app.Setting.CAFile, + CertFile: app.Setting.CertFile, + KeyFile: app.Setting.KeyFile, + ServerName: server[0], + } + + transportCreds, err := certificate.GetTransportCredsForClient() + if err != nil { + return nil, err + } + + return grpc.Dial(addr, grpc.WithTransportCredentials(transportCreds)) }, Close: func(v interface{}) error { conn, ok := v.(*grpc.ClientConn) diff --git a/modules/rpc/server/server.go b/modules/rpc/server/server.go index 6f2a4d7..a529a40 100644 --- a/modules/rpc/server/server.go +++ b/modules/rpc/server/server.go @@ -7,6 +7,8 @@ import ( "google.golang.org/grpc" pb "github.com/ouqiang/gocron/modules/rpc/proto" "github.com/ouqiang/gocron/modules/utils" + "github.com/ouqiang/gocron/modules/rpc/auth" + "google.golang.org/grpc/credentials" ) type Server struct {} @@ -29,22 +31,35 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, return resp, nil } -func Start(addr string) { +func Start(addr string, enableTLS bool, certificate auth.Certificate) { defer func() { if err := recover(); err != nil { grpclog.Println("panic", err) } } () + l, err := net.Listen("tcp", addr) if err != nil { grpclog.Fatal(err) } - s := grpc.NewServer() - pb.RegisterTaskServer(s, Server{}) - grpclog.Println("listen ", addr) - err = s.Serve(l) - if err != nil { - grpclog.Fatal(err) + + var s *grpc.Server + if enableTLS { + tlsConfig, err := certificate.GetTLSConfigForServer() + if err != nil { + grpclog.Fatal(err) + } + opt := grpc.Creds(credentials.NewTLS(tlsConfig)) + s = grpc.NewServer(opt) + pb.RegisterTaskServer(s, Server{}) + grpclog.Printf("listen %s with TLS", addr) + } else { + s = grpc.NewServer() + pb.RegisterTaskServer(s, Server{}) + grpclog.Printf("listen %s", addr) } + + err = s.Serve(l) + grpclog.Fatal(err) } diff --git a/modules/setting/setting.go b/modules/setting/setting.go index 334da08..d5da574 100644 --- a/modules/setting/setting.go +++ b/modules/setting/setting.go @@ -3,19 +3,84 @@ package setting import ( "errors" "gopkg.in/ini.v1" + "github.com/ouqiang/gocron/modules/utils" + "github.com/ouqiang/gocron/modules/logger" ) const DefaultSection = "default" +type Setting struct { + Db struct{ + Engine string + Host string + Port int + User string + Password string + Database string + Prefix string + Charset string + MaxIdleConns int + MaxOpenConns int + } + AllowIps string + AppName string + ApiKey string + ApiSecret string + ApiSignEnable bool + + EnableTLS bool + CAFile string + CertFile string + KeyFile string +} + // 读取配置 -func Read(filename string) (*ini.Section,error) { +func Read(filename string) (*Setting,error) { config, err := ini.Load(filename) if err != nil { return nil, err } section := config.Section(DefaultSection) - return section, nil + var s Setting + + s.Db.Engine = section.Key("db.engine").MustString("mysql") + s.Db.Host = section.Key("db.host").MustString("127.0.0.1") + s.Db.Port = section.Key("db.port").MustInt(3306) + s.Db.User = section.Key("db.user").MustString("") + s.Db.Password = section.Key("db.password").MustString("") + s.Db.Database = section.Key("db.database").MustString("gocron") + s.Db.Prefix = section.Key("db.prefix").MustString("") + s.Db.Charset = section.Key("db.charset").MustString("utf8") + s.Db.MaxIdleConns = section.Key("db.max.idle.conns").MustInt(30) + s.Db.MaxOpenConns = section.Key("db.max.open.conns").MustInt(100) + + s.AllowIps = section.Key("allow_ips").MustString("") + s.AppName = section.Key("app.name").MustString("定时任务管理系统") + s.ApiKey = section.Key("api.key").MustString("") + s.ApiSecret = section.Key("api.secret").MustString("") + s.ApiSignEnable = section.Key("api.sign.enable").MustBool(true) + + s.EnableTLS = section.Key("enable_tls").MustBool(false) + s.CAFile = section.Key("ca_file").MustString("") + s.CertFile = section.Key("cert_file").MustString("") + s.KeyFile = section.Key("key_file").MustString("") + + if s.EnableTLS { + if !utils.FileExist(s.CAFile) { + logger.Fatalf("failed to read ca cert file: %s", s.CAFile) + } + + if !utils.FileExist(s.CertFile) { + logger.Fatalf("failed to read client cert file: %s", s.CertFile) + } + + if !utils.FileExist(s.KeyFile) { + logger.Fatalf("failed to read client key file: %s", s.KeyFile) + } + } + + return &s, nil } // 写入配置 diff --git a/routers/install/install.go b/routers/install/install.go index 3fe0244..b3c7458 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -139,14 +139,14 @@ func createAdminUser(form InstallForm) error { // 测试数据库连接 func testDbConnection(form InstallForm) error { - var dbConfig map[string]string = make(map[string]string) - dbConfig["engine"] = form.DbType - dbConfig["host"] = form.DbHost - dbConfig["port"] = strconv.Itoa(form.DbPort) - dbConfig["user"] = form.DbUsername - dbConfig["password"] = form.DbPassword - dbConfig["charset"] = "utf8" - db, err := models.CreateTmpDb(dbConfig) + var s setting.Setting + s.Db.Engine = form.DbType + s.Db.Host = form.DbHost + s.Db.Port = form.DbPort + s.Db.User = form.DbUsername + s.Db.Password = form.DbPassword + s.Db.Charset = "utf8" + db, err := models.CreateTmpDb(&s) if err != nil { return err } diff --git a/routers/routers.go b/routers/routers.go index 79a39a9..4fd7a7f 100644 --- a/routers/routers.go +++ b/routers/routers.go @@ -184,7 +184,7 @@ func checkAppInstall(m *macaron.Macaron) { // IP验证, 通过反向代理访问gocron,需设置Header X-Real-IP才能获取到客户端真实IP func ipAuth(ctx *macaron.Context) { - allowIpsStr := app.Setting.Key("allow_ips").String() + allowIpsStr := app.Setting.AllowIps if allowIpsStr == "" { return } @@ -230,20 +230,16 @@ func setShareData(ctx *macaron.Context, sess session.Store) { } ctx.Data["LoginUsername"] = user.Username(sess) ctx.Data["LoginUid"] = user.Uid(sess) - ctx.Data["AppName"] = app.Setting.Key("app.name").String() + ctx.Data["AppName"] = app.Setting.AppName } /** API接口签名验证 **/ func apiAuth(ctx *macaron.Context) { - apiSignEnable := app.Setting.Key("api.sign.enable").String() - apiSignEnable = strings.TrimSpace(apiSignEnable) - if apiSignEnable == "false" { + if !app.Setting.ApiSignEnable { return } - apiKey := app.Setting.Key("api.key").String() - apiSecret := app.Setting.Key("api.secret").String() - apiKey = strings.TrimSpace(apiKey) - apiSecret = strings.TrimSpace(apiSecret) + apiKey := strings.TrimSpace(app.Setting.ApiKey) + apiSecret := strings.TrimSpace(app.Setting.ApiSecret) json := utils.JsonResponse{} if apiKey == "" || apiSecret == "" { msg := json.CommonFailure("使用API前, 请先配置密钥")