U 优化批量执行

pull/517/head
vapao 2022-07-04 11:25:06 +08:00
parent 3e2357ae50
commit d5b9828564
6 changed files with 87 additions and 83 deletions

View File

@ -19,7 +19,7 @@ class Job:
self.key = key
self.command = self._handle_command(command, interpreter)
self.token = token
self.rds_cli = None
self.rds = get_redis_connection()
self.env = dict(
SPUG_HOST_ID=str(self.key),
SPUG_HOST_NAME=name,
@ -31,12 +31,8 @@ class Job:
if isinstance(params, dict):
self.env.update({f'_SPUG_{k}': str(v) for k, v in params.items()})
def _send(self, message, with_expire=False):
if self.rds_cli is None:
self.rds_cli = get_redis_connection()
self.rds_cli.lpush(self.token, json.dumps(message))
if with_expire:
self.rds_cli.expire(self.token, 300)
def _send(self, message):
self.rds.publish(self.token, json.dumps(message))
def _handle_command(self, command, interpreter):
if interpreter == 'python':
@ -45,12 +41,10 @@ class Job:
return command
def send(self, data):
message = {'key': self.key, 'data': data}
self._send(message)
self._send({'key': self.key, 'data': data})
def send_status(self, code):
message = {'key': self.key, 'status': code}
self._send(message, True)
self._send({'key': self.key, 'status': code})
def run(self):
if not self.token:

View File

@ -40,6 +40,7 @@ class ExecHistory(models.Model, ModelMixin):
digest = models.CharField(max_length=32, db_index=True)
interpreter = models.CharField(max_length=20)
command = models.TextField()
params = models.CharField(max_length=500, default='{}')
host_ids = models.TextField()
updated_at = models.CharField(max_length=20, default=human_datetime)

View File

@ -3,6 +3,7 @@
# Released under the AGPL-3.0 License.
from django.views.generic import View
from django.conf import settings
from django.db import close_old_connections
from django_redis import get_redis_connection
from apps.exec.models import Transfer
from apps.account.utils import has_host_perm
@ -10,6 +11,7 @@ from apps.host.models import Host
from apps.setting.utils import AppSetting
from libs import json_response, JsonParser, Argument, auth
from concurrent import futures
from threading import Thread
import subprocess
import tempfile
import uuid
@ -76,28 +78,33 @@ class TransferView(View):
Argument('token', help='参数错误')
).parse(request.body)
if error is None:
rds = get_redis_connection()
task = Transfer.objects.get(digest=form.token)
threads = []
max_workers = max(10, os.cpu_count() * 5)
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
t = executor.submit(_do_sync, rds, task, host)
t.token = task.digest
t.key = host.id
threads.append(t)
for t in futures.as_completed(threads):
exc = t.exception()
if exc:
rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'}))
if task.host_id:
command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}'
else:
command = f'rm -rf {task.src_dir}'
subprocess.run(command, shell=True)
Thread(target=_dispatch_sync, args=(task,)).start()
return json_response(error=error)
def _dispatch_sync(task):
rds = get_redis_connection()
threads = []
max_workers = max(10, os.cpu_count() * 5)
with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
t = executor.submit(_do_sync, rds, task, host)
t.token = task.digest
t.key = host.id
threads.append(t)
for t in futures.as_completed(threads):
exc = t.exception()
if exc:
rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'}))
if task.host_id:
command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}'
else:
command = f'rm -rf {task.src_dir}'
subprocess.run(command, shell=True)
close_old_connections()
def _do_sync(rds, task, host):
token = task.digest
rds.publish(token, json.dumps({'key': host.id, 'data': '\r\n\x1b[36m### Executing ...\x1b[0m\r\n'}))

View File

@ -8,7 +8,6 @@ from apps.exec.transfer import TransferView
urlpatterns = [
url(r'template/$', TemplateView.as_view()),
url(r'history/$', get_histories),
url(r'do/$', do_task),
url(r'do/$', TaskView.as_view()),
url(r'transfer/$', TransferView.as_view()),
]

