perf: 账号模版更新 (#10184)

Co-authored-by: feng <1304903146@qq.com>
pull/10189/head
fit2bot 2023-04-12 17:59:13 +08:00 committed by GitHub
parent b0365838fb
commit 30b89e5cc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 36 deletions

View File

@ -18,6 +18,7 @@ class AliasAccount(TextChoices):
class Source(TextChoices):
LOCAL = 'local', _('Local')
COLLECTED = 'collected', _('Collected')
TEMPLATE = 'template', _('Template')
class AccountInvalidPolicy(TextChoices):

View File

@ -1,5 +1,4 @@
import uuid
from collections import defaultdict
from django.db import IntegrityError
from django.db.models import Q
@ -10,7 +9,7 @@ from rest_framework.validators import UniqueTogetherValidator
from accounts.const import SecretType, Source, AccountInvalidPolicy
from accounts.models import Account, AccountTemplate
from accounts.tasks import push_accounts_to_assets_task
from assets.const import Category, AllTypes, Protocol
from assets.const import Category, AllTypes
from assets.models import Asset
from common.serializers import SecretReadableMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
@ -80,12 +79,12 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
raise serializers.ValidationError({'template': 'Template not found'})
# Set initial data from template
ignore_fields = ['id', 'name', 'date_created', 'date_updated', 'org_id']
ignore_fields = ['id', 'date_created', 'date_updated', 'org_id']
field_names = [
field.name for field in template._meta.fields
if field.name not in ignore_fields
]
attrs = {'source': 'template', 'source_id': template.id}
attrs = {'source': Source.TEMPLATE, 'source_id': str(template.id)}
for name in field_names:
value = getattr(template, name, None)
if value is None:
@ -135,6 +134,16 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
else:
raise serializers.ValidationError('Account already exists')
def validate(self, attrs):
attrs = super().validate(attrs)
if self.instance:
return attrs
if 'source' in self.initial_data:
attrs['source'] = self.initial_data['source']
attrs['source_id'] = self.initial_data['source_id']
return attrs
def create(self, validated_data):
push_now = validated_data.pop('push_now', None)
instance, stat = self.do_create(validated_data)
@ -146,6 +155,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
validated_data.pop('username', None)
validated_data.pop('on_invalid', None)
push_now = validated_data.pop('push_now', None)
validated_data['source_id'] = None
instance = super().update(instance, validated_data)
self.push_account_if_need(instance, push_now, 'updated')
return instance
@ -233,25 +243,6 @@ class AssetAccountBulkSerializer(AccountCreateUpdateSerializerMixin, serializers
initial_data = self.initial_data
self.from_template_if_need(initial_data)
@staticmethod
def _get_valid_secret_type_assets(assets, secret_type):
if isinstance(assets, list):
asset_ids = [a.id for a in assets]
assets = Asset.objects.filter(id__in=asset_ids)
asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name')
protocol_secret_types_map = Protocol.protocol_secret_types()
asset_secret_types_mapp = defaultdict(set)
for asset_id, protocol in asset_protocol:
secret_types = set(protocol_secret_types_map.get(protocol, []))
asset_secret_types_mapp[asset_id].update(secret_types)
return [
asset for asset in assets
if secret_type in asset_secret_types_mapp.get(asset.id, [])
]
@staticmethod
def get_filter_lookup(vd):
return {
@ -314,7 +305,8 @@ class AssetAccountBulkSerializer(AccountCreateUpdateSerializerMixin, serializers
vd['name'] = vd.get('username')
create_handler = self.get_create_handler(on_invalid)
secret_type_supports = self._get_valid_secret_type_assets(assets, secret_type)
asset_ids = [asset.id for asset in assets]
secret_type_supports = Asset.get_secret_type_assets(asset_ids, secret_type)
_results = {}
for asset in assets:

View File

@ -1,4 +1,5 @@
from accounts.models import AccountTemplate
from accounts.models import AccountTemplate, Account
from assets.models import Asset
from common.serializers import SecretReadableMixin
from .base import BaseAccountSerializer
@ -7,17 +8,47 @@ class AccountTemplateSerializer(BaseAccountSerializer):
class Meta(BaseAccountSerializer.Meta):
model = AccountTemplate
# @classmethod
# def validate_required(cls, attrs):
# # TODO 选择模版后检查一些必填项
# required_field_dict = {}
# error = _('This field is required.')
# for k, v in cls().fields.items():
# if v.required and k not in attrs:
# required_field_dict[k] = error
# if not required_field_dict:
# return
# raise serializers.ValidationError(required_field_dict)
@staticmethod
def bulk_update_accounts(instance, diff):
accounts = Account.objects.filter(source_id=instance.id)
if not accounts:
return
secret_type = diff.pop('secret_type', None)
diff.pop('secret', None)
update_accounts = []
for account in accounts:
for field, value in diff.items():
setattr(account, field, value)
update_accounts.append(account)
if update_accounts:
Account.objects.bulk_update(update_accounts, diff.keys())
if secret_type is None:
return
update_accounts = []
asset_ids = accounts.values_list('asset_id', flat=True)
secret_type_supports = Asset.get_secret_type_assets(asset_ids, secret_type)
asset_ids_supports = [asset.id for asset in secret_type_supports]
for account in accounts:
asset_id = account.asset_id
if asset_id not in asset_ids_supports:
continue
account.secret_type = secret_type
account.secret = instance.secret
update_accounts.append(account)
if update_accounts:
Account.objects.bulk_update(update_accounts, ['secret', 'secret_type'])
def update(self, instance, validated_data):
diff = {
k: v for k, v in validated_data.items()
if getattr(instance, k) != v
}
instance = super().update(instance, validated_data)
self.bulk_update_accounts(instance, diff)
return instance
class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer):

View File

@ -271,6 +271,22 @@ class Asset(NodesRelationMixin, AbsConnectivity, JMSOrgBaseModel):
tree_node = TreeNode(**data)
return tree_node
@staticmethod
def get_secret_type_assets(asset_ids, secret_type):
assets = Asset.objects.filter(id__in=asset_ids)
asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name')
protocol_secret_types_map = const.Protocol.protocol_secret_types()
asset_secret_types_mapp = defaultdict(set)
for asset_id, protocol in asset_protocol:
secret_types = set(protocol_secret_types_map.get(protocol, []))
asset_secret_types_mapp[asset_id].update(secret_types)
return [
asset for asset in assets
if secret_type in asset_secret_types_mapp.get(asset.id, [])
]
class Meta:
unique_together = [('org_id', 'name')]
verbose_name = _("Asset")