mirror of https://github.com/jumpserver/jumpserver
perf: 修改 m2m json field
parent
ebaa8d2637
commit
a261d69cd2
|
@ -15,7 +15,7 @@ from assets.filters import IpInFilterBackend, LabelFilterBackend, NodeFilterBack
|
|||
from assets.models import Asset, Gateway, Platform
|
||||
from assets.tasks import test_assets_connectivity_manual, update_assets_hardware_info_manual
|
||||
from common.api import SuggestionMixin
|
||||
from common.drf.filters import BaseFilterSet
|
||||
from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
|
||||
from common.utils import get_logger, is_uuid
|
||||
from orgs.mixins import generics
|
||||
from orgs.mixins.api import OrgBulkModelViewSet
|
||||
|
@ -110,7 +110,10 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
|
|||
("spec_info", "assets.view_asset"),
|
||||
("gathered_info", "assets.view_asset"),
|
||||
)
|
||||
extra_filter_backends = [LabelFilterBackend, IpInFilterBackend, NodeFilterBackend]
|
||||
extra_filter_backends = [
|
||||
LabelFilterBackend, IpInFilterBackend,
|
||||
NodeFilterBackend, AttrRulesFilterBackend
|
||||
]
|
||||
|
||||
def get_serializer_class(self):
|
||||
cls = super().get_serializer_class()
|
||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import Q
|
||||
from django.forms import model_to_dict
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
@ -116,7 +117,32 @@ class Protocol(models.Model):
|
|||
return self.asset_platform_protocol.get('public', True)
|
||||
|
||||
|
||||
class Asset(NodesRelationMixin, AbsConnectivity, JMSOrgBaseModel):
|
||||
class JSONFilterMixin:
|
||||
@staticmethod
|
||||
def get_json_filter_attr_q(name, value, match):
|
||||
"""
|
||||
:param name: 属性名称
|
||||
:param value: 定义的结果
|
||||
:param match: 匹配方式
|
||||
:return:
|
||||
"""
|
||||
from ..node import Node
|
||||
if not isinstance(value, (list, tuple)):
|
||||
value = [value]
|
||||
if name == 'nodes':
|
||||
nodes = Node.objects.filter(id__in=value)
|
||||
children = Node.get_nodes_all_children(nodes, with_self=True).values_list('id', flat=True)
|
||||
return Q(nodes__in=children)
|
||||
elif name == 'category':
|
||||
return Q(platform__category__in=value)
|
||||
elif name == 'type':
|
||||
return Q(platform__type__in=value)
|
||||
elif name == 'protocols':
|
||||
return Q(protocols__name__in=value)
|
||||
return None
|
||||
|
||||
|
||||
class Asset(NodesRelationMixin, AbsConnectivity, JSONFilterMixin, JMSOrgBaseModel):
|
||||
Category = const.Category
|
||||
Type = const.AllTypes
|
||||
|
||||
|
|
|
@ -63,6 +63,19 @@ class FamilyMixin:
|
|||
pattern += r'|^{0}$'.format(key)
|
||||
return pattern
|
||||
|
||||
@classmethod
|
||||
def get_nodes_children_key_pattern(cls, nodes, with_self=True):
|
||||
keys = [i.key for i in nodes]
|
||||
keys = cls.clean_children_keys(keys)
|
||||
patterns = [cls.get_node_all_children_key_pattern(key) for key in keys]
|
||||
patterns = '|'.join(patterns)
|
||||
return patterns
|
||||
|
||||
@classmethod
|
||||
def get_nodes_all_children(cls, nodes, with_self=True):
|
||||
pattern = cls.get_nodes_children_key_pattern(nodes, with_self=with_self)
|
||||
return Node.objects.filter(key__iregex=pattern)
|
||||
|
||||
@classmethod
|
||||
def get_node_children_key_pattern(cls, key, with_self=True):
|
||||
pattern = r'^{0}:[0-9]+$'.format(key)
|
||||
|
|
|
@ -315,6 +315,7 @@ class RelatedManager:
|
|||
else:
|
||||
queryset = to_model.objects.all()
|
||||
q = cls.get_filter_q(value, to_model)
|
||||
print("Q: ", q)
|
||||
return queryset.filter(q)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -189,7 +189,7 @@ class NumberInFilter(drf_filters.BaseInFilter, drf_filters.NumberFilter):
|
|||
pass
|
||||
|
||||
|
||||
class AttrRulesFilter(filters.BaseFilterBackend):
|
||||
class AttrRulesFilterBackend(filters.BaseFilterBackend):
|
||||
def get_schema_fields(self, view):
|
||||
return [
|
||||
coreapi.Field(
|
||||
|
@ -204,7 +204,15 @@ class AttrRulesFilter(filters.BaseFilterBackend):
|
|||
if not attr_rules:
|
||||
return queryset
|
||||
|
||||
attr_rules = base64.b64decode(attr_rules.encode('utf-8'))
|
||||
attr_rules = json.loads(attr_rules)
|
||||
try:
|
||||
attr_rules = base64.b64decode(attr_rules.encode('utf-8'))
|
||||
except Exception:
|
||||
raise ValidationError({'attr_rules': 'attr_rules should be base64'})
|
||||
try:
|
||||
attr_rules = json.loads(attr_rules)
|
||||
except Exception:
|
||||
raise ValidationError({'attr_rules': 'attr_rules should be json'})
|
||||
|
||||
logging.debug('attr_rules: %s', attr_rules)
|
||||
q = RelatedManager.get_filter_q(attr_rules, queryset.model)
|
||||
return queryset.filter(q)
|
||||
return queryset.filter(q).distinct()
|
||||
|
|
|
@ -8,7 +8,7 @@ from rest_framework.response import Response
|
|||
from rest_framework_bulk import BulkModelViewSet
|
||||
|
||||
from common.api import CommonApiMixin, SuggestionMixin
|
||||
from common.drf.filters import AttrRulesFilter
|
||||
from common.drf.filters import AttrRulesFilterBackend
|
||||
from common.utils import get_logger
|
||||
from orgs.utils import current_org, tmp_to_root_org
|
||||
from rbac.models import Role, RoleBinding
|
||||
|
@ -30,7 +30,7 @@ __all__ = [
|
|||
|
||||
class UserViewSet(CommonApiMixin, UserQuerysetMixin, SuggestionMixin, BulkModelViewSet):
|
||||
filterset_class = UserFilter
|
||||
extra_filter_backends = [AttrRulesFilter]
|
||||
extra_filter_backends = [AttrRulesFilterBackend]
|
||||
search_fields = ('username', 'email', 'name')
|
||||
serializer_classes = {
|
||||
'default': UserSerializer,
|
||||
|
|
Loading…
Reference in New Issue