From ebaa8d26377fca87f21dc22f42fd6b479233f0d8 Mon Sep 17 00:00:00 2001
From: ibuler <ibuler@qq.com>
Date: Thu, 18 May 2023 17:31:40 +0800
Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20json=20error?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/common/db/fields.py   | 50 +++++++++++++++++++++-----------------
 apps/common/drf/filters.py | 25 +++++++++++++++++++
 apps/users/api/user.py     | 10 +++-----
 3 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/apps/common/db/fields.py b/apps/common/db/fields.py
index d2087f5d4..6039e943a 100644
--- a/apps/common/db/fields.py
+++ b/apps/common/db/fields.py
@@ -294,6 +294,29 @@ class RelatedManager:
         self.value = value
         self.instance.__dict__[self.field.name] = value
 
+    @classmethod
+    def get_filter_q(cls, value, to_model):
+        if not value or not isinstance(value, dict):
+            return Q()
+
+        if value["type"] == "all":
+            return Q()
+        elif value["type"] == "ids" and isinstance(value.get("ids"), list):
+            return Q(id__in=value["ids"])
+        elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
+            return cls._get_filter_attrs_q(value, to_model)
+        else:
+            return Q()
+
+    @classmethod
+    def filter_queryset_by_model(cls, value, to_model):
+        if hasattr(to_model, "get_queryset"):
+            queryset = to_model.get_queryset()
+        else:
+            queryset = to_model.objects.all()
+        q = cls.get_filter_q(value, to_model)
+        return queryset.filter(q)
+
     @staticmethod
     def get_ip_in_q(name, val):
         q = Q()
@@ -322,7 +345,8 @@ class RelatedManager:
                 continue
         return q
 
-    def _get_filter_attrs_q(self, value, to_model):
+    @classmethod
+    def _get_filter_attrs_q(cls, value, to_model):
         filters = Q()
         # 特殊情况有这几种,
         # 1. 像 资产中的 type 和 category,集成自 Platform。所以不能直接查询
@@ -340,16 +364,14 @@ class RelatedManager:
             if name is None or val is None:
                 continue
 
-            print("Has custom filter: {}".format(custom_attr_filter))
             if custom_attr_filter:
                 custom_filter_q = custom_attr_filter(name, val, match)
-                print("Custom filter: {}".format(custom_filter_q))
                 if custom_filter_q:
                     filters &= custom_filter_q
                     continue
 
             if match == 'ip_in':
-                q = self.get_ip_in_q(name, val)
+                q = cls.get_ip_in_q(name, val)
             elif match in ("exact", "contains", "startswith", "endswith", "regex", "gte", "lte", "gt", "lt"):
                 lookup = "{}__{}".format(name, match)
                 q = Q(**{lookup: val})
@@ -377,26 +399,10 @@ class RelatedManager:
     def _get_queryset(self):
         to_model = apps.get_model(self.field.to)
         value = self.value
-        if hasattr(to_model, "get_queryset"):
-            queryset = to_model.get_queryset()
-        else:
-            queryset = to_model.objects.all()
-
-        if not value or not isinstance(value, dict):
-            return queryset.none()
-
-        if value["type"] == "all":
-            return queryset
-        elif value["type"] == "ids" and isinstance(value.get("ids"), list):
-            return queryset.filter(id__in=value["ids"])
-        elif value["type"] == "attrs" and isinstance(value.get("attrs"), list):
-            q = self._get_filter_attrs_q(value, to_model)
-            return queryset.filter(q)
-        else:
-            return queryset.none()
+        return self.filter_queryset_by_model(value, to_model)
 
     def get_attr_q(self):
-        q = self._get_filter_attrs_q(self.value)
+        q = self._get_filter_attrs_q(self.value, apps.get_model(self.field.to))
         return q
 
     def all(self):
diff --git a/apps/common/drf/filters.py b/apps/common/drf/filters.py
index 949260a47..6278efccf 100644
--- a/apps/common/drf/filters.py
+++ b/apps/common/drf/filters.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 #
+import base64
+import json
 import logging
 
 from django.core.cache import cache
@@ -18,6 +20,8 @@ __all__ = [
     "BaseFilterSet"
 ]
 
+from common.db.fields import RelatedManager
+
 
 class BaseFilterSet(drf_filters.FilterSet):
     def do_nothing(self, queryset, name, value):
@@ -183,3 +187,24 @@ class UUIDInFilter(drf_filters.BaseInFilter, drf_filters.UUIDFilter):
 
 class NumberInFilter(drf_filters.BaseInFilter, drf_filters.NumberFilter):
     pass
+
+
+class AttrRulesFilter(filters.BaseFilterBackend):
+    def get_schema_fields(self, view):
+        return [
+            coreapi.Field(
+                name='attr_rules', location='query', required=False,
+                type='string', example='/api/v1/users/users?attr_rules=jsonbase64',
+                description='Filter by json like {"type": "attrs", "attrs": []} to base64'
+            )
+        ]
+
+    def filter_queryset(self, request, queryset, view):
+        attr_rules = request.query_params.get('attr_rules')
+        if not attr_rules:
+            return queryset
+
+        attr_rules = base64.b64decode(attr_rules.encode('utf-8'))
+        attr_rules = json.loads(attr_rules)
+        q = RelatedManager.get_filter_q(attr_rules, queryset.model)
+        return queryset.filter(q)
diff --git a/apps/users/api/user.py b/apps/users/api/user.py
index d1eee8083..26a8e9d05 100644
--- a/apps/users/api/user.py
+++ b/apps/users/api/user.py
@@ -7,8 +7,8 @@ from rest_framework.decorators import action
 from rest_framework.response import Response
 from rest_framework_bulk import BulkModelViewSet
 
-from common.api import CommonApiMixin
-from common.api import SuggestionMixin
+from common.api import CommonApiMixin, SuggestionMixin
+from common.drf.filters import AttrRulesFilter
 from common.utils import get_logger
 from orgs.utils import current_org, tmp_to_root_org
 from rbac.models import Role, RoleBinding
@@ -18,10 +18,7 @@ from .. import serializers
 from ..filters import UserFilter
 from ..models import User
 from ..notifications import ResetMFAMsg
-from ..serializers import (
-    UserSerializer,
-    MiniUserSerializer, InviteSerializer
-)
+from ..serializers import UserSerializer, MiniUserSerializer, InviteSerializer
 from ..signals import post_user_create
 
 logger = get_logger(__name__)
@@ -33,6 +30,7 @@ __all__ = [
 
 class UserViewSet(CommonApiMixin, UserQuerysetMixin, SuggestionMixin, BulkModelViewSet):
     filterset_class = UserFilter
+    extra_filter_backends = [AttrRulesFilter]
     search_fields = ('username', 'email', 'name')
     serializer_classes = {
         'default': UserSerializer,