You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

183 lines
6.3 KiB

import json
import time
from celery import signals
from django.core.cache import cache
from django.db import transaction
from django.db.models.signals import pre_save
from django.db.utils import ProgrammingError
from django.dispatch import receiver
from django.utils import translation, timezone
from django.utils.functional import LazyObject
from rest_framework.utils.encoders import JSONEncoder
from common.db.utils import close_old_connections, get_logger
from common.signals import django_ready
from common.utils.connection import RedisPubSub
from jumpserver.utils import get_current_request
from orgs.utils import get_current_org_id, set_current_org
from .ansible.runner import interface
from .celery import app
from .models import CeleryTaskExecution, CeleryTask, Job
# Module-level logger shared by the signal handlers below.
logger = get_logger(__name__)
@receiver(pre_save, sender=Job)
def on_account_pre_create(sender, instance, **kwargs):
    """Bump ``instance.version`` each time a ``Job`` is about to be saved.

    Registered for ``pre_save`` with ``sender=Job``, so it runs before every
    save of a Job (creation and later updates alike). Despite the name, it
    is not connected to any account model.
    """
    # Increment the version number; the modified instance is what the
    # triggering save goes on to write.
    instance.version = instance.version + 1
@receiver(signals.worker_ready)
def sync_registered_tasks(*args, **kwargs):
    """Mirror the task names registered in the Celery app into ``CeleryTask`` rows.

    Runs when a Celery worker becomes ready:

    * rows whose ``name`` is no longer a registered task are deleted;
    * every registered task name without a row gets a new ``CeleryTask``.

    A cache flag with a 60-second timeout ensures only the first worker that
    starts within that window performs the sync. A ``ProgrammingError``
    raised during the sync is logged and otherwise ignored (presumably to
    tolerate the table not existing yet before migrations run - verify).
    """
    # ``cache.add`` stores the key only if it is absent and reports whether it
    # did, so checking and claiming the flag is a single step rather than a
    # get-then-set that two simultaneously starting workers can both pass.
    if not cache.add('synced_registered_tasks', True, 60):
        return

    # The try wraps the atomic block (not the other way round): catching a
    # database error *inside* ``atomic`` is discouraged by the Django docs,
    # because the transaction is left in a broken state.
    try:
        with transaction.atomic():
            celery_task_names = set(app.tasks)
            db_tasks = CeleryTask.objects.all()
            db_tasks.exclude(name__in=celery_task_names).delete()
            db_task_names = set(db_tasks.values_list('name', flat=True))
            tasks_to_create = [
                CeleryTask(name=name)
                for name in celery_task_names - db_task_names
            ]
            CeleryTask.objects.bulk_create(tasks_to_create)
    except ProgrammingError as e:
        logger.warning('Sync registered celery tasks error: {}'.format(e))
@receiver(django_ready)
def check_registered_tasks(*args, **kwargs):
    """Report registered Celery tasks that lack the expected metadata attributes.

    Built-in ``celery.*`` tasks and the task names listed in ``ignores`` are
    skipped. Every other task in ``app.tasks`` is checked for
    ``verbose_name`` and ``activity_callback``. Each missing attribute is
    logged at DEBUG level, replacing the previously commented-out print, so
    the check was a no-op before.
    """
    attrs = ('verbose_name', 'activity_callback')
    # A set gives constant-time membership tests inside the loop.
    ignores = {
        'users.tasks.check_user_expired_periodic', 'ops.tasks.clean_celery_periodic_tasks',
        'terminal.tasks.delete_terminal_status_period', 'ops.tasks.check_server_performance_period',
        'settings.tasks.ldap.import_ldap_user', 'users.tasks.check_password_expired',
        'assets.tasks.nodes_amount.check_node_assets_amount_task', 'notifications.notifications.publish_task',
        'perms.tasks.check_asset_permission_will_expired',
        'ops.tasks.create_or_update_registered_periodic_tasks', 'perms.tasks.check_asset_permission_expired',
        'settings.tasks.ldap.import_ldap_user_periodic', 'users.tasks.check_password_expired_periodic',
        'common.utils.verify_code.send_sms_async', 'assets.tasks.nodes_amount.check_node_assets_amount_period_task',
        'users.tasks.check_user_expired', 'orgs.tasks.refresh_org_cache_task',
        'terminal.tasks.upload_session_replay_to_external_storage', 'terminal.tasks.clean_orphan_session',
        'terminal.tasks.upload_session_replay_file_to_external_storage',
        'audits.tasks.clean_audits_log_period', 'authentication.tasks.clean_django_sessions',
    }
    for name, task in app.tasks.items():
        if name.startswith('celery.') or name in ignores:
            continue
        for attr in attrs:
            if not hasattr(task, attr):
                logger.debug('Task {} has no attribute {}'.format(name, attr))
@signals.before_task_publish.connect
def before_task_publish(body=None, **kwargs):
    """Inject the publisher's language and organization into the task kwargs.

    ``body`` is unpacked as ``(args, kwargs, ...)``, the message layout this
    module also assumes in ``task_sent_handler``. The task's kwargs dict is
    updated in place under ``__current_lang`` and ``__current_org_id``;
    ``on_celery_task_pre_run`` pops those keys again on the worker side.
    A missing/empty ``body`` is ignored.
    """
    if not body:
        return
    current_lang = translation.get_language()
    current_org_id = get_current_org_id()
    # A distinct local name, so the task's kwargs don't shadow this handler's
    # own ``**kwargs`` (the signal's keyword arguments).
    _, task_kwargs = body[:2]
    task_kwargs['__current_lang'] = current_lang
    task_kwargs['__current_org_id'] = current_org_id
