Merge pull request #5766 from jumpserver/dev

v2.8 发版 (rc4)
pull/5813/head
老广 2021-03-16 20:49:04 +08:00 committed by GitHub
commit 73b57a662e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 181 additions and 178 deletions

View File

@ -1,16 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.apps import AppConfig from django.apps import AppConfig
from django.db.models.signals import post_migrate
def initial_some_nodes():
from .models import Node
Node.initial_some_nodes()
def initial_some_nodes_callback(sender, **kwargs):
initial_some_nodes()
class AssetsConfig(AppConfig): class AssetsConfig(AppConfig):
@ -19,7 +9,3 @@ class AssetsConfig(AppConfig):
def ready(self): def ready(self):
super().ready() super().ready()
from . import signals_handler from . import signals_handler
try:
initial_some_nodes()
except Exception:
post_migrate.connect(initial_some_nodes_callback, sender=self)

View File

@ -465,44 +465,6 @@ class SomeNodesMixin:
empty_key = '-11' empty_key = '-11'
empty_value = _("empty") empty_value = _("empty")
@classmethod
def correct_default_node_if_need(cls):
with tmp_to_root_org():
wrong_default_org = cls.objects.filter(key='1', value='Default').first()
if not wrong_default_org:
return
if wrong_default_org.has_children_or_has_assets():
return
default_org = Organization.default()
right_default_org = cls.objects.filter(value=default_org.name).first()
if not right_default_org:
return
if right_default_org.date_create > wrong_default_org.date_create:
return
with atomic():
logger.warn(f'Correct default node: '
f'old={wrong_default_org.value}-{wrong_default_org.key} '
f'new={right_default_org.value}-{right_default_org.key}')
wrong_default_org.delete()
right_default_org.key = '1'
right_default_org.save()
@classmethod
def default_node(cls):
cls.correct_default_node_if_need()
default_org = Organization.default()
with tmp_to_org(default_org):
defaults = {'value': default_org.name}
obj, created = cls.objects.get_or_create(
defaults=defaults, key=cls.default_key,
)
return obj
def is_default_node(self): def is_default_node(self):
return self.key == self.default_key return self.key == self.default_key
@ -513,15 +475,36 @@ class SomeNodesMixin:
return False return False
@classmethod @classmethod
def get_next_org_root_node_key(cls): def org_root(cls):
with tmp_to_org(Organization.root()): # 如果使用current_org 在set_current_org时会死循环
org_nodes_roots = cls.objects.filter(key__regex=r'^[0-9]+$') ori_org = get_current_org()
org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True)
if not org_nodes_roots_keys: if ori_org and ori_org.is_default():
org_nodes_roots_keys = ['1'] return cls.default_node()
max_key = max([int(k) for k in org_nodes_roots_keys])
key = str(max_key + 1) if max_key > 0 else '2' if ori_org and ori_org.is_root():
return key return None
org_roots = cls.org_root_nodes()
org_roots_length = len(org_roots)
if org_roots_length == 1:
root = org_roots[0]
return root
elif org_roots_length == 0:
root = cls.create_org_root_node()
return root
else:
error = 'Current org {} root node not 1, get {}'.format(ori_org, org_roots_length)
raise ValueError(error)
@classmethod
def default_node(cls):
default_org = Organization.default()
with tmp_to_org(default_org):
defaults = {'value': default_org.name}
obj, created = cls.objects.get_or_create(defaults=defaults, key=cls.default_key)
return obj
@classmethod @classmethod
def create_org_root_node(cls): def create_org_root_node(cls):
@ -531,68 +514,22 @@ class SomeNodesMixin:
root = cls.objects.create(key=key, value=ori_org.name) root = cls.objects.create(key=key, value=ori_org.name)
return root return root
@classmethod
def get_next_org_root_node_key(cls):
with tmp_to_root_org():
org_nodes_roots = cls.org_root_nodes()
org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True)
if not org_nodes_roots_keys:
org_nodes_roots_keys = ['1']
max_key = max([int(k) for k in org_nodes_roots_keys])
key = str(max_key + 1) if max_key > 0 else '2'
return key
@classmethod @classmethod
def org_root_nodes(cls): def org_root_nodes(cls):
nodes = cls.objects.filter(parent_key='') \ root_nodes = cls.objects.filter(parent_key='', key__regex=r'^[0-9]+$') \
.filter(key__regex=r'^[0-9]+$') \ .exclude(key__startswith='-').order_by('key')
.exclude(key__startswith='-') \ return root_nodes
.order_by('key')
return nodes
@classmethod
def org_root(cls):
# 如果使用current_org 在set_current_org时会死循环
ori_org = get_current_org()
if ori_org and ori_org.is_default():
return cls.default_node()
if ori_org and ori_org.is_root():
return None
org_roots = cls.org_root_nodes()
org_roots_length = len(org_roots)
if org_roots_length == 1:
return org_roots[0]
elif org_roots_length == 0:
root = cls.create_org_root_node()
return root
else:
raise ValueError('Current org root node not 1, get {}'.format(org_roots_length))
@classmethod
def initial_some_nodes(cls):
cls.default_node()
@classmethod
def modify_other_org_root_node_key(cls):
"""
解决创建 default 节点失败的问题
因为在其他组织下存在 default 节点故在 DEFAULT 组织下 get 不到 create 失败
"""
logger.info("Modify other org root node key")
with tmp_to_org(Organization.root()):
node_key1 = cls.objects.filter(key='1').first()
if not node_key1:
logger.info("Not found node that `key` = 1")
return
if node_key1.org_id == '':
node_key1.org_id = str(Organization.default().id)
node_key1.save()
return
with transaction.atomic():
with tmp_to_org(node_key1.org):
org_root_node_new_key = cls.get_next_org_root_node_key()
for n in cls.objects.all():
old_key = n.key
key_list = n.key.split(':')
key_list[0] = org_root_node_new_key
new_key = ':'.join(key_list)
n.key = new_key
n.save()
logger.info('Modify key ( {} > {} )'.format(old_key, new_key))
class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin): class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):

