diff --git a/apps/assets/migrations/0104_auto_20220803_1859.py b/apps/assets/migrations/0104_auto_20220803_1859.py index ecd75743b..f1836f713 100644 --- a/apps/assets/migrations/0104_auto_20220803_1859.py +++ b/apps/assets/migrations/0104_auto_20220803_1859.py @@ -6,10 +6,12 @@ from django.db import migrations def migrate_asset_protocols(apps, schema_editor): asset_model = apps.get_model('assets', 'Asset') protocol_model = apps.get_model('assets', 'Protocol') + asset_protocol_through = asset_model.protocols.through count = 0 bulk_size = 1000 print("\nStart migrate asset protocols") + protocol_map = {} while True: start = time.time() assets = asset_model.objects.all()[count:count+bulk_size] @@ -17,15 +19,24 @@ def migrate_asset_protocols(apps, schema_editor): if not assets: break - protocols = [] + assets_protocols = [] for asset in assets: - for protocol in asset.protocols.all(): - protocols.append(protocol_model( - asset_id=asset.id, - protocol=protocol.protocol, - port=protocol.port, - )) - protocol_model.objects.bulk_create(protocols, ignore_conflicts=True) + old_protocols = asset._protocols + + for name_port in old_protocols.split(','): + name_port_list = name_port.split('/') + if len(name_port_list) != 2: + continue + + name, port = name_port_list + protocol = protocol_map.get(name_port) + if not protocol: + protocol = protocol_model.objects.get_or_create( + defaults={'name': name, 'port': port}, + name=name, port=port + )[0] + assets_protocols.append(asset_protocol_through(asset_id=asset.id, protocol_id=protocol.id)) + asset_model.protocols.through.objects.bulk_create(assets_protocols, ignore_conflicts=True) print("Create asset protocols: {}-{} using: {:.2f}s".format( count - bulk_size, count, time.time()-start )) diff --git a/apps/assets/serializers/asset/common.py b/apps/assets/serializers/asset/common.py index d3ca2b257..0e8acfe12 100644 --- a/apps/assets/serializers/asset/common.py +++ b/apps/assets/serializers/asset/common.py @@ -3,70 +3,30 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ -from orgs.mixins.serializers import BulkOrgResourceModelSerializer from orgs.mixins.serializers import OrgResourceModelSerializerMixin -from ...models import Asset, Node, Platform, SystemUser +from ...models import Asset, Node, Platform, SystemUser, Protocol, Label from ..mixin import CategoryDisplayMixin from ..account import AccountSerializer __all__ = [ 'AssetSerializer', 'AssetSimpleSerializer', 'MiniAssetSerializer', - 'AssetTaskSerializer', 'AssetsTaskSerializer', 'ProtocolsField', + 'AssetTaskSerializer', 'AssetsTaskSerializer', ] -class ProtocolField(serializers.RegexField): - default_error_messages = { - 'invalid': _('Protocol format should {}/{}').format('protocol', '1-65535') - } - regex = r'^(\w+)/(\d{1,5})$' - - def __init__(self, *args, **kwargs): - super().__init__(self.regex, **kwargs) +class AssetProtocolsSerializer(serializers.ModelSerializer): + class Meta: + model = Protocol + fields = ['id', 'name', 'port'] -def validate_duplicate_protocols(values): - errors = [] - names = [] - - print("Value is: ", values) - - for value in values.split(' '): - if not value or '/' not in value: - continue - name = value.split('/')[0] - if name in names: - errors.append(_("Protocol duplicate: {}").format(name)) - names.append(name) - errors.append('') - if any(errors): - raise serializers.ValidationError(errors) - - -class ProtocolsField(serializers.ListField): - default_validators = [validate_duplicate_protocols] - - def __init__(self, *args, **kwargs): - kwargs['child'] = ProtocolField() - kwargs['allow_null'] = True - kwargs['allow_empty'] = True - kwargs['min_length'] = 1 - kwargs['max_length'] = 32 - super().__init__(*args, **kwargs) - - def to_representation(self, value): - if not value: - return [] - if isinstance(value, str): - return value.split(' ') - return value - - def to_internal_value(self, data): - return ' '.join(data) +class AssetLabelSerializer(serializers.ModelSerializer): + class Meta: + model = Label + fields = ['id', 'name', 'value'] class AssetSerializer(CategoryDisplayMixin, OrgResourceModelSerializerMixin): - protocols = ProtocolsField(label=_('Protocols'), required=False, default=['ssh/22']) domain_display = serializers.ReadOnlyField(source='domain.name', label=_('Domain name')) nodes_display = serializers.ListField( child=serializers.CharField(), label=_('Nodes name'), required=False @@ -75,10 +35,12 @@ class AssetSerializer(CategoryDisplayMixin, OrgResourceModelSerializerMixin): child=serializers.CharField(), label=_('Labels name'), required=False, read_only=True ) + labels = AssetLabelSerializer(many=True, required=False) platform_display = serializers.SlugField( source='platform.name', label=_("Platform display"), read_only=True ) accounts = AccountSerializer(many=True, write_only=True, required=False) + protocols = AssetProtocolsSerializer(many=True) """ 资产的数据结构 @@ -87,17 +49,16 @@ class AssetSerializer(CategoryDisplayMixin, OrgResourceModelSerializerMixin): class Meta: model = Asset fields_mini = [ - 'id', 'hostname', 'ip', 'platform', 'protocols' + 'id', 'hostname', 'ip', ] fields_small = fields_mini + [ - 'protocol', 'port', 'is_active', - 'public_ip', 'number', 'comment', + 'is_active', 'number', 'comment', ] fields_fk = [ - 'domain', 'domain_display', 'platform', + 'domain', 'domain_display', 'platform', 'platform', 'platform_display', ] fields_m2m = [ - 'nodes', 'nodes_display', 'labels', 'labels_display', 'accounts' + 'nodes', 'nodes_display', 'labels', 'labels_display', 'accounts', 'protocols', ] read_only_fields = [ 'category', 'category_display', 'type', 'type_display', diff --git a/apps/assets/serializers/platform.py b/apps/assets/serializers/platform.py index 7a68e045d..b7277741d 100644 --- a/apps/assets/serializers/platform.py +++ b/apps/assets/serializers/platform.py @@ -3,8 +3,6 @@ from django.core.validators import RegexValidator from django.utils.translation import gettext_lazy as _ from assets.models import Platform -from assets.serializers.asset import ProtocolsField -from assets.const import Protocol from .mixin import CategoryDisplayMixin __all__ = ['PlatformSerializer'] @@ -12,7 +10,7 @@ __all__ = ['PlatformSerializer'] class PlatformSerializer(CategoryDisplayMixin, serializers.ModelSerializer): meta = serializers.DictField(required=False, allow_null=True, label=_('Meta')) - protocols_default = ProtocolsField(label=_('Protocols'), required=False) + protocols_default = serializers.ListField(label=_('Protocols'), required=False) type_limits = serializers.ReadOnlyField(required=False, read_only=True) def __init__(self, *args, **kwargs): diff --git a/apps/authentication/serializers/connection_token.py b/apps/authentication/serializers/connection_token.py index 1b639bec6..a05d8c0e0 100644 --- a/apps/authentication/serializers/connection_token.py +++ b/apps/authentication/serializers/connection_token.py @@ -8,7 +8,6 @@ from common.utils.random import random_string from assets.models import Asset, SystemUser, Gateway, Domain, CommandFilterRule from users.models import User from applications.models import Application -from assets.serializers import ProtocolsField from perms.serializers.base import ActionsField @@ -122,7 +121,6 @@ class ConnectionTokenUserSerializer(serializers.ModelSerializer): class ConnectionTokenAssetSerializer(serializers.ModelSerializer): - protocols = ProtocolsField(label='Protocols', read_only=True) class Meta: model = Asset diff --git a/apps/perms/serializers/asset/user_permission.py b/apps/perms/serializers/asset/user_permission.py index d2ad97894..26a21d7f0 100644 --- a/apps/perms/serializers/asset/user_permission.py +++ b/apps/perms/serializers/asset/user_permission.py @@ -5,7 +5,6 @@ from rest_framework import serializers from django.utils.translation import ugettext_lazy as _ from assets.models import Node, SystemUser, Asset, Platform -from assets.serializers import ProtocolsField from perms.serializers.base import ActionsField __all__ = [ @@ -38,7 +37,6 @@ class AssetGrantedSerializer(serializers.ModelSerializer): """ 被授权资产的数据结构 """ - protocols = ProtocolsField(label=_('Protocols'), required=False, read_only=True) platform = serializers.SlugRelatedField( slug_field='name', queryset=Platform.objects.all(), label=_("Platform") )