diff --git a/apps/assets/urls/api_urls.py b/apps/assets/urls/api_urls.py index 279bd26be..d70accc23 100644 --- a/apps/assets/urls/api_urls.py +++ b/apps/assets/urls/api_urls.py @@ -1,7 +1,6 @@ # coding:utf-8 from django.urls import path, re_path from rest_framework_nested import routers -# from rest_framework.routers import DefaultRouter from rest_framework_bulk.routes import BulkRouter from common import api as capi diff --git a/apps/common/db/models.py b/apps/common/db/models.py index 5d9827c06..b807b9dd5 100644 --- a/apps/common/db/models.py +++ b/apps/common/db/models.py @@ -12,11 +12,12 @@ import uuid from django.db.models import * +from django.db.models.functions import Concat from django.utils.translation import ugettext_lazy as _ class Choice(str): - def __new__(cls, value, label): + def __new__(cls, value, label=''): # `deepcopy` 的时候不会传 `label` self = super().__new__(cls, value) self.label = label return self @@ -77,3 +78,7 @@ class JMSModel(JMSBaseModel): class Meta: abstract = True + + +def concated_display(name1, name2): + return Concat(F(name1), Value('('), F(name2), Value(')')) diff --git a/apps/common/drf/api.py b/apps/common/drf/api.py index 692d567f5..febd4467e 100644 --- a/apps/common/drf/api.py +++ b/apps/common/drf/api.py @@ -2,7 +2,8 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet from rest_framework_bulk import BulkModelViewSet from ..mixins.api import ( - SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin + SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin, + RelationMixin, AllowBulkDestoryMixin ) @@ -26,5 +27,16 @@ class JMSBulkModelViewSet(SerializerMixin2, QuerySetMixin, ExtraFilterFieldsMixin, PaginatedResponseMixin, + AllowBulkDestoryMixin, BulkModelViewSet): pass + + +class JMSBulkRelationModelViewSet(SerializerMixin2, + QuerySetMixin, + ExtraFilterFieldsMixin, + PaginatedResponseMixin, + RelationMixin, + AllowBulkDestoryMixin, + BulkModelViewSet): + pass diff --git a/apps/common/mixins/api.py b/apps/common/mixins/api.py index 5c17a5cca..3e9aea665 100644 --- a/apps/common/mixins/api.py +++ b/apps/common/mixins/api.py @@ -11,6 +11,8 @@ from django.core.cache import cache from django.http import JsonResponse from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework import status +from rest_framework_bulk.drf3.mixins import BulkDestroyModelMixin from common.drf.filters import IDSpmFilter, CustomFilter, IDInFilter from ..utils import lazyproperty @@ -223,10 +225,11 @@ class RelationMixin: self.through = getattr(self.m2m_field.model, self.m2m_field.attname).through def get_queryset(self): + # 注意,此处拦截了 `get_queryset` 没有 `super` queryset = self.through.objects.all() return queryset - def send_post_add_signal(self, instances): + def send_m2m_changed_signal(self, instances, action): if not isinstance(instances, list): instances = [instances] @@ -239,13 +242,17 @@ class RelationMixin: for from_obj, to_ids in from_to_mapper.items(): m2m_changed.send( - sender=self.through, instance=from_obj, action='post_add', + sender=self.through, instance=from_obj, action=action, reverse=False, model=self.to_model, pk_set=to_ids ) def perform_create(self, serializer): instance = serializer.save() - self.send_post_add_signal(instance) + self.send_m2m_changed_signal(instance, 'post_add') + + def perform_destroy(self, instance): + instance.delete() + self.send_m2m_changed_signal(instance, 'post_remove') class SerializerMixin2: @@ -275,3 +282,12 @@ class QuerySetMixin: queryset = serializer_class.setup_eager_loading(queryset) return queryset + + +class AllowBulkDestoryMixin: + def allow_bulk_destroy(self, qs, filtered): + """ + 我们规定,批量删除的情况必须用 `id` 指定要删除的数据。 + """ + query = str(filtered.query) + return '`id` IN (' in query or '`id` =' in query diff --git a/apps/orgs/api.py b/apps/orgs/api.py index e29a14e22..d283019d3 100644 --- a/apps/orgs/api.py +++ b/apps/orgs/api.py @@ -7,14 +7,18 @@ from rest_framework.views import Response from rest_framework_bulk import BulkModelViewSet from common.permissions import IsSuperUserOrAppUser +from common.drf.api import JMSBulkRelationModelViewSet from .models import Organization, ROLE -from .serializers import OrgSerializer, OrgReadSerializer, \ - OrgAllUserSerializer, OrgRetrieveSerializer +from .serializers import ( + OrgSerializer, OrgReadSerializer, + OrgRetrieveSerializer, OrgMemberSerializer +) from users.models import User, UserGroup from assets.models import Asset, Domain, AdminUser, SystemUser, Label from perms.models import AssetPermission from orgs.utils import current_org from common.utils import get_logger +from .filters import OrgMemberRelationFilterSet logger = get_logger(__file__) @@ -61,15 +65,13 @@ class OrgViewSet(BulkModelViewSet): return Response({'msg': True}, status=status.HTTP_200_OK) -class OrgAllUserListApi(generics.ListAPIView): +class OrgMemberRelationBulkViewSet(JMSBulkRelationModelViewSet): permission_classes = (IsSuperUserOrAppUser,) - serializer_class = OrgAllUserSerializer - filter_fields = ("username", "name") - search_fields = filter_fields + m2m_field = Organization.members.field + serializer_class = OrgMemberSerializer + filterset_class = OrgMemberRelationFilterSet - def get_queryset(self): - pk = self.kwargs.get("pk") - users = User.objects.filter( - orgs=pk, m2m_org_members__role=ROLE.USER - ).only(*self.serializer_class.Meta.only_fields) - return users + def perform_bulk_destroy(self, queryset): + objs = list(queryset.all().prefetch_related('user', 'org')) + queryset.delete() + self.send_m2m_changed_signal(objs, action='post_remove') diff --git a/apps/orgs/filters.py b/apps/orgs/filters.py new file mode 100644 index 000000000..df68e468f --- /dev/null +++ b/apps/orgs/filters.py @@ -0,0 +1,16 @@ +from django_filters.rest_framework import filterset +from django_filters.rest_framework import filters + +from .models import OrganizationMember + + +class UUIDInFilter(filters.BaseInFilter, filters.UUIDFilter): + pass + + +class OrgMemberRelationFilterSet(filterset.FilterSet): + id = UUIDInFilter(field_name='id', lookup_expr='in') + + class Meta: + model = OrganizationMember + fields = ('org_id', 'user_id', 'role', 'id') diff --git a/apps/orgs/mixins/serializers.py b/apps/orgs/mixins/serializers.py index 2b415e31b..fed9d1713 100644 --- a/apps/orgs/mixins/serializers.py +++ b/apps/orgs/mixins/serializers.py @@ -11,8 +11,7 @@ from ..utils import get_current_org_id_for_serializer __all__ = [ "OrgResourceSerializerMixin", "BulkOrgResourceSerializerMixin", - "BulkOrgResourceModelSerializer", "OrgMembershipSerializerMixin", - "OrgResourceModelSerializerMixin", + "BulkOrgResourceModelSerializer", "OrgResourceModelSerializerMixin", ] @@ -53,9 +52,3 @@ class BulkOrgResourceSerializerMixin(BulkSerializerMixin, OrgResourceSerializerM class BulkOrgResourceModelSerializer(BulkOrgResourceSerializerMixin, serializers.ModelSerializer): pass - - -class OrgMembershipSerializerMixin: - def run_validation(self, initial_data=None): - initial_data['organization'] = str(self.context['org'].id) - return super().run_validation(initial_data) diff --git a/apps/orgs/models.py b/apps/orgs/models.py index effabf50f..c72d1ae82 100644 --- a/apps/orgs/models.py +++ b/apps/orgs/models.py @@ -149,6 +149,13 @@ class Organization(models.Model): m2m_org_members__user_id=user.id ).distinct() + @classmethod + def get_user_all_orgs(cls, user): + return [ + *cls.objects.filter(members=user).distinct(), + cls.default() + ] + @classmethod def get_user_admin_orgs(cls, user): if user.is_anonymous: @@ -161,7 +168,10 @@ class Organization(models.Model): def get_user_user_orgs(cls, user): if user.is_anonymous: return cls.objects.none() - return cls.get_user_orgs_by_role(user, ROLE.USER) + return [ + *cls.get_user_orgs_by_role(user, ROLE.USER), + cls.default() + ] @classmethod def get_user_audit_orgs(cls, user): diff --git a/apps/orgs/serializers.py b/apps/orgs/serializers.py index 4cf54f92a..5b20ce47e 100644 --- a/apps/orgs/serializers.py +++ b/apps/orgs/serializers.py @@ -1,11 +1,12 @@ - +from django.db.models import F from rest_framework.serializers import ModelSerializer from rest_framework import serializers from users.models.user import User from common.serializers import AdaptedBulkListSerializer +from common.drf.serializers import BulkModelSerializer +from common.db.models import concated_display as display from .models import Organization, OrganizationMember -from .mixins.serializers import OrgMembershipSerializerMixin class OrgSerializer(ModelSerializer): @@ -50,30 +51,20 @@ class OrgReadSerializer(OrgSerializer): pass -class OrgMembershipAdminSerializer(OrgMembershipSerializerMixin, ModelSerializer): +class OrgMemberSerializer(BulkModelSerializer): + org_display = serializers.CharField() + user_display = serializers.CharField() + class Meta: model = Organization.members.through - list_serializer_class = AdaptedBulkListSerializer - fields = '__all__' + fields = ('id', 'org', 'user', 'role', 'org_display', 'user_display') - -class OrgMembershipUserSerializer(OrgMembershipSerializerMixin, ModelSerializer): - class Meta: - model = Organization.members.through - list_serializer_class = AdaptedBulkListSerializer - fields = '__all__' - - -class OrgAllUserSerializer(serializers.Serializer): - user = serializers.UUIDField(read_only=True, source='id') - user_display = serializers.SerializerMethodField() - - class Meta: - only_fields = ['id', 'username', 'name'] - - @staticmethod - def get_user_display(obj): - return str(obj) + @classmethod + def setup_eager_loading(cls, queryset): + return queryset.annotate( + org_display=F('org__name'), + user_display=display('user__name', 'user__username') + ).distinct() class OrgRetrieveSerializer(OrgReadSerializer): diff --git a/apps/orgs/urls/api_urls.py b/apps/orgs/urls/api_urls.py index e14435868..56a135fcd 100644 --- a/apps/orgs/urls/api_urls.py +++ b/apps/orgs/urls/api_urls.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- # -from django.urls import re_path, path +from django.urls import re_path from rest_framework.routers import DefaultRouter +from rest_framework_bulk.routes import BulkRouter from common import api as capi from .. import api @@ -10,15 +11,13 @@ from .. import api app_name = 'orgs' router = DefaultRouter() +bulk_router = BulkRouter() router.register(r'orgs', api.OrgViewSet, 'org') +bulk_router.register(r'org-memeber-relation', api.OrgMemberRelationBulkViewSet, 'org-memeber-relation') old_version_urlpatterns = [ re_path('(?Porg)/.*', capi.redirect_plural_name_api) ] -urlpatterns = [ - path('/users/all/', api.OrgAllUserListApi.as_view(), name='org-all-users'), -] - -urlpatterns += router.urls + old_version_urlpatterns +urlpatterns = router.urls + bulk_router.urls + old_version_urlpatterns diff --git a/apps/users/serializers/user.py b/apps/users/serializers/user.py index 21c38d7bc..4924aef17 100644 --- a/apps/users/serializers/user.py +++ b/apps/users/serializers/user.py @@ -18,7 +18,7 @@ __all__ = [ 'ChangeUserPasswordSerializer', 'ResetOTPSerializer', 'UserProfileSerializer', 'UserOrgSerializer', 'UserUpdatePasswordSerializer', 'UserUpdatePublicKeySerializer', - 'UserRetrieveSerializer' + 'UserRetrieveSerializer', 'MiniUserSerializer', ]