From c0b301d52b6dcb739550a4a8a7d0b39d1ff0c291 Mon Sep 17 00:00:00 2001 From: wangruidong <940853815@qq.com> Date: Wed, 8 Jan 2025 16:33:19 +0800 Subject: [PATCH] fix: ldap ha periodic task did not execute as expected --- apps/settings/serializers/auth/ldap.py | 14 ++++++++------ apps/settings/serializers/auth/ldap_ha.py | 15 +++++++++------ apps/settings/serializers/auth/mixin.py | 21 +++++++++++++++++++++ apps/settings/tasks/ldap.py | 11 +++++++---- 4 files changed, 45 insertions(+), 16 deletions(-) create mode 100644 apps/settings/serializers/auth/mixin.py diff --git a/apps/settings/serializers/auth/ldap.py b/apps/settings/serializers/auth/ldap.py index 7bfd78e55..767bbdeea 100644 --- a/apps/settings/serializers/auth/ldap.py +++ b/apps/settings/serializers/auth/ldap.py @@ -3,6 +3,7 @@ from rest_framework import serializers from common.serializers.fields import EncryptedField from .base import OrgListField +from .mixin import LDAPSerializerMixin __all__ = [ 'LDAPTestConfigSerializer', 'LDAPUserSerializer', 'LDAPTestLoginSerializer', @@ -36,7 +37,7 @@ class LDAPUserSerializer(serializers.Serializer): status = serializers.JSONField(read_only=True) -class LDAPSettingSerializer(serializers.Serializer): +class LDAPSettingSerializer(LDAPSerializerMixin, serializers.Serializer): # encrypt_fields 现在使用 write_only 来判断了 PREFIX_TITLE = _('LDAP') @@ -103,10 +104,11 @@ class LDAPSettingSerializer(serializers.Serializer): AUTH_LDAP = serializers.BooleanField(required=False, label=_('LDAP')) AUTH_LDAP_SYNC_ORG_IDS = OrgListField() - def post_save(self): - keys = ['AUTH_LDAP_SYNC_IS_PERIODIC', 'AUTH_LDAP_SYNC_INTERVAL', 'AUTH_LDAP_SYNC_CRONTAB'] - kwargs = {k: self.validated_data[k] for k in keys if k in self.validated_data} - if not kwargs: - return + periodic_key = 'AUTH_LDAP_SYNC_IS_PERIODIC' + interval_key = 'AUTH_LDAP_SYNC_INTERVAL' + crontab_key = 'AUTH_LDAP_SYNC_CRONTAB' + + @staticmethod + def import_task_function(**kwargs): from settings.tasks import import_ldap_user_periodic import_ldap_user_periodic(**kwargs) diff --git a/apps/settings/serializers/auth/ldap_ha.py b/apps/settings/serializers/auth/ldap_ha.py index 6b78682b2..95cf3c5b9 100644 --- a/apps/settings/serializers/auth/ldap_ha.py +++ b/apps/settings/serializers/auth/ldap_ha.py @@ -3,6 +3,8 @@ from rest_framework import serializers from common.serializers.fields import EncryptedField from .base import OrgListField +from .mixin import LDAPSerializerMixin +from ops.mixin import PeriodTaskSerializerMixin __all__ = ['LDAPHATestConfigSerializer', 'LDAPHASettingSerializer'] @@ -18,7 +20,7 @@ class LDAPHATestConfigSerializer(serializers.Serializer): AUTH_LDAP_HA = serializers.BooleanField(required=False) -class LDAPHASettingSerializer(serializers.Serializer): +class LDAPHASettingSerializer(LDAPSerializerMixin, serializers.Serializer): # encrypt_fields 现在使用 write_only 来判断了 PREFIX_TITLE = _('LDAP HA') @@ -85,10 +87,11 @@ class LDAPHASettingSerializer(serializers.Serializer): AUTH_LDAP_HA = serializers.BooleanField(required=False, label=_('LDAP HA')) AUTH_LDAP_HA_SYNC_ORG_IDS = OrgListField() - def post_save(self): - keys = ['AUTH_LDAP_HA_SYNC_IS_PERIODIC', 'AUTH_LDAP_HA_SYNC_INTERVAL', 'AUTH_LDAP_HA_SYNC_CRONTAB'] - kwargs = {k: self.validated_data[k] for k in keys if k in self.validated_data} - if not kwargs: - return + periodic_key = 'AUTH_LDAP_HA_SYNC_IS_PERIODIC' + interval_key = 'AUTH_LDAP_HA_SYNC_INTERVAL' + crontab_key = 'AUTH_LDAP_HA_SYNC_CRONTAB' + + @staticmethod + def import_task_function(**kwargs): from settings.tasks import import_ldap_ha_user_periodic import_ldap_ha_user_periodic(**kwargs) diff --git a/apps/settings/serializers/auth/mixin.py b/apps/settings/serializers/auth/mixin.py new file mode 100644 index 000000000..2eed91289 --- /dev/null +++ b/apps/settings/serializers/auth/mixin.py @@ -0,0 +1,21 @@ +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers +from ops.mixin import PeriodTaskSerializerMixin + + +class LDAPSerializerMixin: + def validate(self, attrs): + is_periodic = attrs.get(self.periodic_key) + crontab = attrs.get(self.crontab_key) + interval = attrs.get(self.interval_key) + if is_periodic and not any([crontab, interval]): + msg = _("Require interval or crontab setting") + raise serializers.ValidationError(msg) + return super().validate(attrs) + + def post_save(self): + keys = [self.periodic_key, self.interval_key, self.crontab_key] + kwargs = {k: self.validated_data[k] for k in keys if k in self.validated_data} + if not kwargs: + return + self.import_task_function(**kwargs) diff --git a/apps/settings/tasks/ldap.py b/apps/settings/tasks/ldap.py index 6694cf54c..b90394814 100644 --- a/apps/settings/tasks/ldap.py +++ b/apps/settings/tasks/ldap.py @@ -9,7 +9,7 @@ from common.utils import get_logger from common.utils.timezone import local_now_display from ops.celery.decorator import after_app_ready_start from ops.celery.utils import ( - create_or_update_celery_periodic_tasks + create_or_update_celery_periodic_tasks, disable_celery_periodic_task ) from orgs.models import Organization from settings.notifications import LDAPImportMessage @@ -90,9 +90,12 @@ def import_ldap_ha_user(): def register_periodic_task(task_name, task_func, interval_key, enabled_key, crontab_key, **kwargs): - interval = kwargs.get(interval_key, settings.AUTH_LDAP_SYNC_INTERVAL) - enabled = kwargs.get(enabled_key, settings.AUTH_LDAP_SYNC_IS_PERIODIC) - crontab = kwargs.get(crontab_key, settings.AUTH_LDAP_SYNC_CRONTAB) + interval = kwargs.get(interval_key, getattr(settings, interval_key)) + enabled = kwargs.get(enabled_key, getattr(settings, enabled_key)) + crontab = kwargs.get(crontab_key, getattr(settings, crontab_key)) + + if not enabled: + disable_celery_periodic_task(task_name) if isinstance(interval, int): interval = interval * 3600