# -*- coding: utf-8 -*-
#
import time
from hashlib import md5
from threading import Thread
from collections import defaultdict
from itertools import chain

from django.conf import settings
from django.db.models.signals import m2m_changed
from django.core.cache import cache
from django.http import JsonResponse
# `gettext` is the same function as the `ugettext` alias, which Django
# deprecated in 3.0 and removed in 4.0.
from django.utils.translation import gettext as _
from django.contrib.auth import get_user_model
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.decorators import action
from rest_framework.request import Request

from common.const.http import POST
from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter
from ..utils import lazyproperty

__all__ = [
    'JSONResponseMixin', 'CommonApiMixin', 'AsyncApiMixin', 'RelationMixin',
    'QuerySetMixin', 'ExtraFilterFieldsMixin', 'RenderToJsonMixin',
    'SerializerMixin', 'AllowBulkDestroyMixin', 'PaginatedResponseMixin'
]


UserModel = get_user_model()


class JSONResponseMixin(object):
    """JSON mixin"""

    @staticmethod
    def render_json_response(context):
        """Wrap `context` in a Django `JsonResponse`."""
        return JsonResponse(context)


# SerializerMixin
# ----------------------


class RenderToJsonMixin:
    """Add a `render-to-json` POST route that echoes the parsed request data."""

    @action(methods=[POST], detail=False, url_path='render-to-json')
    def render_to_json(self, request: Request):
        """Return the parsed request body together with its column titles.

        Responds with HTTP 400 when the body is a list/tuple that contains
        no truthy item; otherwise responds with ``{'title': ..., 'data': ...}``.
        """
        # `request.data` must be read before `jms_context` is looked up, as in
        # the original statement order.
        # NOTE(review): presumably the parser sets `jms_context` while parsing
        # the body - confirm against the parser implementation.
        payload = request.data
        if isinstance(payload, (list, tuple)) and not any(payload):
            error = _("Request file format may be wrong")
            return Response(data={"error": error}, status=400)

        jms_context = getattr(request, 'jms_context', {})
        data = {
            'title': jms_context.get('column_title_field_pairs', ()),
            'data': payload,
        }
        return Response(data=data)


class SerializerMixin:
    """Select a different `serializer_class` depending on the requested action."""
    action: str
    request: Request

    # Mapping of action / HTTP method / fallback name -> serializer class.
    # Left as `None` (or any non-dict) to disable the lookup entirely.
    serializer_classes = None
    single_actions = ['put', 'retrieve', 'patch']

    def get_serializer_class_by_view_action(self):
        """Look up the serializer class for the current request.

        Lookup order in `serializer_classes`: the view action (the `action`
        query parameter wins over `self.action`, defaulting to ``'list'``),
        the lower-cased HTTP method, ``'single'`` (only when the action is in
        `single_actions`), ``'display'``, then ``'default'``.

        Returns `None` when `serializer_classes` is not a dict or nothing
        matches.
        """
        if not isinstance(getattr(self, 'serializer_classes', None), dict):
            return None

        # `action` is only declared as an annotation above, so it does not
        # exist on views that never assign it; fall back instead of raising
        # AttributeError.
        view_action = (
            self.request.query_params.get('action')
            or getattr(self, 'action', None)
            or 'list'
        )
        serializer_class = self.serializer_classes.get(view_action)

        if serializer_class is None:
            view_method = self.request.method.lower()
            serializer_class = self.serializer_classes.get(view_method)

        if serializer_class is None and view_action in self.single_actions:
            serializer_class = self.serializer_classes.get('single')
        if serializer_class is None:
            serializer_class = self.serializer_classes.get('display')
        if serializer_class is None:
            serializer_class = self.serializer_classes.get('default')
        return serializer_class

    def get_serializer_class(self):
        """Return the action-specific serializer, else defer to `super()`."""
        serializer_class = self.get_serializer_class_by_view_action()
        if serializer_class is None:
            serializer_class = super().get_serializer_class()
        return serializer_class


class ExtraFilterFieldsMixin:
    """Extra API filters."""
    default_added_filters = [CustomFilter, IDSpmFilter, IDInFilter]
    filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
    extra_filter_fields = []
    extra_filter_backends = []

    def get_filter_backends(self):
        """Return the filter backend classes to apply.

        When `filter_backends` on the instance differs from the value on its
        class, that value is returned untouched. Otherwise the result is
        `filter_backends` + `default_added_filters` + `extra_filter_backends`.
        """
        if self.filter_backends != self.__class__.filter_backends:
            return self.filter_backends
        backends = list(chain(
            self.filter_backends,
            self.default_added_filters,
            self.extra_filter_backends))
        return backends

    def filter_queryset(self, queryset):
        """Run `queryset` through every backend from `get_filter_backends`."""
        for backend in self.get_filter_backends():
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset


