import json
import time

from celery import signals
from django.core.cache import cache
from django.db import transaction
from django.db.models.signals import pre_save
from django.db.utils import ProgrammingError
from django.dispatch import receiver
from django.utils import translation, timezone
from django.utils.functional import LazyObject
from rest_framework.utils.encoders import JSONEncoder

from common.db.utils import close_old_connections, get_logger
from common.signals import django_ready
from common.utils.connection import RedisPubSub
from jumpserver.utils import get_current_request
from orgs.utils import get_current_org_id, set_current_org
from .ansible.runner import interface
from .celery import app
from .models import CeleryTaskExecution, CeleryTask, Job

logger = get_logger(__name__)


@receiver(pre_save, sender=Job)
def on_account_pre_create(sender, instance, **kwargs):
    """Increment ``Job.version`` every time a Job is about to be saved.

    NOTE(review): despite its name this handler is connected to ``Job``,
    not to an account model; the name is kept so existing references work.
    """
    # Bump the version number.
    instance.version += 1


@receiver(signals.worker_ready)
def sync_registered_tasks(*args, **kwargs):
    """Mirror the tasks registered in the Celery app into ``CeleryTask`` rows.

    Rows whose name is no longer registered are deleted and rows for newly
    registered names are created. A 60-second cache flag prevents workers
    that start together from all repeating the sync.
    """
    # cache.add() only stores the key if it is absent and reports whether it
    # did, so it is an atomic check-and-set; the previous get()/set() pair
    # let several workers pass the check at the same time.
    if not cache.add('synced_registered_tasks', True, 60):
        return

    celery_task_names = set(app.tasks)
    try:
        # The try wraps the atomic block (not the other way round): catching
        # a database error *inside* ``atomic`` leaves the transaction broken.
        with transaction.atomic():
            db_tasks = CeleryTask.objects.all()
            db_task_names = set(db_tasks.values_list('name', flat=True))
            db_tasks.exclude(name__in=celery_task_names).delete()
            tasks_to_create = [
                CeleryTask(name=name)
                for name in celery_task_names - db_task_names
            ]
            CeleryTask.objects.bulk_create(tasks_to_create)
    except ProgrammingError as e:
        # Best-effort: e.g. the table may not exist yet before migrations run.
        logger.warning('Sync registered celery tasks skipped: {}'.format(e))


@receiver(django_ready)
def check_registered_tasks(*args, **kwargs):
    """Report registered Celery tasks missing ``verbose_name``/``activity_callback``.

    Celery's own ``celery.*`` tasks and the system tasks listed in
    ``ignores`` are skipped. Missing attributes are only logged at DEBUG
    level; nothing is enforced.
    """
    attrs = ('verbose_name', 'activity_callback')
    ignores = {
        'users.tasks.check_user_expired_periodic',
        'ops.tasks.clean_celery_periodic_tasks',
        'terminal.tasks.delete_terminal_status_period',
        'ops.tasks.check_server_performance_period',
        'settings.tasks.ldap.import_ldap_user',
        'users.tasks.check_password_expired',
        'assets.tasks.nodes_amount.check_node_assets_amount_task',
        'notifications.notifications.publish_task',
        'perms.tasks.check_asset_permission_will_expired',
        'ops.tasks.create_or_update_registered_periodic_tasks',
        'perms.tasks.check_asset_permission_expired',
        'settings.tasks.ldap.import_ldap_user_periodic',
        'users.tasks.check_password_expired_periodic',
        'common.utils.verify_code.send_sms_async',
        'assets.tasks.nodes_amount.check_node_assets_amount_period_task',
        'users.tasks.check_user_expired',
        'orgs.tasks.refresh_org_cache_task',
        'terminal.tasks.upload_session_replay_to_external_storage',
        'terminal.tasks.clean_orphan_session',
        'terminal.tasks.upload_session_replay_file_to_external_storage',
        'audits.tasks.clean_audits_log_period',
        'authentication.tasks.clean_django_sessions',
    }
    for name, task in app.tasks.items():
        if name.startswith('celery.') or name in ignores:
            continue
        for attr in attrs:
            if not hasattr(task, attr):
                logger.debug('Task {} has no attribute {}'.format(name, attr))


@signals.before_task_publish.connect
def before_task_publish(body=None, **kwargs):
    """Inject the publisher's language and organization into the task kwargs.

    ``on_celery_task_pre_run`` pops these two keys again on the worker and
    restores the context before the task runs.
    """
    if not body:
        return
    # Assumes the protocol-2 body tuple (args, kwargs, embed); the kwargs
    # dict is mutated in place so the added keys are part of the message.
    _args, task_kwargs = body[:2]
    task_kwargs['__current_lang'] = translation.get_language()
    task_kwargs['__current_org_id'] = get_current_org_id()


