mirror of https://github.com/jumpserver/jumpserver
				
				
				
			perf: 账号模版信息同步到所关联的账号
							parent
							
								
									3ef8e9603a
								
							
						
					
					
						commit
						ca7d164034
					
				| 
						 | 
					@ -1,10 +1,12 @@
 | 
				
			||||||
from django_filters import rest_framework as drf_filters
 | 
					from django_filters import rest_framework as drf_filters
 | 
				
			||||||
 | 
					from rest_framework import status
 | 
				
			||||||
from rest_framework.decorators import action
 | 
					from rest_framework.decorators import action
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from accounts import serializers
 | 
					from accounts import serializers
 | 
				
			||||||
from accounts.models import AccountTemplate
 | 
					 | 
				
			||||||
from accounts.mixins import AccountRecordViewLogMixin
 | 
					from accounts.mixins import AccountRecordViewLogMixin
 | 
				
			||||||
 | 
					from accounts.models import AccountTemplate
 | 
				
			||||||
 | 
					from accounts.tasks import template_sync_related_accounts
 | 
				
			||||||
from assets.const import Protocol
 | 
					from assets.const import Protocol
 | 
				
			||||||
from common.drf.filters import BaseFilterSet
 | 
					from common.drf.filters import BaseFilterSet
 | 
				
			||||||
from common.permissions import UserConfirmation, ConfirmType
 | 
					from common.permissions import UserConfirmation, ConfirmType
 | 
				
			||||||
