// gocron/modules/ssh/ssh.go

package ssh
import (
"golang.org/x/crypto/ssh"
"fmt"
"net"
"time"
"errors"
)
// HostAuthType identifies which credential is used to log in to a remote host.
type HostAuthType int8

// The auth constants are intentionally untyped so they compare directly with
// a HostAuthType field as well as with plain integer values.
const (
	// HostPassword authenticates with SSHConfig.Password.
	HostPassword = 1
	// HostPublicKey authenticates with SSHConfig.PrivateKey. parseSSHConfig
	// uses key authentication for every value other than HostPassword.
	HostPublicKey = 2

	// SSHConnectTimeout is the connection timeout in seconds; parseSSHConfig
	// multiplies it by time.Second and stores it in ssh.ClientConfig.Timeout.
	SSHConnectTimeout = 10
)
type SSHConfig struct {
2017-04-20 01:36:42 +00:00
AuthType HostAuthType
User string
Password string
2017-04-20 01:36:42 +00:00
PrivateKey string
Host string
Port int
ExecTimeout int// 执行超时时间
}
// Result is what the background command goroutine in Exec hands back: the
// command's combined stdout/stderr and the error from the SSH session, if any.
type Result struct {
	Output string // combined stdout and stderr
	Err    error  // error returned by session.CombinedOutput
}
// parseSSHConfig builds the *ssh.ClientConfig used to dial a host.
//
// When sshConfig.AuthType is HostPassword the client authenticates with
// sshConfig.Password. For every other AuthType value it parses
// sshConfig.PrivateKey with ssh.ParsePrivateKey and authenticates with that
// key. If parsing fails, the parse error is returned unchanged with a nil
// config. The connection timeout is SSHConnectTimeout seconds.
//
// SECURITY NOTE(review): the host key callback accepts every server key, so
// connections are not protected against man-in-the-middle attacks. Checking
// against a known_hosts file or a stored fingerprint would close this gap.
func parseSSHConfig(sshConfig SSHConfig) (config *ssh.ClientConfig, err error) {
	var auth ssh.AuthMethod
	if sshConfig.AuthType == HostPassword {
		auth = ssh.Password(sshConfig.Password)
	} else {
		// Anything that is not password auth is treated as public-key auth.
		signer, parseErr := ssh.ParsePrivateKey([]byte(sshConfig.PrivateKey))
		if parseErr != nil {
			return nil, parseErr
		}
		auth = ssh.PublicKeys(signer)
	}

	// Accepts any host key; behaves like ssh.InsecureIgnoreHostKey().
	acceptAnyHostKey := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
		return nil
	}

	return &ssh.ClientConfig{
		User:            sshConfig.User,
		Auth:            []ssh.AuthMethod{auth},
		Timeout:         SSHConnectTimeout * time.Second,
		HostKeyCallback: acceptAnyHostKey,
	}, nil
}
// 执行shell命令
func Exec(sshConfig SSHConfig, cmd string) (output string, err error) {
2017-04-16 18:01:41 +00:00
client, err := getClient(sshConfig)
if err != nil {
return "", err
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
return "", err
}
defer session.Close()
if sshConfig.ExecTimeout <= 0 {
outputByte, execErr := session.CombinedOutput(cmd)
output = string(outputByte)
err = execErr
return
}
var resultChan chan Result = make(chan Result)
var timeoutChan chan bool = make(chan bool)
2017-04-16 18:01:41 +00:00
go func() {
output, err := session.CombinedOutput(cmd)
resultChan <- Result{string(output), err}
}()
go triggerTimeout(timeoutChan, sshConfig.ExecTimeout)
select {
case result := <- resultChan:
output = result.Output
err = result.Err
case <- timeoutChan:
output = ""
err = errors.New("timeout")
}
return
}
func getClient(sshConfig SSHConfig) (*ssh.Client, error) {
2017-04-20 01:36:42 +00:00
config, err := parseSSHConfig(sshConfig)
if err != nil {
return nil, err
}
addr := fmt.Sprintf("%s:%d", sshConfig.Host, sshConfig.Port)
return ssh.Dial("tcp", addr, config)
}
// triggerTimeout blocks for timeout seconds and then closes ch, which releases
// every goroutine receiving from it. The wait is capped at 24 hours: a zero,
// negative, or greater-than-86400 timeout is treated as exactly 86400 seconds.
// It blocks for the whole wait, so run it in its own goroutine.
func triggerTimeout(ch chan bool, timeout int) {
	const maxWaitSeconds = 86400 // no execution may run longer than 24 hours

	wait := maxWaitSeconds
	if timeout > 0 && timeout <= maxWaitSeconds {
		wait = timeout
	}

	<-time.After(time.Duration(wait) * time.Second)
	close(ch)
}