@signals.task_prerun.connect
def on_celery_task_pre_run(task_id='', kwargs=None, **others):
    """Mark the execution RUNNING and restore the publisher's lang/org context.

    The ``CeleryTaskExecution`` row is created by ``task_sent_handler`` after
    publishing, so the worker may start first; wait up to ~5 seconds for the
    row to appear before updating it.
    """
    qs = CeleryTaskExecution.objects.filter(id=task_id)
    for _ in range(5):
        if qs.exists():
            break
        time.sleep(1)
    # Update the state.
    qs.update(state='RUNNING', date_start=timezone.now())
    # Close stale database connections before the task body runs.
    close_old_connections()

    # Restore the language / organization context set at publish time.
    # The helper keys are popped (not just read), presumably so they are not
    # forwarded to the task function itself.
    if kwargs is None:
        kwargs = {}
    lang = kwargs.pop('__current_lang', None)
    org_id = kwargs.pop('__current_org_id', None)
    if lang:
        print('>> Set language to {}'.format(lang))
        translation.activate(lang)
    if org_id:
        print('>> Set org to {}'.format(org_id))
        set_current_org(org_id)


@signals.task_postrun.connect
def on_celery_task_post_run(task_id='', state='', **kwargs):
    """Mark the execution as finished with the final ``state`` Celery reports."""
    close_old_connections()
    CeleryTaskExecution.objects.filter(id=task_id).update(
        state=state, date_finished=timezone.now(), is_finished=True
    )


@signals.after_task_publish.connect
def task_sent_handler(headers=None, body=None, **kwargs):
    """Record a PENDING ``CeleryTaskExecution`` for every published task.

    Also stamps ``CeleryTask.date_last_publish`` for the task's name. Args and
    kwargs are stored after a JSON round-trip; if they cannot be serialized,
    empty values are stored instead.
    """
    headers = headers or {}
    info = headers if 'task' in headers else body
    task = task_id = None
    if isinstance(info, dict):
        task = info.get('task')
        task_id = info.get('id')
    if not task_id or not task:
        logger.error("Not found task id or name: {}".format(info))
        return

    # Assumes the protocol-2 body tuple (args, kwargs, embed).
    args, task_kwargs, __ = body
    try:
        args = json.loads(json.dumps(list(args), cls=JSONEncoder))
        task_kwargs = json.loads(json.dumps(task_kwargs, cls=JSONEncoder))
    except Exception as e:
        logger.warning('Parse task args or kwargs error (Need handle): {}'.format(e))
        args = []
        task_kwargs = {}

    # Do not persist __current_lang / __current_org_id; otherwise re-running
    # a system task from its saved record raises an error.
    task_kwargs.pop('__current_lang', None)
    task_kwargs.pop('__current_org_id', None)

    data = {
        'id': task_id,
        'name': task,
        'state': 'PENDING',
        'is_finished': False,
        'args': args,
        'kwargs': task_kwargs,
    }
    request = get_current_request()
    if request and request.user.is_authenticated:
        data['creator'] = request.user

    try:
        # The try wraps the atomic block rather than sitting inside it: a
        # database error caught inside ``atomic`` leaves the transaction
        # unusable, which would make the update below fail and, when
        # publishing inside an outer transaction, break the caller's work.
        with transaction.atomic():
            task_execution = CeleryTaskExecution.objects.create(**data)
            task_execution.set_creator_if_need()
    except Exception as e:
        logger.error('Create celery task execution error: {}'.format(e))
    CeleryTask.objects.filter(name=task).update(date_last_publish=timezone.now())


@receiver(django_ready)
def subscribe_stop_job_execution(sender, **kwargs):
    """Subscribe to stop requests and kill the job process whose pid is received."""
    logger.info("Start subscribe for stop job execution")

    def on_stop(pid):
        logger.info(f"Stop job execution {pid} start")
        interface.kill_process(pid)

    job_execution_stop_pub_sub.subscribe(on_stop)


class JobExecutionPubSub(LazyObject):
    """Lazily constructed ``RedisPubSub`` on the ``fm.job_execution_stop`` channel.

    ``LazyObject`` defers ``_setup`` until the first attribute access, so no
    Redis connection is made at import time.
    """

    def _setup(self):
        self._wrapped = RedisPubSub('fm.job_execution_stop')


job_execution_stop_pub_sub = JobExecutionPubSub()