feat($gocron-node): 增加token鉴权, #14

pull/21/merge v1.2.1
ouqiang 2017-09-06 22:46:56 +08:00
parent 601f250882
commit 8ec80a02b6
11 changed files with 101 additions and 18 deletions

View File

@ -74,6 +74,7 @@
* -s ip:port 监听地址 * -s ip:port 监听地址
* -cert-file 证书文件 * -cert-file 证书文件
* -key-file 私钥文件 * -key-file 私钥文件
* -token
* -h 查看帮助 * -h 查看帮助
* -v 查看版本 * -v 查看版本

View File

@ -20,10 +20,12 @@ func main() {
var version bool var version bool
var keyFile string var keyFile string
var certFile string var certFile string
var token string
flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root") flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root")
flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port") flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port")
flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path") flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path")
flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path") flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path")
flag.StringVar(&token, "token", "", "./gocron-node -token")
flag.BoolVar(&version, "v", false, "./gocron-node -v") flag.BoolVar(&version, "v", false, "./gocron-node -v")
flag.Parse() flag.Parse()
@ -52,5 +54,5 @@ func main() {
server.Start(serverAddr, certFile, keyFile) server.Start(serverAddr, certFile, keyFile, token)
} }

View File

@ -11,6 +11,7 @@ type Host struct {
Alias string `xorm:"varchar(32) notnull default '' "` // 主机别名 Alias string `xorm:"varchar(32) notnull default '' "` // 主机别名
Port int `xorm:"notnull default 22"` // 主机端口 Port int `xorm:"notnull default 22"` // 主机端口
CertFile string `xorm:"varchar(64) notnull default '' "` CertFile string `xorm:"varchar(64) notnull default '' "`
Token string `xorm:"varchar(128) notnull default '' "`
Remark string `xorm:"varchar(100) notnull default '' "` // 备注 Remark string `xorm:"varchar(100) notnull default '' "` // 备注
BaseModel `xorm:"-"` BaseModel `xorm:"-"`
Selected bool `xorm:"-"` Selected bool `xorm:"-"`
@ -27,7 +28,7 @@ func (host *Host) Create() (insertId int16, err error) {
} }
func (host *Host) UpdateBean(id int16) (int64, error) { func (host *Host) UpdateBean(id int16) (int64, error) {
return Db.ID(id).Cols("name,alias,port,cert_file,remark").Update(host) return Db.ID(id).Cols("name,alias,port,cert_file,token,remark").Update(host)
} }

View File

@ -143,6 +143,11 @@ func (migration *Migration) upgradeFor120(session *xorm.Session) error {
// host表增加cert_file字段 // host表增加cert_file字段
tableName := TablePrefix + "host" tableName := TablePrefix + "host"
_, err := session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN cert_file VARCHAR(64) NOT NULL DEFAULT ''", tableName)) _, err := session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN cert_file VARCHAR(64) NOT NULL DEFAULT ''", tableName))
if err != nil {
return err
}
_, err = session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN token VARCHAR(64) NOT NULL DEFAULT ''", tableName))
return err return err
} }

View File

@ -13,6 +13,7 @@ type TaskHostDetail struct {
Port int Port int
Alias string Alias string
CertFile string CertFile string
Token string
} }
func (TaskHostDetail) TableName() string { func (TaskHostDetail) TableName() string {
@ -49,7 +50,7 @@ func (th *TaskHost) Add(taskId int, hostIds []int) error {
func (th *TaskHost) GetHostIdsByTaskId(taskId int) ([]TaskHostDetail, error) { func (th *TaskHost) GetHostIdsByTaskId(taskId int) ([]TaskHostDetail, error) {
list := make([]TaskHostDetail, 0) list := make([]TaskHostDetail, 0)
fields := "th.id,th.host_id,h.alias,h.name,h.port,h.cert_file" fields := "th.id,th.host_id,h.alias,h.name,h.port,h.cert_file,h.token"
err := Db.Alias("th"). err := Db.Alias("th").
Join("LEFT", hostTableName(), "th.host_id=h.id"). Join("LEFT", hostTableName(), "th.host_id=h.id").
Where("th.task_id = ?", taskId). Where("th.task_id = ?", taskId).

View File

@ -16,11 +16,11 @@ var (
errUnavailable = errors.New("无法连接远程服务器") errUnavailable = errors.New("无法连接远程服务器")
) )
func ExecWithRetry(ip string, port int, certFile string,taskReq *pb.TaskRequest) (string, error) { func ExecWithRetry(ip string, port int, certFile string, token string, taskReq *pb.TaskRequest) (string, error) {
tryTimes := 15 tryTimes := 15
i := 0 i := 0
for i < tryTimes { for i < tryTimes {
output, err := Exec(ip, port, certFile, taskReq) output, err := Exec(ip, port, certFile, token, taskReq)
if err != errUnavailable { if err != errUnavailable {
return output, err return output, err
} }
@ -31,14 +31,14 @@ func ExecWithRetry(ip string, port int, certFile string,taskReq *pb.TaskRequest)
return "", errUnavailable return "", errUnavailable
} }
func Exec(ip string, port int, certFile string, taskReq *pb.TaskRequest) (string, error) { func Exec(ip string, port int, certFile string, token string, taskReq *pb.TaskRequest) (string, error) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
logger.Error("panic#rpc/client.go:Exec#", err) logger.Error("panic#rpc/client.go:Exec#", err)
} }
} () } ()
addr := fmt.Sprintf("%s:%d", ip, port) addr := fmt.Sprintf("%s:%d", ip, port)
conn, err := grpcpool.Pool.Get(addr, certFile) conn, err := grpcpool.Pool.Get(addr, certFile, token)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -77,4 +77,4 @@ func parseGRPCError(err error, conn *grpc.ClientConn, connClosed *bool) (string,
return "", errors.New("执行超时, 强制结束") return "", errors.New("执行超时, 强制结束")
} }
return "", err return "", err
} }

