feat($task): 调度器与任务节点支持HTTPS通信, #14

pull/21/merge
ouqiang 2017-09-06 11:27:54 +08:00
parent d4e0898674
commit 019fee2cce
12 changed files with 99 additions and 24 deletions

View File

@ -9,6 +9,7 @@ import (
"runtime"
"os"
"fmt"
"strings"
)
const AppVersion = "1.1"
@ -17,8 +18,12 @@ func main() {
var serverAddr string
var allowRoot bool
var version bool
var keyFile string
var certFile string
flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root")
flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port")
flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path")
flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path")
flag.BoolVar(&version, "v", false, "./gocron-node -v")
flag.Parse()
@ -27,11 +32,25 @@ func main() {
os.Exit(0)
}
certFile = strings.TrimSpace(certFile)
keyFile = strings.TrimSpace(keyFile)
if certFile != "" && keyFile == "" {
fmt.Println("missing argument key-file")
return
}
if keyFile != "" && certFile == "" {
fmt.Println("missing argument cert-file")
return
}
if runtime.GOOS != "windows" && os.Getuid() == 0 && !allowRoot {
fmt.Println("Do not run gocron-node as root user")
os.Exit(1)
}
server.Start(serverAddr)
server.Start(serverAddr, certFile, keyFile)
}

View File