| 
						 | 
					@ -44,6 +46,7 @@ class AccountTemplateViewSet(OrgBulkModelViewSet):
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    rbac_perms = {
 | 
					    rbac_perms = {
 | 
				
			||||||
        'su_from_account_templates': 'accounts.view_accounttemplate',
 | 
					        'su_from_account_templates': 'accounts.view_accounttemplate',
 | 
				
			||||||
 | 
					        'sync_related_accounts': 'accounts.change_accounttemplate',
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @action(methods=['get'], detail=False, url_path='su-from-account-templates')
 | 
					    @action(methods=['get'], detail=False, url_path='su-from-account-templates')
 | 
				
			||||||
| 
						 | 
					@ -54,6 +57,13 @@ class AccountTemplateViewSet(OrgBulkModelViewSet):
 | 
				
			||||||
        serializer = self.get_serializer(templates, many=True)
 | 
					        serializer = self.get_serializer(templates, many=True)
 | 
				
			||||||
        return Response(data=serializer.data)
 | 
					        return Response(data=serializer.data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @action(methods=['patch'], detail=True, url_path='sync-related-accounts')
 | 
				
			||||||
 | 
					    def sync_related_accounts(self, request, *args, **kwargs):
 | 
				
			||||||
 | 
					        instance = self.get_object()
 | 
				
			||||||
 | 
					        user_id = str(request.user.id)
 | 
				
			||||||
 | 
					        task = template_sync_related_accounts.delay(str(instance.id), user_id)
 | 
				
			||||||
 | 
					        return Response({'task': task.id}, status=status.HTTP_200_OK)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AccountTemplateSecretsViewSet(AccountRecordViewLogMixin, AccountTemplateViewSet):
 | 
					class AccountTemplateSecretsViewSet(AccountRecordViewLogMixin, AccountTemplateViewSet):
 | 
				
			||||||
    serializer_classes = {
 | 
					    serializer_classes = {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,8 +49,7 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
 | 
				
			||||||
            ).first()
 | 
					            ).first()
 | 
				
			||||||
            return account
 | 
					            return account
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    def bulk_update_accounts(self, accounts):
 | 
				
			||||||
    def bulk_update_accounts(accounts, data):
 | 
					 | 
				
			||||||
        history_model = Account.history.model
 | 
					        history_model = Account.history.model
 | 
				
			||||||
        account_ids = accounts.values_list('id', flat=True)
 | 
					        account_ids = accounts.values_list('id', flat=True)
 | 
				
			||||||
        history_accounts = history_model.objects.filter(id__in=account_ids)
 | 
					        history_accounts = history_model.objects.filter(id__in=account_ids)
 | 
				
			||||||
| 
						 | 
					@ -63,8 +62,7 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
 | 
				
			||||||
        for account in accounts:
 | 
					        for account in accounts:
 | 
				
			||||||
            account_id = str(account.id)
 | 
					            account_id = str(account.id)
 | 
				
			||||||
            account.version = account_id_count_map.get(account_id) + 1
 | 
					            account.version = account_id_count_map.get(account_id) + 1
 | 
				
			||||||
            for k, v in data.items():
 | 
					            account.secret = self.get_secret()
 | 
				
			||||||
                setattr(account, k, v)
 | 
					 | 
				
			||||||
        Account.objects.bulk_update(accounts, ['version', 'secret'])
 | 
					        Account.objects.bulk_update(accounts, ['version', 'secret'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
| 
						 | 
					@ -86,7 +84,5 @@ class AccountTemplate(BaseAccount, SecretWithRandomMixin):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def bulk_sync_account_secret(self, accounts, user_id):
 | 
					    def bulk_sync_account_secret(self, accounts, user_id):
 | 
				
			||||||
        """ 批量同步账号密码 """
 | 
					        """ 批量同步账号密码 """
 | 
				
			||||||
        if not accounts:
 | 
					        self.bulk_update_accounts(accounts)
 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        self.bulk_update_accounts(accounts, {'secret': self.secret})
 | 
					 | 
				
			||||||
        self.bulk_create_history_accounts(accounts, user_id)
 | 
					        self.bulk_create_history_accounts(accounts, user_id)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
from rest_framework import serializers
 | 
					from rest_framework import serializers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from accounts.const import SecretStrategy, SecretType
 | 
					from accounts.const import SecretStrategy, SecretType
 | 
				
			||||||
from accounts.models import AccountTemplate, Account
 | 
					from accounts.models import AccountTemplate
 | 
				
			||||||
from accounts.utils import SecretGenerator
 | 
					from accounts.utils import SecretGenerator
 | 
				
			||||||
from common.serializers import SecretReadableMixin
 | 
					from common.serializers import SecretReadableMixin
 | 
				
			||||||
from common.serializers.fields import ObjectRelatedField
 | 
					from common.serializers.fields import ObjectRelatedField
 | 
				
			||||||
| 
						 | 
					@ -18,9 +18,6 @@ class PasswordRulesSerializer(serializers.Serializer):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AccountTemplateSerializer(BaseAccountSerializer):
 | 
					class AccountTemplateSerializer(BaseAccountSerializer):
 | 
				
			||||||
    is_sync_account = serializers.BooleanField(default=False, write_only=True)
 | 
					 | 
				
			||||||
    _is_sync_account = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    password_rules = PasswordRulesSerializer(required=False, label=_('Password rules'))
 | 
					    password_rules = PasswordRulesSerializer(required=False, label=_('Password rules'))
 | 
				
			||||||
    su_from = ObjectRelatedField(
 | 
					    su_from = ObjectRelatedField(
 | 
				
			||||||
        required=False, queryset=AccountTemplate.objects, allow_null=True,
 | 
					        required=False, queryset=AccountTemplate.objects, allow_null=True,
 | 
				
			||||||
| 
						 | 
					@ -32,7 +29,7 @@ class AccountTemplateSerializer(BaseAccountSerializer):
 | 
				
			||||||
        fields = BaseAccountSerializer.Meta.fields + [
 | 
					        fields = BaseAccountSerializer.Meta.fields + [
 | 
				
			||||||
            'secret_strategy', 'password_rules',
 | 
					            'secret_strategy', 'password_rules',
 | 
				
			||||||
            'auto_push', 'push_params', 'platforms',
 | 
					            'auto_push', 'push_params', 'platforms',
 | 
				
			||||||
            'is_sync_account', 'su_from'
 | 
					            'su_from'
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
        extra_kwargs = {
 | 
					        extra_kwargs = {
 | 
				
			||||||
            'secret_strategy': {'help_text': _('Secret generation strategy for account creation')},
 | 
					            'secret_strategy': {'help_text': _('Secret generation strategy for account creation')},
 | 
				
			||||||
| 
						 | 
					@ -46,17 +43,6 @@ class AccountTemplateSerializer(BaseAccountSerializer):
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def sync_accounts_secret(self, instance, diff):
 | 
					 | 
				
			||||||
        if not self._is_sync_account or 'secret' not in diff:
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
        query_data = {
 | 
					 | 
				
			||||||
            'source_id': instance.id,
 | 
					 | 
				
			||||||
            'username': instance.username,
 | 
					 | 
				
			||||||
            'secret_type': instance.secret_type
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        accounts = Account.objects.filter(**query_data)
 | 
					 | 
				
			||||||
        instance.bulk_sync_account_secret(accounts, self.context['request'].user.id)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def generate_secret(attrs):
 | 
					    def generate_secret(attrs):
 | 
				
			||||||
        secret_type = attrs.get('secret_type', SecretType.PASSWORD)
 | 
					        secret_type = attrs.get('secret_type', SecretType.PASSWORD)
 | 
				
			||||||
| 
						 | 
					@ -68,23 +54,10 @@ class AccountTemplateSerializer(BaseAccountSerializer):
 | 
				
			||||||
        attrs['secret'] = generator.get_secret()
 | 
					        attrs['secret'] = generator.get_secret()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, attrs):
 | 
					    def validate(self, attrs):
 | 
				
			||||||
        self._is_sync_account = attrs.pop('is_sync_account', None)
 | 
					 | 
				
			||||||
        attrs = super().validate(attrs)
 | 
					        attrs = super().validate(attrs)
 | 
				
			||||||
        self.generate_secret(attrs)
 | 
					        self.generate_secret(attrs)
 | 
				
			||||||
        return attrs
 | 
					        return attrs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(self, instance, validated_data):
 | 
					 | 
				
			||||||
        diff = {
 | 
					 | 
				
			||||||
            k: v for k, v in validated_data.items()
 | 
					 | 
				
			||||||
            if getattr(instance, k, None) != v
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        instance = super().update(instance, validated_data)
 | 
					 | 
				
			||||||
        if {'username', 'secret_type'} & set(diff.keys()):
 | 
					 | 
				
			||||||
            Account.objects.filter(source_id=instance.id).update(source_id=None)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            self.sync_accounts_secret(instance, diff)
 | 
					 | 
				
			||||||
        return instance
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer):
 | 
					class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer):
 | 
				
			||||||
    class Meta(AccountTemplateSerializer.Meta):
 | 
					    class Meta(AccountTemplateSerializer.Meta):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,6 @@
 | 
				
			||||||
from .backup_account import *
 | 
					 | 
				
			||||||
from .automation import *
 | 
					from .automation import *
 | 
				
			||||||
from .push_account import *
 | 
					from .backup_account import *
 | 
				
			||||||
from .verify_account import *
 | 
					 | 
				
			||||||
from .gather_accounts import *
 | 
					from .gather_accounts import *
 | 
				
			||||||
 | 
					from .push_account import *
 | 
				
			||||||
 | 
					from .template import *
 | 
				
			||||||
 | 
					from .verify_account import *
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,60 @@
 | 
				
			||||||
 | 
					from datetime import datetime
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from celery import shared_task
 | 
				
			||||||
 | 
					from django.shortcuts import get_object_or_404
 | 
				
			||||||
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from orgs.utils import tmp_to_root_org, tmp_to_org
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@shared_task(
 | 
				
			||||||
 | 
					    verbose_name=_('Template sync info to related accounts'),
 | 
				
			||||||
 | 
					    activity_callback=lambda self, template_id, *args, **kwargs: (template_id, None)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def template_sync_related_accounts(template_id, user_id=None):
 | 
				
			||||||
 | 
					    from accounts.models import Account, AccountTemplate
 | 
				
			||||||
 | 
					    with tmp_to_root_org():
 | 
				
			||||||
 | 
					        template = get_object_or_404(AccountTemplate, id=template_id)
 | 
				
			||||||
 | 
					    org_id = template.org_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with tmp_to_org(org_id):
 | 
				
			||||||
 | 
					        accounts = Account.objects.filter(source_id=template_id)
 | 
				
			||||||
 | 
					    if not accounts:
 | 
				
			||||||
 | 
					        print('\033[35m>>> 没有需要同步的账号, 结束任务')
 | 
				
			||||||
 | 
					        print('\033[0m')
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    failed, succeeded = 0, 0
 | 
				
			||||||
 | 
					    succeeded_account_ids = []
 | 
				
			||||||
 | 
					    name = template.name
 | 
				
			||||||
 | 
					    username = template.username
 | 
				
			||||||
 | 
					    secret_type = template.secret_type
 | 
				
			||||||
 | 
					    print(f'\033[32m>>> 开始同步模版名称、用户名、密钥类型到相关联的账号 ({datetime.now().strftime("%Y-%m-%d %H:%M:%S")})')
 | 
				
			||||||
 | 
					    with tmp_to_org(org_id):
 | 
				
			||||||
 | 
					        for account in accounts:
 | 
				
			||||||
 | 
					            account.name = name
 | 
				
			||||||
 | 
					            account.username = username
 | 
				
			||||||
 | 
					            account.secret_type = secret_type
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                account.save(update_fields=['name', 'username', 'secret_type'])
 | 
				
			||||||
 | 
					                succeeded += 1
 | 
				
			||||||
 | 
					                succeeded_account_ids.append(account.id)
 | 
				
			||||||
 | 
					            except Exception as e:
 | 
				
			||||||
 | 
					                account.source_id = None
 | 
				
			||||||
 | 
					                account.save(update_fields=['source_id'])
 | 
				
			||||||
 | 
					                print(f'\033[31m- 同步失败: [{account}] 原因: [{e}]')
 | 
				
			||||||
 | 
					                failed += 1
 | 
				
			||||||
 | 
					        accounts = Account.objects.filter(id__in=succeeded_account_ids)
 | 
				
			||||||
 | 
					        if accounts:
 | 
				
			||||||
 | 
					            print(f'\033[33m>>> 批量更新账号密文 ({datetime.now().strftime("%Y-%m-%d %H:%M:%S")})')
 | 
				
			||||||
 | 
					            template.bulk_sync_account_secret(accounts, user_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    total = succeeded + failed
 | 
				
			||||||
 | 
					    print(
 | 
				
			||||||
 | 
					        f'\033[33m>>> 同步完成:, '
 | 
				
			||||||
 | 
					        f'共计: {total}, '
 | 
				
			||||||
 | 
					        f'成功: {succeeded}, '
 | 
				
			||||||
 | 
					        f'失败: {failed}, '
 | 
				
			||||||
 | 
					        f'({datetime.now().strftime("%Y-%m-%d %H:%M:%S")}) '
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    print('\033[0m')
 | 
				
			||||||
		Loading…
	
		Reference in New Issue