View File

@ -1,13 +1,17 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# #
import inspect
from urllib.parse import urlencode from urllib.parse import urlencode
from functools import partial from functools import partial
import time import time
from django.conf import settings from django.conf import settings
from django.contrib.auth import authenticate from django.contrib import auth
from django.contrib.auth import (
BACKEND_SESSION_KEY, _get_backends,
PermissionDenied, user_login_failed, _clean_credentials
)
from django.shortcuts import reverse from django.shortcuts import reverse
from django.contrib.auth import BACKEND_SESSION_KEY
from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get
from users.models import User from users.models import User
@ -22,6 +26,59 @@ from .const import RSA_PRIVATE_KEY
logger = get_logger(__name__) logger = get_logger(__name__)
def check_backend_can_auth(username, backend_path, allowed_auth_backends):
if allowed_auth_backends is not None and backend_path not in allowed_auth_backends:
logger.debug('Skip user auth backend: {}, {} not in'.format(
username, backend_path, ','.join(allowed_auth_backends)
)
)
return False
return True
def authenticate(request=None, **credentials):
"""
If the given credentials are valid, return a User object.
"""
username = credentials.get('username')
allowed_auth_backends = User.get_user_allowed_auth_backends(username)
for backend, backend_path in _get_backends(return_tuples=True):
# 预先检查,不浪费认证时间
if not check_backend_can_auth(username, backend_path, allowed_auth_backends):
continue
backend_signature = inspect.signature(backend.authenticate)
try:
backend_signature.bind(request, **credentials)
except TypeError:
# This backend doesn't accept these credentials as arguments. Try the next one.
continue
try:
user = backend.authenticate(request, **credentials)
except PermissionDenied:
# This backend says to stop in our tracks - this user should not be allowed in at all.
break
if user is None:
continue
# 如果是 None, 证明没有检查过, 需要再次检查
if allowed_auth_backends is None:
# 有些 authentication 参数中不带 username, 之后还要再检查
allowed_auth_backends = user.get_allowed_auth_backends()
if not check_backend_can_auth(user.username, backend_path, allowed_auth_backends):
continue
# Annotate the user object with the path of the backend.
user.backend = backend_path
return user
# The credentials supplied are invalid to all backends, fire signal
user_login_failed.send(sender=__name__, credentials=_clean_credentials(credentials), request=request)
auth.authenticate = authenticate
class AuthMixin: class AuthMixin:
request = None request = None
partial_credential_error = None partial_credential_error = None
@ -121,13 +178,6 @@ class AuthMixin:
self.raise_credential_error(errors.reason_user_inactive) self.raise_credential_error(errors.reason_user_inactive)
return user return user
def _check_auth_source_is_valid(self, user, auth_backend):
# 限制只能从认证来源登录
if settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
auth_backends_allowed = user.SOURCE_BACKEND_MAPPING.get(user.source)
if auth_backend not in auth_backends_allowed:
self.raise_credential_error(error=errors.reason_backend_not_match)
def _check_login_acl(self, user, ip): def _check_login_acl(self, user, ip):
# ACL 限制用户登录 # ACL 限制用户登录
from acls.models import LoginACL from acls.models import LoginACL
@ -144,9 +194,6 @@ class AuthMixin:
user = self._check_auth_user_is_valid(username, password, public_key) user = self._check_auth_user_is_valid(username, password, public_key)
# 校验login-acl规则 # 校验login-acl规则
self._check_login_acl(user, ip) self._check_login_acl(user, ip)
# 限制只能从认证来源登录
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend')
self._check_auth_source_is_valid(user, auth_backend)
self._check_password_require_reset_or_not(user) self._check_password_require_reset_or_not(user)
self._check_passwd_is_too_simple(user, password) self._check_passwd_is_too_simple(user, password)
@ -154,7 +201,7 @@ class AuthMixin:
request.session['auth_password'] = 1 request.session['auth_password'] = 1
request.session['user_id'] = str(user.id) request.session['user_id'] = str(user.id)
request.session['auto_login'] = auto_login request.session['auto_login'] = auto_login
request.session['auth_backend'] = auth_backend request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
return user return user
@classmethod @classmethod

