refactor($auth): 调度器与任务节点支持HTTPS双向认证

pull/21/merge
ouqiang 2017-09-07 18:01:48 +08:00
parent a966f3aeda
commit 250cbdde7c
10 changed files with 255 additions and 81 deletions

View File

@ -9,29 +9,62 @@ import (
"runtime" "runtime"
"os" "os"
"fmt" "fmt"
"strings"
"github.com/ouqiang/gocron/modules/rpc/auth"
"github.com/ouqiang/gocron/modules/utils"
) )
const AppVersion = "1.2" const AppVersion = "1.2.2"
func main() { func main() {
var serverAddr string var serverAddr string
var allowRoot bool var allowRoot bool
var version bool var version bool
var CAFile string
var certFile string
var keyFile string
var enableTLS bool
flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root") flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root")
flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port") flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port")
flag.BoolVar(&version, "v", false, "./gocron-node -v") flag.BoolVar(&version, "v", false, "./gocron-node -v")
flag.BoolVar(&enableTLS, "enable-tls", false, "./gocron-node -enable-tls")
flag.StringVar(&CAFile, "ca-file", "", "./gocron-node -ca-file path")
flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path")
flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path")
flag.Parse() flag.Parse()
if version { if version {
fmt.Println(AppVersion) fmt.Println(AppVersion)
os.Exit(0) return
} }
if (enableTLS) {
if !utils.FileExist(CAFile) {
fmt.Printf("failed to read ca cert file: %s", CAFile)
return
}
if !utils.FileExist(certFile) {
fmt.Printf("failed to read server cert file: %s", certFile)
return
}
if !utils.FileExist(keyFile) {
fmt.Printf("failed to read server key file: %s", keyFile)
return
}
}
certificate := auth.Certificate{
CAFile: strings.TrimSpace(CAFile),
CertFile: strings.TrimSpace(certFile),
KeyFile: strings.TrimSpace(keyFile),
}
if runtime.GOOS != "windows" && os.Getuid() == 0 && !allowRoot { if runtime.GOOS != "windows" && os.Getuid() == 0 && !allowRoot {
fmt.Println("Do not run gocron-node as root user") fmt.Println("Do not run gocron-node as root user")
os.Exit(1) return
} }
server.Start(serverAddr) server.Start(serverAddr, enableTLS, certificate)
} }

View File

