250 lines
9.8 KiB
Python
250 lines
9.8 KiB
Python
from types import FunctionType, MethodType
|
|
|
|
# from rest_framework_mongoengine.generics import GenericAPIView as MongoGenericAPIView
|
|
from django.core.exceptions import ValidationError
|
|
from django.http.response import Http404
|
|
from django.shortcuts import get_object_or_404 as _get_object_or_404
|
|
from django_filters.rest_framework import DjangoFilterBackend
|
|
from mongoengine.queryset.base import BaseQuerySet
|
|
from rest_framework.filters import OrderingFilter, SearchFilter
|
|
from rest_framework.request import Request
|
|
from rest_framework.settings import api_settings
|
|
from rest_framework.viewsets import ViewSetMixin
|
|
|
|
from utils.exceptions import APIException
|
|
from . import mixins
|
|
from .filters import MongoSearchFilter, MongoOrderingFilter, AdvancedSearchFilter, MongoAdvancedSearchFilter
|
|
from .generics import GenericAPIView
|
|
from .logging.view_logger import CustomerModelViewLogger
|
|
from .pagination import Pagination
|
|
from .serializers import CustomModelSerializer
|
|
|
|
|
|
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
    """Fetch a single object, converting lookup failures into an ``APIException``.

    Wraps Django's ``get_object_or_404``. A missing object (``Http404``) and
    malformed lookup values (``TypeError``, ``ValueError``, Django
    ``ValidationError``) all surface as the same API error. Its single message
    covers both "object does not exist" and "no access permission".

    NOTE(review): ``MongoGenericAPIView.get_object`` also passes mongoengine
    querysets here; confirm Django's helper maps their ``DoesNotExist`` correctly.

    :param queryset: model, manager or queryset handed straight to Django's helper
    :param filter_args: positional lookup arguments forwarded to ``.get()``
    :param filter_kwargs: keyword lookup arguments forwarded to ``.get()``
    :return: the matched object
    :raises APIException: when the lookup fails for any of the reasons above
    """
    try:
        return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
    except (TypeError, ValueError, ValidationError, Http404) as exc:
        # Chain explicitly so tracebacks/logs keep the real cause of the failure.
        raise APIException(message='该对象不存在或者无访问权限') from exc