@ -10,7 +10,7 @@ import (
"github.com/ouqiang/gocron/cmd"
)
const AppVersion = "1.1"
const AppVersion = "1.2"
func main() {
app := cli.NewApp()

View File

@ -10,6 +10,7 @@ type Host struct {
Name string `xorm:"varchar(64) notnull"` // 主机名称
Alias string `xorm:"varchar(32) notnull default '' "` // 主机别名
Port int `xorm:"notnull default 22"` // 主机端口
CertFile string `xorm:"varchar(64) notnull default '' "`
Remark string `xorm:"varchar(100) notnull default '' "` // 备注
BaseModel `xorm:"-"`
Selected bool `xorm:"-"`
@ -26,7 +27,7 @@ func (host *Host) Create() (insertId int16, err error) {
}
func (host *Host) UpdateBean(id int16) (int64, error) {
return Db.ID(id).Cols("name,alias,port,remark").Update(host)
return Db.ID(id).Cols("name,alias,port,cert_file,remark").Update(host)
}

View File

@ -49,9 +49,10 @@ func isDatabaseExist(name string) bool {
// 迭代升级数据库, 新建表、新增字段等
func (migration *Migration) Upgrade(oldVersionId int) {
versionIds := []int{110}
versionIds := []int{110, 120}
upgradeFuncs := []func(*xorm.Session) error {
migration.upgradeFor110,
migration.upgradeFor120,
}
// 默认当前版本为v1.0
@ -135,3 +136,13 @@ func (migration *Migration) upgradeFor110(session *xorm.Session) error {
return err
}
// 升级到v1.2版本
func (migration *Migration) upgradeFor120(session *xorm.Session) error {
// host表增加cert_file字段
tableName := TablePrefix + "host"
_, err := session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN cert_file VARCHAR(64) NOT NULL DEFAULT ''", tableName))
return err
}

View File

@ -12,6 +12,7 @@ type TaskHostDetail struct {
Name string
Port int
Alias string
CertFile string
}
func (TaskHostDetail) TableName() string {
@ -48,7 +49,7 @@ func (th *TaskHost) Add(taskId int, hostIds []int) error {
func (th *TaskHost) GetHostIdsByTaskId(taskId int) ([]TaskHostDetail, error) {
list := make([]TaskHostDetail, 0)
fields := "th.id,th.host_id,h.alias,h.name,h.port"
fields := "th.id,th.host_id,h.alias,h.name,h.port,h.cert_file"
err := Db.Alias("th").
Join("LEFT", hostTableName(), "th.host_id=h.id").
Where("th.task_id = ?", taskId).

View File

@ -16,11 +16,11 @@ var (
errUnavailable = errors.New("无法连接远程服务器")
)
func ExecWithRetry(ip string, port int, taskReq *pb.TaskRequest) (string, error) {
tryTimes := 60
func ExecWithRetry(ip string, port int, certFile string,taskReq *pb.TaskRequest) (string, error) {
tryTimes := 15
i := 0
for i < tryTimes {
output, err := Exec(ip, port, taskReq)
output, err := Exec(ip, port, certFile, taskReq)
if err != errUnavailable {
return output, err
}
@ -31,14 +31,14 @@ func ExecWithRetry(ip string, port int, taskReq *pb.TaskRequest) (string, error)
return "", errUnavailable
}
func Exec(ip string, port int, taskReq *pb.TaskRequest) (string, error) {
func Exec(ip string, port int, certFile string, taskReq *pb.TaskRequest) (string, error) {
defer func() {
if err := recover(); err != nil {
logger.Error("panic#rpc/client.go:Exec#", err)
}
} ()
addr := fmt.Sprintf("%s:%d", ip, port)
conn, err := grpcpool.Pool.Get(addr)
conn, err := grpcpool.Pool.Get(addr, certFile)
if err != nil {
return "", err
}

View File

@ -6,6 +6,8 @@ import (
"time"
"google.golang.org/grpc"
"errors"
"google.golang.org/grpc/credentials"
"strings"
)
@ -30,12 +32,12 @@ type GRPCPool struct {
sync.RWMutex
}
func (p *GRPCPool) Get(addr string) (*grpc.ClientConn, error) {
func (p *GRPCPool) Get(addr, certFile string) (*grpc.ClientConn, error) {
p.RLock()
pool, ok := p.conns[addr]
p.RUnlock()
if !ok {
err := p.newCommonPool(addr)
err := p.newCommonPool(addr, certFile)
if err != nil {
return nil, err
}
@ -86,7 +88,7 @@ func (p *GRPCPool) ReleaseAll() {
}
// 初始化底层连接池
func (p *GRPCPool) newCommonPool(addr string) (error) {
func (p *GRPCPool) newCommonPool(addr, certFile string) (error) {
p.Lock()
defer p.Unlock()
commonPool, ok := p.conns[addr]
@ -97,7 +99,17 @@ func (p *GRPCPool) newCommonPool(addr string) (error) {
InitialCap: 1,
MaxCap: 30,
Factory: func() (interface{}, error) {
if certFile == "" {
return grpc.Dial(addr, grpc.WithInsecure())
}
server := strings.Split(addr, ":")
creds, err := credentials.NewClientTLSFromFile(certFile, server[0])
if err != nil {
return nil, err
}
return grpc.Dial(addr, grpc.WithTransportCredentials(creds))
},
Close: func(v interface{}) error {
conn, ok := v.(*grpc.ClientConn)

View File

@ -7,6 +7,7 @@ import (
"google.golang.org/grpc"
pb "github.com/ouqiang/gocron/modules/rpc/proto"
"github.com/ouqiang/gocron/modules/utils"
"google.golang.org/grpc/credentials"
)
type Server struct {}
@ -29,7 +30,7 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse,
return resp, nil
}
func Start(addr string) {
func Start(addr, certFile, keyFile string) {
defer func() {
if err := recover(); err != nil {
grpclog.Println("panic", err)
@ -39,9 +40,23 @@ func Start(addr string) {
if err != nil {
grpclog.Fatal(err)
}
s := grpc.NewServer()
var s *grpc.Server
if certFile != "" {
// TLS认证
creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
if err != nil {
grpclog.Fatalf("Failed to generate credentials %v", err)
}
s = grpc.NewServer(grpc.Creds(creds))
pb.RegisterTaskServer(s, Server{})
grpclog.Printf("listen %s with TLS", addr)
} else {
s = grpc.NewServer()
pb.RegisterTaskServer(s, Server{})
grpclog.Println("listen ", addr)
}
err = s.Serve(l)
if err != nil {
grpclog.Fatal(err)

View File

@ -64,6 +64,7 @@ type HostForm struct {
Name string `binding:"Required;MaxSize(64)"`
Alias string `binding:"Required;MaxSize(32)"`
Port int `binding:"Required;Range(1-65535)"`
CertFile string
Remark string
}
@ -93,6 +94,12 @@ func Store(ctx *macaron.Context, form HostForm) string {
hostModel.Alias = strings.TrimSpace(form.Alias)
hostModel.Port = form.Port
hostModel.Remark = strings.TrimSpace(form.Remark)
hostModel.CertFile = strings.TrimSpace(form.CertFile)
if hostModel.CertFile != "" && !utils.FileExist(hostModel.CertFile) {
return json.CommonFailure("证书文件不存在或无权限访问")
}
isCreate := false
oldHostModel := new(models.Host)
err = oldHostModel.Find(int(id))
@ -100,6 +107,7 @@ func Store(ctx *macaron.Context, form HostForm) string {
return json.CommonFailure("主机不存在")
}
if id > 0 {
_, err = hostModel.UpdateBean(id)
} else {
@ -112,10 +120,7 @@ func Store(ctx *macaron.Context, form HostForm) string {
if !isCreate {
oldAddr := fmt.Sprintf("%s:%d", oldHostModel.Name, oldHostModel.Port)
newAddr := fmt.Sprintf("%s:%d", hostModel.Name, hostModel.Port)
if oldAddr != newAddr {
grpcpool.Pool.Release(oldAddr)
}
taskModel := new(models.Task)
tasks, err := taskModel.ActiveListByHostId(id)
@ -175,7 +180,7 @@ func Ping(ctx *macaron.Context) string {
taskReq := &rpc.TaskRequest{}
taskReq.Command = "echo hello"
taskReq.Timeout = 10
output, err := client.Exec(hostModel.Name, hostModel.Port, taskReq)
output, err := client.Exec(hostModel.Name, hostModel.Port, hostModel.CertFile, taskReq)
if err != nil {
return json.CommonFailure("连接失败-" + err.Error() + " " + output, err)
}

View File

@ -181,7 +181,7 @@ func (h *RPCHandler) Run(taskModel models.Task) (result string, err error) {
var resultChan chan TaskResult = make(chan TaskResult, len(taskModel.Hosts))
for _, taskHost := range taskModel.Hosts {
go func(th models.TaskHostDetail) {
output, err := rpcClient.ExecWithRetry(th.Name, th.Port, taskRequest)
output, err := rpcClient.ExecWithRetry(th.Name, th.Port, th.CertFile, taskRequest)
var errorMessage string = ""
if err != nil {
errorMessage = err.Error()

View File

@ -36,6 +36,15 @@
</div>
</div>
</div>
<div class="two fields">
<div class="field">
<label></label>
<div class="ui small input">
<input type="text" name="cert_file" value="{{{.Host.CertFile}}}"
placeholder="data/certs/server.pem">
</div>
</div>
</div>
<div class="two fields">
<div class="field">
<label></label>

View File

@ -36,6 +36,7 @@
<th></th>
<th></th>
<th></th>
<th></th>
<th></th>
<th></th>
</tr>
@ -47,6 +48,7 @@
<td>{{{.Name}}}</td>
<td>{{{.Alias}}}</td>
<td>{{{.Port}}}</td>
<td>{{{.CertFile}}}</td>
<td>{{{.Remark}}}</td>
<td class="operation">
<a class="ui purple button" href="/host/edit/{{{.Id}}}"></a>