class PaginatedResponseMixin:
    """Helper for turning a queryset into a (possibly paginated) response."""

    def get_paginated_response_with_query_set(self, queryset):
        """Serialize `queryset`, paginated when `paginate_queryset` yields a page."""
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)


class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin):
    pass


class InterceptMixin:
    """
    Override the default `dispatch` so that subclasses can implement `self.do`.
    """

    def dispatch(self, request, *args, **kwargs):
        """Dispatch the request, routing the resolved handler through `self.do`.

        The handler is not called directly: it is passed to
        `self.do(handler, request, *args, **kwargs)`, which decides how to run
        it. Any exception raised on the way is turned into a response by
        `handle_exception`.
        """
        self.args = args
        self.kwargs = kwargs
        request = self.initialize_request(request, *args, **kwargs)
        self.request = request
        self.headers = self.default_response_headers  # deprecate?

        try:
            self.initial(request, *args, **kwargs)

            # Get the appropriate handler method
            if request.method.lower() in self.http_method_names:
                handler = getattr(self, request.method.lower(),
                                  self.http_method_not_allowed)
            else:
                handler = self.http_method_not_allowed

            response = self.do(handler, request, *args, **kwargs)

        except Exception as exc:
            response = self.handle_exception(exc)

        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response


class AsyncApiMixin(InterceptMixin):
    """Optionally run the request handler in a thread and serve it from cache.

    Async mode is off unless a subclass overrides `is_need_async`. In async
    mode the first request starts the handler in a `Thread` and answers with a
    "running" status; later requests with the same cache key read the stored
    status/result from the cache.
    """

    def get_request_user_id(self):
        """Return the request user's id as a string, or '' if it has none."""
        user = self.request.user
        if hasattr(user, 'id'):
            return str(user.id)
        return ''

    @lazyproperty
    def async_cache_key(self):
        """Cache key built from HTTP method, request digest and user id."""
        method = self.request.method
        path = self.get_request_md5()
        user = self.get_request_user_id()
        key = '{}_{}_{}'.format(method, path, user)
        return key

    def get_request_md5(self):
        """Return the md5 hex digest of the path plus its query string.

        The `_` and `refresh` parameters are left out, so a refresh request
        maps to the same key as the request it refreshes.
        """
        path = self.request.path
        ignored = ("_", "refresh")
        # `QueryDict.items()` yields only the LAST value of a repeated
        # parameter, so `?id=1&id=2` and `?id=2` used to share one cache key.
        # `lists()` keeps every value; single-valued queries produce exactly
        # the same string (and therefore the same key) as before.
        pairs = [
            "{}={}".format(name, value)
            for name, values in self.request.GET.lists()
            if name not in ignored
            for value in values
        ]
        full_path = "{}?{}".format(path, "&".join(pairs))
        return md5(full_path.encode()).hexdigest()

    @lazyproperty
    def initial_data(self):
        """Payload describing a task that has just been started."""
        data = {
            "status": "running",
            "start_time": time.time(),
            "key": self.async_cache_key,
        }
        return data

    def get_cache_data(self):
        """Return the cached task data, or `None`.

        On a refresh request the cached entry is deleted and `None` is
        returned, which makes `do_async` start the task again.
        """
        key = self.async_cache_key
        if self.is_need_refresh():
            cache.delete(key)
            return None
        data = cache.get(key)
        return data

    def do(self, handler, *args, **kwargs):
        """Run `handler` directly, or through `do_async` in async mode."""
        if not self.is_need_async():
            return handler(*args, **kwargs)
        resp = self.do_async(handler, *args, **kwargs)
        return resp

    def is_need_refresh(self):
        """True when the request carries a truthy `refresh` query parameter."""
        return bool(self.request.GET.get("refresh"))

    def is_need_async(self):
        """Hook: return True to enable async handling. Off by default."""
        return False

    def do_async(self, handler, *args, **kwargs):
        """Start the handler in a thread, or answer from the cached state.

        With no cached data, a thread running `do_in_thread` is started and
        `initial_data` is returned. With cached data, a finished task
        (`status == "ok"` with a stored `resp`) is rebuilt into a `Response`;
        anything else is returned as the raw cached dict.
        """
        data = self.get_cache_data()
        if not data:
            t = Thread(
                target=self.do_in_thread,
                args=(handler, *args),
                kwargs=kwargs
            )
            t.start()
            resp = Response(self.initial_data)
            return resp

        status = data.get("status")
        resp = data.get("resp")
        if status == "ok" and resp:
            resp = Response(**resp)
        else:
            resp = Response(data)
        return resp

    def do_in_thread(self, handler, *args, **kwargs):
        """Run `handler` and store its outcome in the cache for 600 seconds.

        Stores status "ok" plus the response data/status code on success, or
        status "error" plus the exception text on failure.
        """
        key = self.async_cache_key
        # Work on a copy: `do_async` hands this same `initial_data` object to
        # the "running" `Response`, so mutating it here could change that
        # response before it is rendered.
        # NOTE(review): assumes `lazyproperty` caches the value per instance;
        # if it does not, the copy is simply harmless.
        data = dict(self.initial_data)
        cache.set(key, data, 600)
        try:
            response = handler(*args, **kwargs)
            data["status"] = "ok"
            data["resp"] = {
                "data": response.data,
                "status": response.status_code
            }
            cache.set(key, data, 600)
        except Exception as e:
            # Thread boundary: nothing above can catch this, so the failure is
            # recorded in the cache where the next request will see it.
            data["error"] = str(e)
            data["status"] = "error"
            cache.set(key, data, 600)