View File

@ -4,12 +4,10 @@
from django.views.generic import View
from django_redis import get_redis_connection
from django.conf import settings
from django.db.models import F
from libs import json_response, JsonParser, Argument, human_datetime, auth
from apps.exec.models import ExecTemplate, ExecHistory
from apps.host.models import Host
from apps.account.utils import has_host_perm
import hashlib
import uuid
import json
@ -54,60 +52,66 @@ class TemplateView(View):
return json_response(error=error)
@auth('exec.task.do')
def do_task(request):
form, error = JsonParser(
Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'),
Argument('command', help='请输入执行命令内容'),
Argument('interpreter', default='sh'),
Argument('template_id', type=int, required=False),
Argument('params', type=dict, required=False)
).parse(request.body)
if error is None:
if not has_host_perm(request.user, form.host_ids):
return json_response(error='无权访问主机,请联系管理员')
token, rds = uuid.uuid4().hex, get_redis_connection()
for host in Host.objects.filter(id__in=form.host_ids):
data = dict(
key=host.id,
name=host.name,
token=token,
interpreter=form.interpreter,
hostname=host.hostname,
port=host.port,
username=host.username,
command=form.command,
pkey=host.private_key,
params=form.params
)
rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data))
form.host_ids.sort()
host_ids = json.dumps(form.host_ids)
tmp_str = f'{form.interpreter},{host_ids},{form.command}'
digest = hashlib.md5(tmp_str.encode()).hexdigest()
record = ExecHistory.objects.filter(user=request.user, digest=digest).first()
if form.template_id:
template = ExecTemplate.objects.filter(pk=form.template_id).first()
if not template or template.body != form.command:
form.template_id = None
if record:
record.template_id = form.template_id
record.updated_at = human_datetime()
record.save()
else:
class TaskView(View):
@auth('exec.task.do')
def get(self, request):
records = ExecHistory.objects.filter(user=request.user).select_related('template')
return json_response([x.to_view() for x in records])
@auth('exec.task.do')
def post(self, request):
form, error = JsonParser(
Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'),
Argument('command', help='请输入执行命令内容'),
Argument('interpreter', default='sh'),
Argument('template_id', type=int, required=False),
Argument('params', type=dict, handler=json.dumps, default={})
).parse(request.body)
if error is None:
if not has_host_perm(request.user, form.host_ids):
return json_response(error='无权访问主机,请联系管理员')
token, rds = uuid.uuid4().hex, get_redis_connection()
form.host_ids.sort()
if form.template_id:
template = ExecTemplate.objects.filter(pk=form.template_id).first()
if not template or template.body != form.command:
form.template_id = None
ExecHistory.objects.create(
user=request.user,
digest=digest,
digest=token,
interpreter=form.interpreter,
template_id=form.template_id,
command=form.command,
host_ids=json.dumps(form.host_ids),
params=form.params
)
return json_response(token)
return json_response(error=error)
return json_response(token)
return json_response(error=error)
@auth('exec.task.do')
def patch(self, request):
form, error = JsonParser(
Argument('token', help='参数错误')
).parse(request.body)
if error is None:
rds = get_redis_connection()
task = ExecHistory.objects.get(digest=form.token)
for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
data = dict(
key=host.id,
name=host.name,
token=task.digest,
interpreter=task.interpreter,
hostname=host.hostname,
port=host.port,
username=host.username,
command=task.command,
pkey=host.private_key,
params=json.loads(task.params)
)
rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data))
return json_response(error=error)
@auth('exec.task.do')
def get_histories(request):
records = ExecHistory.objects.filter(user=request.user).select_related('template')
return json_response([x.to_view() for x in records])

View File

@ -89,5 +89,4 @@ class ObjectView(View):
def _compute_progress(self, rds_cli, token, total, value, *args):
percent = '%.1f' % (value / total * 100)
rds_cli.lpush(token, percent)
rds_cli.expire(token, 300)
rds_cli.publish(token, percent)