perf: 优化 celery task context

pull/9644/head
ibuler 2023-02-20 15:01:00 +08:00 committed by Jiangjie.Bai
parent 7c3b98cf3b
commit ca6d71f442
1 changed files with 16 additions and 15 deletions

View File

@ -7,18 +7,15 @@ from django.db.models.signals import pre_save
from django.db.utils import ProgrammingError from django.db.utils import ProgrammingError
from django.dispatch import receiver from django.dispatch import receiver
from django.utils import translation, timezone from django.utils import translation, timezone
from django.utils.translation import gettext as _
from common.db.utils import close_old_connections, get_logger from common.db.utils import close_old_connections, get_logger
from common.signals import django_ready from common.signals import django_ready
from orgs.utils import get_current_org_id, set_current_org
from .celery import app from .celery import app
from .models import CeleryTaskExecution, CeleryTask, Job from .models import CeleryTaskExecution, CeleryTask, Job
logger = get_logger(__name__) logger = get_logger(__name__)
TASK_LANG_CACHE_KEY = 'TASK_LANG_{}'
TASK_LANG_CACHE_TTL = 1800
@receiver(pre_save, sender=Job) @receiver(pre_save, sender=Job)
def on_account_pre_create(sender, instance, **kwargs): def on_account_pre_create(sender, instance, **kwargs):
@ -58,32 +55,36 @@ def check_registered_tasks(*args, **kwargs):
@signals.before_task_publish.connect @signals.before_task_publish.connect
def before_task_publish(headers=None, **kwargs): def before_task_publish(body=None, **kwargs):
task_id = headers.get('id')
current_lang = translation.get_language() current_lang = translation.get_language()
key = TASK_LANG_CACHE_KEY.format(task_id) current_org_id = get_current_org_id()
cache.set(key, current_lang, 1800) args, kwargs = body[:2]
kwargs['__current_lang'] = current_lang
kwargs['__current_org_id'] = current_org_id
@signals.task_prerun.connect @signals.task_prerun.connect
def on_celery_task_pre_run(task_id='', **kwargs): def on_celery_task_pre_run(task_id='', kwargs=None, **others):
# 更新状态 # 更新状态
CeleryTaskExecution.objects.filter(id=task_id) \ CeleryTaskExecution.objects.filter(id=task_id) \
.update(state='RUNNING', date_start=timezone.now()) .update(state='RUNNING', date_start=timezone.now())
# 关闭之前的数据库连接 # 关闭之前的数据库连接
close_old_connections() close_old_connections()
# 保存 Lang context # 设置语言的一些上下文
key = TASK_LANG_CACHE_KEY.format(task_id) lang = kwargs.pop('__current_lang', None)
task_lang = cache.get(key) org_id = kwargs.pop('__current_org_id', None)
if task_lang: if lang:
translation.activate(task_lang) print('>> Set language to {}'.format(lang))
translation.activate(lang)
if org_id:
print('>> Set org to {}'.format(org_id))
set_current_org(org_id)
@signals.task_postrun.connect @signals.task_postrun.connect
def on_celery_task_post_run(task_id='', state='', **kwargs): def on_celery_task_post_run(task_id='', state='', **kwargs):
close_old_connections() close_old_connections()
print(_("Task") + ": {} {}".format(task_id, state))
CeleryTaskExecution.objects.filter(id=task_id).update( CeleryTaskExecution.objects.filter(id=task_id).update(
state=state, date_finished=timezone.now(), is_finished=True state=state, date_finished=timezone.now(), is_finished=True