View File

@ -17,14 +17,16 @@ from ..serializers import (
AdHocDetailSerializer, AdHocDetailSerializer,
) )
from ..tasks import run_ansible_task from ..tasks import run_ansible_task
from orgs.mixins.api import OrgBulkModelViewSet
from orgs.utils import current_org
__all__ = [ __all__ = [
'TaskViewSet', 'TaskRun', 'AdHocViewSet', 'AdHocRunHistoryViewSet' 'TaskViewSet', 'TaskRun', 'AdHocViewSet', 'AdHocRunHistoryViewSet'
] ]
class TaskViewSet(JMSBulkModelViewSet): class TaskViewSet(OrgBulkModelViewSet):
queryset = Task.objects.all() model = Task
filterset_fields = ("name",) filterset_fields = ("name",)
search_fields = filterset_fields search_fields = filterset_fields
serializer_class = TaskSerializer serializer_class = TaskSerializer

View File

@ -6,7 +6,7 @@ import sys
from django.db import migrations from django.db import migrations
default_id = '00000000-0000-0000-0000-000000000001' default_id = '00000000-0000-0000-0000-000000000002'
def add_default_org(apps, schema_editor): def add_default_org(apps, schema_editor):

View File

@ -28,8 +28,8 @@ class Organization(models.Model):
ROOT_ID = '00000000-0000-0000-0000-000000000000' ROOT_ID = '00000000-0000-0000-0000-000000000000'
ROOT_NAME = _('GLOBAL') ROOT_NAME = _('GLOBAL')
DEFAULT_ID = '00000000-0000-0000-0000-000000000001' DEFAULT_ID = '00000000-0000-0000-0000-000000000002'
DEFAULT_NAME = 'DEFAULT' DEFAULT_NAME = 'Default'
orgs_mapping = None orgs_mapping = None
class Meta: class Meta:
@ -150,10 +150,7 @@ class Organization(models.Model):
@classmethod @classmethod
def get_user_all_orgs(cls, user): def get_user_all_orgs(cls, user):
return [ return cls.objects.filter(members=user).distinct()
*cls.objects.filter(members=user).distinct(),
cls.default()
]
@classmethod @classmethod
def get_user_admin_orgs(cls, user): def get_user_admin_orgs(cls, user):
@ -363,13 +360,7 @@ class OrgMemberManager(models.Manager):
if role in to_add: if role in to_add:
to_add[role].add(user) to_add[role].add(user)
self.remove_users_by_role( # 先添加再移除 (防止用户角色由组织用户->组织管理员时从组织清除用户)
org,
to_remove.users,
to_remove.admins,
to_remove.auditors
)
self.add_users_by_role( self.add_users_by_role(
org, org,
to_add.users, to_add.users,
@ -377,6 +368,13 @@ class OrgMemberManager(models.Manager):
to_add.auditors to_add.auditors
) )
self.remove_users_by_role(
org,
to_remove.users,
to_remove.admins,
to_remove.auditors
)
def set_users_by_role(self, org, users=None, admins=None, auditors=None): def set_users_by_role(self, org, users=None, admins=None, auditors=None):
""" """
给组织设置带角色的用户 给组织设置带角色的用户

