From 487c945d1d5fec5571524047a69e5b1ccf594b0e Mon Sep 17 00:00:00 2001 From: ibuler Date: Thu, 21 Oct 2021 16:16:50 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BF=AE=E6=94=B9=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E4=BD=8D=E7=BD=AE=EF=BC=8C=E7=94=A8=E6=88=B7sugestion=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=88=B06=E4=B8=AA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/applications/api/application.py | 2 +- apps/assets/api/asset.py | 2 +- apps/assets/api/system_user.py | 2 +- apps/common/mixins/api.py | 347 --------------------------- apps/common/mixins/api/__init__.py | 7 + apps/common/mixins/api/action.py | 55 +++++ apps/common/mixins/api/common.py | 30 +++ apps/common/mixins/api/filter.py | 35 +++ apps/common/mixins/api/patch.py | 136 +++++++++++ apps/common/mixins/api/permission.py | 37 +++ apps/common/mixins/api/queryset.py | 14 ++ apps/common/mixins/api/serializer.py | 95 ++++++++ apps/common/mixins/views.py | 53 +--- apps/users/api/user.py | 9 +- 14 files changed, 420 insertions(+), 404 deletions(-) delete mode 100644 apps/common/mixins/api.py create mode 100644 apps/common/mixins/api/__init__.py create mode 100644 apps/common/mixins/api/action.py create mode 100644 apps/common/mixins/api/common.py create mode 100644 apps/common/mixins/api/filter.py create mode 100644 apps/common/mixins/api/patch.py create mode 100644 apps/common/mixins/api/permission.py create mode 100644 apps/common/mixins/api/queryset.py create mode 100644 apps/common/mixins/api/serializer.py diff --git a/apps/applications/api/application.py b/apps/applications/api/application.py index 245379630..9428df39f 100644 --- a/apps/applications/api/application.py +++ b/apps/applications/api/application.py @@ -6,7 +6,7 @@ from rest_framework.decorators import action from rest_framework.response import Response from common.tree import TreeNodeSerializer -from common.mixins.views import SuggestionMixin +from common.mixins.api import SuggestionMixin from ..hands import IsOrgAdminOrAppUser from .. import serializers from ..models import Application diff --git a/apps/assets/api/asset.py b/apps/assets/api/asset.py index 79229e832..410020d39 100644 --- a/apps/assets/api/asset.py +++ b/apps/assets/api/asset.py @@ -8,7 +8,7 @@ from django.db.models import Q from common.utils import get_logger, get_object_or_none from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsSuperUser -from common.mixins.views import SuggestionMixin +from common.mixins.api import SuggestionMixin from users.models import User, UserGroup from users.serializers import UserSerializer, UserGroupSerializer from users.filters import UserFilter diff --git a/apps/assets/api/system_user.py b/apps/assets/api/system_user.py index 908ab87b9..f03f2e7d4 100644 --- a/apps/assets/api/system_user.py +++ b/apps/assets/api/system_user.py @@ -6,7 +6,7 @@ from common.utils import get_logger from common.permissions import IsOrgAdmin, IsOrgAdminOrAppUser, IsValidUser from orgs.mixins.api import OrgBulkModelViewSet from orgs.mixins import generics -from common.mixins.views import SuggestionMixin +from common.mixins.api import SuggestionMixin from orgs.utils import tmp_to_root_org from ..models import SystemUser, Asset from .. import serializers diff --git a/apps/common/mixins/api.py b/apps/common/mixins/api.py deleted file mode 100644 index caef88127..000000000 --- a/apps/common/mixins/api.py +++ /dev/null @@ -1,347 +0,0 @@ -# -*- coding: utf-8 -*- -# -import time -from hashlib import md5 -from threading import Thread -from collections import defaultdict -from itertools import chain - -from django.conf import settings -from django.db.models.signals import m2m_changed -from django.core.cache import cache -from django.http import JsonResponse -from django.utils.translation import ugettext as _ -from django.contrib.auth import get_user_model -from rest_framework.response import Response -from rest_framework.settings import api_settings -from rest_framework.decorators import action -from rest_framework.request import Request - -from common.const.http import POST -from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter -from ..utils import lazyproperty - -__all__ = [ - 'JSONResponseMixin', 'CommonApiMixin', 'AsyncApiMixin', 'RelationMixin', - 'QuerySetMixin', 'ExtraFilterFieldsMixin', 'RenderToJsonMixin', - 'SerializerMixin', 'AllowBulkDestroyMixin', 'PaginatedResponseMixin' -] - - -UserModel = get_user_model() - - -class JSONResponseMixin(object): - """JSON mixin""" - @staticmethod - def render_json_response(context): - return JsonResponse(context) - - -# SerializerMixin -# ---------------------- - - -class RenderToJsonMixin: - @action(methods=[POST], detail=False, url_path='render-to-json') - def render_to_json(self, request: Request): - data = { - 'title': (), - 'data': request.data, - } - - jms_context = getattr(request, 'jms_context', {}) - column_title_field_pairs = jms_context.get('column_title_field_pairs', ()) - data['title'] = column_title_field_pairs - - if isinstance(request.data, (list, tuple)) and not any(request.data): - error = _("Request file format may be wrong") - return Response(data={"error": error}, status=400) - return Response(data=data) - - -class SerializerMixin: - """ 根据用户请求动作的不同,获取不同的 `serializer_class `""" - - action: str - request: Request - - serializer_classes = None - single_actions = ['put', 'retrieve', 'patch'] - - def get_serializer_class_by_view_action(self): - if not hasattr(self, 'serializer_classes'): - return None - if not isinstance(self.serializer_classes, dict): - return None - - view_action = self.request.query_params.get('action') or self.action or 'list' - serializer_class = self.serializer_classes.get(view_action) - - if serializer_class is None: - view_method = self.request.method.lower() - serializer_class = self.serializer_classes.get(view_method) - - if serializer_class is None and view_action in self.single_actions: - serializer_class = self.serializer_classes.get('single') - if serializer_class is None: - serializer_class = self.serializer_classes.get('display') - if serializer_class is None: - serializer_class = self.serializer_classes.get('default') - return serializer_class - - def get_serializer_class(self): - serializer_class = self.get_serializer_class_by_view_action() - if serializer_class is None: - serializer_class = super().get_serializer_class() - return serializer_class - - -class ExtraFilterFieldsMixin: - """ - 额外的 api filter - """ - default_added_filters = [CustomFilter, IDSpmFilter, IDInFilter] - filter_backends = api_settings.DEFAULT_FILTER_BACKENDS - extra_filter_fields = [] - extra_filter_backends = [] - - def get_filter_backends(self): - if self.filter_backends != self.__class__.filter_backends: - return self.filter_backends - backends = list(chain( - self.filter_backends, - self.default_added_filters, - self.extra_filter_backends - )) - return backends - - def filter_queryset(self, queryset): - for backend in self.get_filter_backends(): - queryset = backend().filter_queryset(self.request, queryset, self) - return queryset - - -class PaginatedResponseMixin: - def get_paginated_response_with_query_set(self, queryset): - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) - - serializer = self.get_serializer(queryset, many=True) - return Response(serializer.data) - - -class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin): - pass - - -class InterceptMixin: - """ - Hack默认的dispatch, 让用户可以实现 self.do - """ - def dispatch(self, request, *args, **kwargs): - self.args = args - self.kwargs = kwargs - request = self.initialize_request(request, *args, **kwargs) - self.request = request - self.headers = self.default_response_headers # deprecate? - - try: - self.initial(request, *args, **kwargs) - - # Get the appropriate handler method - if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) - else: - handler = self.http_method_not_allowed - - response = self.do(handler, request, *args, **kwargs) - - except Exception as exc: - response = self.handle_exception(exc) - - self.response = self.finalize_response(request, response, *args, **kwargs) - return self.response - - -class AsyncApiMixin(InterceptMixin): - def get_request_user_id(self): - user = self.request.user - if hasattr(user, 'id'): - return str(user.id) - return '' - - @lazyproperty - def async_cache_key(self): - method = self.request.method - path = self.get_request_md5() - user = self.get_request_user_id() - key = '{}_{}_{}'.format(method, path, user) - return key - - def get_request_md5(self): - path = self.request.path - query = {k: v for k, v in self.request.GET.items()} - query.pop("_", None) - query.pop('refresh', None) - query = "&".join(["{}={}".format(k, v) for k, v in query.items()]) - full_path = "{}?{}".format(path, query) - return md5(full_path.encode()).hexdigest() - - @lazyproperty - def initial_data(self): - data = { - "status": "running", - "start_time": time.time(), - "key": self.async_cache_key, - } - return data - - def get_cache_data(self): - key = self.async_cache_key - if self.is_need_refresh(): - cache.delete(key) - return None - data = cache.get(key) - return data - - def do(self, handler, *args, **kwargs): - if not self.is_need_async(): - return handler(*args, **kwargs) - resp = self.do_async(handler, *args, **kwargs) - return resp - - def is_need_refresh(self): - if self.request.GET.get("refresh"): - return True - return False - - def is_need_async(self): - return False - - def do_async(self, handler, *args, **kwargs): - data = self.get_cache_data() - if not data: - t = Thread( - target=self.do_in_thread, - args=(handler, *args), - kwargs=kwargs - ) - t.start() - resp = Response(self.initial_data) - return resp - status = data.get("status") - resp = data.get("resp") - if status == "ok" and resp: - resp = Response(**resp) - else: - resp = Response(data) - return resp - - def do_in_thread(self, handler, *args, **kwargs): - key = self.async_cache_key - data = self.initial_data - cache.set(key, data, 600) - try: - response = handler(*args, **kwargs) - data["status"] = "ok" - data["resp"] = { - "data": response.data, - "status": response.status_code - } - cache.set(key, data, 600) - except Exception as e: - data["error"] = str(e) - data["status"] = "error" - cache.set(key, data, 600) - - -class RelationMixin: - m2m_field = None - from_field = None - to_field = None - to_model = None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - assert self.m2m_field is not None, ''' - `m2m_field` should not be `None` - ''' - - self.from_field = self.m2m_field.m2m_field_name() - self.to_field = self.m2m_field.m2m_reverse_field_name() - self.to_model = self.m2m_field.related_model - self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through - - def get_queryset(self): - # 注意,此处拦截了 `get_queryset` 没有 `super` - queryset = self.through.objects.all() - return queryset - - def send_m2m_changed_signal(self, instances, action): - if not isinstance(instances, list): - instances = [instances] - - from_to_mapper = defaultdict(list) - - for i in instances: - to_id = getattr(i, self.to_field).id - # TODO 优化,不应该每次都查询数据库 - from_obj = getattr(i, self.from_field) - from_to_mapper[from_obj].append(to_id) - - for from_obj, to_ids in from_to_mapper.items(): - m2m_changed.send( - sender=self.through, instance=from_obj, action=action, - reverse=False, model=self.to_model, pk_set=to_ids - ) - - def perform_create(self, serializer): - instance = serializer.save() - self.send_m2m_changed_signal(instance, 'post_add') - - def perform_destroy(self, instance): - instance.delete() - self.send_m2m_changed_signal(instance, 'post_remove') - - -class QuerySetMixin: - def get_queryset(self): - queryset = super().get_queryset() - serializer_class = self.get_serializer_class() - - if serializer_class and hasattr(serializer_class, 'setup_eager_loading'): - queryset = serializer_class.setup_eager_loading(queryset) - - return queryset - - -class AllowBulkDestroyMixin: - def allow_bulk_destroy(self, qs, filtered): - """ - 我们规定,批量删除的情况必须用 `id` 指定要删除的数据。 - """ - query = str(filtered.query) - return '`id` IN (' in query or '`id` =' in query - - -class RoleAdminMixin: - kwargs: dict - user_id_url_kwarg = 'pk' - - @lazyproperty - def user(self): - user_id = self.kwargs.get(self.user_id_url_kwarg) - return UserModel.objects.get(id=user_id) - - -class RoleUserMixin: - request: Request - - @lazyproperty - def user(self): - return self.request.user diff --git a/apps/common/mixins/api/__init__.py b/apps/common/mixins/api/__init__.py new file mode 100644 index 000000000..a5827ffef --- /dev/null +++ b/apps/common/mixins/api/__init__.py @@ -0,0 +1,7 @@ +from .common import * +from .action import * +from .patch import * +from .filter import * +from .permission import * +from .queryset import * +from .serializer import * diff --git a/apps/common/mixins/api/action.py b/apps/common/mixins/api/action.py new file mode 100644 index 000000000..994ade06b --- /dev/null +++ b/apps/common/mixins/api/action.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +from typing import Callable + +from django.utils.translation import ugettext as _ +from rest_framework.response import Response +from rest_framework.decorators import action +from rest_framework.request import Request + +from common.const.http import POST +from common.permissions import IsValidUser + + +__all__ = ['SuggestionMixin', 'RenderToJsonMixin'] + + +class SuggestionMixin: + suggestion_limit = 10 + + filter_queryset: Callable + get_queryset: Callable + paginate_queryset: Callable + get_serializer: Callable + get_paginated_response: Callable + + @action(methods=['get'], detail=False, permission_classes=(IsValidUser,)) + def suggestions(self, request, *args, **kwargs): + queryset = self.filter_queryset(self.get_queryset()) + queryset = queryset[:self.suggestion_limit] + page = self.paginate_queryset(queryset) + + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + +class RenderToJsonMixin: + @action(methods=[POST], detail=False, url_path='render-to-json') + def render_to_json(self, request: Request): + data = { + 'title': (), + 'data': request.data, + } + + jms_context = getattr(request, 'jms_context', {}) + column_title_field_pairs = jms_context.get('column_title_field_pairs', ()) + data['title'] = column_title_field_pairs + + if isinstance(request.data, (list, tuple)) and not any(request.data): + error = _("Request file format may be wrong") + return Response(data={"error": error}, status=400) + return Response(data=data) diff --git a/apps/common/mixins/api/common.py b/apps/common/mixins/api/common.py new file mode 100644 index 000000000..ba3895356 --- /dev/null +++ b/apps/common/mixins/api/common.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# +from rest_framework.response import Response + +from .serializer import SerializerMixin +from .filter import ExtraFilterFieldsMixin +from .action import RenderToJsonMixin + +__all__ = [ + 'CommonApiMixin', 'PaginatedResponseMixin', +] + + +class PaginatedResponseMixin: + def get_paginated_response_with_query_set(self, queryset): + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + +class CommonApiMixin(SerializerMixin, ExtraFilterFieldsMixin, RenderToJsonMixin): + pass + + + + diff --git a/apps/common/mixins/api/filter.py b/apps/common/mixins/api/filter.py new file mode 100644 index 000000000..1d1451b66 --- /dev/null +++ b/apps/common/mixins/api/filter.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +from itertools import chain + +from rest_framework.settings import api_settings + +from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter + + +__all__ = ['ExtraFilterFieldsMixin'] + + +class ExtraFilterFieldsMixin: + """ + 额外的 api filter + """ + default_added_filters = [CustomFilter, IDSpmFilter, IDInFilter] + filter_backends = api_settings.DEFAULT_FILTER_BACKENDS + extra_filter_fields = [] + extra_filter_backends = [] + + def get_filter_backends(self): + if self.filter_backends != self.__class__.filter_backends: + return self.filter_backends + backends = list(chain( + self.filter_backends, + self.default_added_filters, + self.extra_filter_backends + )) + return backends + + def filter_queryset(self, queryset): + for backend in self.get_filter_backends(): + queryset = backend().filter_queryset(self.request, queryset, self) + return queryset diff --git a/apps/common/mixins/api/patch.py b/apps/common/mixins/api/patch.py new file mode 100644 index 000000000..f79957546 --- /dev/null +++ b/apps/common/mixins/api/patch.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +# +import time +from hashlib import md5 +from threading import Thread + +from django.core.cache import cache +from rest_framework.response import Response + +from common.utils import lazyproperty + + +__all__ = ['InterceptMixin', 'AsyncApiMixin'] + + +class InterceptMixin: + """ + Hack默认的dispatch, 让用户可以实现 self.do + """ + def dispatch(self, request, *args, **kwargs): + self.args = args + self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request + self.headers = self.default_response_headers # deprecate? + + try: + self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + response = self.do(handler, request, *args, **kwargs) + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + return self.response + + +class AsyncApiMixin(InterceptMixin): + def get_request_user_id(self): + user = self.request.user + if hasattr(user, 'id'): + return str(user.id) + return '' + + @lazyproperty + def async_cache_key(self): + method = self.request.method + path = self.get_request_md5() + user = self.get_request_user_id() + key = '{}_{}_{}'.format(method, path, user) + return key + + def get_request_md5(self): + path = self.request.path + query = {k: v for k, v in self.request.GET.items()} + query.pop("_", None) + query.pop('refresh', None) + query = "&".join(["{}={}".format(k, v) for k, v in query.items()]) + full_path = "{}?{}".format(path, query) + return md5(full_path.encode()).hexdigest() + + @lazyproperty + def initial_data(self): + data = { + "status": "running", + "start_time": time.time(), + "key": self.async_cache_key, + } + return data + + def get_cache_data(self): + key = self.async_cache_key + if self.is_need_refresh(): + cache.delete(key) + return None + data = cache.get(key) + return data + + def do(self, handler, *args, **kwargs): + if not self.is_need_async(): + return handler(*args, **kwargs) + resp = self.do_async(handler, *args, **kwargs) + return resp + + def is_need_refresh(self): + if self.request.GET.get("refresh"): + return True + return False + + def is_need_async(self): + return False + + def do_async(self, handler, *args, **kwargs): + data = self.get_cache_data() + if not data: + t = Thread( + target=self.do_in_thread, + args=(handler, *args), + kwargs=kwargs + ) + t.start() + resp = Response(self.initial_data) + return resp + status = data.get("status") + resp = data.get("resp") + if status == "ok" and resp: + resp = Response(**resp) + else: + resp = Response(data) + return resp + + def do_in_thread(self, handler, *args, **kwargs): + key = self.async_cache_key + data = self.initial_data + cache.set(key, data, 600) + try: + response = handler(*args, **kwargs) + data["status"] = "ok" + data["resp"] = { + "data": response.data, + "status": response.status_code + } + cache.set(key, data, 600) + except Exception as e: + data["error"] = str(e) + data["status"] = "error" + cache.set(key, data, 600) + diff --git a/apps/common/mixins/api/permission.py b/apps/common/mixins/api/permission.py new file mode 100644 index 000000000..6ffa15e53 --- /dev/null +++ b/apps/common/mixins/api/permission.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# +from django.contrib.auth import get_user_model +from rest_framework.request import Request + +from common.utils import lazyproperty + + +__all__ = ['AllowBulkDestroyMixin', 'RoleAdminMixin', 'RoleUserMixin'] + + +class AllowBulkDestroyMixin: + def allow_bulk_destroy(self, qs, filtered): + """ + 我们规定,批量删除的情况必须用 `id` 指定要删除的数据。 + """ + query = str(filtered.query) + return '`id` IN (' in query or '`id` =' in query + + +class RoleAdminMixin: + kwargs: dict + user_id_url_kwarg = 'pk' + + @lazyproperty + def user(self): + user_id = self.kwargs.get(self.user_id_url_kwarg) + user_model = get_user_model() + return user_model.objects.get(id=user_id) + + +class RoleUserMixin: + request: Request + + @lazyproperty + def user(self): + return self.request.user \ No newline at end of file diff --git a/apps/common/mixins/api/queryset.py b/apps/common/mixins/api/queryset.py new file mode 100644 index 000000000..4f56e8a51 --- /dev/null +++ b/apps/common/mixins/api/queryset.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# + +__all__ = ['QuerySetMixin'] + + +class QuerySetMixin: + def get_queryset(self): + queryset = super().get_queryset() + serializer_class = self.get_serializer_class() + + if serializer_class and hasattr(serializer_class, 'setup_eager_loading'): + queryset = serializer_class.setup_eager_loading(queryset) + return queryset diff --git a/apps/common/mixins/api/serializer.py b/apps/common/mixins/api/serializer.py new file mode 100644 index 000000000..c5c9b4737 --- /dev/null +++ b/apps/common/mixins/api/serializer.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# +from collections import defaultdict + +from django.db.models.signals import m2m_changed +from rest_framework.request import Request + +__all__ = ['SerializerMixin', 'RelationMixin'] + + +class SerializerMixin: + """ 根据用户请求动作的不同,获取不同的 `serializer_class `""" + + action: str + request: Request + + serializer_classes = None + single_actions = ['put', 'retrieve', 'patch'] + + def get_serializer_class_by_view_action(self): + if not hasattr(self, 'serializer_classes'): + return None + if not isinstance(self.serializer_classes, dict): + return None + + view_action = self.request.query_params.get('action') or self.action or 'list' + serializer_class = self.serializer_classes.get(view_action) + + if serializer_class is None: + view_method = self.request.method.lower() + serializer_class = self.serializer_classes.get(view_method) + + if serializer_class is None and view_action in self.single_actions: + serializer_class = self.serializer_classes.get('single') + if serializer_class is None: + serializer_class = self.serializer_classes.get('display') + if serializer_class is None: + serializer_class = self.serializer_classes.get('default') + return serializer_class + + def get_serializer_class(self): + serializer_class = self.get_serializer_class_by_view_action() + if serializer_class is None: + serializer_class = super().get_serializer_class() + return serializer_class + + +class RelationMixin: + m2m_field = None + from_field = None + to_field = None + to_model = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.m2m_field is not None, ''' + `m2m_field` should not be `None` + ''' + + self.from_field = self.m2m_field.m2m_field_name() + self.to_field = self.m2m_field.m2m_reverse_field_name() + self.to_model = self.m2m_field.related_model + self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through + + def get_queryset(self): + # 注意,此处拦截了 `get_queryset` 没有 `super` + queryset = self.through.objects.all() + return queryset + + def send_m2m_changed_signal(self, instances, action): + if not isinstance(instances, list): + instances = [instances] + + from_to_mapper = defaultdict(list) + + for i in instances: + to_id = getattr(i, self.to_field).id + # TODO 优化,不应该每次都查询数据库 + from_obj = getattr(i, self.from_field) + from_to_mapper[from_obj].append(to_id) + + for from_obj, to_ids in from_to_mapper.items(): + m2m_changed.send( + sender=self.through, instance=from_obj, action=action, + reverse=False, model=self.to_model, pk_set=to_ids + ) + + def perform_create(self, serializer): + instance = serializer.save() + self.send_m2m_changed_signal(instance, 'post_add') + + def perform_destroy(self, instance): + instance.delete() + self.send_m2m_changed_signal(instance, 'post_remove') diff --git a/apps/common/mixins/views.py b/apps/common/mixins/views.py index a4bc32b76..f167d2001 100644 --- a/apps/common/mixins/views.py +++ b/apps/common/mixins/views.py @@ -1,49 +1,16 @@ # -*- coding: utf-8 -*- # -# coding: utf-8 from django.contrib.auth.mixins import UserPassesTestMixin -from django.utils import timezone -from rest_framework.decorators import action from rest_framework import permissions -from rest_framework.response import Response - -from common.permissions import IsValidUser - -__all__ = ["DatetimeSearchMixin", "PermissionsMixin"] +from rest_framework.request import Request -class DatetimeSearchMixin: - date_format = '%Y-%m-%d' - date_from = date_to = None - - def get_date_range(self): - date_from_s = self.request.GET.get('date_from') - date_to_s = self.request.GET.get('date_to') - - if date_from_s: - date_from = timezone.datetime.strptime(date_from_s, self.date_format) - tz = timezone.get_current_timezone() - self.date_from = tz.localize(date_from) - else: - self.date_from = timezone.now() - timezone.timedelta(7) - - if date_to_s: - date_to = timezone.datetime.strptime( - date_to_s + ' 23:59:59', self.date_format + ' %H:%M:%S' - ) - self.date_to = date_to.replace( - tzinfo=timezone.get_current_timezone() - ) - else: - self.date_to = timezone.now() - - def get(self, request, *args, **kwargs): - self.get_date_range() - return super().get(request, *args, **kwargs) +__all__ = ["PermissionsMixin"] class PermissionsMixin(UserPassesTestMixin): permission_classes = [permissions.IsAuthenticated] + request: Request def get_permissions(self): return self.permission_classes @@ -56,17 +23,3 @@ class PermissionsMixin(UserPassesTestMixin): return True -class SuggestionMixin: - suggestion_mini_count = 10 - - @action(methods=['get'], detail=False, permission_classes=(IsValidUser,)) - def suggestions(self, request, *args, **kwargs): - queryset = self.filter_queryset(self.get_queryset()) - queryset = queryset[:self.suggestion_mini_count] - page = self.paginate_queryset(queryset) - if page is not None: - serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) - - serializer = self.get_serializer(queryset, many=True) - return Response(serializer.data) diff --git a/apps/users/api/user.py b/apps/users/api/user.py index d72f04f73..a1b808a22 100644 --- a/apps/users/api/user.py +++ b/apps/users/api/user.py @@ -8,7 +8,6 @@ from rest_framework.response import Response from rest_framework_bulk import BulkModelViewSet from django.db.models import Prefetch -from users.notifications import ResetMFAMsg from common.permissions import ( IsOrgAdmin, IsOrgAdminOrAppUser, CanUpdateDeleteUser, IsSuperUser @@ -18,9 +17,10 @@ from common.utils import get_logger from orgs.utils import current_org from orgs.models import ROLE as ORG_ROLE, OrganizationMember from users.utils import LoginBlockUtil, MFABlockUtils +from .mixins import UserQuerysetMixin +from ..notifications import ResetMFAMsg from .. import serializers from ..serializers import UserSerializer, MiniUserSerializer, InviteSerializer -from .mixins import UserQuerysetMixin from ..models import User from ..signals import post_user_create from ..filters import OrgRoleUserFilterBackend, UserFilter @@ -128,9 +128,9 @@ class UserViewSet(CommonApiMixin, UserQuerysetMixin, BulkModelViewSet): return super().perform_bulk_update(serializer) @action(methods=['get'], detail=False, permission_classes=(IsOrgAdmin,)) - def suggestion(self, request): + def suggestion(self, *args, **kwargs): queryset = User.objects.exclude(role=User.ROLE.APP) - queryset = self.filter_queryset(queryset)[:3] + queryset = self.filter_queryset(queryset)[:6] serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) @@ -206,6 +206,7 @@ class UserResetOTPApi(UserQuerysetMixin, generics.RetrieveAPIView): if user == request.user: msg = _("Could not reset self otp, use profile reset instead") return Response({"error": msg}, status=401) + if user.mfa_enabled: user.reset_mfa() user.save()