@ -10,7 +10,7 @@ import (
"github.com/ouqiang/gocron/cmd" "github.com/ouqiang/gocron/cmd"
) )
const AppVersion = "1.1" const AppVersion = "1.2.2"
func main() { func main() {
app := cli.NewApp() app := cli.NewApp()

View File

@ -9,8 +9,8 @@ import (
"strings" "strings"
"github.com/ouqiang/gocron/modules/logger" "github.com/ouqiang/gocron/modules/logger"
"github.com/ouqiang/gocron/modules/app" "github.com/ouqiang/gocron/modules/app"
"strconv"
"time" "time"
"github.com/ouqiang/gocron/modules/setting"
) )
type Status int8 type Status int8
@ -65,27 +65,18 @@ func (model *BaseModel) pageLimitOffset() int {
// 创建Db // 创建Db
func CreateDb() *xorm.Engine { func CreateDb() *xorm.Engine {
config := getDbConfig() dsn := getDbEngineDSN(app.Setting)
dsn := getDbEngineDSN(config["engine"], config) engine, err := xorm.NewEngine(app.Setting.Db.Engine, dsn)
engine, err := xorm.NewEngine(config["engine"], dsn)
if err != nil { if err != nil {
logger.Fatal("创建xorm引擎失败", err) logger.Fatal("创建xorm引擎失败", err)
} }
maxIdleConns, err := strconv.Atoi(config["max_idle_conns"]) engine.SetMaxIdleConns(app.Setting.Db.MaxIdleConns)
maxOpenConns, err := strconv.Atoi(config["max_open_conns"]) engine.SetMaxOpenConns(app.Setting.Db.MaxOpenConns)
if maxIdleConns <= 0 {
maxIdleConns = 30
}
if maxOpenConns <= 0 {
maxOpenConns = 100
}
engine.SetMaxIdleConns(maxIdleConns)
engine.SetMaxOpenConns(maxOpenConns)
if config["prefix"] != "" { if app.Setting.Db.Prefix != "" {
// 设置表前缀 // 设置表前缀
TablePrefix = config["prefix"] TablePrefix = app.Setting.Db.Prefix
mapper := core.NewPrefixMapper(core.SnakeMapper{}, config["prefix"]) mapper := core.NewPrefixMapper(core.SnakeMapper{}, app.Setting.Db.Prefix)
engine.SetTableMapper(mapper) engine.SetTableMapper(mapper)
} }
// 本地环境开启日志 // 本地环境开启日志
@ -100,48 +91,30 @@ func CreateDb() *xorm.Engine {
} }
// 创建临时数据库连接 // 创建临时数据库连接
func CreateTmpDb(config map[string]string) (*xorm.Engine, error) { func CreateTmpDb(setting *setting.Setting) (*xorm.Engine, error) {
dsn := getDbEngineDSN(config["engine"], config) dsn := getDbEngineDSN(setting)
return xorm.NewEngine(config["engine"], dsn) return xorm.NewEngine(setting.Db.Engine, dsn)
} }
// 获取数据库引擎DSN mysql,sqlite // 获取数据库引擎DSN mysql,sqlite
func getDbEngineDSN(engine string, config map[string]string) string { func getDbEngineDSN(setting *setting.Setting) string {
engine = strings.ToLower(engine) engine := strings.ToLower(setting.Db.Engine)
var dsn string = "" var dsn string = ""
switch engine { switch engine {
case "mysql": case "mysql":
dsn = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s", dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s",
config["user"], setting.Db.User,
config["password"], setting.Db.Password,
config["host"], setting.Db.Host,
config["port"], setting.Db.Port ,
config["database"], setting.Db.Database,
config["charset"]) setting.Db.Charset)
} }
return dsn return dsn
} }
// 获取数据库配置
func getDbConfig() map[string]string {
var db map[string]string = make(map[string]string)
db["user"] = app.Setting.Key("db.user").String()
db["password"] = app.Setting.Key("db.password").String()
db["host"] = app.Setting.Key("db.host").String()
db["port"] = app.Setting.Key("db.port").String()
db["database"] = app.Setting.Key("db.database").String()
db["charset"] = app.Setting.Key("db.charset").String()
db["prefix"] = app.Setting.Key("db.prefix").String()
db["engine"] = app.Setting.Key("db.engine").String()
db["max_idle_conns"] = app.Setting.Key("db.max.idle.conns").String()
db["max_open_conns"] = app.Setting.Key("db.max.open.conns").String()
return db
}
func keepDbAlived(engine *xorm.Engine) { func keepDbAlived(engine *xorm.Engine) {
t := time.Tick(180 * time.Second) t := time.Tick(180 * time.Second)
for { for {

View File

@ -5,10 +5,10 @@ import (
"github.com/ouqiang/gocron/modules/logger" "github.com/ouqiang/gocron/modules/logger"
"github.com/ouqiang/gocron/modules/utils" "github.com/ouqiang/gocron/modules/utils"
"gopkg.in/ini.v1"
"io/ioutil" "io/ioutil"
"strconv" "strconv"
"strings" "strings"
"github.com/ouqiang/gocron/modules/setting"
) )
var ( var (
@ -18,7 +18,7 @@ var (
DataDir string // 存放session等 DataDir string // 存放session等
AppConfig string // 应用配置文件 AppConfig string // 应用配置文件
Installed bool // 应用是否安装过 Installed bool // 应用是否安装过
Setting *ini.Section // 应用配置 Setting *setting.Setting // 应用配置
VersionId int // 版本号 VersionId int // 版本号
VersionFile string // 版本号文件 VersionFile string // 版本号文件
) )

View File

@ -0,0 +1,71 @@
package auth
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"errors"
"fmt"
"google.golang.org/grpc/credentials"
)
type Certificate struct {
CAFile string
CertFile string
KeyFile string
ServerName string
}
func (c Certificate) GetTLSConfigForServer() (*tls.Config, error) {
certificate, err := tls.LoadX509KeyPair(
c.CertFile,
c.KeyFile,
)
certPool := x509.NewCertPool()
bs, err := ioutil.ReadFile(c.CAFile)
if err != nil {
return nil, errors.New(fmt.Sprintf("failed to read client ca cert: %s", err))
}
ok := certPool.AppendCertsFromPEM(bs)
if !ok {
return nil, errors.New("failed to append client certs")
}
tlsConfig := &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
ClientCAs: certPool,
}
return tlsConfig, nil
}
func (c Certificate) GetTransportCredsForClient() (credentials.TransportCredentials, error) {
certificate, err := tls.LoadX509KeyPair(
c.CertFile,
c.KeyFile,
)
certPool := x509.NewCertPool()
bs, err := ioutil.ReadFile(c.CAFile)
if err != nil {
return nil, errors.New(fmt.Sprintf("failed to read ca cert: %s", err))
}
ok := certPool.AppendCertsFromPEM(bs)
if !ok {
return nil, errors.New("failed to append certs")
}
transportCreds := credentials.NewTLS(&tls.Config{
ServerName: c.ServerName,
Certificates: []tls.Certificate{certificate},
RootCAs: certPool,
})
return transportCreds, nil
}

View File

@ -6,6 +6,9 @@ import (
"time" "time"
"google.golang.org/grpc" "google.golang.org/grpc"
"errors" "errors"
"github.com/ouqiang/gocron/modules/rpc/auth"
"github.com/ouqiang/gocron/modules/app"
"strings"
) )
@ -97,7 +100,25 @@ func (p *GRPCPool) newCommonPool(addr string) (error) {
InitialCap: 1, InitialCap: 1,
MaxCap: 30, MaxCap: 30,
Factory: func() (interface{}, error) { Factory: func() (interface{}, error) {
if !app.Setting.EnableTLS {
return grpc.Dial(addr, grpc.WithInsecure()) return grpc.Dial(addr, grpc.WithInsecure())
}
server := strings.Split(addr, ":")
certificate := auth.Certificate{
CAFile: app.Setting.CAFile,
CertFile: app.Setting.CertFile,
KeyFile: app.Setting.KeyFile,
ServerName: server[0],
}
transportCreds, err := certificate.GetTransportCredsForClient()
if err != nil {
return nil, err
}
return grpc.Dial(addr, grpc.WithTransportCredentials(transportCreds))
}, },
Close: func(v interface{}) error { Close: func(v interface{}) error {
conn, ok := v.(*grpc.ClientConn) conn, ok := v.(*grpc.ClientConn)

View File

@ -7,6 +7,8 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
pb "github.com/ouqiang/gocron/modules/rpc/proto" pb "github.com/ouqiang/gocron/modules/rpc/proto"
"github.com/ouqiang/gocron/modules/utils" "github.com/ouqiang/gocron/modules/utils"
"github.com/ouqiang/gocron/modules/rpc/auth"
"google.golang.org/grpc/credentials"
) )
type Server struct {} type Server struct {}
@ -29,22 +31,35 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse,
return resp, nil return resp, nil
} }
func Start(addr string) { func Start(addr string, enableTLS bool, certificate auth.Certificate) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
grpclog.Println("panic", err) grpclog.Println("panic", err)
} }
} () } ()
l, err := net.Listen("tcp", addr) l, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
grpclog.Fatal(err) grpclog.Fatal(err)
} }
s := grpc.NewServer()
pb.RegisterTaskServer(s, Server{}) var s *grpc.Server
grpclog.Println("listen ", addr) if enableTLS {
err = s.Serve(l) tlsConfig, err := certificate.GetTLSConfigForServer()
if err != nil { if err != nil {
grpclog.Fatal(err) grpclog.Fatal(err)
} }
opt := grpc.Creds(credentials.NewTLS(tlsConfig))
s = grpc.NewServer(opt)
pb.RegisterTaskServer(s, Server{})
grpclog.Printf("listen %s with TLS", addr)
} else {
s = grpc.NewServer()
pb.RegisterTaskServer(s, Server{})
grpclog.Printf("listen %s", addr)
}
err = s.Serve(l)
grpclog.Fatal(err)
} }