View File

@ -7,6 +7,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"errors" "errors"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"golang.org/x/net/context"
"strings" "strings"
) )
@ -32,12 +33,12 @@ type GRPCPool struct {
sync.RWMutex sync.RWMutex
} }
func (p *GRPCPool) Get(addr, certFile string) (*grpc.ClientConn, error) { func (p *GRPCPool) Get(addr, certFile, token string) (*grpc.ClientConn, error) {
p.RLock() p.RLock()
pool, ok := p.conns[addr] pool, ok := p.conns[addr]
p.RUnlock() p.RUnlock()
if !ok { if !ok {
err := p.newCommonPool(addr, certFile) err := p.newCommonPool(addr, certFile, token)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +89,7 @@ func (p *GRPCPool) ReleaseAll() {
} }
// 初始化底层连接池 // 初始化底层连接池
func (p *GRPCPool) newCommonPool(addr, certFile string) (error) { func (p *GRPCPool) newCommonPool(addr, certFile, token string) (error) {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
commonPool, ok := p.conns[addr] commonPool, ok := p.conns[addr]
@ -109,7 +110,13 @@ func (p *GRPCPool) newCommonPool(addr, certFile string) (error) {
return nil, err return nil, err
} }
return grpc.Dial(addr, grpc.WithTransportCredentials(creds)) customCredential := &CustomCredential{Token: token}
return grpc.Dial(addr,
grpc.WithTransportCredentials(creds),
grpc.WithPerRPCCredentials(customCredential),
)
}, },
Close: func(v interface{}) error { Close: func(v interface{}) error {
conn, ok := v.(*grpc.ClientConn) conn, ok := v.(*grpc.ClientConn)
@ -129,4 +136,19 @@ func (p *GRPCPool) newCommonPool(addr, certFile string) (error) {
p.conns[addr] = commonPool p.conns[addr] = commonPool
return nil return nil
}
type CustomCredential struct
{
Token string
}
func (c CustomCredential) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{
"token": c.Token,
}, nil
}
func (c CustomCredential) RequireTransportSecurity() bool {
return true
} }

View File

@ -8,9 +8,32 @@ import (
pb "github.com/ouqiang/gocron/modules/rpc/proto" pb "github.com/ouqiang/gocron/modules/rpc/proto"
"github.com/ouqiang/gocron/modules/utils" "github.com/ouqiang/gocron/modules/utils"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"errors"
) )
type Server struct {} type Server struct
{
Token string
}
func (s Server) auth(ctx context.Context) error {
// 验证token是否有效
meta, ok := metadata.FromContext(ctx)
if !ok {
return errors.New("missing metadata")
}
token, ok := meta["token"]
if !ok {
return errors.New("missing param token")
}
if token[0] != s.Token {
return errors.New("invalid token")
}
return nil
}
func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, error) { func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, error) {
defer func() { defer func() {
@ -18,6 +41,15 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse,
grpclog.Println(err) grpclog.Println(err)
} }
} () } ()
if s.Token != "" {
err := s.auth(ctx)
if err != nil {
return nil, err
}
}
output, err := utils.ExecShell(ctx, req.Command) output, err := utils.ExecShell(ctx, req.Command)
resp := new(pb.TaskResponse) resp := new(pb.TaskResponse)
resp.Output = output resp.Output = output
@ -30,7 +62,7 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse,
return resp, nil return resp, nil
} }
func Start(addr, certFile, keyFile string) { func Start(addr, certFile, keyFile, token string) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
grpclog.Println("panic", err) grpclog.Println("panic", err)
@ -42,6 +74,7 @@ func Start(addr, certFile, keyFile string) {
} }
var s *grpc.Server var s *grpc.Server
server := Server{Token: token}
if certFile != "" { if certFile != "" {
// TLS认证 // TLS认证
creds, err := credentials.NewServerTLSFromFile(certFile, keyFile) creds, err := credentials.NewServerTLSFromFile(certFile, keyFile)
@ -50,11 +83,11 @@ func Start(addr, certFile, keyFile string) {
} }
s = grpc.NewServer(grpc.Creds(creds)) s = grpc.NewServer(grpc.Creds(creds))
pb.RegisterTaskServer(s, Server{}) pb.RegisterTaskServer(s, server)
grpclog.Printf("listen %s with TLS", addr) grpclog.Printf("listen %s with TLS", addr)
} else { } else {
s = grpc.NewServer() s = grpc.NewServer()
pb.RegisterTaskServer(s, Server{}) pb.RegisterTaskServer(s, server)
grpclog.Println("listen ", addr) grpclog.Println("listen ", addr)
} }
err = s.Serve(l) err = s.Serve(l)

