U Optimize batch execution

pull/517/head
vapao 2022-07-04 11:25:06 +08:00
parent 3e2357ae50
commit d5b9828564
6 changed files with 87 additions and 83 deletions

View File

@@ -19,7 +19,7 @@ class Job:
         self.key = key
         self.command = self._handle_command(command, interpreter)
         self.token = token
-        self.rds_cli = None
+        self.rds = get_redis_connection()
         self.env = dict(
             SPUG_HOST_ID=str(self.key),
             SPUG_HOST_NAME=name,
@@ -31,12 +31,8 @@ class Job:
         if isinstance(params, dict):
             self.env.update({f'_SPUG_{k}': str(v) for k, v in params.items()})
 
-    def _send(self, message, with_expire=False):
-        if self.rds_cli is None:
-            self.rds_cli = get_redis_connection()
-        self.rds_cli.lpush(self.token, json.dumps(message))
-        if with_expire:
-            self.rds_cli.expire(self.token, 300)
+    def _send(self, message):
+        self.rds.publish(self.token, json.dumps(message))
 
     def _handle_command(self, command, interpreter):
         if interpreter == 'python':
@@ -45,12 +41,10 @@ class Job:
         return command
 
     def send(self, data):
-        message = {'key': self.key, 'data': data}
-        self._send(message)
+        self._send({'key': self.key, 'data': data})
 
     def send_status(self, code):
-        message = {'key': self.key, 'status': code}
-        self._send(message, True)
+        self._send({'key': self.key, 'status': code})
 
     def run(self):
         if not self.token:
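
With this change, Job._send publishes each message on a Redis pub/sub channel named after the task token instead of pushing onto a list with a 300-second TTL, so output is only delivered to consumers that are already subscribed. A minimal sketch of such a consumer, assuming a redis-py style client from get_redis_connection() and the message shapes produced by send()/send_status() above; the function name and flow are illustrative, not part of this commit:

import json
from django_redis import get_redis_connection

def follow_job_output(token):
    # Illustrative consumer: the channel name (the task token) and the payload
    # keys ('key', 'data', 'status') come from Job._send/send/send_status above.
    rds = get_redis_connection()
    pubsub = rds.pubsub()
    pubsub.subscribe(token)
    for item in pubsub.listen():
        if item['type'] != 'message':
            continue                      # skip the subscribe confirmation
        message = json.loads(item['data'])
        if 'data' in message:             # incremental output from send()
            print(message['key'], message['data'])
        if 'status' in message:           # exit code from send_status()
            print(message['key'], 'finished with status', message['status'])
            break                         # single-host example; a real consumer tracks every host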

View File

@@ -40,6 +40,7 @@ class ExecHistory(models.Model, ModelMixin):
     digest = models.CharField(max_length=32, db_index=True)
     interpreter = models.CharField(max_length=20)
     command = models.TextField()
+    params = models.CharField(max_length=500, default='{}')
     host_ids = models.TextField()
     updated_at = models.CharField(max_length=20, default=human_datetime)
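
The new params column keeps the execution parameters as a JSON-encoded string (default '{}'), which is how TaskView below writes it (the parser already applies json.dumps) and reads it back with json.loads. A small illustrative round trip; the digest value is a placeholder:

import json
from apps.exec.models import ExecHistory

record = ExecHistory.objects.get(digest='<task-token>')   # placeholder lookup
record.params = json.dumps({'VERSION': '1.2.3'})          # stored as JSON text, fits CharField(500)
record.save()
params = json.loads(record.params)                        # back to a dict for the worker payload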

View File

@@ -3,6 +3,7 @@
 # Released under the AGPL-3.0 License.
 from django.views.generic import View
 from django.conf import settings
+from django.db import close_old_connections
 from django_redis import get_redis_connection
 from apps.exec.models import Transfer
 from apps.account.utils import has_host_perm
@@ -10,6 +11,7 @@ from apps.host.models import Host
 from apps.setting.utils import AppSetting
 from libs import json_response, JsonParser, Argument, auth
 from concurrent import futures
+from threading import Thread
 import subprocess
 import tempfile
 import uuid
@@ -76,28 +78,33 @@ class TransferView(View):
             Argument('token', help='参数错误')
         ).parse(request.body)
         if error is None:
-            rds = get_redis_connection()
             task = Transfer.objects.get(digest=form.token)
-            threads = []
-            max_workers = max(10, os.cpu_count() * 5)
-            with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-                for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
-                    t = executor.submit(_do_sync, rds, task, host)
-                    t.token = task.digest
-                    t.key = host.id
-                    threads.append(t)
-                for t in futures.as_completed(threads):
-                    exc = t.exception()
-                    if exc:
-                        rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'}))
-            if task.host_id:
-                command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}'
-            else:
-                command = f'rm -rf {task.src_dir}'
-            subprocess.run(command, shell=True)
+            Thread(target=_dispatch_sync, args=(task,)).start()
         return json_response(error=error)
 
 
+def _dispatch_sync(task):
+    rds = get_redis_connection()
+    threads = []
+    max_workers = max(10, os.cpu_count() * 5)
+    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
+            t = executor.submit(_do_sync, rds, task, host)
+            t.token = task.digest
+            t.key = host.id
+            threads.append(t)
+        for t in futures.as_completed(threads):
+            exc = t.exception()
+            if exc:
+                rds.publish(t.token, json.dumps({'key': t.key, 'status': -1, 'data': f'Exception: {exc}'}))
+    if task.host_id:
+        command = f'umount -f {task.src_dir} && rm -rf {task.src_dir}'
+    else:
+        command = f'rm -rf {task.src_dir}'
+    subprocess.run(command, shell=True)
+    close_old_connections()
+
+
 def _do_sync(rds, task, host):
     token = task.digest
     rds.publish(token, json.dumps({'key': host.id, 'data': '\r\n\x1b[36m### Executing ...\x1b[0m\r\n'}))
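
The transfer view now answers the request immediately and moves the per-host fan-out into a plain background Thread, calling close_old_connections() at the end because that thread used ORM connections outside the request/response cycle. The same pattern in isolation, as a rough sketch; the helper name and arguments are illustrative, not from the commit:

from concurrent import futures
from threading import Thread
from django.db import close_old_connections

def run_in_background(work_fn, items, max_workers=10):
    # Illustrative helper: fan work out over a thread pool inside a plain
    # Thread so the HTTP handler can return without waiting for the work.
    def _runner():
        try:
            with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                for _ in executor.map(work_fn, items):
                    pass
        finally:
            close_old_connections()   # release ORM connections opened in this thread
    Thread(target=_runner).start()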

View File

@@ -8,7 +8,6 @@ from apps.exec.transfer import TransferView
 urlpatterns = [
     url(r'template/$', TemplateView.as_view()),
-    url(r'history/$', get_histories),
-    url(r'do/$', do_task),
+    url(r'do/$', TaskView.as_view()),
     url(r'transfer/$', TransferView.as_view()),
 ]

View File

@@ -4,12 +4,10 @@
 from django.views.generic import View
 from django_redis import get_redis_connection
 from django.conf import settings
-from django.db.models import F
 from libs import json_response, JsonParser, Argument, human_datetime, auth
 from apps.exec.models import ExecTemplate, ExecHistory
 from apps.host.models import Host
 from apps.account.utils import has_host_perm
-import hashlib
 import uuid
 import json
@@ -54,60 +52,66 @@ class TemplateView(View):
         return json_response(error=error)
 
 
-@auth('exec.task.do')
-def do_task(request):
-    form, error = JsonParser(
-        Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'),
-        Argument('command', help='请输入执行命令内容'),
-        Argument('interpreter', default='sh'),
-        Argument('template_id', type=int, required=False),
-        Argument('params', type=dict, required=False)
-    ).parse(request.body)
-    if error is None:
-        if not has_host_perm(request.user, form.host_ids):
-            return json_response(error='无权访问主机,请联系管理员')
-        token, rds = uuid.uuid4().hex, get_redis_connection()
-        for host in Host.objects.filter(id__in=form.host_ids):
-            data = dict(
-                key=host.id,
-                name=host.name,
-                token=token,
-                interpreter=form.interpreter,
-                hostname=host.hostname,
-                port=host.port,
-                username=host.username,
-                command=form.command,
-                pkey=host.private_key,
-                params=form.params
-            )
-            rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data))
-        form.host_ids.sort()
-        host_ids = json.dumps(form.host_ids)
-        tmp_str = f'{form.interpreter},{host_ids},{form.command}'
-        digest = hashlib.md5(tmp_str.encode()).hexdigest()
-        record = ExecHistory.objects.filter(user=request.user, digest=digest).first()
-        if form.template_id:
-            template = ExecTemplate.objects.filter(pk=form.template_id).first()
-            if not template or template.body != form.command:
-                form.template_id = None
-        if record:
-            record.template_id = form.template_id
-            record.updated_at = human_datetime()
-            record.save()
-        else:
-            ExecHistory.objects.create(
-                user=request.user,
-                digest=digest,
-                interpreter=form.interpreter,
-                template_id=form.template_id,
-                command=form.command,
-                host_ids=json.dumps(form.host_ids),
-            )
-        return json_response(token)
-    return json_response(error=error)
-
-
-@auth('exec.task.do')
-def get_histories(request):
-    records = ExecHistory.objects.filter(user=request.user).select_related('template')
-    return json_response([x.to_view() for x in records])
+class TaskView(View):
+    @auth('exec.task.do')
+    def get(self, request):
+        records = ExecHistory.objects.filter(user=request.user).select_related('template')
+        return json_response([x.to_view() for x in records])
+
+    @auth('exec.task.do')
+    def post(self, request):
+        form, error = JsonParser(
+            Argument('host_ids', type=list, filter=lambda x: len(x), help='请选择执行主机'),
+            Argument('command', help='请输入执行命令内容'),
+            Argument('interpreter', default='sh'),
+            Argument('template_id', type=int, required=False),
+            Argument('params', type=dict, handler=json.dumps, default={})
+        ).parse(request.body)
+        if error is None:
+            if not has_host_perm(request.user, form.host_ids):
+                return json_response(error='无权访问主机,请联系管理员')
+            token, rds = uuid.uuid4().hex, get_redis_connection()
+            form.host_ids.sort()
+            if form.template_id:
+                template = ExecTemplate.objects.filter(pk=form.template_id).first()
+                if not template or template.body != form.command:
+                    form.template_id = None
+            ExecHistory.objects.create(
+                user=request.user,
+                digest=token,
+                interpreter=form.interpreter,
+                template_id=form.template_id,
+                command=form.command,
+                host_ids=json.dumps(form.host_ids),
+                params=form.params
+            )
+            return json_response(token)
+        return json_response(error=error)
+
+    @auth('exec.task.do')
+    def patch(self, request):
+        form, error = JsonParser(
+            Argument('token', help='参数错误')
+        ).parse(request.body)
+        if error is None:
+            rds = get_redis_connection()
+            task = ExecHistory.objects.get(digest=form.token)
+            for host in Host.objects.filter(id__in=json.loads(task.host_ids)):
+                data = dict(
+                    key=host.id,
+                    name=host.name,
+                    token=task.digest,
+                    interpreter=task.interpreter,
+                    hostname=host.hostname,
+                    port=host.port,
+                    username=host.username,
+                    command=task.command,
+                    pkey=host.private_key,
+                    params=json.loads(task.params)
+                )
+                rds.rpush(settings.EXEC_WORKER_KEY, json.dumps(data))
+        return json_response(error=error)
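
Taken together with the urls.py change above, history listing, task creation, and task dispatch now live behind a single endpoint selected by HTTP verb: GET returns the caller's ExecHistory records, POST validates and records a task and returns its token, and PATCH pushes the per-host jobs onto EXEC_WORKER_KEY for that token. A hypothetical client-side walk-through; the base URL and the json_response envelope ({'data': ..., 'error': ...}) are assumptions about the surrounding project, not shown in this diff:

import requests

BASE = 'http://spug.example.com/api/exec/do/'    # assumed mount point for apps.exec.urls
session = requests.Session()                     # assumed to already carry auth cookies/headers

history = session.get(BASE).json()['data']       # TaskView.get: list my ExecHistory records
token = session.post(BASE, json={
    'host_ids': [1, 2],
    'command': 'uptime',
    'interpreter': 'sh',
}).json()['data']                                # TaskView.post: record the task, returns its token
session.patch(BASE, json={'token': token})       # TaskView.patch: enqueue the per-host jobs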

View File

@@ -89,5 +89,4 @@ class ObjectView(View):
     def _compute_progress(self, rds_cli, token, total, value, *args):
         percent = '%.1f' % (value / total * 100)
-        rds_cli.lpush(token, percent)
-        rds_cli.expire(token, 300)
+        rds_cli.publish(token, percent)