diff --git a/apps/common/db/fields.py b/apps/common/db/fields.py index 1691fad1d..0a620e1f0 100644 --- a/apps/common/db/fields.py +++ b/apps/common/db/fields.py @@ -362,11 +362,15 @@ class RelatedManager: if name is None or val is None: continue - if custom_attr_filter: + custom_filter_q = None + spec_attr_filter = getattr(to_model, "get_{}_filter_attr_q".format(name), None) + if spec_attr_filter: + custom_filter_q = spec_attr_filter(val, match) + elif custom_attr_filter: custom_filter_q = custom_attr_filter(name, val, match) - if custom_filter_q: - filters.append(custom_filter_q) - continue + if custom_filter_q: + filters.append(custom_filter_q) + continue if match == 'ip_in': q = cls.get_ip_in_q(name, val) @@ -464,11 +468,15 @@ class JSONManyToManyDescriptor: rule_value = rule.get('value', '') rule_match = rule.get('match', 'exact') - if custom_attr_filter: - q = custom_attr_filter(rule['name'], rule_value, rule_match) - if q: - custom_q &= q - continue + custom_filter_q = None + spec_attr_filter = getattr(to_model, "get_filter_{}_attr_q".format(rule['name']), None) + if spec_attr_filter: + custom_filter_q = spec_attr_filter(rule_value, rule_match) + elif custom_attr_filter: + custom_filter_q = custom_attr_filter(rule['name'], rule_value, rule_match) + if custom_filter_q: + custom_q &= custom_filter_q + continue if rule_match == 'in': res &= value in rule_value or '*' in rule_value @@ -517,7 +525,6 @@ class JSONManyToManyDescriptor: res &= rule_value.issubset(value) else: res &= bool(value & rule_value) - else: logging.error("unknown match: {}".format(rule['match'])) res &= False diff --git a/apps/common/drf/filters.py b/apps/common/drf/filters.py index 85475f479..803849909 100644 --- a/apps/common/drf/filters.py +++ b/apps/common/drf/filters.py @@ -6,7 +6,7 @@ import logging from django.core.cache import cache from django.core.exceptions import ImproperlyConfigured -from django.db.models import Q, Count +from django.db.models import Q from django_filters import rest_framework as drf_filters from rest_framework import filters from rest_framework.compat import coreapi, coreschema @@ -180,36 +180,30 @@ class LabelFilterBackend(filters.BaseFilterBackend): ] @staticmethod - def filter_resources(resources, labels_id): + def parse_label_ids(labels_id): + from labels.models import Label label_ids = [i.strip() for i in labels_id.split(',')] + cleaned = [] args = [] for label_id in label_ids: kwargs = {} if ':' in label_id: k, v = label_id.split(':', 1) - kwargs['label__name'] = k.strip() + kwargs['name'] = k.strip() if v != '*': - kwargs['label__value'] = v.strip() + kwargs['value'] = v.strip() + args.append(kwargs) else: - kwargs['label_id'] = label_id - args.append(kwargs) + cleaned.append(label_id) - if len(args) == 1: - resources = resources.filter(**args[0]) - return resources - - q = Q() - for kwarg in args: - q |= Q(**kwarg) - - resources = resources.filter(q) \ - .values('res_id') \ - .order_by('res_id') \ - .annotate(count=Count('res_id', distinct=True)) \ - .values('res_id', 'count') \ - .filter(count=len(args)) - return resources + if len(args) != 0: + q = Q() + for kwarg in args: + q |= Q(**kwarg) + ids = Label.objects.filter(q).values_list('id', flat=True) + cleaned.extend(list(ids)) + return cleaned def filter_queryset(self, request, queryset, view): labels_id = request.query_params.get('labels') @@ -230,7 +224,8 @@ class LabelFilterBackend(filters.BaseFilterBackend): resources = labeled_resource_cls.objects.filter( res_type__app_label=app_label, res_type__model=model_name, ) - resources = self.filter_resources(resources, labels_id) + label_ids = self.parse_label_ids(labels_id) + resources = model.filter_resources_by_labels(resources, label_ids) res_ids = resources.values_list('res_id', flat=True) queryset = queryset.filter(id__in=set(res_ids)) return queryset diff --git a/apps/labels/mixins.py b/apps/labels/mixins.py index 4b775cde5..33e73b60b 100644 --- a/apps/labels/mixins.py +++ b/apps/labels/mixins.py @@ -1,6 +1,6 @@ from django.contrib.contenttypes.fields import GenericRelation from django.db import models -from django.db.models import OneToOneField +from django.db.models import OneToOneField, Count from common.utils import lazyproperty from .models import LabeledResource @@ -36,3 +36,37 @@ class LabeledMixin(models.Model): @res_labels.setter def res_labels(self, value): self.real.labels.set(value, bulk=False) + + @classmethod + def filter_resources_by_labels(cls, resources, label_ids): + return cls._get_filter_res_by_labels_m2m_all(resources, label_ids) + + @classmethod + def _get_filter_res_by_labels_m2m_in(cls, resources, label_ids): + return resources.filter(label_id__in=label_ids) + + @classmethod + def _get_filter_res_by_labels_m2m_all(cls, resources, label_ids): + if len(label_ids) == 1: + return cls._get_filter_res_by_labels_m2m_in(resources, label_ids) + + resources = resources.filter(label_id__in=label_ids) \ + .values('res_id') \ + .order_by('res_id') \ + .annotate(count=Count('res_id', distinct=True)) \ + .values('res_id', 'count') \ + .filter(count=len(label_ids)) + return resources + + @classmethod + def get_labels_filter_attr_q(cls, value, match): + resources = LabeledResource.objects.all() + if not value: + return None + + if match != 'm2m_all': + resources = cls._get_filter_res_by_labels_m2m_in(resources, value) + else: + resources = cls._get_filter_res_by_labels_m2m_all(resources, value) + res_ids = set(resources.values_list('res_id', flat=True)) + return models.Q(id__in=res_ids)