perf: 优化 labels 在 json field 中的筛选 (#12577)

* perf: 优化 labels 在 json field 中的筛选

* perf: 修改 labels 搜索

---------

Co-authored-by: ibuler <ibuler@qq.com>
pull/12583/head
fit2bot 2024-01-22 11:36:18 +08:00 committed by GitHub
parent 3853d0bcc6
commit 0c74e92bfb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 33 deletions

View File

@ -362,11 +362,15 @@ class RelatedManager:
if name is None or val is None: if name is None or val is None:
continue continue
if custom_attr_filter: custom_filter_q = None
spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None)
if spec_attr_filter:
custom_filter_q = spec_attr_filter(val, match)
elif custom_attr_filter:
custom_filter_q = custom_attr_filter(name, val, match) custom_filter_q = custom_attr_filter(name, val, match)
if custom_filter_q: if custom_filter_q:
filters.append(custom_filter_q) filters.append(custom_filter_q)
continue continue
if match == 'ip_in': if match == 'ip_in':
q = cls.get_ip_in_q(name, val) q = cls.get_ip_in_q(name, val)
@ -464,11 +468,15 @@ class JSONManyToManyDescriptor:
rule_value = rule.get('value', '') rule_value = rule.get('value', '')
rule_match = rule.get('match', 'exact') rule_match = rule.get('match', 'exact')
if custom_attr_filter: custom_filter_q = None
q = custom_attr_filter(rule['name'], rule_value, rule_match) spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None)
if q: if spec_attr_filter:
custom_q &= q custom_filter_q = spec_attr_filter(rule_value, rule_match)
continue elif custom_attr_filter:
custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match)
if custom_filter_q:
custom_q &= custom_filter_q
continue
if rule_match == 'in': if rule_match == 'in':
res &= value in rule_value or '*' in rule_value res &= value in rule_value or '*' in rule_value
@ -517,7 +525,6 @@ class JSONManyToManyDescriptor:
res &= rule_value.issubset(value) res &= rule_value.issubset(value)
else: else:
res &= bool(value & rule_value) res &= bool(value & rule_value)
else: else:
logging.error("unknown match: {}".format(rule['match'])) logging.error("unknown match: {}".format(rule['match']))
res &= False res &= False

View File

@ -6,7 +6,7 @@ import logging
from django.core.cache import cache from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Count from django.db.models import Q
from django_filters import rest_framework as drf_filters from django_filters import rest_framework as drf_filters
from rest_framework import filters from rest_framework import filters
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend):
] ]
@staticmethod @staticmethod
def filter_resources(resources, labels_id): def parse_label_ids(labels_id):
from labels.models import Label
label_ids = [i.strip() for i in labels_id.split(',')] label_ids = [i.strip() for i in labels_id.split(',')]
cleaned = []
args = [] args = []
for label_id in label_ids: for label_id in label_ids:
kwargs = {} kwargs = {}
if ':' in label_id: if ':' in label_id:
k, v = label_id.split(':', 1) k, v = label_id.split(':', 1)
kwargs['label__name'] = k.strip() kwargs['name'] = k.strip()
if v != '*': if v != '*':
kwargs['label__value'] = v.strip() kwargs['value'] = v.strip()
args.append(kwargs)
else: else:
kwargs['label_id'] = label_id cleaned.append(label_id)
args.append(kwargs)
if len(args) == 1: if len(args) != 0:
resources = resources.filter(**args[0]) q = Q()
return resources for kwarg in args:
q |= Q(**kwarg)
q = Q() ids = Label.objects.filter(q).values_list('id', flat=True)
for kwarg in args: cleaned.extend(list(ids))
q |= Q(**kwarg) return cleaned
resources = resources.filter(q) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(args))
return resources
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
labels_id = request.query_params.get('labels') labels_id = request.query_params.get('labels')
@ -230,7 +224,8 @@ class LabelFilterBackend(filters.BaseFilterBackend):
resources = labeled_resource_cls.objects.filter( resources = labeled_resource_cls.objects.filter(
res_type__app_label=app_label, res_type__model=model_name, res_type__app_label=app_label, res_type__model=model_name,
) )
resources = self.filter_resources(resources, labels_id) label_ids = self.parse_label_ids(labels_id)
resources = model.filter_resources_by_labels(resources, label_ids)
res_ids = resources.values_list('res_id', flat=True) res_ids = resources.values_list('res_id', flat=True)
queryset = queryset.filter(id__in=set(res_ids)) queryset = queryset.filter(id__in=set(res_ids))
return queryset return queryset

View File

@ -1,6 +1,6 @@
from django.contrib.contenttypes.fields import GenericRelation from django.contrib.contenttypes.fields import GenericRelation
from django.db import models from django.db import models
from django.db.models import OneToOneField from django.db.models import OneToOneField, Count
from common.utils import lazyproperty from common.utils import lazyproperty
from .models import LabeledResource from .models import LabeledResource
@ -36,3 +36,37 @@ class LabeledMixin(models.Model):
@res_labels.setter @res_labels.setter
def res_labels(self, value): def res_labels(self, value):
self.real.labels.set(value, bulk=False) self.real.labels.set(value, bulk=False)
@classmethod
def filter_resources_by_labels(cls, resources, label_ids):
return cls._get_filter_res_by_labels_m2m_all(resources, label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids):
return resources.filter(label_id__in=label_ids)
@classmethod
def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids):
if len(label_ids) == 1:
return cls._get_filter_res_by_labels_m2m_in(resources, label_ids)
resources = resources.filter(label_id__in=label_ids) \
.values('res_id') \
.order_by('res_id') \
.annotate(count=Count('res_id', distinct=True)) \
.values('res_id', 'count') \
.filter(count=len(label_ids))
return resources
@classmethod
def get_labels_filter_attr_q(cls, value, match):
resources = LabeledResource.objects.all()
if not value:
return None
if match != 'm2m_all':
resources = cls._get_filter_res_by_labels_m2m_in(resources, value)
else:
resources = cls._get_filter_res_by_labels_m2m_all(resources, value)
res_ids = set(resources.values_list('res_id', flat=True))
return models.Q(id__in=res_ids)