diff --git a/backend/dvadmin/utils/filters.py b/backend/dvadmin/utils/filters.py index 52de84a..13beff2 100644 --- a/backend/dvadmin/utils/filters.py +++ b/backend/dvadmin/utils/filters.py @@ -19,7 +19,7 @@ from django.db.models.constants import LOOKUP_SEP from django_filters import utils from django_filters.conf import settings from django_filters.constants import ALL_FIELDS -from django_filters.filters import CharFilter +from django_filters.filters import CharFilter, BooleanFilter from django_filters.filterset import FilterSet, FilterSetMetaclass from django_filters.rest_framework import DjangoFilterBackend from django_filters.utils import get_model_field @@ -395,8 +395,6 @@ def calculate_execution_time(func): def next_layer_data(qs_filter, qs_node): parent_nodes = set(qs_node.values_list("id", flat=True)) - # print(f"过滤查询集 ==> {qs_filter}", flush=True) - # print(f"待渲染节点的id ==> {parent_nodes=}", flush=True) if set(qs_filter) == set(qs_node): return parent_nodes # qs_filter内所有父级id 去重 @@ -410,10 +408,32 @@ def next_layer_data(qs_filter, qs_node): parent_ids.add(node.parent.id) break node = node.parent + # print(f"过滤查询集 ==> {qs_filter}", flush=True) + # print(f"待渲染节点的id ==> {parent_nodes=}", flush=True) # print(f"过滤查询集的父节点id ==> {parent_ids=}", flush=True) return parent_ids +def construct_data(qs_filter, qs_node, is_parent): + filter_node_ids = set(qs_filter.values_list("id", flat=True)) + render_node_ids = set(qs_node.values_list("id", flat=True)) + + hidden_node_ids = set() + for node in qs_filter: + while node.parent: + if node.parent in qs_filter: + hidden_node_ids.add(node.id) + node = node.parent + on_show = filter_node_ids.difference(hidden_node_ids) + on_expand = hidden_node_ids & render_node_ids + # print(f"完整查询结果 {filter_node_ids}") + # print(f"待展示的节点(未过滤) {render_node_ids}") + # print(f"查询结果中的子节点 {hidden_node_ids}") + # print(f"查询后首先渲染的父节点 {on_show}") + # print(f"展开父节点时要渲染的节点 {on_expand}") + return on_expand if is_parent else on_show + + class FilterSetOptions: def __init__(self, options=None): self.model = getattr(options, "model", None) @@ -430,7 +450,13 @@ class FilterSetOptions: "extra": lambda f: { "lookup_expr": "icontains", }, - } + }, + models.BooleanField: { + "filter_class": BooleanFilter, + "extra": lambda f: { + "widget": forms.RadioSelect, + }, + }, }, ) @@ -449,17 +475,24 @@ class LazyLoadFilterSetMetaclass(FilterSetMetaclass): class LazyLoadFilter(FilterSet, metaclass=LazyLoadFilterSetMetaclass): - # @calculate_execution_time @property + # @calculate_execution_time def qs(self): queryset = self.queryset - filter_params = [k for k, v in self.form.cleaned_data.items() if not v] + # print(self.form.cleaned_data, flush=True) + filter_params = [k for k, v in self.form.cleaned_data.items() if v in [None, ""]] for field in filter_params: self.form.cleaned_data.pop(field) - self.form.cleaned_data.pop("parent", None) + is_parent = self.form.cleaned_data.pop("parent", None) is not None # print(queryset, flush=True) if self.form.cleaned_data: self.queryset = queryset.model.objects.all() - node_ids = next_layer_data(super().qs, queryset) + + # 从根节点开始 + # node_ids = next_layer_data(super().qs, queryset) + + # 按匹配结果显示 + node_ids = construct_data(super().qs, queryset, is_parent) + return queryset.model.objects.filter(id__in=node_ids) return super().qs