jumpserver/apps/ops/signal_handlers.py

100 lines
3.0 KiB
Python

import ast
from celery import signals
from django.db import transaction
from django.core.cache import cache
from django.dispatch import receiver
from django.db.utils import ProgrammingError
from django.utils import translation, timezone
from django.utils.translation import gettext as _
from common.signals import django_ready
from common.db.utils import close_old_connections, get_logger
from .celery import app
from .models import CeleryTaskExecution, CeleryTask
# Module-level logger obtained from the project's get_logger helper.
logger = get_logger(__name__)
# Cache key template; the placeholder is filled with a Celery task id.
TASK_LANG_CACHE_KEY = 'TASK_LANG_{}'
# Lifetime for cached task-language entries.
# NOTE(review): presumably seconds (1800 s = 30 minutes) — confirm.
TASK_LANG_CACHE_TTL = 1800
@receiver(django_ready)
def sync_registered_tasks(*args, **kwargs):
    """Reconcile the ``CeleryTask`` table with the tasks registered on the app.

    Connected to the ``django_ready`` signal. Rows whose ``name`` is not in
    ``app.tasks`` are deleted, then a row is created for every registered
    task name that the table does not contain yet. Both steps run inside one
    transaction.

    A ``ProgrammingError`` is tolerated on purpose (best effort), but it is
    now logged instead of being swallowed silently.
    NOTE(review): presumably this covers start-up before the table has been
    migrated — confirm.
    """
    # The try/except wraps the atomic block rather than living inside it, so
    # the transaction machinery sees the failure and rolls back before the
    # error is handled here.
    try:
        with transaction.atomic():
            celery_task_names = list(app.tasks)
            db_tasks = CeleryTask.objects.all()

            # Drop rows for tasks that are no longer registered.
            db_tasks.exclude(name__in=celery_task_names).delete()

            # Create rows for registered tasks missing from the table.
            db_task_names = set(db_tasks.values_list('name', flat=True))
            missing_names = set(celery_task_names) - db_task_names
            CeleryTask.objects.bulk_create(
                [CeleryTask(name=name) for name in missing_names]
            )
    except ProgrammingError as error:
        logger.debug("Skip syncing registered celery tasks: {}".format(error))
@signals.before_task_publish.connect
def before_task_publish(headers=None, **kwargs):
    """Cache the currently active language under the published task's id.

    Connected to Celery's ``before_task_publish`` signal. The value is stored
    under ``TASK_LANG_CACHE_KEY`` formatted with the task id, for
    ``TASK_LANG_CACHE_TTL``.

    :param headers: message headers; the task id is read from ``headers['id']``.
        ``None`` or headers without an id are ignored.
    """
    task_id = (headers or {}).get('id')
    if not task_id:
        # Without an id there is no meaningful cache key to write.
        return
    current_lang = translation.get_language()
    key = TASK_LANG_CACHE_KEY.format(task_id)
    # Use the shared TTL constant instead of repeating the literal value.
    cache.set(key, current_lang, TASK_LANG_CACHE_TTL)
@signals.task_prerun.connect
def on_celery_task_pre_run(task_id='', **kwargs):
    """Mark the task execution as running and activate its cached language.

    Connected to Celery's ``task_prerun`` signal.

    :param task_id: id used to look up both the ``CeleryTaskExecution`` row
        and the cached language entry.
    """
    # Update the execution state and stamp the start time.
    executions = CeleryTaskExecution.objects.filter(id=task_id)
    executions.update(state='RUNNING', date_start=timezone.now())

    # Close previously opened database connections (project helper).
    close_old_connections()

    # Restore the language context cached for this task id, if any.
    cached_lang = cache.get(TASK_LANG_CACHE_KEY.format(task_id))
    if not cached_lang:
        return
    translation.activate(cached_lang)
@signals.task_postrun.connect
def on_celery_task_post_run(task_id='', state='', **kwargs):
    """Print a one-line task summary and mark the execution as finished.

    Connected to Celery's ``task_postrun`` signal.

    :param task_id: id of the ``CeleryTaskExecution`` row to update.
    :param state: final state value written to the row and printed.
    """
    close_old_connections()

    # Translated label followed by the task id and its final state.
    label = _("Task")
    print(label + ": {} {}".format(task_id, state))

    finished_fields = {
        'state': state,
        'date_finished': timezone.now(),
        'is_finished': True,
    }
    CeleryTaskExecution.objects.filter(id=task_id).update(**finished_fields)
def _literal_or_default(text, expected_types, default):
    """Parse *text* as a Python literal, returning *default* when unusable.

    *default* is returned when ``ast.literal_eval`` cannot parse *text*, or
    when the parsed value is not an instance of *expected_types*.
    """
    try:
        value = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return default
    return value if isinstance(value, expected_types) else default


@signals.after_task_publish.connect
def task_sent_handler(headers=None, body=None, **kwargs):
    """Record a published task as a pending ``CeleryTaskExecution``.

    Connected to Celery's ``after_task_publish`` signal. Task metadata is
    read from ``headers`` when it contains a ``'task'`` entry, otherwise from
    ``body``. On success a ``CeleryTaskExecution`` row is created in state
    ``'PENDING'`` and the matching ``CeleryTask.last_published_time`` is set
    to the current time.

    :param headers: message headers; may be ``None``.
    :param body: message body, used only when ``headers`` has no ``'task'``.
    """
    headers = headers or {}
    info = headers if 'task' in headers else body

    # The fallback body may be None or a non-mapping; treat that like
    # missing metadata instead of raising inside the signal handler.
    has_fields = hasattr(info, 'get')
    task_id = info.get('id') if has_fields else None
    task_name = info.get('task') if has_fields else None
    if not task_id or not task_name:
        logger.error("Not found task id or name: {}".format(info))
        return

    # The reprs are parsed independently so that an unparsable kwargsrepr
    # does not discard successfully parsed positional arguments (and vice
    # versa). Values of an unexpected type fall back to empty containers.
    task_args = list(
        _literal_or_default(info.get('argsrepr', '()'), (list, tuple), ())
    )
    task_kwargs = _literal_or_default(info.get('kwargsrepr', '{}'), dict, {})

    CeleryTaskExecution.objects.create(
        id=task_id,
        name=task_name,
        state='PENDING',
        is_finished=False,
        args=task_args,
        kwargs=task_kwargs,
    )
    CeleryTask.objects.filter(name=task_name).update(
        last_published_time=timezone.now()
    )