|
|
|
|
|
|
class GenericViewSet(ViewSetMixin, GenericAPIView):
    """Project-wide generic viewset.

    On top of ``ViewSetMixin`` + the project ``GenericAPIView`` it adds:

    * per-action serializer classes via ``<action>_serializer_class``;
    * per-action extra filter backends via ``<action>_extra_filter_backends``,
      applied after ``filter_backends`` / ``extra_filter_backends``;
    * per-action extra permissions via ``<action>_extra_permission_classes``,
      checked at request level (``initial``) and object level
      (``check_object_permissions``);
    * ``handle_logging``, which dispatches to ``handle_<action>`` on view loggers.
    """

    # Extra backends merged (de-duplicated) with ``filter_backends`` in filter_queryset.
    extra_filter_backends = []

    pagination_class = Pagination

    filter_backends = [DjangoFilterBackend, OrderingFilter, SearchFilter, AdvancedSearchFilter]

    view_logger_classes = (CustomerModelViewLogger,)

    def handle_logging(self, request: Request, *args, **kwargs):
        """Call ``handle_<action>`` on every view logger that defines it.

        Loggers come from ``self.get_view_loggers`` (not defined in this class,
        presumably on the project ``GenericAPIView``). Only plain functions or
        bound methods are invoked; any other attribute with that name is ignored.
        """
        view_loggers = self.get_view_loggers(request, *args, **kwargs)
        for view_logger in view_loggers:
            handle_action = getattr(view_logger, f'handle_{self.action}', None)
            if handle_action and isinstance(handle_action, (FunctionType, MethodType)):
                handle_action(request, *args, **kwargs)

    def get_serializer(self, *args, **kwargs):
        """Instantiate the serializer for the current action.

        The default serializer context is supplied unless the caller passed an
        explicit ``context=``, matching DRF's ``GenericAPIView.get_serializer``.
        ``CustomModelSerializer`` instances additionally get ``self.request``
        attached as an attribute.
        """
        serializer_class = self.get_serializer_class()
        # setdefault rather than assignment: a caller-supplied context must not be
        # silently discarded.
        kwargs.setdefault('context', self.get_serializer_context())
        serializer = serializer_class(*args, **kwargs)
        if isinstance(serializer, CustomModelSerializer):
            serializer.request = self.request
        return serializer

    def filter_queryset(self, queryset):
        """Apply the shared filter backends, then the action-specific ones.

        ``filter_backends`` and ``extra_filter_backends`` are merged with
        duplicates removed while keeping declaration order. A ``set`` union would
        iterate in hash (object identity) order, which varies between processes
        and makes the backend order non-deterministic.
        """
        merged_backends = dict.fromkeys(
            [*self.filter_backends, *(self.extra_filter_backends or [])]
        )
        for backend in merged_backends:
            queryset = backend().filter_queryset(self.request, queryset, self)
        return self.action_extra_filter_queryset(queryset)

    def action_extra_filter_queryset(self, queryset):
        """Apply ``<action>_extra_filter_backends`` (if configured) in order."""
        action_backends = getattr(self, f"{self.action}_extra_filter_backends", None)
        if not action_backends:
            return queryset
        for backend in action_backends:
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    def get_serializer_class(self):
        """Return ``<action>_serializer_class`` when set (truthy), else the default."""
        action_serializer_class = getattr(self, f"{self.action}_serializer_class", None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    def reverse_action(self, url_name, *args, **kwargs):
        """Delegate unchanged to the parent ``reverse_action``."""
        return super().reverse_action(url_name, *args, **kwargs)

    def get_action_extra_permissions(self):
        """
        Instantiate the permission classes configured for the current action.

        Reads ``<action>_extra_permission_classes``.

        :return: list of permission instances (empty when none are configured)
        """
        action_extra_permission_classes = getattr(self, f"{self.action}_extra_permission_classes", None)
        if not action_extra_permission_classes:
            return []
        return [permission() for permission in action_extra_permission_classes]

    def check_action_extra_permissions(self, request):
        """
        Check each action-specific permission in turn.

        On the first failing permission, ``self.permission_denied`` is called with
        that permission's ``message`` (in DRF this raises).

        :param request: the incoming request
        :return: None
        """
        for permission in self.get_action_extra_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )

    def check_action_extra_object_permissions(self, request, obj):
        """
        Object-level check for the action-specific permissions.

        :param request: the incoming request
        :param obj: the object being accessed
        :return: None
        """
        for permission in self.get_action_extra_permissions():
            if not permission.has_object_permission(request, self, obj):
                self.permission_denied(
                    request, message=getattr(permission, 'message', None)
                )

    def initial(self, request, *args, **kwargs):
        """
        Override of ``initial``.

        Runs the parent's standard initialisation/checks first. It then adds the
        action-specific permission checks.

        :param request: the incoming request
        :param args: positional URL arguments
        :param kwargs: keyword URL arguments
        :return: None
        """
        super().initial(request, *args, **kwargs)
        self.check_action_extra_permissions(request)

    def get_object(self):
        """Return the object addressed by the URL lookup kwarg.

        Uses the module-level ``get_object_or_404``, so failed lookups raise the
        project ``APIException``. Object permissions are checked before returning.
        """
        queryset = self.filter_queryset(self.get_queryset())
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        assert lookup_url_kwarg in self.kwargs, (
            'Expected view %s to be called with a URL keyword argument '
            'named "%s". Fix your URL conf, or set the `.lookup_field` '
            'attribute on the view correctly.' %
            (self.__class__.__name__, lookup_url_kwarg)
        )
        filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
        obj = get_object_or_404(queryset, **filter_kwargs)
        self.check_object_permissions(self.request, obj)
        return obj

    def check_object_permissions(self, request, obj):
        """
        Override of ``check_object_permissions``.

        (1) Adds an entry point for action-specific object permission checks.
        (2) Checks the shared object permissions first, then the action ones.

        :param request: the incoming request
        :param obj: the object being accessed
        :return: None
        """
        super().check_object_permissions(request, obj)
        self.check_action_extra_object_permissions(request, obj)
|
|
|
|
|
|
class MongoGenericAPIView(GenericAPIView):
    """Adaptation of DRF's GenericAPIView for mongoengine documents.

    Objects are looked up by ``id`` by default. A mongoengine ``BaseQuerySet``
    returned by the parent ``get_queryset`` is re-derived with ``.all()``.
    """

    lookup_field = 'id'

    def get_queryset(self):
        """Return the parent queryset, passed through ``.all()`` if it is a ``BaseQuerySet``."""
        base_queryset = super().get_queryset()
        return base_queryset.all() if isinstance(base_queryset, BaseQuerySet) else base_queryset

    def get_object(self):
        """Return the object addressed by the URL, after object-permission checks.

        Missing or invalid lookups raise the project ``APIException`` via the
        module-level ``get_object_or_404``.
        """
        filtered = self.filter_queryset(self.get_queryset())

        # Resolve which URL kwarg carries the lookup value.
        url_kwarg = self.lookup_url_kwarg or self.lookup_field
        assert url_kwarg in self.kwargs, (
            'Expected view %s to be called with a URL keyword argument '
            'named "%s". Fix your URL conf, or set the `.lookup_field` '
            'attribute on the view correctly.' %
            (self.__class__.__name__, url_kwarg)
        )

        instance = get_object_or_404(filtered, **{self.lookup_field: self.kwargs[url_kwarg]})
        self.check_object_permissions(self.request, instance)
        return instance
|
|
|
|
|
|
class MongoGenericViewSet(ViewSetMixin, MongoGenericAPIView):
    """ViewSet flavour of ``MongoGenericAPIView`` that uses the project ``Pagination``."""

    pagination_class = Pagination