@signals.task_prerun.connect
def on_celery_task_pre_run(task_id='', kwargs=None, **others):
    """Mark the task's execution row as running and restore the publisher's context.

    Steps:

    1. Clean up old database connections before any query is issued.
    2. Wait up to ~5 seconds for the ``CeleryTaskExecution`` row with
       ``id=task_id`` to appear, then set ``state='RUNNING'`` and
       ``date_start``. If the row never appears, the update matches nothing.
    3. Pop ``__current_lang`` / ``__current_org_id`` (added by
       ``before_task_publish``) from ``kwargs``. Activate that language and
       organization when they are set.
    """
    # Close stale database connections first, so the queries below don't run
    # on a connection left over from a previous task. This mirrors
    # on_celery_task_post_run.
    close_old_connections()

    # The execution row is created on the publisher side after publishing and
    # may not be visible yet, so poll for it briefly.
    execution_qs = CeleryTaskExecution.objects.filter(id=task_id)
    for _ in range(5):
        if execution_qs.exists():
            break
        time.sleep(1)
    # Update the state.
    execution_qs.update(state='RUNNING', date_start=timezone.now())

    # Restore the language / organization context.
    # NOTE(review): assumes ``kwargs`` is the same dict Celery later passes to
    # the task, so popping here keeps the injected keys away from the task
    # function - confirm.
    if kwargs is None:
        kwargs = {}
    lang = kwargs.pop('__current_lang', None)
    org_id = kwargs.pop('__current_org_id', None)
    if lang:
        print('>> Set language to {}'.format(lang))
        translation.activate(lang)
    if org_id:
        print('>> Set org to {}'.format(org_id))
        set_current_org(org_id)
@signals.task_postrun.connect
def on_celery_task_post_run(task_id='', state='', **kwargs):
    """Record the final state of a finished task on its execution row.

    Old database connections are cleaned up first. Then the
    ``CeleryTaskExecution`` row with ``id=task_id`` gets ``state``, the
    current time as ``date_finished``, and ``is_finished=True``.
    """
    close_old_connections()
    executions = CeleryTaskExecution.objects.filter(id=task_id)
    finished_fields = {
        'state': state,
        'date_finished': timezone.now(),
        'is_finished': True,
    }
    executions.update(**finished_fields)
@signals.after_task_publish.connect
def task_sent_handler(headers=None, body=None, **kwargs):
    """Record a ``CeleryTaskExecution`` row for every published task.

    Name and id are read from ``headers`` when they contain ``task``,
    otherwise from ``body``. If either is missing, an error is logged and
    nothing is stored.

    The task's args/kwargs (``body`` is unpacked as ``(args, kwargs, _)``)
    are normalized through a JSON round-trip with DRF's ``JSONEncoder``.
    If that fails, they are stored as ``[]`` / ``{}``. The context keys
    injected by ``before_task_publish`` are removed before storing. When
    the publish happens inside a request by an authenticated user, that
    user is recorded as ``creator``. Finally, the matching ``CeleryTask``'s
    ``date_last_publish`` is set to now.
    """
    info = headers if headers and 'task' in headers else body
    task = info.get('task')
    execution_id = info.get('id')
    if not execution_id or not task:
        logger.error("Not found task id or name: {}".format(info))
        return

    # Distinct names, so the task's kwargs don't shadow this handler's
    # own ``**kwargs``.
    task_args, task_kwargs, __ = body
    try:
        task_args = json.loads(json.dumps(list(task_args), cls=JSONEncoder))
        task_kwargs = json.loads(json.dumps(task_kwargs, cls=JSONEncoder))
    except Exception as e:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        logger.warning('Parse task args or kwargs error (Need handle): {}'.format(e))
        task_args = []
        task_kwargs = {}
    # Don't store __current_lang / __current_org_id. Otherwise re-running a
    # system task from the stored parameters fails.
    task_kwargs.pop('__current_lang', None)
    task_kwargs.pop('__current_org_id', None)

    data = {
        'id': execution_id,
        'name': task,
        'state': 'PENDING',
        'is_finished': False,
        'args': task_args,
        'kwargs': task_kwargs
    }
    request = get_current_request()
    if request and request.user.is_authenticated:
        data['creator'] = request.user

    with transaction.atomic():
        try:
            # The inner atomic block is a savepoint. If the insert fails, only
            # the savepoint is rolled back and the update below still runs.
            # Without it, a database error here marks the outer transaction
            # for rollback, and the next query raises TransactionManagementError.
            with transaction.atomic():
                CeleryTaskExecution.objects.create(**data)
        except Exception as e:
            logger.error('Create celery task execution error: {}'.format(e))
        CeleryTask.objects.filter(name=task).update(date_last_publish=timezone.now())
@receiver(django_ready)
def subscribe_stop_job_execution(sender, **kwargs):
    """Subscribe to the job-execution stop channel once Django is ready.

    Registers a callback on ``job_execution_stop_pub_sub``. The value the
    callback receives (named ``pid``) is logged and passed to
    ``interface.kill_process``.
    """
    logger.info("Start subscribe for stop job execution")

    def _kill_job_process(pid):
        logger.info(f"Stop job execution {pid} start")
        interface.kill_process(pid)

    job_execution_stop_pub_sub.subscribe(_kill_job_process)
class JobExecutionPubSub(LazyObject):
    """Lazy wrapper around a ``RedisPubSub`` for ``fm.job_execution_stop``.

    ``LazyObject`` calls ``_setup`` only on first use of the wrapper, so the
    ``RedisPubSub`` object is not constructed at import time.
    """

    def _setup(self):
        channel = 'fm.job_execution_stop'
        self._wrapped = RedisPubSub(channel)


# Module-level handle; subscribe_stop_job_execution subscribes through it.
job_execution_stop_pub_sub = JobExecutionPubSub()