class RelationMixin:
    """API mixin operating on the `through` model of a many-to-many field.

    Subclasses must set `m2m_field`; the other attributes are derived from it
    in `__init__`.
    """
    m2m_field = None
    from_field = None
    to_field = None
    to_model = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        assert self.m2m_field is not None, ''' `m2m_field` should not be `None` '''

        self.from_field = self.m2m_field.m2m_field_name()
        self.to_field = self.m2m_field.m2m_reverse_field_name()
        self.to_model = self.m2m_field.related_model
        self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through

    def get_queryset(self):
        """Return all rows of the `through` model."""
        # Note: this intercepts `get_queryset` and deliberately does not call `super`.
        queryset = self.through.objects.all()
        return queryset

    def send_m2m_changed_signal(self, instances, action):
        """Send `m2m_changed` for the given `through` instance(s).

        `instances` is one `through` row or a list of them. Rows are grouped
        by their "from" object and one signal is sent per group with `action`
        (e.g. ``'post_add'`` / ``'post_remove'``).
        """
        if not isinstance(instances, list):
            instances = [instances]

        # Django documents `pk_set` of `m2m_changed` as a *set* of primary
        # keys, so collect into a set rather than a list.
        from_to_mapper = defaultdict(set)

        for i in instances:
            to_id = getattr(i, self.to_field).id
            # TODO: optimize; this should not hit the database every time.
            from_obj = getattr(i, self.from_field)
            from_to_mapper[from_obj].add(to_id)

        for from_obj, to_ids in from_to_mapper.items():
            m2m_changed.send(
                sender=self.through, instance=from_obj, action=action,
                reverse=False, model=self.to_model, pk_set=to_ids
            )

    def perform_create(self, serializer):
        """Save the relation, then announce it with a `post_add` signal."""
        instance = serializer.save()
        self.send_m2m_changed_signal(instance, 'post_add')

    def perform_destroy(self, instance):
        """Delete the relation, then announce it with a `post_remove` signal."""
        instance.delete()
        self.send_m2m_changed_signal(instance, 'post_remove')


class QuerySetMixin:
    """Apply the serializer's `setup_eager_loading` hook to the queryset."""

    def get_queryset(self):
        """Return `super().get_queryset()`, eager-loaded when supported.

        If the serializer class defines `setup_eager_loading`, the queryset is
        passed through it.
        """
        queryset = super().get_queryset()
        serializer_class = self.get_serializer_class()
        if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
            queryset = serializer_class.setup_eager_loading(queryset)
        return queryset


class AllowBulkDestroyMixin:
    def allow_bulk_destroy(self, qs, filtered):
        """
        By convention, bulk deletion must select the rows to delete by `id`.

        Returns True only when the SQL text of `filtered.query` contains an
        `id` equality or `IN (...)` condition. `qs` is not used.

        NOTE(review): the check matches backtick-quoted identifiers only,
        which looks MySQL-specific - confirm behaviour on other backends.
        """
        query = str(filtered.query)
        return '`id` IN (' in query or '`id` =' in query


class RoleAdminMixin:
    """Resolve `self.user` from the user id found in the URL kwargs."""
    kwargs: dict
    user_id_url_kwarg = 'pk'

    @lazyproperty
    def user(self):
        """The user whose id is in `kwargs[user_id_url_kwarg]`."""
        user_id = self.kwargs.get(self.user_id_url_kwarg)
        return UserModel.objects.get(id=user_id)


class RoleUserMixin:
    """Resolve `self.user` to the user making the request."""
    request: Request

    @lazyproperty
    def user(self):
        """The request's user."""
        return self.request.user