perf: 账号模版更新 (#10184)

Co-authored-by: feng <1304903146@qq.com>
pull/10189/head
fit2bot 2023-04-12 17:59:13 +08:00 committed by GitHub
parent b0365838fb
commit 30b89e5cc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 36 deletions

View File

@ -18,6 +18,7 @@ class AliasAccount(TextChoices):
class Source(TextChoices): class Source(TextChoices):
LOCAL = 'local', _('Local') LOCAL = 'local', _('Local')
COLLECTED = 'collected', _('Collected') COLLECTED = 'collected', _('Collected')
TEMPLATE = 'template', _('Template')
class AccountInvalidPolicy(TextChoices): class AccountInvalidPolicy(TextChoices):

View File

@ -1,5 +1,4 @@
import uuid import uuid
from collections import defaultdict
from django.db import IntegrityError from django.db import IntegrityError
from django.db.models import Q from django.db.models import Q
@ -10,7 +9,7 @@ from rest_framework.validators import UniqueTogetherValidator
from accounts.const import SecretType, Source, AccountInvalidPolicy from accounts.const import SecretType, Source, AccountInvalidPolicy
from accounts.models import Account, AccountTemplate from accounts.models import Account, AccountTemplate
from accounts.tasks import push_accounts_to_assets_task from accounts.tasks import push_accounts_to_assets_task
from assets.const import Category, AllTypes, Protocol from assets.const import Category, AllTypes
from assets.models import Asset from assets.models import Asset
from common.serializers import SecretReadableMixin from common.serializers import SecretReadableMixin
from common.serializers.fields import ObjectRelatedField, LabeledChoiceField from common.serializers.fields import ObjectRelatedField, LabeledChoiceField
@ -80,12 +79,12 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
raise serializers.ValidationError({'template': 'Template not found'}) raise serializers.ValidationError({'template': 'Template not found'})
# Set initial data from template # Set initial data from template
ignore_fields = ['id', 'name', 'date_created', 'date_updated', 'org_id'] ignore_fields = ['id', 'date_created', 'date_updated', 'org_id']
field_names = [ field_names = [
field.name for field in template._meta.fields field.name for field in template._meta.fields
if field.name not in ignore_fields if field.name not in ignore_fields
] ]
attrs = {'source': 'template', 'source_id': template.id} attrs = {'source': Source.TEMPLATE, 'source_id': str(template.id)}
for name in field_names: for name in field_names:
value = getattr(template, name, None) value = getattr(template, name, None)
if value is None: if value is None:
@ -135,6 +134,16 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
else: else:
raise serializers.ValidationError('Account already exists') raise serializers.ValidationError('Account already exists')
def validate(self, attrs):
attrs = super().validate(attrs)
if self.instance:
return attrs
if 'source' in self.initial_data:
attrs['source'] = self.initial_data['source']
attrs['source_id'] = self.initial_data['source_id']
return attrs
def create(self, validated_data): def create(self, validated_data):
push_now = validated_data.pop('push_now', None) push_now = validated_data.pop('push_now', None)
instance, stat = self.do_create(validated_data) instance, stat = self.do_create(validated_data)
@ -146,6 +155,7 @@ class AccountCreateUpdateSerializerMixin(serializers.Serializer):
validated_data.pop('username', None) validated_data.pop('username', None)
validated_data.pop('on_invalid', None) validated_data.pop('on_invalid', None)
push_now = validated_data.pop('push_now', None) push_now = validated_data.pop('push_now', None)
validated_data['source_id'] = None
instance = super().update(instance, validated_data) instance = super().update(instance, validated_data)
self.push_account_if_need(instance, push_now, 'updated') self.push_account_if_need(instance, push_now, 'updated')
return instance return instance
@ -233,25 +243,6 @@ class AssetAccountBulkSerializer(AccountCreateUpdateSerializerMixin, serializers
initial_data = self.initial_data initial_data = self.initial_data
self.from_template_if_need(initial_data) self.from_template_if_need(initial_data)
@staticmethod
def _get_valid_secret_type_assets(assets, secret_type):
if isinstance(assets, list):
asset_ids = [a.id for a in assets]
assets = Asset.objects.filter(id__in=asset_ids)
asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name')
protocol_secret_types_map = Protocol.protocol_secret_types()
asset_secret_types_mapp = defaultdict(set)
for asset_id, protocol in asset_protocol:
secret_types = set(protocol_secret_types_map.get(protocol, []))
asset_secret_types_mapp[asset_id].update(secret_types)
return [
asset for asset in assets
if secret_type in asset_secret_types_mapp.get(asset.id, [])
]
@staticmethod @staticmethod
def get_filter_lookup(vd): def get_filter_lookup(vd):
return { return {
@ -314,7 +305,8 @@ class AssetAccountBulkSerializer(AccountCreateUpdateSerializerMixin, serializers
vd['name'] = vd.get('username') vd['name'] = vd.get('username')
create_handler = self.get_create_handler(on_invalid) create_handler = self.get_create_handler(on_invalid)
secret_type_supports = self._get_valid_secret_type_assets(assets, secret_type) asset_ids = [asset.id for asset in assets]
secret_type_supports = Asset.get_secret_type_assets(asset_ids, secret_type)
_results = {} _results = {}
for asset in assets: for asset in assets:

View File

@ -1,4 +1,5 @@
from accounts.models import AccountTemplate from accounts.models import AccountTemplate, Account
from assets.models import Asset
from common.serializers import SecretReadableMixin from common.serializers import SecretReadableMixin
from .base import BaseAccountSerializer from .base import BaseAccountSerializer
@ -7,17 +8,47 @@ class AccountTemplateSerializer(BaseAccountSerializer):
class Meta(BaseAccountSerializer.Meta): class Meta(BaseAccountSerializer.Meta):
model = AccountTemplate model = AccountTemplate
# @classmethod @staticmethod
# def validate_required(cls, attrs): def bulk_update_accounts(instance, diff):
# # TODO 选择模版后检查一些必填项 accounts = Account.objects.filter(source_id=instance.id)
# required_field_dict = {} if not accounts:
# error = _('This field is required.') return
# for k, v in cls().fields.items():
# if v.required and k not in attrs: secret_type = diff.pop('secret_type', None)
# required_field_dict[k] = error diff.pop('secret', None)
# if not required_field_dict: update_accounts = []
# return for account in accounts:
# raise serializers.ValidationError(required_field_dict) for field, value in diff.items():
setattr(account, field, value)
update_accounts.append(account)
if update_accounts:
Account.objects.bulk_update(update_accounts, diff.keys())
if secret_type is None:
return
update_accounts = []
asset_ids = accounts.values_list('asset_id', flat=True)
secret_type_supports = Asset.get_secret_type_assets(asset_ids, secret_type)
asset_ids_supports = [asset.id for asset in secret_type_supports]
for account in accounts:
asset_id = account.asset_id
if asset_id not in asset_ids_supports:
continue
account.secret_type = secret_type
account.secret = instance.secret
update_accounts.append(account)
if update_accounts:
Account.objects.bulk_update(update_accounts, ['secret', 'secret_type'])
def update(self, instance, validated_data):
diff = {
k: v for k, v in validated_data.items()
if getattr(instance, k) != v
}
instance = super().update(instance, validated_data)
self.bulk_update_accounts(instance, diff)
return instance
class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer): class AccountTemplateSecretSerializer(SecretReadableMixin, AccountTemplateSerializer):

View File

@ -271,6 +271,22 @@ class Asset(NodesRelationMixin, AbsConnectivity, JMSOrgBaseModel):
tree_node = TreeNode(**data) tree_node = TreeNode(**data)
return tree_node return tree_node
@staticmethod
def get_secret_type_assets(asset_ids, secret_type):
assets = Asset.objects.filter(id__in=asset_ids)
asset_protocol = assets.prefetch_related('protocols').values_list('id', 'protocols__name')
protocol_secret_types_map = const.Protocol.protocol_secret_types()
asset_secret_types_mapp = defaultdict(set)
for asset_id, protocol in asset_protocol:
secret_types = set(protocol_secret_types_map.get(protocol, []))
asset_secret_types_mapp[asset_id].update(secret_types)
return [
asset for asset in assets
if secret_type in asset_secret_types_mapp.get(asset.id, [])
]
class Meta: class Meta:
unique_together = [('org_id', 'name')] unique_together = [('org_id', 'name')]
verbose_name = _("Asset") verbose_name = _("Asset")