diff --git a/spug_api/apps/exec/executors.py b/spug_api/apps/exec/executors.py index 0122dcd..f85e754 100644 --- a/spug_api/apps/exec/executors.py +++ b/spug_api/apps/exec/executors.py @@ -19,7 +19,7 @@ class Job: self.key = key self.command = self._handle_command(command, interpreter) self.token = token - self.rds_cli = None + self.rds = get_redis_connection() self.env = dict( SPUG_HOST_ID=str(self.key), SPUG_HOST_NAME=name, @@ -31,12 +31,8 @@ class Job: if isinstance(params, dict): self.env.update({f'_SPUG_{k}': str(v) for k, v in params.items()}) - def _send(self, message, with_expire=False): - if self.rds_cli is None: - self.rds_cli = get_redis_connection() - self.rds_cli.lpush(self.token, json.dumps(message)) - if with_expire: - self.rds_cli.expire(self.token, 300) + def _send(self, message): + self.rds.publish(self.token, json.dumps(message)) def _handle_command(self, command, interpreter): if interpreter == 'python': @@ -45,12 +41,10 @@ class Job: return command def send(self, data): - message = {'key': self.key, 'data': data} - self._send(message) + self._send({'key': self.key, 'data': data}) def send_status(self, code): - message = {'key': self.key, 'status': code} - self._send(message, True) + self._send({'key': self.key, 'status': code}) def run(self): if not self.token: diff --git a/spug_api/apps/exec/models.py b/spug_api/apps/exec/models.py index 990e397..5ec8eec 100644 --- a/spug_api/apps/exec/models.py +++ b/spug_api/apps/exec/models.py @@ -40,6 +40,7 @@ class ExecHistory(models.Model, ModelMixin): digest = models.CharField(max_length=32, db_index=True) interpreter = models.CharField(max_length=20) command = models.TextField() + params = models.CharField(max_length=500, default='{}') host_ids = models.TextField() updated_at = models.CharField(max_length=20, default=human_datetime) diff --git a/spug_api/apps/exec/transfer.py b/spug_api/apps/exec/transfer.py index 5e4b64a..0b1537d 100644 --- a/spug_api/apps/exec/transfer.py +++ b/spug_api/apps/exec/transfer.py @@ -3,6 +3,7 @@ # Released under the AGPL-3.0 License. from django.views.generic import View from django.conf import settings +from django.db import close_old_connections from django_redis import get_redis_connection from apps.exec.models import Transfer from apps.account.utils import has_host_perm @@ -10,6 +11,7 @@ from apps.host.models import Host from apps.setting.utils import AppSetting from libs import json_response, JsonParser, Argument, auth from concurrent import futures +from threading import Thread import subprocess import tempfile import uuid @@ -76,28 +78,33 @@ class TransferView(View): Argument('token', help='参数错误') ).parse(request.body) if error is None: - rds = get_redis_connection() task = Transfer.objects.get(digest=form.token) - threads = [] - max_workers = max(10, os.cpu_count() * 5) - with futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - for host in Host.objects.filter(id__in=json.loads(task.host_ids)): - t = executor.submit(_do_sync, rds, task, host) - t.token = task.digest - t.key = host.id - threads.append(t) - for t in futures.as_completed(threads): - exc = t.exception() - if exc: - rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'})) - if task.host_id: - command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}' - else: - command = f'rm -rf {task.src_dir}' - subprocess.run(command, shell=True) + Thread(target=_dispatch_sync, args=(task,)).start() return json_response(error=error) +def _dispatch_sync(task): + rds = get_redis_connection() + threads = [] + max_workers = max(10, os.cpu_count() * 5) + with futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + for host in Host.objects.filter(id__in=json.loads(task.host_ids)): + t = executor.submit(_do_sync, rds, task, host) + t.token = task.digest + t.key = host.id + threads.append(t) + for t in futures.as_completed(threads): + exc = t.exception() + if exc: + rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'})) + if task.host_id: + command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}' + else: + command = f'rm -rf {task.src_dir}' + subprocess.run(command, shell=True) + close_old_connections() + + def _do_sync(rds, task, host): token = task.digest rds.publish(token, json.dumps({'key': host.id, 'data': '\r\n\x1b[36m### Executing ...\x1b[0m\r\n'})) diff --git a/spug_api/apps/exec/urls.py b/spug_api/apps/exec/urls.py index 3c99a19..ffce12f 100644 --- a/spug_api/apps/exec/urls.py +++ b/spug_api/apps/exec/urls.py @@ -8,7 +8,6 @@ from apps.exec.transfer import TransferView urlpatterns = [ url(r'template/$', TemplateView.as_view()), - url(r'history/$', get_histories), - url(r'do/$', do_task), + url(r'do/$', TaskView.as_view()), url(r'transfer/$', TransferView.as_view()), ] diff --git a/spug_api/apps/exec/views.py b/spug_api/apps/exec/views.py index 06a8cf4..03bb8e4 100644 --- a/spug_api/apps/exec/views.py +++ b/spug_api/apps/exec/views.py @@ -4,12 +4,10 @@ from django.views.generic import View from django_redis import get_redis_connection from django.conf import settings -from django.db.models import F from libs import json_response, JsonParser, Argument, human_datetime, auth from apps.exec.models import ExecTemplate, ExecHistory from apps.host.models import Host from apps.account.utils import has_host_perm -import hashlib import uuid import json @@ -54,60 +52,66 @@ class TemplateView(View): return json_response(error=error) -@auth('exec.task.do') -def do_task(request): - form, error = JsonParser( - Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'), - Argument('command', help='请输入执行命令内容'), - Argument('interpreter', default='sh'), - Argument('template_id', type=int, required=False), - Argument('params', type=dict, required=False) - ).parse(request.body) - if error is None: - if not has_host_perm(request.user, form.host_ids): - return json_response(error='无权访问主机,请联系管理员') - token, rds = uuid.uuid4().hex, get_redis_connection() - for host in Host.objects.filter(id__in=form.host_ids): - data = dict( - key=host.id, - name=host.name, - token=token, - interpreter=form.interpreter, - hostname=host.hostname, - port=host.port, - username=host.username, - command=form.command, - pkey=host.private_key, - params=form.params - ) - rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data)) - form.host_ids.sort() - host_ids = json.dumps(form.host_ids) - tmp_str = f'{form.interpreter},{host_ids},{form.command}' - digest = hashlib.md5(tmp_str.encode()).hexdigest() - record = ExecHistory.objects.filter(user=request.user, digest=digest).first() - if form.template_id: - template = ExecTemplate.objects.filter(pk=form.template_id).first() - if not template or template.body != form.command: - form.template_id = None - if record: - record.template_id = form.template_id - record.updated_at = human_datetime() - record.save() - else: +class TaskView(View): + @auth('exec.task.do') + def get(self, request): + records = ExecHistory.objects.filter(user=request.user).select_related('template') + return json_response([x.to_view() for x in records]) + + @auth('exec.task.do') + def post(self, request): + form, error = JsonParser( + Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'), + Argument('command', help='请输入执行命令内容'), + Argument('interpreter', default='sh'), + Argument('template_id', type=int, required=False), + Argument('params', type=dict, handler=json.dumps, default={}) + ).parse(request.body) + if error is None: + if not has_host_perm(request.user, form.host_ids): + return json_response(error='无权访问主机,请联系管理员') + token, rds = uuid.uuid4().hex, get_redis_connection() + form.host_ids.sort() + if form.template_id: + template = ExecTemplate.objects.filter(pk=form.template_id).first() + if not template or template.body != form.command: + form.template_id = None + ExecHistory.objects.create( user=request.user, - digest=digest, + digest=token, interpreter=form.interpreter, template_id=form.template_id, command=form.command, host_ids=json.dumps(form.host_ids), + params=form.params ) - return json_response(token) - return json_response(error=error) + return json_response(token) + return json_response(error=error) + + @auth('exec.task.do') + def patch(self, request): + form, error = JsonParser( + Argument('token', help='参数错误') + ).parse(request.body) + if error is None: + rds = get_redis_connection() + task = ExecHistory.objects.get(digest=form.token) + for host in Host.objects.filter(id__in=json.loads(task.host_ids)): + data = dict( + key=host.id, + name=host.name, + token=task.digest, + interpreter=task.interpreter, + hostname=host.hostname, + port=host.port, + username=host.username, + command=task.command, + pkey=host.private_key, + params=json.loads(task.params) + ) + rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data)) + return json_response(error=error) + -@auth('exec.task.do') -def get_histories(request): - records = ExecHistory.objects.filter(user=request.user).select_related('template') - return json_response([x.to_view() for x in records]) diff --git a/spug_api/apps/file/views.py b/spug_api/apps/file/views.py index 53004bf..fdadd02 100644 --- a/spug_api/apps/file/views.py +++ b/spug_api/apps/file/views.py @@ -89,5 +89,4 @@ class ObjectView(View): def _compute_progress(self, rds_cli, token, total, value, *args): percent = '%.1f' % (value / total * 100) - rds_cli.lpush(token, percent) - rds_cli.expire(token, 300) + rds_cli.publish(token, percent)