perf: 优化 json error

pull/10327/head
ibuler 2023-05-18 17:31:40 +08:00
parent 4e5ab5a605
commit ebaa8d2637
3 changed files with 57 additions and 28 deletions

View File

@ -294,6 +294,29 @@ class RelatedManager:
self.value = value self.value = value
self.instance.__dict__[self.field.name] = value self.instance.__dict__[self.field.name] = value
@classmethod
def get_filter_q(cls, value, to_model):
if not value or not isinstance(value, dict):
return Q()
if value["type"] == "all":
return Q()
elif value["type"] == "ids" and isinstance(value.get("ids"), list):
return Q(id__in=value["ids"])
elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
return cls._get_filter_attrs_q(value, to_model)
else:
return Q()
@classmethod
def filter_queryset_by_model(cls, value, to_model):
if hasattr(to_model, "get_queryset"):
queryset = to_model.get_queryset()
else:
queryset = to_model.objects.all()
q = cls.get_filter_q(value, to_model)
return queryset.filter(q)
@staticmethod @staticmethod
def get_ip_in_q(name, val): def get_ip_in_q(name, val):
q = Q() q = Q()
@ -322,7 +345,8 @@ class RelatedManager:
continue continue
return q return q
def _get_filter_attrs_q(self, value, to_model): @classmethod
def _get_filter_attrs_q(cls, value, to_model):
filters = Q() filters = Q()
# 特殊情况有这几种, # 特殊情况有这几种,
# 1. 像 资产中的 type 和 category集成自 Platform。所以不能直接查询 # 1. 像 资产中的 type 和 category集成自 Platform。所以不能直接查询
@ -340,16 +364,14 @@ class RelatedManager:
if name is None or val is None: if name is None or val is None:
continue continue
print("Has custom filter: {}".format(custom_attr_filter))
if custom_attr_filter: if custom_attr_filter:
custom_filter_q = custom_attr_filter(name, val, match) custom_filter_q = custom_attr_filter(name, val, match)
print("Custom filter: {}".format(custom_filter_q))
if custom_filter_q: if custom_filter_q:
filters &= custom_filter_q filters &= custom_filter_q
continue continue
if match == 'ip_in': if match == 'ip_in':
q = self.get_ip_in_q(name, val) q = cls.get_ip_in_q(name, val)
elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"): elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"):
lookup = "{}__{}".format(name, match) lookup = "{}__{}".format(name, match)
q = Q(**{lookup: val}) q = Q(**{lookup: val})
@ -377,26 +399,10 @@ class RelatedManager:
def _get_queryset(self): def _get_queryset(self):
to_model = apps.get_model(self.field.to) to_model = apps.get_model(self.field.to)
value = self.value value = self.value
if hasattr(to_model, "get_queryset"): return self.filter_queryset_by_model(value, to_model)
queryset = to_model.get_queryset()
else:
queryset = to_model.objects.all()
if not value or not isinstance(value, dict):
return queryset.none()
if value["type"] == "all":
return queryset
elif value["type"] == "ids" and isinstance(value.get("ids"), list):
return queryset.filter(id__in=value["ids"])
elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
q = self._get_filter_attrs_q(value, to_model)
return queryset.filter(q)
else:
return queryset.none()
def get_attr_q(self): def get_attr_q(self):
q = self._get_filter_attrs_q(self.value) q = self._get_filter_attrs_q(self.value, apps.get_model(self.field.to))
return q return q
def all(self): def all(self):

View File

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import base64
import json
import logging import logging
from django.core.cache import cache from django.core.cache import cache
@ -18,6 +20,8 @@ __all__ = [
"BaseFilterSet" "BaseFilterSet"
] ]
from common.db.fields import RelatedManager
class BaseFilterSet(drf_filters.FilterSet): class BaseFilterSet(drf_filters.FilterSet):
def do_nothing(self, queryset, name, value): def do_nothing(self, queryset, name, value):
@ -183,3 +187,24 @@ class UUIDInFilter(drf_filters.BaseInFilter, drf_filters.UUIDFilter):
class NumberInFilter(drf_filters.BaseInFilter, drf_filters.NumberFilter): class NumberInFilter(drf_filters.BaseInFilter, drf_filters.NumberFilter):
pass pass
class AttrRulesFilter(filters.BaseFilterBackend):
def get_schema_fields(self, view):
return [
coreapi.Field(
name='attr_rules', location='query', required=False,
type='string', example='/api/v1/users/users?attr_rules=jsonbase64',
description='Filter by json like {"type": "attrs", "attrs": []} to base64'
)
]
def filter_queryset(self, request, queryset, view):
attr_rules = request.query_params.get('attr_rules')
if not attr_rules:
return queryset
attr_rules = base64.b64decode(attr_rules.encode('utf-8'))
attr_rules = json.loads(attr_rules)
q = RelatedManager.get_filter_q(attr_rules, queryset.model)
return queryset.filter(q)

View File

@ -7,8 +7,8 @@ from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework_bulk import BulkModelViewSet from rest_framework_bulk import BulkModelViewSet
from common.api import CommonApiMixin from common.api import CommonApiMixin, SuggestionMixin
from common.api import SuggestionMixin from common.drf.filters import AttrRulesFilter
from common.utils import get_logger from common.utils import get_logger
from orgs.utils import current_org, tmp_to_root_org from orgs.utils import current_org, tmp_to_root_org
from rbac.models import Role, RoleBinding from rbac.models import Role, RoleBinding
@ -18,10 +18,7 @@ from .. import serializers
from ..filters import UserFilter from ..filters import UserFilter
from ..models import User from ..models import User
from ..notifications import ResetMFAMsg from ..notifications import ResetMFAMsg
from ..serializers import ( from ..serializers import UserSerializer, MiniUserSerializer, InviteSerializer
UserSerializer,
MiniUserSerializer, InviteSerializer
)
from ..signals import post_user_create from ..signals import post_user_create
logger = get_logger(__name__) logger = get_logger(__name__)
@ -33,6 +30,7 @@ __all__ = [
class UserViewSet(CommonApiMixin, UserQuerysetMixin, SuggestionMixin, BulkModelViewSet): class UserViewSet(CommonApiMixin, UserQuerysetMixin, SuggestionMixin, BulkModelViewSet):
filterset_class = UserFilter filterset_class = UserFilter
extra_filter_backends = [AttrRulesFilter]
search_fields = ('username', 'email', 'name') search_fields = ('username', 'email', 'name')
serializer_classes = { serializer_classes = {
'default': UserSerializer, 'default': UserSerializer,