diff --git a/apps/applications/forms/remote_app.py b/apps/applications/forms/remote_app.py index cfa841f83..ba7661acd 100644 --- a/apps/applications/forms/remote_app.py +++ b/apps/applications/forms/remote_app.py @@ -5,7 +5,7 @@ from django.utils.translation import ugettext as _ from django import forms from orgs.mixins import OrgModelForm -from assets.models import SystemUser, Protocol +from assets.models import SystemUser from ..models import RemoteApp from .. import const @@ -88,9 +88,7 @@ class RemoteAppCreateUpdateForm(RemoteAppTypeForms, OrgModelForm): # 过滤RDP资产和系统用户 super().__init__(*args, **kwargs) field_asset = self.fields['asset'] - field_asset.queryset = field_asset.queryset.filter( - protocols__name=Protocol.PROTOCOL_RDP - ) + field_asset.queryset = field_asset.queryset.has_protocol('rdp') field_system_user = self.fields['system_user'] field_system_user.queryset = field_system_user.queryset.filter( protocol=SystemUser.PROTOCOL_RDP diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 3e1c10ede..2d22dd41b 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -37,7 +37,7 @@ __all__ = [ ] -class AssetViewSet(LabelFilter, ApiMessageMixin, OrgBulkModelViewSet): +class AssetViewSet(LabelFilter, OrgBulkModelViewSet): """ API endpoint that allows Asset to be viewed or edited. """ diff --git a/apps/assets/api/node.py b/apps/assets/api/node.py index a6072806d..51f9a8739 100644 --- a/apps/assets/api/node.py +++ b/apps/assets/api/node.py @@ -130,7 +130,7 @@ class NodeChildrenAsTreeApi(generics.ListAPIView): include_assets = self.request.query_params.get('assets', '0') == '1' if not include_assets: return queryset - assets = self.node.get_assets().prefetch_related("protocols").only( + assets = self.node.get_assets().only( "id", "hostname", "ip", 'platform', "os", "org_id", ) for asset in assets: diff --git a/apps/assets/forms/asset.py b/apps/assets/forms/asset.py index 8f6ffab18..5973f731b 100644 --- a/apps/assets/forms/asset.py +++ b/apps/assets/forms/asset.py @@ -6,33 +6,27 @@ from django.utils.translation import gettext_lazy as _ from common.utils import get_logger from orgs.mixins import OrgModelForm -from ..models import Asset, Protocol, Node +from ..models import Asset, Node logger = get_logger(__file__) __all__ = [ - 'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm', - 'ProtocolForm' + 'AssetCreateForm', 'AssetUpdateForm', 'AssetBulkUpdateForm', 'ProtocolForm', ] -class ProtocolForm(forms.ModelForm): - class Meta: - model = Protocol - fields = ['name', 'port'] - widgets = { - 'name': forms.Select(attrs={ - 'class': 'form-control protocol-name' - }), - 'port': forms.TextInput(attrs={ - 'class': 'form-control protocol-port' - }), - } +class ProtocolForm(forms.Form): + name = forms.ChoiceField( + choices=Asset.PROTOCOL_CHOICES, label=_("Name"), initial='ssh', + widget=forms.Select(attrs={'class': 'form-control protocol-name'}) + ) + port = forms.IntegerField( + max_value=65534, min_value=1, label=_("Port"), initial=22, + widget=forms.TextInput(attrs={'class': 'form-control protocol-port'}) + ) class AssetCreateForm(OrgModelForm): - PROTOCOL_CHOICES = Protocol.PROTOCOL_CHOICES - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if not self.data: diff --git a/apps/assets/migrations/0029_auto_20190522_1114.py b/apps/assets/migrations/0029_auto_20190522_1114.py index e1d20bb5f..46836c846 100644 --- a/apps/assets/migrations/0029_auto_20190522_1114.py +++ b/apps/assets/migrations/0029_auto_20190522_1114.py @@ -3,14 +3,6 @@ from django.db import migrations -def migrate_assets_protocol(apps, schema_editor): - asset_model = apps.get_model("assets", "Asset") - db_alias = schema_editor.connection.alias - assets = asset_model.objects.using(db_alias).all() - for asset in assets: - asset.protocols.create(name=asset.protocol, port=asset.port) - - class Migration(migrations.Migration): dependencies = [ @@ -18,5 +10,4 @@ class Migration(migrations.Migration): ] operations = [ - migrations.RunPython(migrate_assets_protocol), ] diff --git a/apps/assets/migrations/0034_auto_20190705_1348.py b/apps/assets/migrations/0034_auto_20190705_1348.py new file mode 100644 index 000000000..161ce2b30 --- /dev/null +++ b/apps/assets/migrations/0034_auto_20190705_1348.py @@ -0,0 +1,39 @@ +# Generated by Django 2.1.7 on 2019-07-05 05:48 + +from django.db import migrations +from django.db.models import F +from django.db.models import CharField, Value as V +from django.db.models.functions import Concat + + +def migrate_assets_protocol(apps, schema_editor): + asset_model = apps.get_model("assets", "Asset") + db_alias = schema_editor.connection.alias + assets = asset_model.objects.using(db_alias).all().annotate( + protocols_new=Concat( + 'protocol', V('/'), 'port', + output_field=CharField(), + ), + ) + assets.update(protocols=F('protocols_new')) + + +class Migration(migrations.Migration): + + dependencies = [ + ('assets', '0033_auto_20190624_2108'), + ] + + operations = [ + migrations.RemoveField( + model_name='asset', + name='protocols', + ), + migrations.AddField( + model_name='asset', + name='protocols', + field=CharField(blank=True, default='ssh/22', max_length=128, verbose_name='Protocols'), + ), + migrations.RunPython(migrate_assets_protocol), + migrations.DeleteModel(name='Protocol'), + ] diff --git a/apps/assets/models/asset.py b/apps/assets/models/asset.py index 0e5fc2090..bbf543fcb 100644 --- a/apps/assets/models/asset.py +++ b/apps/assets/models/asset.py @@ -6,16 +6,16 @@ import uuid import logging import random from functools import reduce +from collections import OrderedDict from django.db import models from django.utils.translation import ugettext_lazy as _ -from django.core.validators import MinValueValidator, MaxValueValidator from .user import AdminUser, SystemUser from .utils import Connectivity from orgs.mixins import OrgModelMixin, OrgManager -__all__ = ['Asset', 'Protocol'] +__all__ = ['Asset'] logger = logging.getLogger(__name__) @@ -45,8 +45,12 @@ class AssetQuerySet(models.QuerySet): def valid(self): return self.active() + def has_protocol(self, name): + return self.filter(protocols__contains=name) -class Protocol(models.Model): + +class ProtocolsMixin: + protocols = '' PROTOCOL_SSH = 'ssh' PROTOCOL_RDP = 'rdp' PROTOCOL_TELNET = 'telnet' @@ -57,19 +61,42 @@ class Protocol(models.Model): (PROTOCOL_TELNET, 'telnet (beta)'), (PROTOCOL_VNC, 'vnc'), ) - PORT_VALIDATORS = [MaxValueValidator(65535), MinValueValidator(1)] - id = models.UUIDField(default=uuid.uuid4, primary_key=True) - name = models.CharField(max_length=16, choices=PROTOCOL_CHOICES, - default=PROTOCOL_SSH, verbose_name=_("Name")) - port = models.IntegerField(default=22, verbose_name=_("Port"), - validators=PORT_VALIDATORS) + @property + def protocols_as_list(self): + if not self.protocols: + return [] + return self.protocols.split(' ') - def __str__(self): - return "{}/{}".format(self.name, self.port) + @property + def protocols_as_dict(self): + d = OrderedDict() + protocols = self.protocols_as_list + for i in protocols: + if '/' not in i: + continue + name, port = i.split('/')[:2] + if not all([name, port]): + continue + d[name] = int(port) + return d + + @property + def protocols_as_json(self): + return [ + {"name": name, "port": port} + for name, port in self.protocols_as_dict.items() + ] + + def has_protocol(self, name): + return name in self.protocols_as_dict + + @property + def ssh_port(self): + return self.protocols_as_dict.get("ssh", 22) -class Asset(OrgModelMixin): +class Asset(ProtocolsMixin, OrgModelMixin): # Important PLATFORM_CHOICES = ( ('Linux', 'Linux'), @@ -84,12 +111,12 @@ class Asset(OrgModelMixin): id = models.UUIDField(default=uuid.uuid4, primary_key=True) ip = models.CharField(max_length=128, verbose_name=_('IP'), db_index=True) hostname = models.CharField(max_length=128, verbose_name=_('Hostname')) - protocol = models.CharField(max_length=128, default=Protocol.PROTOCOL_SSH, - choices=Protocol.PROTOCOL_CHOICES, + protocol = models.CharField(max_length=128, default=ProtocolsMixin.PROTOCOL_SSH, + choices=ProtocolsMixin.PROTOCOL_CHOICES, verbose_name=_('Protocol')) port = models.IntegerField(default=22, verbose_name=_('Port')) - protocols = models.ManyToManyField('Protocol', verbose_name=_("Protocol")) + protocols = models.CharField(max_length=128, default='ssh/22', blank=True, verbose_name=_("Protocols")) platform = models.CharField(max_length=128, choices=PLATFORM_CHOICES, default='Linux', verbose_name=_('Platform')) domain = models.ForeignKey("assets.Domain", null=True, blank=True, related_name='assets', verbose_name=_("Domain"), on_delete=models.SET_NULL) nodes = models.ManyToManyField('assets.Node', default=default_node, related_name='assets', verbose_name=_("Nodes")) @@ -136,41 +163,9 @@ class Asset(OrgModelMixin): warning = '' if not self.is_active: warning += ' inactive' - else: - return True, '' - return False, warning - - @property - def protocols_name(self): - names = [] - for protocol in self.protocols.all(): - names.append(protocol.name) - return names - - def has_protocol(self, name): - return name in self.protocols_name - - def get_protocol_by_name(self, name): - for i in self.protocols.all(): - if i.name.lower() == name.lower(): - return i - return None - - @property - def protocol_ssh(self): - return self.get_protocol_by_name("ssh") - - @property - def protocol_rdp(self): - return self.get_protocol_by_name("rdp") - - @property - def ssh_port(self): - if self.protocol_ssh: - port = self.protocol_ssh.port - else: - port = 22 - return port + if warning: + return False, warning + return True, warning def is_windows(self): if self.platform in ("Windows", "Windows2016"): @@ -278,10 +273,7 @@ class Asset(OrgModelMixin): 'id': self.id, 'hostname': self.hostname, 'ip': self.ip, - 'protocols': [ - {"name": p.name, "port": p.port} - for p in self.protocols.all() - ], + 'protocols': self.protocols_as_list, 'platform': self.platform, } } @@ -314,7 +306,7 @@ class Asset(OrgModelMixin): created_by='Fake') try: asset.save() - asset.protocols.create(name="ssh", port=22) + asset.protocols = 'ssh/22' if nodes and len(nodes) > 3: _nodes = random.sample(nodes, 3) else: diff --git a/apps/assets/serializers/asset.py b/apps/assets/serializers/asset.py index 24e19c82d..5fdb3a362 100644 --- a/apps/assets/serializers/asset.py +++ b/apps/assets/serializers/asset.py @@ -1,48 +1,66 @@ # -*- coding: utf-8 -*- # from rest_framework import serializers -from rest_framework.validators import ValidationError from django.db.models import Prefetch from django.utils.translation import ugettext_lazy as _ from orgs.mixins import BulkOrgResourceModelSerializer from common.serializers import AdaptedBulkListSerializer -from ..models import Asset, Protocol, Node, Label +from ..models import Asset, Node, Label from .base import ConnectivitySerializer __all__ = [ 'AssetSerializer', 'AssetSimpleSerializer', - 'ProtocolSerializer', 'ProtocolsRelatedField', + 'ProtocolsField', ] -class ProtocolSerializer(serializers.ModelSerializer): - class Meta: - model = Protocol - fields = ["name", "port"] +class ProtocolField(serializers.RegexField): + protocols = '|'.join(dict(Asset.PROTOCOL_CHOICES).keys()) + default_error_messages = { + 'invalid': _('Protocol format should {}/{}'.format(protocols, '1-65535')) + } + regex = r'^(%s)/(\d{1,5})$' % protocols + + def __init__(self, *args, **kwargs): + super().__init__(self.regex, **kwargs) -class ProtocolsRelatedField(serializers.RelatedField): +def validate_duplicate_protocols(values): + errors = [] + names = [] + + for value in values: + if not value or '/' not in value: + continue + name = value.split('/')[0] + if name in names: + errors.append(_("Protocol duplicate: {}").format(name)) + names.append(name) + errors.append('') + if any(errors): + raise serializers.ValidationError(errors) + + +class ProtocolsField(serializers.ListField): + default_validators = [validate_duplicate_protocols] + + def __init__(self, *args, **kwargs): + kwargs['child'] = ProtocolField() + kwargs['allow_null'] = True + kwargs['allow_empty'] = True + kwargs['min_length'] = 1 + kwargs['max_length'] = 4 + super().__init__(*args, **kwargs) + def to_representation(self, value): - return str(value) - - def to_internal_value(self, data): - if isinstance(data, dict): - return data - if '/' not in data: - raise ValidationError("protocol not contain /: {}".format(data)) - v = data.split("/") - if len(v) != 2: - raise ValidationError("protocol format should be name/port: {}".format(data)) - name, port = v - cleaned_data = {"name": name, "port": port} - return cleaned_data + if not value: + return [] + return value.split(' ') class AssetSerializer(BulkOrgResourceModelSerializer): - protocols = ProtocolsRelatedField( - many=True, queryset=Protocol.objects.all(), label=_("Protocols") - ) + protocols = ProtocolsField(label=_('Protocols'), required=False) connectivity = ConnectivitySerializer(read_only=True, label=_("Connectivity")) """ @@ -79,66 +97,32 @@ class AssetSerializer(BulkOrgResourceModelSerializer): queryset = queryset.prefetch_related( Prefetch('nodes', queryset=Node.objects.all().only('id')), Prefetch('labels', queryset=Label.objects.all().only('id')), - 'protocols' ).select_related('admin_user', 'domain') return queryset - @staticmethod - def validate_protocols(attr): - protocols_serializer = ProtocolSerializer(data=attr, many=True) - protocols_serializer.is_valid(raise_exception=True) - protocols_name = [i.get("name", "ssh") for i in attr] - errors = [{} for i in protocols_name] - for i, name in enumerate(protocols_name): - if name in protocols_name[:i]: - errors[i] = {"name": _("Protocol duplicate: {}").format(name)} - if any(errors): - raise ValidationError(errors) - return attr - - def create(self, validated_data): + def compatible_with_old_protocol(self, validated_data): protocols_data = validated_data.pop("protocols", []) # 兼容老的api - protocol = validated_data.get("protocol") + name = validated_data.get("protocol") port = validated_data.get("port") - if not protocols_data and protocol and port: - protocols_data = [{"name": protocol, "port": port}] + if not protocols_data and name and port: + protocols_data.insert(0, '/'.join([name, str(port)])) + elif not name and not port and protocols_data: + protocol = protocols_data[0].split('/') + validated_data["protocol"] = protocol[0] + validated_data["port"] = int(protocol[1]) + if validated_data: + validated_data["protocols"] = ' '.join(protocols_data) - if not protocol and not port and protocols_data: - validated_data["protocol"] = protocols_data[0]["name"] - validated_data["port"] = protocols_data[0]["port"] - - protocols_serializer = ProtocolSerializer(data=protocols_data, many=True) - protocols_serializer.is_valid(raise_exception=True) - protocols = protocols_serializer.save() + def create(self, validated_data): + self.compatible_with_old_protocol(validated_data) instance = super().create(validated_data) - instance.protocols.set(protocols) return instance def update(self, instance, validated_data): - protocols_data = validated_data.pop("protocols", []) - - # 兼容老的api - protocol = validated_data.get("protocol") - port = validated_data.get("port") - if not protocols_data and protocol and port: - protocols_data = [{"name": protocol, "port": port}] - - if not protocol and not port and protocols_data: - validated_data["protocol"] = protocols_data[0]["name"] - validated_data["port"] = protocols_data[0]["port"] - protocols = None - if protocols_data: - protocols_serializer = ProtocolSerializer(data=protocols_data, many=True) - protocols_serializer.is_valid(raise_exception=True) - protocols = protocols_serializer.save() - - instance = super().update(instance, validated_data) - if protocols: - instance.protocols.all().delete() - instance.protocols.set(protocols) - return instance + self.compatible_with_old_protocol(validated_data) + return super().update(instance, validated_data) class AssetSimpleSerializer(serializers.ModelSerializer): diff --git a/apps/assets/templates/assets/asset_detail.html b/apps/assets/templates/assets/asset_detail.html index 2dcb569ac..ad635ca5b 100644 --- a/apps/assets/templates/assets/asset_detail.html +++ b/apps/assets/templates/assets/asset_detail.html @@ -70,11 +70,7 @@ {% trans 'Protocol' %} - - {% for protocol in asset.protocols.all %} - {{ protocol }} - {% endfor %} - + {{ asset.protocols }} {% trans 'Admin user' %}: diff --git a/apps/assets/views/asset.py b/apps/assets/views/asset.py index ed0fe7efc..ce701b129 100644 --- a/apps/assets/views/asset.py +++ b/apps/assets/views/asset.py @@ -1,37 +1,25 @@ # coding:utf-8 from __future__ import absolute_import, unicode_literals -import csv -import json -import uuid -import codecs -import chardet -from io import StringIO - -from django.db import transaction from django.contrib import messages from django.utils.translation import ugettext_lazy as _ -from django.views.generic import TemplateView, ListView, View -from django.views.generic.edit import CreateView, DeleteView, FormView, UpdateView +from django.views.generic import TemplateView, ListView +from django.views.generic.edit import FormMixin +from django.views.generic.edit import CreateView, DeleteView, UpdateView from django.urls import reverse_lazy from django.views.generic.detail import DetailView -from django.http import HttpResponse, JsonResponse -from django.views.decorators.csrf import csrf_exempt -from django.utils.decorators import method_decorator from django.core.cache import cache -from django.utils import timezone from django.shortcuts import redirect from django.contrib.messages.views import SuccessMessageMixin from django.forms.formsets import formset_factory -from common.mixins import JSONResponseMixin from common.utils import get_object_or_none, get_logger from common.permissions import PermissionsMixin, IsOrgAdmin, IsValidUser from common.const import ( create_success_msg, update_success_msg, KEY_CACHE_RESOURCES_ID ) from .. import forms -from ..models import Asset, AdminUser, SystemUser, Label, Node, Domain +from ..models import Asset, SystemUser, Label, Node __all__ = [ @@ -87,7 +75,7 @@ class UserAssetListView(PermissionsMixin, TemplateView): return super().get_context_data(**kwargs) -class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): +class AssetCreateView(PermissionsMixin, FormMixin, TemplateView): model = Asset form_class = forms.AssetCreateForm template_name = 'assets/asset_create.html' @@ -112,16 +100,6 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): formset = ProtocolFormset() return formset - def form_valid(self, form): - formset = self.get_protocol_formset() - valid = formset.is_valid() - if not valid: - return self.form_invalid(form) - protocols = formset.save() - instance = super().form_valid(form) - instance.protocols.set(protocols) - return instance - def get_context_data(self, **kwargs): formset = self.get_protocol_formset() context = { @@ -132,8 +110,32 @@ class AssetCreateView(PermissionsMixin, SuccessMessageMixin, CreateView): kwargs.update(context) return super().get_context_data(**kwargs) - def get_success_message(self, cleaned_data): - return create_success_msg % ({"name": cleaned_data["hostname"]}) + +class AssetUpdateView(PermissionsMixin, UpdateView): + model = Asset + form_class = forms.AssetUpdateForm + template_name = 'assets/asset_update.html' + success_url = reverse_lazy('assets:asset-list') + permission_classes = [IsOrgAdmin] + + def get_protocol_formset(self): + ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5) + if self.request.method == "POST": + formset = ProtocolFormset(self.request.POST) + else: + initial_data = self.object.protocols_as_json + formset = ProtocolFormset(initial=initial_data) + return formset + + def get_context_data(self, **kwargs): + formset = self.get_protocol_formset() + context = { + 'app': _('Assets'), + 'action': _('Update asset'), + 'formset': formset, + } + kwargs.update(context) + return super().get_context_data(**kwargs) class AssetBulkUpdateView(PermissionsMixin, ListView): @@ -177,36 +179,6 @@ class AssetBulkUpdateView(PermissionsMixin, ListView): return super().get_context_data(**kwargs) -class AssetUpdateView(PermissionsMixin, SuccessMessageMixin, UpdateView): - model = Asset - form_class = forms.AssetUpdateForm - template_name = 'assets/asset_update.html' - success_url = reverse_lazy('assets:asset-list') - permission_classes = [IsOrgAdmin] - - def get_protocol_formset(self): - ProtocolFormset = formset_factory(forms.ProtocolForm, extra=0, min_num=1, max_num=5) - if self.request.method == "POST": - formset = ProtocolFormset(self.request.POST) - else: - initial_data = [{"name": p.name, "port": p.port} for p in self.object.protocols.all()] - formset = ProtocolFormset(initial=initial_data) - return formset - - def get_context_data(self, **kwargs): - formset = self.get_protocol_formset() - context = { - 'app': _('Assets'), - 'action': _('Update asset'), - 'formset': formset, - } - kwargs.update(context) - return super().get_context_data(**kwargs) - - def get_success_message(self, cleaned_data): - return update_success_msg % ({"name": cleaned_data["hostname"]}) - - class AssetDeleteView(PermissionsMixin, DeleteView): model = Asset template_name = 'delete_confirm.html' @@ -222,7 +194,7 @@ class AssetDetailView(PermissionsMixin, DetailView): def get_queryset(self): return super().get_queryset().prefetch_related( - "nodes", "labels", "protocols" + "nodes", "labels", ).select_related('admin_user', 'domain') def get_context_data(self, **kwargs): diff --git a/apps/ops/templates/ops/command_execution_create.html b/apps/ops/templates/ops/command_execution_create.html index 230b4e6a1..a9cce3d5a 100644 --- a/apps/ops/templates/ops/command_execution_create.html +++ b/apps/ops/templates/ops/command_execution_create.html @@ -135,9 +135,7 @@ function getSelectedAssetsNode() { var assetsNode = []; nodes.forEach(function (node) { if (node.meta.type === 'asset' && !node.isHidden) { - var protocols = $.map(node.meta.asset.protocols, function (v) { - return v.name - }); + var protocols = node.meta.asset.protocols; if (assetsNodeId.indexOf(node.id) === -1 && protocols.indexOf("ssh") > -1) { assetsNodeId.push(node.id); assetsNode.push(node) diff --git a/apps/perms/utils/asset_permission.py b/apps/perms/utils/asset_permission.py index 4273e197e..4ce485d71 100644 --- a/apps/perms/utils/asset_permission.py +++ b/apps/perms/utils/asset_permission.py @@ -126,7 +126,7 @@ class GenerateTree: for asset, system_users in assets.items(): self.add_asset(asset, system_users) - #@timeit + # #@timeit def add_asset(self, asset, system_users=None): nodes = asset.nodes.all() nodes = self.node_util.get_nodes_by_queryset(nodes) @@ -493,12 +493,13 @@ class AssetPermissionUtil(AssetPermissionCacheMixin): pattern.add(r'^{0}$|^{0}:'.format(node.key)) pattern = '|'.join(list(pattern)) if pattern: - assets = Asset.objects.filter(nodes__key__regex=pattern)\ - .prefetch_related('nodes', "protocols")\ + assets = Asset.objects.filter(nodes__key__regex=pattern) \ + .prefetch_related('nodes')\ .only(*self.assets_only)\ .distinct() else: assets = [] + assets = list(assets) self.tree.add_assets_without_system_users(assets) assets = self.tree.get_assets() self._assets = assets @@ -598,7 +599,7 @@ def parse_asset_to_tree_node(node, asset, system_users): 'id': asset.id, 'hostname': asset.hostname, 'ip': asset.ip, - 'protocols': [str(p) for p in asset.protocols.all()], + 'protocols': asset.protocols_as_list, 'platform': asset.platform, 'domain': None if not asset.domain else asset.domain.id, 'is_active': asset.is_active,