View File

@ -89,6 +89,8 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
""" """
资产 序列化成树的结构返回 资产 序列化成树的结构返回
""" """
filterset_fields = ['hostname', 'ip', 'id', 'comment']
search_fields = ['hostname', 'ip', 'comment']
def list(self, request: Request, *args, **kwargs): def list(self, request: Request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
@ -99,6 +101,3 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
queryset = queryset[:999] queryset = queryset[:999]
data = self.serialize_assets(queryset, None) data = self.serialize_assets(queryset, None)
return Response(data=data) return Response(data=data)
# def get_serializer_class(self):
# return EmptySerializer

View File

@ -82,7 +82,7 @@ class MyAllAssetsAsTreeApi(UserAllGrantedAssetsQuerysetMixin,
RoleUserMixin, RoleUserMixin,
AssetsTreeFormatMixin, AssetsTreeFormatMixin,
ListAPIView): ListAPIView):
search_fields = ['hostname', 'ip'] pass
class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin, class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin,

View File

@ -1,10 +0,0 @@
# -*- coding: utf-8 -*-
#
from django.utils.translation import ugettext_lazy as _
UNGROUPED_NODE_ID = "00000000-0000-0000-0000-000000000002"
UNGROUPED_NODE_KEY = '-2'
UNGROUPED_NODE_VALUE = _("Ungrouped")
EMPTY_NODE_ID = "00000000-0000-0000-0000-000000000003"
EMPTY_NODE_KEY = "-3"
EMPTY_NODE_VALUE = _("Empty")

View File

@ -111,6 +111,33 @@ class CommandViewSet(viewsets.ModelViewSet):
filterset_class = CommandFilter filterset_class = CommandFilter
ordering_fields = ('timestamp', ) ordering_fields = ('timestamp', )
def merge_all_storage_list(self, request, *args, **kwargs):
merged_commands = []
storages = CommandStorage.objects.all()
for storage in storages:
qs = storage.get_command_queryset()
commands = self.filter_queryset(qs)
merged_commands.extend(commands)
merged_commands.sort(key=lambda command: command.timestamp, reverse=True)
page = self.paginate_queryset(merged_commands)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
def list(self, request, *args, **kwargs):
command_storage_id = self.request.query_params.get('command_storage_id')
session_id = self.request.query_params.get('session_id')
if session_id and not command_storage_id:
# 会话里的命令列表肯定会提供 session_id这里防止 merge 的时候取全量的数据
return self.merge_all_storage_list(request, *args, **kwargs)
return super().list(request, *args, **kwargs)
def get_queryset(self): def get_queryset(self):
command_storage_id = self.request.query_params.get('command_storage_id') command_storage_id = self.request.query_params.get('command_storage_id')
storage = CommandStorage.objects.get(id=command_storage_id) storage = CommandStorage.objects.get(id=command_storage_id)

