diff --git a/README.md b/README.md index 8b5b384..819c4a7 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ * -s ip:port 监听地址 * -cert-file 证书文件 * -key-file 私钥文件 + * -token * -h 查看帮助 * -v 查看版本 diff --git a/gocron-node.go b/gocron-node.go index 84eda20..df5258f 100644 --- a/gocron-node.go +++ b/gocron-node.go @@ -20,10 +20,12 @@ func main() { var version bool var keyFile string var certFile string + var token string flag.BoolVar(&allowRoot, "allow-root", false, "./gocron-node -allow-root") flag.StringVar(&serverAddr, "s", "0.0.0.0:5921", "./gocron-node -s ip:port") flag.StringVar(&certFile, "cert-file", "", "./gocron-node -cert-file path") flag.StringVar(&keyFile, "key-file", "", "./gocron-node -key-file path") + flag.StringVar(&token, "token", "", "./gocron-node -token") flag.BoolVar(&version, "v", false, "./gocron-node -v") flag.Parse() @@ -52,5 +54,5 @@ func main() { - server.Start(serverAddr, certFile, keyFile) + server.Start(serverAddr, certFile, keyFile, token) } \ No newline at end of file diff --git a/models/host.go b/models/host.go index 0066607..00f280f 100644 --- a/models/host.go +++ b/models/host.go @@ -11,6 +11,7 @@ type Host struct { Alias string `xorm:"varchar(32) notnull default '' "` // 主机别名 Port int `xorm:"notnull default 22"` // 主机端口 CertFile string `xorm:"varchar(64) notnull default '' "` + Token string `xorm:"varchar(128) notnull default '' "` Remark string `xorm:"varchar(100) notnull default '' "` // 备注 BaseModel `xorm:"-"` Selected bool `xorm:"-"` @@ -27,7 +28,7 @@ func (host *Host) Create() (insertId int16, err error) { } func (host *Host) UpdateBean(id int16) (int64, error) { - return Db.ID(id).Cols("name,alias,port,cert_file,remark").Update(host) + return Db.ID(id).Cols("name,alias,port,cert_file,token,remark").Update(host) } diff --git a/models/migration.go b/models/migration.go index c9d085c..ebd6ee8 100644 --- a/models/migration.go +++ b/models/migration.go @@ -143,6 +143,11 @@ func (migration *Migration) upgradeFor120(session *xorm.Session) error { // host表增加cert_file字段 tableName := TablePrefix + "host" _, err := session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN cert_file VARCHAR(64) NOT NULL DEFAULT ''", tableName)) + if err != nil { + return err + } + + _, err = session.Exec(fmt.Sprintf("ALTER TABLE %s Add COLUMN token VARCHAR(64) NOT NULL DEFAULT ''", tableName)) return err } \ No newline at end of file diff --git a/models/task_host.go b/models/task_host.go index 10611fe..670c1c1 100644 --- a/models/task_host.go +++ b/models/task_host.go @@ -13,6 +13,7 @@ type TaskHostDetail struct { Port int Alias string CertFile string + Token string } func (TaskHostDetail) TableName() string { @@ -49,7 +50,7 @@ func (th *TaskHost) Add(taskId int, hostIds []int) error { func (th *TaskHost) GetHostIdsByTaskId(taskId int) ([]TaskHostDetail, error) { list := make([]TaskHostDetail, 0) - fields := "th.id,th.host_id,h.alias,h.name,h.port,h.cert_file" + fields := "th.id,th.host_id,h.alias,h.name,h.port,h.cert_file,h.token" err := Db.Alias("th"). Join("LEFT", hostTableName(), "th.host_id=h.id"). Where("th.task_id = ?", taskId). diff --git a/modules/rpc/client/client.go b/modules/rpc/client/client.go index ca34e40..063e275 100644 --- a/modules/rpc/client/client.go +++ b/modules/rpc/client/client.go @@ -16,11 +16,11 @@ var ( errUnavailable = errors.New("无法连接远程服务器") ) -func ExecWithRetry(ip string, port int, certFile string,taskReq *pb.TaskRequest) (string, error) { +func ExecWithRetry(ip string, port int, certFile string, token string, taskReq *pb.TaskRequest) (string, error) { tryTimes := 15 i := 0 for i < tryTimes { - output, err := Exec(ip, port, certFile, taskReq) + output, err := Exec(ip, port, certFile, token, taskReq) if err != errUnavailable { return output, err } @@ -31,14 +31,14 @@ func ExecWithRetry(ip string, port int, certFile string,taskReq *pb.TaskRequest) return "", errUnavailable } -func Exec(ip string, port int, certFile string, taskReq *pb.TaskRequest) (string, error) { +func Exec(ip string, port int, certFile string, token string, taskReq *pb.TaskRequest) (string, error) { defer func() { if err := recover(); err != nil { logger.Error("panic#rpc/client.go:Exec#", err) } } () addr := fmt.Sprintf("%s:%d", ip, port) - conn, err := grpcpool.Pool.Get(addr, certFile) + conn, err := grpcpool.Pool.Get(addr, certFile, token) if err != nil { return "", err } @@ -77,4 +77,4 @@ func parseGRPCError(err error, conn *grpc.ClientConn, connClosed *bool) (string, return "", errors.New("执行超时, 强制结束") } return "", err -} +} \ No newline at end of file diff --git a/modules/rpc/grpcpool/grpc_pool.go b/modules/rpc/grpcpool/grpc_pool.go index 7b73c18..b2ff67e 100644 --- a/modules/rpc/grpcpool/grpc_pool.go +++ b/modules/rpc/grpcpool/grpc_pool.go @@ -7,6 +7,7 @@ import ( "google.golang.org/grpc" "errors" "google.golang.org/grpc/credentials" + "golang.org/x/net/context" "strings" ) @@ -32,12 +33,12 @@ type GRPCPool struct { sync.RWMutex } -func (p *GRPCPool) Get(addr, certFile string) (*grpc.ClientConn, error) { +func (p *GRPCPool) Get(addr, certFile, token string) (*grpc.ClientConn, error) { p.RLock() pool, ok := p.conns[addr] p.RUnlock() if !ok { - err := p.newCommonPool(addr, certFile) + err := p.newCommonPool(addr, certFile, token) if err != nil { return nil, err } @@ -88,7 +89,7 @@ func (p *GRPCPool) ReleaseAll() { } // 初始化底层连接池 -func (p *GRPCPool) newCommonPool(addr, certFile string) (error) { +func (p *GRPCPool) newCommonPool(addr, certFile, token string) (error) { p.Lock() defer p.Unlock() commonPool, ok := p.conns[addr] @@ -109,7 +110,13 @@ func (p *GRPCPool) newCommonPool(addr, certFile string) (error) { return nil, err } - return grpc.Dial(addr, grpc.WithTransportCredentials(creds)) + customCredential := &CustomCredential{Token: token} + + + return grpc.Dial(addr, + grpc.WithTransportCredentials(creds), + grpc.WithPerRPCCredentials(customCredential), + ) }, Close: func(v interface{}) error { conn, ok := v.(*grpc.ClientConn) @@ -129,4 +136,19 @@ func (p *GRPCPool) newCommonPool(addr, certFile string) (error) { p.conns[addr] = commonPool return nil +} + +type CustomCredential struct +{ + Token string +} + +func (c CustomCredential) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return map[string]string{ + "token": c.Token, + }, nil +} + +func (c CustomCredential) RequireTransportSecurity() bool { + return true } \ No newline at end of file diff --git a/modules/rpc/server/server.go b/modules/rpc/server/server.go index 1bd9076..3896133 100644 --- a/modules/rpc/server/server.go +++ b/modules/rpc/server/server.go @@ -8,9 +8,32 @@ import ( pb "github.com/ouqiang/gocron/modules/rpc/proto" "github.com/ouqiang/gocron/modules/utils" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "errors" ) -type Server struct {} +type Server struct +{ + Token string +} + +func (s Server) auth(ctx context.Context) error { + // 验证token是否有效 + meta, ok := metadata.FromContext(ctx) + if !ok { + return errors.New("missing metadata") + } + + token, ok := meta["token"] + if !ok { + return errors.New("missing param token") + } + if token[0] != s.Token { + return errors.New("invalid token") + } + + return nil +} func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, error) { defer func() { @@ -18,6 +41,15 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, grpclog.Println(err) } } () + + + if s.Token != "" { + err := s.auth(ctx) + if err != nil { + return nil, err + } + } + output, err := utils.ExecShell(ctx, req.Command) resp := new(pb.TaskResponse) resp.Output = output @@ -30,7 +62,7 @@ func (s Server) Run(ctx context.Context, req *pb.TaskRequest) (*pb.TaskResponse, return resp, nil } -func Start(addr, certFile, keyFile string) { +func Start(addr, certFile, keyFile, token string) { defer func() { if err := recover(); err != nil { grpclog.Println("panic", err) @@ -42,6 +74,7 @@ func Start(addr, certFile, keyFile string) { } var s *grpc.Server + server := Server{Token: token} if certFile != "" { // TLS认证 creds, err := credentials.NewServerTLSFromFile(certFile, keyFile) @@ -50,11 +83,11 @@ func Start(addr, certFile, keyFile string) { } s = grpc.NewServer(grpc.Creds(creds)) - pb.RegisterTaskServer(s, Server{}) + pb.RegisterTaskServer(s, server) grpclog.Printf("listen %s with TLS", addr) } else { s = grpc.NewServer() - pb.RegisterTaskServer(s, Server{}) + pb.RegisterTaskServer(s, server) grpclog.Println("listen ", addr) } err = s.Serve(l) diff --git a/routers/host/host.go b/routers/host/host.go index f1dba67..8da5a34 100644 --- a/routers/host/host.go +++ b/routers/host/host.go @@ -65,6 +65,7 @@ type HostForm struct { Alias string `binding:"Required;MaxSize(32)"` Port int `binding:"Required;Range(1-65535)"` CertFile string + Token string Remark string } @@ -95,6 +96,7 @@ func Store(ctx *macaron.Context, form HostForm) string { hostModel.Port = form.Port hostModel.Remark = strings.TrimSpace(form.Remark) hostModel.CertFile = strings.TrimSpace(form.CertFile) + hostModel.Token = strings.TrimSpace(form.Token) if hostModel.CertFile != "" && !utils.FileExist(hostModel.CertFile) { return json.CommonFailure("证书文件不存在或无权限访问") @@ -180,7 +182,11 @@ func Ping(ctx *macaron.Context) string { taskReq := &rpc.TaskRequest{} taskReq.Command = "echo hello" taskReq.Timeout = 10 - output, err := client.Exec(hostModel.Name, hostModel.Port, hostModel.CertFile, taskReq) + output, err := client.Exec(hostModel.Name, + hostModel.Port, + hostModel.CertFile, + hostModel.Token, + taskReq) if err != nil { return json.CommonFailure("连接失败-" + err.Error() + " " + output, err) } diff --git a/service/task.go b/service/task.go index df3a966..41766f8 100644 --- a/service/task.go +++ b/service/task.go @@ -181,7 +181,11 @@ func (h *RPCHandler) Run(taskModel models.Task) (result string, err error) { var resultChan chan TaskResult = make(chan TaskResult, len(taskModel.Hosts)) for _, taskHost := range taskModel.Hosts { go func(th models.TaskHostDetail) { - output, err := rpcClient.ExecWithRetry(th.Name, th.Port, th.CertFile, taskRequest) + output, err := rpcClient.ExecWithRetry(th.Name, + th.Port, + th.CertFile, + th.Token, + taskRequest) var errorMessage string = "" if err != nil { errorMessage = err.Error() diff --git a/templates/host/host_form.html b/templates/host/host_form.html index 5c547f3..805a3c7 100644 --- a/templates/host/host_form.html +++ b/templates/host/host_form.html @@ -45,6 +45,14 @@ +