|
|
|
|
|
|
class ReadOnlyModelViewSet(
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    GenericViewSet,
):
    """Read-only viewset: the project's retrieve and list mixins on ``GenericViewSet``."""
|
|
|
|
|
|
class ModelViewSet(
    mixins.CreateModelMixin,
    mixins.RetrieveModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    GenericViewSet,
):
    """Full CRUD viewset: the project's create/retrieve/update/destroy/list mixins on ``GenericViewSet``."""
|
|
|
|
|
|
class MongoModelViewSet(
    mixins.CreateModelMixin,
    mixins.RetrieveModelMixin,
    mixins.UpdateModelMixin,
    mixins.DestroyModelMixin,
    mixins.ListModelMixin,
    MongoGenericViewSet,
):
    """Full CRUD viewset for mongoengine documents, built on ``MongoGenericViewSet``."""
|
|
|
|
|
|
class CustomModelViewSet(ModelViewSet, mixins.TableSerializerMixin):
    """
    Customised ModelViewSet:
    (1) the default paginator is the unified paginator op_drf.pagination.Pagination
    (2) uses the unified standard response format by default
    (3) supports advanced search by default
    (4) supports generating options for front-end dynamic tables
    (5) ORM performance: prefer a ``values_queryset`` where possible
    """

    # Optional queryset (e.g. a ``.values()`` queryset) used instead of ``queryset``.
    values_queryset = None

    ordering_fields = '__all__'

    def get_queryset(self):
        """Return ``values_queryset`` when configured, otherwise defer to the parent.

        ``values_queryset`` is a class attribute shared by all requests. Testing its
        truthiness would execute the query and cache the rows on that shared
        object, so later requests would read stale rows and counts. It would also
        permanently disable ``values_queryset`` if the table was empty at the first
        request. Therefore it is checked with ``is not None``, and Django querysets
        are cloned with ``.all()`` on every call (as DRF does for ``queryset``).
        """
        values_queryset = getattr(self, 'values_queryset', None)
        if values_queryset is None:
            return super().get_queryset()
        # Local import: the module does not import Django's QuerySet at top level.
        from django.db.models.query import QuerySet
        if isinstance(values_queryset, QuerySet):
            values_queryset = values_queryset.all()
        return values_queryset
|
|
|
|
|
|
class CustomMongoModelViewSet(MongoModelViewSet, mixins.TableSerializerMixin):
    """ModelViewSet for mongoengine documents with query-string filtering and ordering.

    ``get_queryset`` turns whitelisted query parameters (``filter_fields``) into
    ``filter()`` keyword arguments. It also applies the ordering parameter,
    restricted to ``ordering_fields``.
    """

    filter_backends = (MongoOrderingFilter, MongoSearchFilter, MongoAdvancedSearchFilter)
    # filter_fields = '__all__'  # '__all__' is not supported yet
    filter_fields = ()
    search_fields = ()
    ordering_fields = '__all__'
    view_logger_classes = (CustomerModelViewLogger,)

    def get_queryset(self):
        """Build the queryset from ``self.queryset`` plus query-string filters and ordering.

        * Each query parameter (whitespace-stripped) that is not a reserved
          pagination/search/ordering parameter becomes an exact ``filter()``
          keyword argument with that parameter's value. The parameter must be
          listed in ``filter_fields``, unless ``filter_fields == '__all__'``.
        * The ordering parameter (``api_settings.ORDERING_PARAM``, e.g.
          ``?ordering=a,-b``) is split on commas and empty entries are dropped.
          Unless ``ordering_fields == '__all__'``, each field must be whitelisted,
          either as given or without its leading ``-`` (descending).
        """
        query_params = self.request.query_params
        reserved_params = {
            'pageSize', 'pageNum', 'search', 'ordering', 'as',
            api_settings.ORDERING_PARAM, api_settings.SEARCH_PARAM,
        }

        filtering_kwargs = {}
        for raw_param in query_params:
            param = raw_param.strip()
            if param in reserved_params:
                continue
            if self.filter_fields == '__all__' or param in self.filter_fields:
                # Index with the original key: the stripped key may not exist in
                # query_params and would raise KeyError (an HTTP 500).
                filtering_kwargs[param] = query_params[raw_param]
        queryset = self.queryset.filter(**filtering_kwargs)

        ordering_params = query_params.get(api_settings.ORDERING_PARAM, None)
        if ordering_params:
            requested_fields = (field.strip() for field in ordering_params.split(','))
            ordering_fields = [
                field for field in requested_fields
                if field and (
                    self.ordering_fields == '__all__'
                    or field in self.ordering_fields
                    or field.lstrip('-') in self.ordering_fields
                )
            ]
            queryset = queryset.order_by(*ordering_fields)
        return queryset

    def filter_queryset(self, queryset):
        """Delegate unchanged to the parent ``filter_queryset``."""
        return super().filter_queryset(queryset)
|