From ebaa8d26377fca87f21dc22f42fd6b479233f0d8 Mon Sep 17 00:00:00 2001 From: ibuler <ibuler@qq.com> Date: Thu, 18 May 2023 17:31:40 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20json=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/db/fields.py | 50 +++++++++++++++++++++----------------- apps/common/drf/filters.py | 25 +++++++++++++++++++ apps/users/api/user.py | 10 +++----- 3 files changed, 57 insertions(+), 28 deletions(-) diff --git a/apps/common/db/fields.py b/apps/common/db/fields.py index d2087f5d4..6039e943a 100644 --- a/apps/common/db/fields.py +++ b/apps/common/db/fields.py @@ -294,6 +294,29 @@ class RelatedManager: self.value = value self.instance.__dict__[self.field.name] = value + @classmethod + def get_filter_q(cls, value, to_model): + if not value or not isinstance(value, dict): + return Q() + + if value["type"] == "all": + return Q() + elif value["type"] == "ids" and isinstance(value.get("ids"), list): + return Q(id__in=value["ids"]) + elif value["type"] == "attrs" and isinstance(value.get("attrs"), list): + return cls._get_filter_attrs_q(value, to_model) + else: + return Q() + + @classmethod + def filter_queryset_by_model(cls, value, to_model): + if hasattr(to_model, "get_queryset"): + queryset = to_model.get_queryset() + else: + queryset = to_model.objects.all() + q = cls.get_filter_q(value, to_model) + return queryset.filter(q) + @staticmethod def get_ip_in_q(name, val): q = Q() @@ -322,7 +345,8 @@ class RelatedManager: continue return q - def _get_filter_attrs_q(self, value, to_model): + @classmethod + def _get_filter_attrs_q(cls, value, to_model): filters = Q() # 特殊情况有这几种, # 1. 像 资产中的 type 和 category,集成自 Platform。所以不能直接查询 @@ -340,16 +364,14 @@ class RelatedManager: if name is None or val is None: continue - print("Has custom filter: {}".format(custom_attr_filter)) if custom_attr_filter: custom_filter_q = custom_attr_filter(name, val, match) - print("Custom filter: {}".format(custom_filter_q)) if custom_filter_q: filters &= custom_filter_q continue if match == 'ip_in': - q = self.get_ip_in_q(name, val) + q = cls.get_ip_in_q(name, val) elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"): lookup = "{}__{}".format(name, match) q = Q(**{lookup: val}) @@ -377,26 +399,10 @@ class RelatedManager: def _get_queryset(self): to_model = apps.get_model(self.field.to) value = self.value - if hasattr(to_model, "get_queryset"): - queryset = to_model.get_queryset() - else: - queryset = to_model.objects.all() - - if not value or not isinstance(value, dict): - return queryset.none() - - if value["type"] == "all": - return queryset - elif value["type"] == "ids" and isinstance(value.get("ids"), list): - return queryset.filter(id__in=value["ids"]) - elif value["type"] == "attrs" and isinstance(value.get("attrs"), list): - q = self._get_filter_attrs_q(value, to_model) - return queryset.filter(q) - else: - return queryset.none() + return self.filter_queryset_by_model(value, to_model) def get_attr_q(self): - q = self._get_filter_attrs_q(self.value) + q = self._get_filter_attrs_q(self.value, apps.get_model(self.field.to)) return q def all(self): diff --git a/apps/common/drf/filters.py b/apps/common/drf/filters.py index 949260a47..6278efccf 100644 --- a/apps/common/drf/filters.py +++ b/apps/common/drf/filters.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # +import base64 +import json import logging from django.core.cache import cache @@ -18,6 +20,8 @@ __all__ = [ "BaseFilterSet" ] +from common.db.fields import RelatedManager + class BaseFilterSet(drf_filters.FilterSet): def do_nothing(self, queryset, name, value): @@ -183,3 +187,24 @@ class UUIDInFilter(drf_filters.BaseInFilter, drf_filters.UUIDFilter): class NumberInFilter(drf_filters.BaseInFilter, drf_filters.NumberFilter): pass + + +class AttrRulesFilter(filters.BaseFilterBackend): + def get_schema_fields(self, view): + return [ + coreapi.Field( + name='attr_rules', location='query', required=False, + type='string', example='/api/v1/users/users?attr_rules=jsonbase64', + description='Filter by json like {"type": "attrs", "attrs": []} to base64' + ) + ] + + def filter_queryset(self, request, queryset, view): + attr_rules = request.query_params.get('attr_rules') + if not attr_rules: + return queryset + + attr_rules = base64.b64decode(attr_rules.encode('utf-8')) + attr_rules = json.loads(attr_rules) + q = RelatedManager.get_filter_q(attr_rules, queryset.model) + return queryset.filter(q) diff --git a/apps/users/api/user.py b/apps/users/api/user.py index d1eee8083..26a8e9d05 100644 --- a/apps/users/api/user.py +++ b/apps/users/api/user.py @@ -7,8 +7,8 @@ from rest_framework.decorators import action from rest_framework.response import Response from rest_framework_bulk import BulkModelViewSet -from common.api import CommonApiMixin -from common.api import SuggestionMixin +from common.api import CommonApiMixin, SuggestionMixin +from common.drf.filters import AttrRulesFilter from common.utils import get_logger from orgs.utils import current_org, tmp_to_root_org from rbac.models import Role, RoleBinding @@ -18,10 +18,7 @@ from .. import serializers from ..filters import UserFilter from ..models import User from ..notifications import ResetMFAMsg -from ..serializers import ( - UserSerializer, - MiniUserSerializer, InviteSerializer -) +from ..serializers import UserSerializer, MiniUserSerializer, InviteSerializer from ..signals import post_user_create logger = get_logger(__name__) @@ -33,6 +30,7 @@ __all__ = [ class UserViewSet(CommonApiMixin, UserQuerysetMixin, SuggestionMixin, BulkModelViewSet): filterset_class = UserFilter + extra_filter_backends = [AttrRulesFilter] search_fields = ('username', 'email', 'name') serializer_classes = { 'default': UserSerializer,