mirror of https://github.com/jumpserver/jumpserver
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
183 lines
6.3 KiB
183 lines
6.3 KiB
import json |
|
import time |
|
|
|
from celery import signals |
|
from django.core.cache import cache |
|
from django.db import transaction |
|
from django.db.models.signals import pre_save |
|
from django.db.utils import ProgrammingError |
|
from django.dispatch import receiver |
|
from django.utils import translation, timezone |
|
from django.utils.functional import LazyObject |
|
from rest_framework.utils.encoders import JSONEncoder |
|
|
|
from common.db.utils import close_old_connections, get_logger |
|
from common.signals import django_ready |
|
from common.utils.connection import RedisPubSub |
|
from jumpserver.utils import get_current_request |
|
from orgs.utils import get_current_org_id, set_current_org |
|
from .ansible.runner import interface |
|
from .celery import app |
|
from .models import CeleryTaskExecution, CeleryTask, Job |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
@receiver(pre_save, sender=Job) |
|
def on_account_pre_create(sender, instance, **kwargs): |
|
# 升级版本号 |
|
instance.version += 1 |
|
|
|
|
|
@receiver(signals.worker_ready) |
|
def sync_registered_tasks(*args, **kwargs): |
|
synced = cache.get('synced_registered_tasks', False) |
|
if synced: |
|
return |
|
cache.set('synced_registered_tasks', True, 60) |
|
with transaction.atomic(): |
|
try: |
|
db_tasks = CeleryTask.objects.all() |
|
celery_task_names = [key for key in app.tasks] |
|
db_task_names = db_tasks.values_list('name', flat=True) |
|
|
|
db_tasks.exclude(name__in=celery_task_names).delete() |
|
not_in_db_tasks = set(celery_task_names) - set(db_task_names) |
|
tasks_to_create = [CeleryTask(name=name) for name in not_in_db_tasks] |
|
CeleryTask.objects.bulk_create(tasks_to_create) |
|
except ProgrammingError: |
|
pass |
|
|
|
|
|
@receiver(django_ready) |
|
def check_registered_tasks(*args, **kwargs): |
|
attrs = ['verbose_name', 'activity_callback'] |
|
ignores = [ |
|
'users.tasks.check_user_expired_periodic', 'ops.tasks.clean_celery_periodic_tasks', |
|
'terminal.tasks.delete_terminal_status_period', 'ops.tasks.check_server_performance_period', |
|
'settings.tasks.ldap.import_ldap_user', 'users.tasks.check_password_expired', |
|
'assets.tasks.nodes_amount.check_node_assets_amount_task', 'notifications.notifications.publish_task', |
|
'perms.tasks.check_asset_permission_will_expired', |
|
'ops.tasks.create_or_update_registered_periodic_tasks', 'perms.tasks.check_asset_permission_expired', |
|
'settings.tasks.ldap.import_ldap_user_periodic', 'users.tasks.check_password_expired_periodic', |
|
'common.utils.verify_code.send_sms_async', 'assets.tasks.nodes_amount.check_node_assets_amount_period_task', |
|
'users.tasks.check_user_expired', 'orgs.tasks.refresh_org_cache_task', |
|
'terminal.tasks.upload_session_replay_to_external_storage', 'terminal.tasks.clean_orphan_session', |
|
'terminal.tasks.upload_session_replay_file_to_external_storage', |
|
'audits.tasks.clean_audits_log_period', 'authentication.tasks.clean_django_sessions' |
|
] |
|
|
|
for name, task in app.tasks.items(): |
|
if name.startswith('celery.'): |
|
continue |
|
if name in ignores: |
|
continue |
|
for attr in attrs: |
|
if not hasattr(task, attr): |
|
# print('>>> Task {} has no attribute {}'.format(name, attr)) |
|
pass |
|
|
|
|
|
@signals.before_task_publish.connect |
|
def before_task_publish(body=None, **kwargs): |
|
current_lang = translation.get_language() |
|
current_org_id = get_current_org_id() |
|
args, kwargs = body[:2] |
|
kwargs['__current_lang'] = current_lang |
|
kwargs['__current_org_id'] = current_org_id |
|
|
|
|
|
@signals.task_prerun.connect |
|
def on_celery_task_pre_run(task_id='', kwargs=None, **others): |
|
count = 0 |
|
qs = CeleryTaskExecution.objects.filter(id=task_id) |
|
while not qs.exists() and count < 5: |
|
count += 1 |
|
time.sleep(1) |
|
qs = CeleryTaskExecution.objects.filter(id=task_id) |
|
|
|
# 更新状态 |
|
qs.update(state='RUNNING', date_start=timezone.now()) |
|
|
|
# 关闭之前的数据库连接 |
|
close_old_connections() |
|
|
|
# 设置语言的一些上下文 |
|
lang = kwargs.pop('__current_lang', None) |
|
org_id = kwargs.pop('__current_org_id', None) |
|
if lang: |
|
print('>> Set language to {}'.format(lang)) |
|
translation.activate(lang) |
|
if org_id: |
|
print('>> Set org to {}'.format(org_id)) |
|
set_current_org(org_id) |
|
|
|
|
|
@signals.task_postrun.connect |
|
def on_celery_task_post_run(task_id='', state='', **kwargs): |
|
close_old_connections() |
|
|
|
CeleryTaskExecution.objects.filter(id=task_id).update( |
|
state=state, date_finished=timezone.now(), is_finished=True |
|
) |
|
|
|
|
|
@signals.after_task_publish.connect |
|
def task_sent_handler(headers=None, body=None, **kwargs): |
|
info = headers if 'task' in headers else body |
|
task = info.get('task') |
|
i = info.get('id') |
|
if not i or not task: |
|
logger.error("Not found task id or name: {}".format(info)) |
|
return |
|
|
|
args, kwargs, __ = body |
|
|
|
try: |
|
args = json.loads(json.dumps(list(args), cls=JSONEncoder)) |
|
kwargs = json.loads(json.dumps(kwargs, cls=JSONEncoder)) |
|
except Exception as e: |
|
logger.warn('Parse task args or kwargs error (Need handle): {}'.format(e)) |
|
args = [] |
|
kwargs = {} |
|
|
|
# 不要保存__current_lang和__current_org_id参数,防止系统任务中点击再次执行报错 |
|
kwargs.pop('__current_lang', None) |
|
kwargs.pop('__current_org_id', None) |
|
data = { |
|
'id': i, |
|
'name': task, |
|
'state': 'PENDING', |
|
'is_finished': False, |
|
'args': args, |
|
'kwargs': kwargs |
|
} |
|
request = get_current_request() |
|
if request and request.user.is_authenticated: |
|
data['creator'] = request.user |
|
|
|
with transaction.atomic(): |
|
try: |
|
CeleryTaskExecution.objects.create(**data) |
|
except Exception as e: |
|
logger.error('Create celery task execution error: {}'.format(e)) |
|
CeleryTask.objects.filter(name=task).update(date_last_publish=timezone.now()) |
|
|
|
|
|
@receiver(django_ready) |
|
def subscribe_stop_job_execution(sender, **kwargs): |
|
logger.info("Start subscribe for stop job execution") |
|
|
|
def on_stop(pid): |
|
logger.info(f"Stop job execution {pid} start") |
|
interface.kill_process(pid) |
|
|
|
job_execution_stop_pub_sub.subscribe(on_stop) |
|
|
|
|
|
class JobExecutionPubSub(LazyObject): |
|
def _setup(self): |
|
self._wrapped = RedisPubSub('fm.job_execution_stop') |
|
|
|
|
|
job_execution_stop_pub_sub = JobExecutionPubSub()
|
|
|