View File

@ -3,19 +3,84 @@ package setting
import ( import (
"errors" "errors"
"gopkg.in/ini.v1" "gopkg.in/ini.v1"
"github.com/ouqiang/gocron/modules/utils"
"github.com/ouqiang/gocron/modules/logger"
) )
const DefaultSection = "default" const DefaultSection = "default"
type Setting struct {
Db struct{
Engine string
Host string
Port int
User string
Password string
Database string
Prefix string
Charset string
MaxIdleConns int
MaxOpenConns int
}
AllowIps string
AppName string
ApiKey string
ApiSecret string
ApiSignEnable bool
EnableTLS bool
CAFile string
CertFile string
KeyFile string
}
// 读取配置 // 读取配置
func Read(filename string) (*ini.Section,error) { func Read(filename string) (*Setting,error) {
config, err := ini.Load(filename) config, err := ini.Load(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
section := config.Section(DefaultSection) section := config.Section(DefaultSection)
return section, nil var s Setting
s.Db.Engine = section.Key("db.engine").MustString("mysql")
s.Db.Host = section.Key("db.host").MustString("127.0.0.1")
s.Db.Port = section.Key("db.port").MustInt(3306)
s.Db.User = section.Key("db.user").MustString("")
s.Db.Password = section.Key("db.password").MustString("")
s.Db.Database = section.Key("db.database").MustString("gocron")
s.Db.Prefix = section.Key("db.prefix").MustString("")
s.Db.Charset = section.Key("db.charset").MustString("utf8")
s.Db.MaxIdleConns = section.Key("db.max.idle.conns").MustInt(30)
s.Db.MaxOpenConns = section.Key("db.max.open.conns").MustInt(100)
s.AllowIps = section.Key("allow_ips").MustString("")
s.AppName = section.Key("app.name").MustString("定时任务管理系统")
s.ApiKey = section.Key("api.key").MustString("")
s.ApiSecret = section.Key("api.secret").MustString("")
s.ApiSignEnable = section.Key("api.sign.enable").MustBool(true)
s.EnableTLS = section.Key("enable_tls").MustBool(false)
s.CAFile = section.Key("ca_file").MustString("")
s.CertFile = section.Key("cert_file").MustString("")
s.KeyFile = section.Key("key_file").MustString("")
if s.EnableTLS {
if !utils.FileExist(s.CAFile) {
logger.Fatalf("failed to read ca cert file: %s", s.CAFile)
}
if !utils.FileExist(s.CertFile) {
logger.Fatalf("failed to read client cert file: %s", s.CertFile)
}
if !utils.FileExist(s.KeyFile) {
logger.Fatalf("failed to read client key file: %s", s.KeyFile)
}
}
return &s, nil
} }
// 写入配置 // 写入配置

View File

@ -139,14 +139,14 @@ func createAdminUser(form InstallForm) error {
// 测试数据库连接 // 测试数据库连接
func testDbConnection(form InstallForm) error { func testDbConnection(form InstallForm) error {
var dbConfig map[string]string = make(map[string]string) var s setting.Setting
dbConfig["engine"] = form.DbType s.Db.Engine = form.DbType
dbConfig["host"] = form.DbHost s.Db.Host = form.DbHost
dbConfig["port"] = strconv.Itoa(form.DbPort) s.Db.Port = form.DbPort
dbConfig["user"] = form.DbUsername s.Db.User = form.DbUsername
dbConfig["password"] = form.DbPassword s.Db.Password = form.DbPassword
dbConfig["charset"] = "utf8" s.Db.Charset = "utf8"
db, err := models.CreateTmpDb(dbConfig) db, err := models.CreateTmpDb(&s)
if err != nil { if err != nil {
return err return err
} }

View File

@ -184,7 +184,7 @@ func checkAppInstall(m *macaron.Macaron) {
// IP验证, 通过反向代理访问gocron需设置Header X-Real-IP才能获取到客户端真实IP // IP验证, 通过反向代理访问gocron需设置Header X-Real-IP才能获取到客户端真实IP
func ipAuth(ctx *macaron.Context) { func ipAuth(ctx *macaron.Context) {
allowIpsStr := app.Setting.Key("allow_ips").String() allowIpsStr := app.Setting.AllowIps
if allowIpsStr == "" { if allowIpsStr == "" {
return return
} }
@ -230,20 +230,16 @@ func setShareData(ctx *macaron.Context, sess session.Store) {
} }
ctx.Data["LoginUsername"] = user.Username(sess) ctx.Data["LoginUsername"] = user.Username(sess)
ctx.Data["LoginUid"] = user.Uid(sess) ctx.Data["LoginUid"] = user.Uid(sess)
ctx.Data["AppName"] = app.Setting.Key("app.name").String() ctx.Data["AppName"] = app.Setting.AppName
} }
/** API接口签名验证 **/ /** API接口签名验证 **/
func apiAuth(ctx *macaron.Context) { func apiAuth(ctx *macaron.Context) {
apiSignEnable := app.Setting.Key("api.sign.enable").String() if !app.Setting.ApiSignEnable {
apiSignEnable = strings.TrimSpace(apiSignEnable)
if apiSignEnable == "false" {
return return
} }
apiKey := app.Setting.Key("api.key").String() apiKey := strings.TrimSpace(app.Setting.ApiKey)
apiSecret := app.Setting.Key("api.secret").String() apiSecret := strings.TrimSpace(app.Setting.ApiSecret)
apiKey = strings.TrimSpace(apiKey)
apiSecret = strings.TrimSpace(apiSecret)
json := utils.JsonResponse{} json := utils.JsonResponse{}
if apiKey == "" || apiSecret == "" { if apiKey == "" || apiSecret == "" {
msg := json.CommonFailure("使用API前, 请先配置密钥") msg := json.CommonFailure("使用API前, 请先配置密钥")