From 1e85805ea3503b32b2eb99b33d4dd41421012543 Mon Sep 17 00:00:00 2001 From: xinwen Date: Tue, 16 Mar 2021 15:34:40 +0800 Subject: [PATCH 1/7] =?UTF-8?q?fix:=20=E7=94=A8=E6=88=B7=E6=8E=88=E6=9D=83?= =?UTF-8?q?=E8=B5=84=E4=BA=A7=E8=BF=87=E6=BB=A4=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../asset/user_permission/user_permission_assets/mixin.py | 5 ++--- .../asset/user_permission/user_permission_assets/views.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py index d7a5c23dc..2464787d6 100644 --- a/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py +++ b/apps/perms/api/asset/user_permission/user_permission_assets/mixin.py @@ -89,6 +89,8 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin): """ 将 资产 序列化成树的结构返回 """ + filterset_fields = ['hostname', 'ip', 'id', 'comment'] + search_fields = ['hostname', 'ip', 'comment'] def list(self, request: Request, *args, **kwargs): queryset = self.filter_queryset(self.get_queryset()) @@ -99,6 +101,3 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin): queryset = queryset[:999] data = self.serialize_assets(queryset, None) return Response(data=data) - - # def get_serializer_class(self): - # return EmptySerializer diff --git a/apps/perms/api/asset/user_permission/user_permission_assets/views.py b/apps/perms/api/asset/user_permission/user_permission_assets/views.py index 05b09442a..8a9690e12 100644 --- a/apps/perms/api/asset/user_permission/user_permission_assets/views.py +++ b/apps/perms/api/asset/user_permission/user_permission_assets/views.py @@ -82,7 +82,7 @@ class MyAllAssetsAsTreeApi(UserAllGrantedAssetsQuerysetMixin, RoleUserMixin, AssetsTreeFormatMixin, ListAPIView): - search_fields = ['hostname', 'ip'] + pass class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin, From adc607dafe27c83de51216ca59e900b25c967676 Mon Sep 17 00:00:00 2001 From: Bai Date: Tue, 16 Mar 2021 15:15:06 +0800 Subject: [PATCH 2/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E8=A7=92=E8=89=B2=E7=94=B1=E7=BB=84=E7=BB=87=E7=94=A8=E6=88=B7?= =?UTF-8?q?->=E7=BB=84=E7=BB=87=E7=AE=A1=E7=90=86=E5=91=98=E6=97=B6?= =?UTF-8?q?=E4=BB=8E=E7=BB=84=E7=BB=87=E6=B8=85=E9=99=A4=E7=94=A8=E6=88=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/orgs/models.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/apps/orgs/models.py b/apps/orgs/models.py index 72231f3c3..d762e5eba 100644 --- a/apps/orgs/models.py +++ b/apps/orgs/models.py @@ -363,13 +363,7 @@ class OrgMemberManager(models.Manager): if role in to_add: to_add[role].add(user) - self.remove_users_by_role( - org, - to_remove.users, - to_remove.admins, - to_remove.auditors - ) - + # 先添加再移除 (防止用户角色由组织用户->组织管理员时从组织清除用户) self.add_users_by_role( org, to_add.users, @@ -377,6 +371,13 @@ class OrgMemberManager(models.Manager): to_add.auditors ) + self.remove_users_by_role( + org, + to_remove.users, + to_remove.admins, + to_remove.auditors + ) + def set_users_by_role(self, org, users=None, admins=None, auditors=None): """ 给组织设置带角色的用户 From 98c6a936584a1361c47993ab588995d8b4876348 Mon Sep 17 00:00:00 2001 From: xinwen Date: Tue, 16 Mar 2021 11:08:15 +0800 Subject: [PATCH 3/7] =?UTF-8?q?fix:=20=E4=BB=BB=E5=8A=A1=E5=BA=94=E8=AF=A5?= =?UTF-8?q?=E5=88=86=E7=BB=84=E7=BB=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/ops/api/adhoc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/apps/ops/api/adhoc.py b/apps/ops/api/adhoc.py index 4845d49dc..f7b32e4fd 100644 --- a/apps/ops/api/adhoc.py +++ b/apps/ops/api/adhoc.py @@ -17,14 +17,16 @@ from ..serializers import ( AdHocDetailSerializer, ) from ..tasks import run_ansible_task +from orgs.mixins.api import OrgBulkModelViewSet +from orgs.utils import current_org __all__ = [ 'TaskViewSet', 'TaskRun', 'AdHocViewSet', 'AdHocRunHistoryViewSet' ] -class TaskViewSet(JMSBulkModelViewSet): - queryset = Task.objects.all() +class TaskViewSet(OrgBulkModelViewSet): + model = Task filterset_fields = ("name",) search_fields = filterset_fields serializer_class = TaskSerializer From 36c083f67457c97658bb17ad67936447877239e4 Mon Sep 17 00:00:00 2001 From: xinwen Date: Tue, 16 Mar 2021 19:42:22 +0800 Subject: [PATCH 4/7] =?UTF-8?q?fix:=20=E4=BC=9A=E8=AF=9D=E9=87=8C=E6=9F=A5?= =?UTF-8?q?=E4=B8=8D=E5=88=B0=E5=91=BD=E4=BB=A4=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/terminal/api/command.py | 27 +++++++++++++++++++++++++++ apps/terminal/backends/command/es.py | 18 ++++++++++++++---- apps/terminal/filters.py | 5 +---- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/apps/terminal/api/command.py b/apps/terminal/api/command.py index e75ce491e..4d7aea9c5 100644 --- a/apps/terminal/api/command.py +++ b/apps/terminal/api/command.py @@ -111,6 +111,33 @@ class CommandViewSet(viewsets.ModelViewSet): filterset_class = CommandFilter ordering_fields = ('timestamp', ) + def merge_all_storage_list(self, request, *args, **kwargs): + merged_commands = [] + + storages = CommandStorage.objects.all() + for storage in storages: + qs = storage.get_command_queryset() + commands = self.filter_queryset(qs) + merged_commands.extend(commands) + + merged_commands.sort(key=lambda command: command.timestamp, reverse=True) + page = self.paginate_queryset(merged_commands) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + def list(self, request, *args, **kwargs): + command_storage_id = self.request.query_params.get('command_storage_id') + session_id = self.request.query_params.get('session_id') + + if session_id and not command_storage_id: + # 会话里的命令列表肯定会提供 session_id,这里防止 merge 的时候取全量的数据 + return self.merge_all_storage_list(request, *args, **kwargs) + return super().list(request, *args, **kwargs) + def get_queryset(self): command_storage_id = self.request.query_params.get('command_storage_id') storage = CommandStorage.objects.get(id=command_storage_id) diff --git a/apps/terminal/backends/command/es.py b/apps/terminal/backends/command/es.py index 1137f6ec8..da51c44e2 100644 --- a/apps/terminal/backends/command/es.py +++ b/apps/terminal/backends/command/es.py @@ -117,11 +117,21 @@ class CommandStore(): timestamp_range['lte'] = timestamp__lte # 处理组织 - must_not = [] + should = [] org_id = match.get('org_id') - if org_id == '': + + real_default_org_id = '00000000-0000-0000-0000-000000000002' + if org_id in (real_default_org_id, ''): match.pop('org_id') - must_not.append({'wildcard': {'org_id': '*'}}) + should.append({ + 'bool':{ + 'must_not': [ + { + 'wildcard': {'org_id': '*'} + } + ]} + }) + should.append({'match': {'org_id': real_default_org_id}}) # 构建 body body = { @@ -130,7 +140,7 @@ class CommandStore(): 'must': [ {'match': {k: v}} for k, v in match.items() ], - 'must_not': must_not, + 'should': should, 'filter': [ { 'term': {k: v} diff --git a/apps/terminal/filters.py b/apps/terminal/filters.py index caed19a9c..a102c149c 100644 --- a/apps/terminal/filters.py +++ b/apps/terminal/filters.py @@ -48,10 +48,7 @@ class CommandFilter(filters.FilterSet): @staticmethod def get_org_id(): - if current_org.is_default(): - org_id = '' - else: - org_id = current_org.id + org_id = current_org.id return org_id From cc3911d2f1179eec35cc9e7dfb2ddb578511039d Mon Sep 17 00:00:00 2001 From: ibuler Date: Tue, 16 Mar 2021 19:35:56 +0800 Subject: [PATCH 5/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20user=20profile?= =?UTF-8?q?=20all=20orgs=20=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/orgs/models.py | 5 +---- apps/users/serializers/profile.py | 7 +------ 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/apps/orgs/models.py b/apps/orgs/models.py index d762e5eba..ac0830592 100644 --- a/apps/orgs/models.py +++ b/apps/orgs/models.py @@ -150,10 +150,7 @@ class Organization(models.Model): @classmethod def get_user_all_orgs(cls, user): - return [ - *cls.objects.filter(members=user).distinct(), - cls.default() - ] + return cls.objects.filter(members=user).distinct() @classmethod def get_user_admin_orgs(cls, user): diff --git a/apps/users/serializers/profile.py b/apps/users/serializers/profile.py index c89b173ff..1c5e99873 100644 --- a/apps/users/serializers/profile.py +++ b/apps/users/serializers/profile.py @@ -15,11 +15,6 @@ class UserOrgSerializer(serializers.Serializer): is_root = serializers.BooleanField(read_only=True) -class UserOrgLabelSerializer(serializers.Serializer): - value = serializers.CharField(source='id') - label = serializers.CharField(source='name') - - class UserUpdatePasswordSerializer(serializers.ModelSerializer): old_password = serializers.CharField(required=True, max_length=128, write_only=True) new_password = serializers.CharField(required=True, max_length=128, write_only=True) @@ -89,7 +84,7 @@ class UserRoleSerializer(serializers.Serializer): class UserProfileSerializer(UserSerializer): admin_or_audit_orgs = UserOrgSerializer(many=True, read_only=True) - user_all_orgs = UserOrgLabelSerializer(many=True, read_only=True) + user_all_orgs = UserOrgSerializer(many=True, read_only=True) current_org_roles = serializers.ListField(read_only=True) public_key_comment = serializers.CharField( source='get_public_key_comment', required=False, read_only=True, max_length=128 From 1216f15e45734f4cc9c6bb3ce69ddd4979a2caef Mon Sep 17 00:00:00 2001 From: Bai Date: Tue, 16 Mar 2021 20:17:13 +0800 Subject: [PATCH 6/7] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=96=B0=E6=97=A7?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=AF=B9=E4=BA=8Edefault=5Fnode=E8=8A=82?= =?UTF-8?q?=E7=82=B9=E5=8F=98=E6=9B=B4=E5=86=B2=E7=AA=81=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98(=E6=97=A7=E7=89=88=E6=9C=AC=E4=BC=9A=E5=B0=86?= =?UTF-8?q?=E6=96=B0=E7=89=88=E6=9C=AC=E8=BF=81=E7=A7=BB=E5=90=8E=E7=9A=84?= =?UTF-8?q?default=5Fnode=E8=8A=82=E7=82=B9=E7=9A=84key=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E9=9D=9E1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/assets/apps.py | 14 -- apps/assets/models/node.py | 151 +++++------------- .../migrations/0010_auto_20210219_1241.py | 2 +- apps/orgs/models.py | 4 +- apps/perms/const.py | 10 -- 5 files changed, 47 insertions(+), 134 deletions(-) delete mode 100644 apps/perms/const.py diff --git a/apps/assets/apps.py b/apps/assets/apps.py index a7267c7b6..04ed9fded 100644 --- a/apps/assets/apps.py +++ b/apps/assets/apps.py @@ -1,16 +1,6 @@ from __future__ import unicode_literals from django.apps import AppConfig -from django.db.models.signals import post_migrate - - -def initial_some_nodes(): - from .models import Node - Node.initial_some_nodes() - - -def initial_some_nodes_callback(sender, **kwargs): - initial_some_nodes() class AssetsConfig(AppConfig): @@ -19,7 +9,3 @@ class AssetsConfig(AppConfig): def ready(self): super().ready() from . import signals_handler - try: - initial_some_nodes() - except Exception: - post_migrate.connect(initial_some_nodes_callback, sender=self) diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index d2a8c3ccf..ad17a8be9 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -465,44 +465,6 @@ class SomeNodesMixin: empty_key = '-11' empty_value = _("empty") - @classmethod - def correct_default_node_if_need(cls): - with tmp_to_root_org(): - wrong_default_org = cls.objects.filter(key='1', value='Default').first() - if not wrong_default_org: - return - - if wrong_default_org.has_children_or_has_assets(): - return - - default_org = Organization.default() - right_default_org = cls.objects.filter(value=default_org.name).first() - if not right_default_org: - return - - if right_default_org.date_create > wrong_default_org.date_create: - return - - with atomic(): - logger.warn(f'Correct default node: ' - f'old={wrong_default_org.value}-{wrong_default_org.key} ' - f'new={right_default_org.value}-{right_default_org.key}') - wrong_default_org.delete() - right_default_org.key = '1' - right_default_org.save() - - @classmethod - def default_node(cls): - cls.correct_default_node_if_need() - - default_org = Organization.default() - with tmp_to_org(default_org): - defaults = {'value': default_org.name} - obj, created = cls.objects.get_or_create( - defaults=defaults, key=cls.default_key, - ) - return obj - def is_default_node(self): return self.key == self.default_key @@ -513,15 +475,36 @@ class SomeNodesMixin: return False @classmethod - def get_next_org_root_node_key(cls): - with tmp_to_org(Organization.root()): - org_nodes_roots = cls.objects.filter(key__regex=r'^[0-9]+$') - org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True) - if not org_nodes_roots_keys: - org_nodes_roots_keys = ['1'] - max_key = max([int(k) for k in org_nodes_roots_keys]) - key = str(max_key + 1) if max_key > 0 else '2' - return key + def org_root(cls): + # 如果使用current_org 在set_current_org时会死循环 + ori_org = get_current_org() + + if ori_org and ori_org.is_default(): + return cls.default_node() + + if ori_org and ori_org.is_root(): + return None + + org_roots = cls.org_root_nodes() + org_roots_length = len(org_roots) + + if org_roots_length == 1: + root = org_roots[0] + return root + elif org_roots_length == 0: + root = cls.create_org_root_node() + return root + else: + error = 'Current org {} root node not 1, get {}'.format(ori_org, org_roots_length) + raise ValueError(error) + + @classmethod + def default_node(cls): + default_org = Organization.default() + with tmp_to_org(default_org): + defaults = {'value': default_org.name} + obj, created = cls.objects.get_or_create(defaults=defaults, key=cls.default_key) + return obj @classmethod def create_org_root_node(cls): @@ -531,68 +514,22 @@ class SomeNodesMixin: root = cls.objects.create(key=key, value=ori_org.name) return root + @classmethod + def get_next_org_root_node_key(cls): + with tmp_to_root_org(): + org_nodes_roots = cls.org_root_nodes() + org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True) + if not org_nodes_roots_keys: + org_nodes_roots_keys = ['1'] + max_key = max([int(k) for k in org_nodes_roots_keys]) + key = str(max_key + 1) if max_key > 0 else '2' + return key + @classmethod def org_root_nodes(cls): - nodes = cls.objects.filter(parent_key='') \ - .filter(key__regex=r'^[0-9]+$') \ - .exclude(key__startswith='-') \ - .order_by('key') - return nodes - - @classmethod - def org_root(cls): - # 如果使用current_org 在set_current_org时会死循环 - ori_org = get_current_org() - - if ori_org and ori_org.is_default(): - return cls.default_node() - if ori_org and ori_org.is_root(): - return None - - org_roots = cls.org_root_nodes() - org_roots_length = len(org_roots) - - if org_roots_length == 1: - return org_roots[0] - elif org_roots_length == 0: - root = cls.create_org_root_node() - return root - else: - raise ValueError('Current org root node not 1, get {}'.format(org_roots_length)) - - @classmethod - def initial_some_nodes(cls): - cls.default_node() - - @classmethod - def modify_other_org_root_node_key(cls): - """ - 解决创建 default 节点失败的问题, - 因为在其他组织下存在 default 节点,故在 DEFAULT 组织下 get 不到 create 失败 - """ - logger.info("Modify other org root node key") - - with tmp_to_org(Organization.root()): - node_key1 = cls.objects.filter(key='1').first() - if not node_key1: - logger.info("Not found node that `key` = 1") - return - if node_key1.org_id == '': - node_key1.org_id = str(Organization.default().id) - node_key1.save() - return - - with transaction.atomic(): - with tmp_to_org(node_key1.org): - org_root_node_new_key = cls.get_next_org_root_node_key() - for n in cls.objects.all(): - old_key = n.key - key_list = n.key.split(':') - key_list[0] = org_root_node_new_key - new_key = ':'.join(key_list) - n.key = new_key - n.save() - logger.info('Modify key ( {} > {} )'.format(old_key, new_key)) + root_nodes = cls.objects.filter(parent_key='', key__regex=r'^[0-9]+$') \ + .exclude(key__startswith='-').order_by('key') + return root_nodes class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin): diff --git a/apps/orgs/migrations/0010_auto_20210219_1241.py b/apps/orgs/migrations/0010_auto_20210219_1241.py index 9a9043f94..37b40e350 100644 --- a/apps/orgs/migrations/0010_auto_20210219_1241.py +++ b/apps/orgs/migrations/0010_auto_20210219_1241.py @@ -6,7 +6,7 @@ import sys from django.db import migrations -default_id = '00000000-0000-0000-0000-000000000001' +default_id = '00000000-0000-0000-0000-000000000002' def add_default_org(apps, schema_editor): diff --git a/apps/orgs/models.py b/apps/orgs/models.py index ac0830592..26f39c231 100644 --- a/apps/orgs/models.py +++ b/apps/orgs/models.py @@ -28,8 +28,8 @@ class Organization(models.Model): ROOT_ID = '00000000-0000-0000-0000-000000000000' ROOT_NAME = _('GLOBAL') - DEFAULT_ID = '00000000-0000-0000-0000-000000000001' - DEFAULT_NAME = 'DEFAULT' + DEFAULT_ID = '00000000-0000-0000-0000-000000000002' + DEFAULT_NAME = 'Default' orgs_mapping = None class Meta: diff --git a/apps/perms/const.py b/apps/perms/const.py deleted file mode 100644 index 476fc9d3f..000000000 --- a/apps/perms/const.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -# -from django.utils.translation import ugettext_lazy as _ - -UNGROUPED_NODE_ID = "00000000-0000-0000-0000-000000000002" -UNGROUPED_NODE_KEY = '-2' -UNGROUPED_NODE_VALUE = _("Ungrouped") -EMPTY_NODE_ID = "00000000-0000-0000-0000-000000000003" -EMPTY_NODE_KEY = "-3" -EMPTY_NODE_VALUE = _("Empty") From ea325f6e529d7a292e4030df1811e1ae3883c299 Mon Sep 17 00:00:00 2001 From: ibuler Date: Tue, 16 Mar 2021 16:57:56 +0800 Subject: [PATCH 7/7] =?UTF-8?q?perf(users):=20=E4=BC=98=E5=8C=96=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=AE=A4=E8=AF=81=E6=9D=A5=E6=BA=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/authentication/mixins.py | 73 ++++++++++++++++++++++++++++------- apps/users/models/user.py | 15 +++++++ 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index f89938e64..97adb0e67 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -1,13 +1,17 @@ # -*- coding: utf-8 -*- # +import inspect from urllib.parse import urlencode from functools import partial import time from django.conf import settings -from django.contrib.auth import authenticate +from django.contrib import auth +from django.contrib.auth import ( + BACKEND_SESSION_KEY, _get_backends, + PermissionDenied, user_login_failed, _clean_credentials +) from django.shortcuts import reverse -from django.contrib.auth import BACKEND_SESSION_KEY from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get from users.models import User @@ -22,6 +26,59 @@ from .const import RSA_PRIVATE_KEY logger = get_logger(__name__) +def check_backend_can_auth(username, backend_path, allowed_auth_backends): + if allowed_auth_backends is not None and backend_path not in allowed_auth_backends: + logger.debug('Skip user auth backend: {}, {} not in'.format( + username, backend_path, ','.join(allowed_auth_backends) + ) + ) + return False + return True + + +def authenticate(request=None, **credentials): + """ + If the given credentials are valid, return a User object. + """ + username = credentials.get('username') + allowed_auth_backends = User.get_user_allowed_auth_backends(username) + + for backend, backend_path in _get_backends(return_tuples=True): + # 预先检查,不浪费认证时间 + if not check_backend_can_auth(username, backend_path, allowed_auth_backends): + continue + + backend_signature = inspect.signature(backend.authenticate) + try: + backend_signature.bind(request, **credentials) + except TypeError: + # This backend doesn't accept these credentials as arguments. Try the next one. + continue + try: + user = backend.authenticate(request, **credentials) + except PermissionDenied: + # This backend says to stop in our tracks - this user should not be allowed in at all. + break + if user is None: + continue + # 如果是 None, 证明没有检查过, 需要再次检查 + if allowed_auth_backends is None: + # 有些 authentication 参数中不带 username, 之后还要再检查 + allowed_auth_backends = user.get_allowed_auth_backends() + if not check_backend_can_auth(user.username, backend_path, allowed_auth_backends): + continue + + # Annotate the user object with the path of the backend. + user.backend = backend_path + return user + + # The credentials supplied are invalid to all backends, fire signal + user_login_failed.send(sender=__name__, credentials=_clean_credentials(credentials), request=request) + + +auth.authenticate = authenticate + + class AuthMixin: request = None partial_credential_error = None @@ -121,13 +178,6 @@ class AuthMixin: self.raise_credential_error(errors.reason_user_inactive) return user - def _check_auth_source_is_valid(self, user, auth_backend): - # 限制只能从认证来源登录 - if settings.ONLY_ALLOW_AUTH_FROM_SOURCE: - auth_backends_allowed = user.SOURCE_BACKEND_MAPPING.get(user.source) - if auth_backend not in auth_backends_allowed: - self.raise_credential_error(error=errors.reason_backend_not_match) - def _check_login_acl(self, user, ip): # ACL 限制用户登录 from acls.models import LoginACL @@ -144,9 +194,6 @@ class AuthMixin: user = self._check_auth_user_is_valid(username, password, public_key) # 校验login-acl规则 self._check_login_acl(user, ip) - # 限制只能从认证来源登录 - auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend') - self._check_auth_source_is_valid(user, auth_backend) self._check_password_require_reset_or_not(user) self._check_passwd_is_too_simple(user, password) @@ -154,7 +201,7 @@ class AuthMixin: request.session['auth_password'] = 1 request.session['user_id'] = str(user.id) request.session['auto_login'] = auto_login - request.session['auth_backend'] = auth_backend + request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL) return user @classmethod diff --git a/apps/users/models/user.py b/apps/users/models/user.py index 53665b832..096ac260d 100644 --- a/apps/users/models/user.py +++ b/apps/users/models/user.py @@ -679,6 +679,21 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser): return return super(User, self).delete() + @classmethod + def get_user_allowed_auth_backends(cls, username): + if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE or not username: + # return settings.AUTHENTICATION_BACKENDS + return None + user = cls.objects.filter(username=username).first() + if not user: + return None + return user.get_allowed_auth_backends() + + def get_allowed_auth_backends(self): + if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE: + return None + return self.SOURCE_BACKEND_MAPPING.get(self.source, []) + class Meta: ordering = ['username'] verbose_name = _("User")