View File

@ -117,11 +117,21 @@ class CommandStore():
timestamp_range['lte'] = timestamp__lte timestamp_range['lte'] = timestamp__lte
# 处理组织 # 处理组织
must_not = [] should = []
org_id = match.get('org_id') org_id = match.get('org_id')
if org_id == '':
real_default_org_id = '00000000-0000-0000-0000-000000000002'
if org_id in (real_default_org_id, ''):
match.pop('org_id') match.pop('org_id')
must_not.append({'wildcard': {'org_id': '*'}}) should.append({
'bool':{
'must_not': [
{
'wildcard': {'org_id': '*'}
}
]}
})
should.append({'match': {'org_id': real_default_org_id}})
# 构建 body # 构建 body
body = { body = {
@ -130,7 +140,7 @@ class CommandStore():
'must': [ 'must': [
{'match': {k: v}} for k, v in match.items() {'match': {k: v}} for k, v in match.items()
], ],
'must_not': must_not, 'should': should,
'filter': [ 'filter': [
{ {
'term': {k: v} 'term': {k: v}

View File

@ -48,10 +48,7 @@ class CommandFilter(filters.FilterSet):
@staticmethod @staticmethod
def get_org_id(): def get_org_id():
if current_org.is_default(): org_id = current_org.id
org_id = ''
else:
org_id = current_org.id
return org_id return org_id

View File

@ -679,6 +679,21 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
return return
return super(User, self).delete() return super(User, self).delete()
@classmethod
def get_user_allowed_auth_backends(cls, username):
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE or not username:
# return settings.AUTHENTICATION_BACKENDS
return None
user = cls.objects.filter(username=username).first()
if not user:
return None
return user.get_allowed_auth_backends()
def get_allowed_auth_backends(self):
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
return None
return self.SOURCE_BACKEND_MAPPING.get(self.source, [])
class Meta: class Meta:
ordering = ['username'] ordering = ['username']
verbose_name = _("User") verbose_name = _("User")

View File

@ -15,11 +15,6 @@ class UserOrgSerializer(serializers.Serializer):
is_root = serializers.BooleanField(read_only=True) is_root = serializers.BooleanField(read_only=True)
class UserOrgLabelSerializer(serializers.Serializer):
value = serializers.CharField(source='id')
label = serializers.CharField(source='name')
class UserUpdatePasswordSerializer(serializers.ModelSerializer): class UserUpdatePasswordSerializer(serializers.ModelSerializer):
old_password = serializers.CharField(required=True, max_length=128, write_only=True) old_password = serializers.CharField(required=True, max_length=128, write_only=True)
new_password = serializers.CharField(required=True, max_length=128, write_only=True) new_password = serializers.CharField(required=True, max_length=128, write_only=True)
@ -89,7 +84,7 @@ class UserRoleSerializer(serializers.Serializer):
class UserProfileSerializer(UserSerializer): class UserProfileSerializer(UserSerializer):
admin_or_audit_orgs = UserOrgSerializer(many=True, read_only=True) admin_or_audit_orgs = UserOrgSerializer(many=True, read_only=True)
user_all_orgs = UserOrgLabelSerializer(many=True, read_only=True) user_all_orgs = UserOrgSerializer(many=True, read_only=True)
current_org_roles = serializers.ListField(read_only=True) current_org_roles = serializers.ListField(read_only=True)
public_key_comment = serializers.CharField( public_key_comment = serializers.CharField(
source='get_public_key_comment', required=False, read_only=True, max_length=128 source='get_public_key_comment', required=False, read_only=True, max_length=128