# django-vue-admin/backend/dvadmin/utils/viewset.py
#
# 138 lines
# 5.3 KiB
# Python

# -*- coding: utf-8 -*-
"""
@author: 猿小天
@contact: QQ:1638245306
@Created on: 2021/6/1 001 22:57
@Remark: 自定义视图集
"""
import uuid
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.viewsets import ModelViewSet
from dvadmin.utils.filters import DataLevelPermissionsFilter
from dvadmin.utils.import_export_mixin import (
ExportSerializerMixin,
ImportSerializerMixin,
)
from dvadmin.utils.json_response import SuccessResponse, ErrorResponse, DetailResponse
from dvadmin.utils.permission import CustomPermission
from django_restql.mixins import QueryArgumentsMixin
from treebeard.models import Node
class CustomModelViewSet(
    ModelViewSet, ImportSerializerMixin, ExportSerializerMixin, QueryArgumentsMixin
):
    """
    Customized ModelViewSet.

    Provides a uniform response format and lets create / retrieve / update use
    different serializers.

    (1) ORM performance: prefer the ``values_queryset`` form where possible.
    (2) ``xxx_serializer_class``: serializer used by a specific action
        (xxx = create | update | list | retrieve | destroy).
    (3) ``filter_fields = '__all__'``: every model field (except JSON fields)
        can be filtered on by default.
    (4) ``import_field_dict = {}``: field mapping used on import
        {model field: model label}.
    (5) ``export_field_label = []``: fields included on export.
    """

    # NOTE: the methods below are documented with ``#`` comments instead of
    # docstrings on purpose. DRF / drf_yasg use an action method's docstring as
    # the Swagger operation description. That docstring would take precedence
    # over the class docstring of every subclass that inherits the action.

    values_queryset = None
    ordering_fields = "__all__"
    create_serializer_class = None
    update_serializer_class = None
    filter_fields = "__all__"
    search_fields = ()
    extra_filter_backends = [DataLevelPermissionsFilter]
    permission_classes = [CustomPermission]
    import_field_dict = {}
    export_field_label = []

    # Run every configured filter backend plus ``extra_filter_backends``.
    # Duplicates are removed with ``dict.fromkeys``, which keeps declaration
    # order. The previous ``set`` gave an arbitrary, run-dependent order.
    def filter_queryset(self, queryset):
        backends = dict.fromkeys(
            [*self.filter_backends, *(self.extra_filter_backends or [])]
        )
        for backend in backends:
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    # Return ``values_queryset`` when one is configured, otherwise defer to DRF.
    def get_queryset(self):
        values_queryset = getattr(self, "values_queryset", None)
        # Compare with ``None`` instead of testing truthiness. ``bool()`` on a
        # QuerySet executes it and fills its result cache, and an empty result
        # would silently fall back to ``self.queryset``.
        if values_queryset is not None:
            # ``.all()`` returns a fresh clone, so rows cached on the class
            # attribute are never shared between requests. DRF's
            # ``GenericAPIView.get_queryset`` does the same for ``queryset``.
            return values_queryset.all()
        return super().get_queryset()

    # Prefer an action-specific serializer named ``<action>_serializer_class``
    # (e.g. ``create_serializer_class``); otherwise use ``serializer_class``.
    def get_serializer_class(self):
        action_serializer_class = getattr(
            self, f"{self.action}_serializer_class", None
        )
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    # Create one object and return its serialized representation.
    # For django-treebeard models (subclasses of ``Node``) the node must be
    # inserted through ``add_root`` / ``add_child``. A missing or empty
    # ``parent`` in the request body means "create a root node".
    def create(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data, request=request)
        serializer.is_valid(raise_exception=True)
        model = self.get_queryset().model
        if issubclass(model, Node):
            data = serializer.validated_data
            parent_id = request.data.get("parent")
            if parent_id in (None, ""):
                instance = model.add_root(**data)
            else:
                try:
                    parent = model.objects.filter(pk=parent_id).first()
                except (TypeError, ValueError):
                    # e.g. a non-numeric ``parent`` for an integer primary key
                    parent = None
                if parent is None:
                    return ErrorResponse(msg="父节点不存在")
                instance = parent.add_child(**data)
            # Attach the created node so ``serializer.data`` renders it (with
            # its primary key) rather than the raw validated input.
            serializer.instance = instance
            return DetailResponse(data=serializer.data, msg="新增成功")
        self.perform_create(serializer)
        return DetailResponse(data=serializer.data, msg="新增成功")

    # List objects after filtering. The result is paginated when a paginator is
    # configured; otherwise every row is returned in a ``SuccessResponse``.
    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True, request=request)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True, request=request)
        return SuccessResponse(data=serializer.data, msg="获取成功")

    # Return a single object (looked up via ``get_object``).
    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return DetailResponse(data=serializer.data, msg="获取成功")

    # Full update of one object. A partial update is requested by passing
    # ``partial=True`` in kwargs (DRF's ``partial_update`` does this for PATCH).
    def update(self, request, *args, **kwargs):
        partial = kwargs.pop("partial", False)
        instance = self.get_object()
        serializer = self.get_serializer(
            instance, data=request.data, request=request, partial=partial
        )
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)
        if getattr(instance, "_prefetched_objects_cache", None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return DetailResponse(data=serializer.data, msg="更新成功")

    # Delete one object (looked up via ``get_object``).
    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return DetailResponse(data=[], msg="删除成功")

    # Request-body schema for ``multiple_delete``: a list of primary keys.
    # ``items`` must itself be a Schema; a bare type string is not valid OpenAPI.
    keys = openapi.Schema(
        description="主键列表",
        type=openapi.TYPE_ARRAY,
        items=openapi.Schema(type=openapi.TYPE_STRING),
    )

    # Bulk delete: remove every object whose primary key is in ``keys``.
    # The queryset goes through ``filter_queryset`` so the same filter backends
    # that guard single-object lookups (including DataLevelPermissionsFilter)
    # also restrict what a bulk delete can reach.
    @swagger_auto_schema(
        request_body=openapi.Schema(
            type=openapi.TYPE_OBJECT, required=["keys"], properties={"keys": keys}
        ),
        operation_summary="批量删除",
    )
    @action(methods=["delete"], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        request_data = request.data
        # Form/multipart bodies arrive as a QueryDict, where ``get`` would only
        # return the last ``keys`` value; ``getlist`` returns all of them.
        if hasattr(request_data, "getlist"):
            keys = request_data.getlist("keys")
        else:
            keys = request_data.get("keys", None)
        if not keys:
            return ErrorResponse(msg="未获取到keys字段")
        if isinstance(keys, (str, int)):
            # A bare scalar is one key. Without this, a string such as "12"
            # would be iterated character by character by ``pk__in``.
            keys = [keys]
        queryset = self.filter_queryset(self.get_queryset())
        queryset.filter(pk__in=keys).delete()
        return SuccessResponse(data=[], msg="删除成功")