diff --git a/spug_api/apps/exec/executors.py b/spug_api/apps/exec/executors.py index a92602b..2af5297 100644 --- a/spug_api/apps/exec/executors.py +++ b/spug_api/apps/exec/executors.py @@ -15,7 +15,7 @@ def exec_worker_handler(job): class Job: def __init__(self, hostname, port, username, pkey, command, token=None): - self.ssh_cli = SSH(hostname, port, username, pkey) + self.ssh = SSH(hostname, port, username, pkey) self.key = f'{hostname}:{port}' self.command = command self.token = token @@ -29,15 +29,7 @@ class Job: self.rds_cli.expire(self.token, 300) def send(self, data): - message = {'key': self.key, 'type': 'info', 'data': data} - self._send(message) - - def send_system(self, data): - message = {'key': self.key, 'type': 'system', 'data': data} - self._send(message) - - def send_error(self, data): - message = {'key': self.key, 'type': 'error', 'data': data} + message = {'key': self.key, 'data': data} self._send(message) def send_status(self, code): @@ -46,17 +38,19 @@ class Job: def run(self): if not self.token: - return self.ssh_cli.exec_command(self.command) - self.send_system('### Executing') + with self.ssh: + return self.ssh.exec_command_raw(self.command) + self.send('\x1b[36m### Executing ...\x1b[0m\r') code = -1 try: - for code, out in self.ssh_cli.exec_command_with_stream(self.command): - self.send(out) + with self.ssh: + for code, out in self.ssh.exec_command_with_stream(self.command): + self.send(out) except socket.timeout: code = 130 - self.send_error('### Time out') + self.send('\r\n\x1b[31m### Time out\x1b[0m') except Exception as e: code = 131 - self.send_error(f'{e}') + self.send(f'\r\n\x1b[31m### Exception {e}\x1b[0m') finally: self.send_status(code) diff --git a/spug_api/apps/file/views.py b/spug_api/apps/file/views.py index 2f33b06..aa405f6 100644 --- a/spug_api/apps/file/views.py +++ b/spug_api/apps/file/views.py @@ -5,7 +5,7 @@ from django.views.generic import View from django_redis import get_redis_connection from apps.host.models import Host from apps.account.utils import has_host_perm -from apps.file.utils import FileResponseAfter, parse_sftp_attr +from apps.file.utils import FileResponseAfter, FileResponse, parse_sftp_attr from libs import json_response, JsonParser, Argument from functools import partial import os @@ -23,8 +23,8 @@ class FileView(View): host = Host.objects.get(pk=form.id) if not host: return json_response(error='未找到指定主机') - cli = host.get_ssh() - objects = cli.list_dir_attr(form.path) + with host.get_ssh() as ssh: + objects = ssh.list_dir_attr(form.path) return json_response([parse_sftp_attr(x) for x in objects]) return json_response(error=error) @@ -42,10 +42,10 @@ class ObjectView(View): if not host: return json_response(error='未找到指定主机') filename = os.path.basename(form.file) - cli = host.get_ssh().get_client() - sftp = cli.open_sftp() + ssh_cli = host.get_ssh().get_client() + sftp = ssh_cli.open_sftp() f = sftp.open(form.file) - return FileResponseAfter(cli.close, f, as_attachment=True, filename=filename) + return FileResponseAfter(ssh_cli.close, f, as_attachment=True, filename=filename) return json_response(error=error) def post(self, request): @@ -63,10 +63,10 @@ class ObjectView(View): host = Host.objects.get(pk=form.id) if not host: return json_response(error='未找到指定主机') - cli = host.get_ssh() rds_cli = get_redis_connection() callback = partial(self._compute_progress, rds_cli, form.token, file.size) - cli.put_file_by_fl(file, f'{form.path}/{file.name}', callback=callback) + with host.get_ssh() as ssh: + ssh.put_file_by_fl(file, f'{form.path}/{file.name}', callback=callback) return json_response(error=error) def delete(self, request): @@ -80,8 +80,8 @@ class ObjectView(View): host = Host.objects.get(pk=form.id) if not host: return json_response(error='未找到指定主机') - cli = host.get_ssh() - cli.remove_file(form.file) + with host.get_ssh() as ssh: + ssh.remove_file(form.file) return json_response(error=error) def _compute_progress(self, rds_cli, token, total, value, *args): diff --git a/spug_api/apps/host/views.py b/spug_api/apps/host/views.py index 2d2c2c0..e2cf399 100644 --- a/spug_api/apps/host/views.py +++ b/spug_api/apps/host/views.py @@ -48,11 +48,11 @@ class HostView(View): if form.pkey: private_key = form.pkey elif password: - ssh = SSH(form.hostname, form.port, form.username, password=password) - ssh.add_public_key(public_key) + with SSH(form.hostname, form.port, form.username, password=password) as ssh: + ssh.add_public_key(public_key) - ssh = SSH(form.hostname, form.port, form.username, private_key) - ssh.ping() + with SSH(form.hostname, form.port, form.username, private_key) as ssh: + ssh.ping() except BadAuthenticationType: return json_response(error='该主机不支持密钥认证,请参考官方文档,错误代码:E01') except AuthenticationException: diff --git a/spug_api/libs/ssh.py b/spug_api/libs/ssh.py index 4dc7071..9f4d013 100644 --- a/spug_api/libs/ssh.py +++ b/spug_api/libs/ssh.py @@ -2,17 +2,21 @@ # Copyright: (c) # Released under the AGPL-3.0 License. from paramiko.client import SSHClient, AutoAddPolicy -from paramiko.config import SSH_PORT from paramiko.rsakey import RSAKey from paramiko.ssh_exception import AuthenticationException from io import StringIO +import time +import re class SSH: - def __init__(self, hostname, port=SSH_PORT, username='root', pkey=None, password=None, connect_timeout=10): - if pkey is None and password is None: - raise Exception('public key and password must have one is not None') + def __init__(self, hostname, port=22, username='root', pkey=None, password=None, connect_timeout=10): + self.stdout = None self.client = None + self.channel = None + self.sftp = None + self.eof = 'Spug EOF 2108111926' + self.regex = re.compile(r'Spug EOF 2108111926 \d+[\r\n]?$') self.arguments = { 'hostname': hostname, 'port': port, @@ -29,94 +33,141 @@ class SSH: key.write_private_key(key_obj) return key_obj.getvalue(), 'ssh-rsa ' + key.get_base64() - def add_public_key(self, public_key): - command = f'mkdir -p -m 700 ~/.ssh && \ - echo {public_key!r} >> ~/.ssh/authorized_keys && \ - chmod 600 ~/.ssh/authorized_keys' - code, out = self.exec_command(command) - if code != 0: - raise Exception(f'add public key error: {out}') - - def ping(self): - with self: - return True - def get_client(self): if self.client is not None: return self.client + print('\n~~ ssh start ~~') self.client = SSHClient() self.client.set_missing_host_key_policy(AutoAddPolicy) self.client.connect(**self.arguments) return self.client + def ping(self): + self.get_client() + return True + + def add_public_key(self, public_key): + command = f'mkdir -p -m 700 ~/.ssh && \ + echo {public_key!r} >> ~/.ssh/authorized_keys && \ + chmod 600 ~/.ssh/authorized_keys' + _, out, _ = self.client.exec_command(command) + if out.channel.recv_exit_status() != 0: + raise Exception(f'add public key error: {out}') + + def exec_command_raw(self, command): + channel = self.client.get_transport().open_session() + try: + channel.set_combine_stderr(True) + channel.exec_command(command) + return channel.recv_exit_status(), channel.recv(-1).decode() + finally: + channel.close() + + def exec_command(self, command, environment=None): + channel = self._get_channel() + command = self._handle_command(command, environment) + channel.send(command) + out, exit_code = '', -1 + for line in self.stdout: + if line.startswith(self.eof): + exit_code = int(line.rsplit()[-1]) + break + out += line + return exit_code, out + + def exec_command_with_stream(self, command, environment=None): + channel = self._get_channel() + command = self._handle_command(command, environment) + channel.send(command) + exit_code, line = -1, '' + while True: + line = channel.recv(8196).decode() + print(repr(line)) + match = self.regex.search(line) + if match: + exit_code = int(line.rsplit()[-1]) + line = line[:match.start()] + break + yield exit_code, line + yield exit_code, line + + def get_file(self, file): + sftp = self._get_sftp() + return sftp.open(file) + def put_file(self, local_path, remote_path): - with self as cli: - sftp = cli.open_sftp() - sftp.put(local_path, remote_path) - sftp.close() - - def exec_command(self, command, timeout=1800, environment=None): - command = 'set -e\n' + command - with self as cli: - chan = cli.get_transport().open_session() - chan.settimeout(timeout) - chan.set_combine_stderr(True) - if environment: - str_env = ' '.join(self._handle_env(k, v) for k, v in environment.items()) - command = f'export {str_env} && {command}' - chan.exec_command(command) - stdout = chan.makefile("rb", -1) - return chan.recv_exit_status(), self._decode(stdout.read()) - - def exec_command_with_stream(self, command, timeout=1800, environment=None): - command = 'set -e\n' + command - with self as cli: - chan = cli.get_transport().open_session() - chan.settimeout(timeout) - chan.set_combine_stderr(True) - if environment: - str_env = ' '.join(self._handle_env(k, v) for k, v in environment.items()) - command = f'export {str_env} && {command}' - chan.exec_command(command) - stdout = chan.makefile("rb", -1) - out = stdout.readline() - while out: - yield chan.exit_status, self._decode(out) - out = stdout.readline() - yield chan.recv_exit_status(), self._decode(out) + sftp = self._get_sftp() + sftp.put(local_path, remote_path) def put_file_by_fl(self, fl, remote_path, callback=None): - with self as cli: - sftp = cli.open_sftp() - sftp.putfo(fl, remote_path, callback=callback) + sftp = self._get_sftp() + sftp.putfo(fl, remote_path, callback=callback) def list_dir_attr(self, path): - with self as cli: - sftp = cli.open_sftp() - return sftp.listdir_attr(path) + sftp = self._get_sftp() + return sftp.listdir_attr(path) def remove_file(self, path): - with self as cli: - sftp = cli.open_sftp() - sftp.remove(path) + sftp = self._get_sftp() + sftp.remove(path) - def _decode(self, out: bytes): - try: - return out.decode() - except UnicodeDecodeError: - return out.decode('GBK') + def _get_channel(self): + if self.channel: + return self.channel - def _handle_env(self, key, value): - key = key.replace('-', '_') - if isinstance(value, str): - value = value.replace("'", "'\"'\"'") - return f"{key}='{value}'" + counter, data = 0, '' + self.channel = self.client.invoke_shell() + self.channel.send(b'export PS1= && stty -echo && echo Spug execute start\n') + while True: + if self.channel.recv_ready(): + data += self.channel.recv(8196).decode() + if 'Spug execute start\r\n' in data: + self.stdout = self.channel.makefile('r') + break + elif counter >= 100: + self.client.close() + raise Exception('Wait spug response timeout') + else: + counter += 1 + time.sleep(0.1) + return self.channel + + def _get_sftp(self): + if self.sftp: + return self.sftp + + self.sftp = self.client.open_sftp() + return self.sftp + + def _break(self): + time.sleep(5) + command = f'\x03 echo {self.eof} -1\n' + self.channel.send(command.encode()) + + def _make_env_command(self, environment): + if not environment: + return None + str_envs = [] + for k, v in environment.items(): + k = k.replace('-', '_') + if isinstance(v, str): + v = v.replace("'", "'\"'\"'") + str_envs.append(f"{k}='{v}'") + str_envs = ' '.join(str_envs) + return f'export {str_envs}' + + def _handle_command(self, command, environment): + commands = list() + commands.append(self._make_env_command(environment)) + commands.append(command.strip('\n')) + commands.append(f'echo {self.eof} $?\n') + return '\n'.join(x for x in commands if x).encode() def __enter__(self): - if self.client is not None: - raise RuntimeError('Already connected') - return self.get_client() + self.get_client() + return self def __exit__(self, exc_type, exc_val, exc_tb): + print('close √') self.client.close() self.client = None