View File

@ -65,6 +65,7 @@ type HostForm struct {
Alias string `binding:"Required;MaxSize(32)"` Alias string `binding:"Required;MaxSize(32)"`
Port int `binding:"Required;Range(1-65535)"` Port int `binding:"Required;Range(1-65535)"`
CertFile string CertFile string
Token string
Remark string Remark string
} }
@ -95,6 +96,7 @@ func Store(ctx *macaron.Context, form HostForm) string {
hostModel.Port = form.Port hostModel.Port = form.Port
hostModel.Remark = strings.TrimSpace(form.Remark) hostModel.Remark = strings.TrimSpace(form.Remark)
hostModel.CertFile = strings.TrimSpace(form.CertFile) hostModel.CertFile = strings.TrimSpace(form.CertFile)
hostModel.Token = strings.TrimSpace(form.Token)
if hostModel.CertFile != "" && !utils.FileExist(hostModel.CertFile) { if hostModel.CertFile != "" && !utils.FileExist(hostModel.CertFile) {
return json.CommonFailure("证书文件不存在或无权限访问") return json.CommonFailure("证书文件不存在或无权限访问")
@ -180,7 +182,11 @@ func Ping(ctx *macaron.Context) string {
taskReq := &rpc.TaskRequest{} taskReq := &rpc.TaskRequest{}
taskReq.Command = "echo hello" taskReq.Command = "echo hello"
taskReq.Timeout = 10 taskReq.Timeout = 10
output, err := client.Exec(hostModel.Name, hostModel.Port, hostModel.CertFile, taskReq) output, err := client.Exec(hostModel.Name,
hostModel.Port,
hostModel.CertFile,
hostModel.Token,
taskReq)
if err != nil { if err != nil {
return json.CommonFailure("连接失败-" + err.Error() + " " + output, err) return json.CommonFailure("连接失败-" + err.Error() + " " + output, err)
} }

View File

@ -181,7 +181,11 @@ func (h *RPCHandler) Run(taskModel models.Task) (result string, err error) {
var resultChan chan TaskResult = make(chan TaskResult, len(taskModel.Hosts)) var resultChan chan TaskResult = make(chan TaskResult, len(taskModel.Hosts))
for _, taskHost := range taskModel.Hosts { for _, taskHost := range taskModel.Hosts {
go func(th models.TaskHostDetail) { go func(th models.TaskHostDetail) {
output, err := rpcClient.ExecWithRetry(th.Name, th.Port, th.CertFile, taskRequest) output, err := rpcClient.ExecWithRetry(th.Name,
th.Port,
th.CertFile,
th.Token,
taskRequest)
var errorMessage string = "" var errorMessage string = ""
if err != nil { if err != nil {
errorMessage = err.Error() errorMessage = err.Error()

View File

@ -45,6 +45,14 @@
</div> </div>
</div> </div>
</div> </div>
<div class="two fields">
<div class="field">
<label>Token</label>
<div class="ui small input">
<textarea rows="4" name="token" placeholder="gocron-node中配置的token">{{{.Host.Token}}}</textarea>
</div>
</div>
</div>
<div class="two fields"> <div class="two fields">
<div class="field"> <div class="field">
<label></label> <label></label>