mirror of https://github.com/jumpserver/jumpserver
commit
73b57a662e
|
@ -1,16 +1,6 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.apps import AppConfig
|
||||
from django.db.models.signals import post_migrate
|
||||
|
||||
|
||||
def initial_some_nodes():
|
||||
from .models import Node
|
||||
Node.initial_some_nodes()
|
||||
|
||||
|
||||
def initial_some_nodes_callback(sender, **kwargs):
|
||||
initial_some_nodes()
|
||||
|
||||
|
||||
class AssetsConfig(AppConfig):
|
||||
|
@ -19,7 +9,3 @@ class AssetsConfig(AppConfig):
|
|||
def ready(self):
|
||||
super().ready()
|
||||
from . import signals_handler
|
||||
try:
|
||||
initial_some_nodes()
|
||||
except Exception:
|
||||
post_migrate.connect(initial_some_nodes_callback, sender=self)
|
||||
|
|
|
@ -465,44 +465,6 @@ class SomeNodesMixin:
|
|||
empty_key = '-11'
|
||||
empty_value = _("empty")
|
||||
|
||||
@classmethod
|
||||
def correct_default_node_if_need(cls):
|
||||
with tmp_to_root_org():
|
||||
wrong_default_org = cls.objects.filter(key='1', value='Default').first()
|
||||
if not wrong_default_org:
|
||||
return
|
||||
|
||||
if wrong_default_org.has_children_or_has_assets():
|
||||
return
|
||||
|
||||
default_org = Organization.default()
|
||||
right_default_org = cls.objects.filter(value=default_org.name).first()
|
||||
if not right_default_org:
|
||||
return
|
||||
|
||||
if right_default_org.date_create > wrong_default_org.date_create:
|
||||
return
|
||||
|
||||
with atomic():
|
||||
logger.warn(f'Correct default node: '
|
||||
f'old={wrong_default_org.value}-{wrong_default_org.key} '
|
||||
f'new={right_default_org.value}-{right_default_org.key}')
|
||||
wrong_default_org.delete()
|
||||
right_default_org.key = '1'
|
||||
right_default_org.save()
|
||||
|
||||
@classmethod
|
||||
def default_node(cls):
|
||||
cls.correct_default_node_if_need()
|
||||
|
||||
default_org = Organization.default()
|
||||
with tmp_to_org(default_org):
|
||||
defaults = {'value': default_org.name}
|
||||
obj, created = cls.objects.get_or_create(
|
||||
defaults=defaults, key=cls.default_key,
|
||||
)
|
||||
return obj
|
||||
|
||||
def is_default_node(self):
|
||||
return self.key == self.default_key
|
||||
|
||||
|
@ -513,15 +475,36 @@ class SomeNodesMixin:
|
|||
return False
|
||||
|
||||
@classmethod
|
||||
def get_next_org_root_node_key(cls):
|
||||
with tmp_to_org(Organization.root()):
|
||||
org_nodes_roots = cls.objects.filter(key__regex=r'^[0-9]+$')
|
||||
org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True)
|
||||
if not org_nodes_roots_keys:
|
||||
org_nodes_roots_keys = ['1']
|
||||
max_key = max([int(k) for k in org_nodes_roots_keys])
|
||||
key = str(max_key + 1) if max_key > 0 else '2'
|
||||
return key
|
||||
def org_root(cls):
|
||||
# 如果使用current_org 在set_current_org时会死循环
|
||||
ori_org = get_current_org()
|
||||
|
||||
if ori_org and ori_org.is_default():
|
||||
return cls.default_node()
|
||||
|
||||
if ori_org and ori_org.is_root():
|
||||
return None
|
||||
|
||||
org_roots = cls.org_root_nodes()
|
||||
org_roots_length = len(org_roots)
|
||||
|
||||
if org_roots_length == 1:
|
||||
root = org_roots[0]
|
||||
return root
|
||||
elif org_roots_length == 0:
|
||||
root = cls.create_org_root_node()
|
||||
return root
|
||||
else:
|
||||
error = 'Current org {} root node not 1, get {}'.format(ori_org, org_roots_length)
|
||||
raise ValueError(error)
|
||||
|
||||
@classmethod
|
||||
def default_node(cls):
|
||||
default_org = Organization.default()
|
||||
with tmp_to_org(default_org):
|
||||
defaults = {'value': default_org.name}
|
||||
obj, created = cls.objects.get_or_create(defaults=defaults, key=cls.default_key)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def create_org_root_node(cls):
|
||||
|
@ -531,68 +514,22 @@ class SomeNodesMixin:
|
|||
root = cls.objects.create(key=key, value=ori_org.name)
|
||||
return root
|
||||
|
||||
@classmethod
|
||||
def get_next_org_root_node_key(cls):
|
||||
with tmp_to_root_org():
|
||||
org_nodes_roots = cls.org_root_nodes()
|
||||
org_nodes_roots_keys = org_nodes_roots.values_list('key', flat=True)
|
||||
if not org_nodes_roots_keys:
|
||||
org_nodes_roots_keys = ['1']
|
||||
max_key = max([int(k) for k in org_nodes_roots_keys])
|
||||
key = str(max_key + 1) if max_key > 0 else '2'
|
||||
return key
|
||||
|
||||
@classmethod
|
||||
def org_root_nodes(cls):
|
||||
nodes = cls.objects.filter(parent_key='') \
|
||||
.filter(key__regex=r'^[0-9]+$') \
|
||||
.exclude(key__startswith='-') \
|
||||
.order_by('key')
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def org_root(cls):
|
||||
# 如果使用current_org 在set_current_org时会死循环
|
||||
ori_org = get_current_org()
|
||||
|
||||
if ori_org and ori_org.is_default():
|
||||
return cls.default_node()
|
||||
if ori_org and ori_org.is_root():
|
||||
return None
|
||||
|
||||
org_roots = cls.org_root_nodes()
|
||||
org_roots_length = len(org_roots)
|
||||
|
||||
if org_roots_length == 1:
|
||||
return org_roots[0]
|
||||
elif org_roots_length == 0:
|
||||
root = cls.create_org_root_node()
|
||||
return root
|
||||
else:
|
||||
raise ValueError('Current org root node not 1, get {}'.format(org_roots_length))
|
||||
|
||||
@classmethod
|
||||
def initial_some_nodes(cls):
|
||||
cls.default_node()
|
||||
|
||||
@classmethod
|
||||
def modify_other_org_root_node_key(cls):
|
||||
"""
|
||||
解决创建 default 节点失败的问题,
|
||||
因为在其他组织下存在 default 节点,故在 DEFAULT 组织下 get 不到 create 失败
|
||||
"""
|
||||
logger.info("Modify other org root node key")
|
||||
|
||||
with tmp_to_org(Organization.root()):
|
||||
node_key1 = cls.objects.filter(key='1').first()
|
||||
if not node_key1:
|
||||
logger.info("Not found node that `key` = 1")
|
||||
return
|
||||
if node_key1.org_id == '':
|
||||
node_key1.org_id = str(Organization.default().id)
|
||||
node_key1.save()
|
||||
return
|
||||
|
||||
with transaction.atomic():
|
||||
with tmp_to_org(node_key1.org):
|
||||
org_root_node_new_key = cls.get_next_org_root_node_key()
|
||||
for n in cls.objects.all():
|
||||
old_key = n.key
|
||||
key_list = n.key.split(':')
|
||||
key_list[0] = org_root_node_new_key
|
||||
new_key = ':'.join(key_list)
|
||||
n.key = new_key
|
||||
n.save()
|
||||
logger.info('Modify key ( {} > {} )'.format(old_key, new_key))
|
||||
root_nodes = cls.objects.filter(parent_key='', key__regex=r'^[0-9]+$') \
|
||||
.exclude(key__startswith='-').order_by('key')
|
||||
return root_nodes
|
||||
|
||||
|
||||
class Node(OrgModelMixin, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
import inspect
|
||||
from urllib.parse import urlencode
|
||||
from functools import partial
|
||||
import time
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.auth import authenticate
|
||||
from django.contrib import auth
|
||||
from django.contrib.auth import (
|
||||
BACKEND_SESSION_KEY, _get_backends,
|
||||
PermissionDenied, user_login_failed, _clean_credentials
|
||||
)
|
||||
from django.shortcuts import reverse
|
||||
from django.contrib.auth import BACKEND_SESSION_KEY
|
||||
|
||||
from common.utils import get_object_or_none, get_request_ip, get_logger, bulk_get
|
||||
from users.models import User
|
||||
|
@ -22,6 +26,59 @@ from .const import RSA_PRIVATE_KEY
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def check_backend_can_auth(username, backend_path, allowed_auth_backends):
|
||||
if allowed_auth_backends is not None and backend_path not in allowed_auth_backends:
|
||||
logger.debug('Skip user auth backend: {}, {} not in'.format(
|
||||
username, backend_path, ','.join(allowed_auth_backends)
|
||||
)
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def authenticate(request=None, **credentials):
|
||||
"""
|
||||
If the given credentials are valid, return a User object.
|
||||
"""
|
||||
username = credentials.get('username')
|
||||
allowed_auth_backends = User.get_user_allowed_auth_backends(username)
|
||||
|
||||
for backend, backend_path in _get_backends(return_tuples=True):
|
||||
# 预先检查,不浪费认证时间
|
||||
if not check_backend_can_auth(username, backend_path, allowed_auth_backends):
|
||||
continue
|
||||
|
||||
backend_signature = inspect.signature(backend.authenticate)
|
||||
try:
|
||||
backend_signature.bind(request, **credentials)
|
||||
except TypeError:
|
||||
# This backend doesn't accept these credentials as arguments. Try the next one.
|
||||
continue
|
||||
try:
|
||||
user = backend.authenticate(request, **credentials)
|
||||
except PermissionDenied:
|
||||
# This backend says to stop in our tracks - this user should not be allowed in at all.
|
||||
break
|
||||
if user is None:
|
||||
continue
|
||||
# 如果是 None, 证明没有检查过, 需要再次检查
|
||||
if allowed_auth_backends is None:
|
||||
# 有些 authentication 参数中不带 username, 之后还要再检查
|
||||
allowed_auth_backends = user.get_allowed_auth_backends()
|
||||
if not check_backend_can_auth(user.username, backend_path, allowed_auth_backends):
|
||||
continue
|
||||
|
||||
# Annotate the user object with the path of the backend.
|
||||
user.backend = backend_path
|
||||
return user
|
||||
|
||||
# The credentials supplied are invalid to all backends, fire signal
|
||||
user_login_failed.send(sender=__name__, credentials=_clean_credentials(credentials), request=request)
|
||||
|
||||
|
||||
auth.authenticate = authenticate
|
||||
|
||||
|
||||
class AuthMixin:
|
||||
request = None
|
||||
partial_credential_error = None
|
||||
|
@ -121,13 +178,6 @@ class AuthMixin:
|
|||
self.raise_credential_error(errors.reason_user_inactive)
|
||||
return user
|
||||
|
||||
def _check_auth_source_is_valid(self, user, auth_backend):
|
||||
# 限制只能从认证来源登录
|
||||
if settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
|
||||
auth_backends_allowed = user.SOURCE_BACKEND_MAPPING.get(user.source)
|
||||
if auth_backend not in auth_backends_allowed:
|
||||
self.raise_credential_error(error=errors.reason_backend_not_match)
|
||||
|
||||
def _check_login_acl(self, user, ip):
|
||||
# ACL 限制用户登录
|
||||
from acls.models import LoginACL
|
||||
|
@ -144,9 +194,6 @@ class AuthMixin:
|
|||
user = self._check_auth_user_is_valid(username, password, public_key)
|
||||
# 校验login-acl规则
|
||||
self._check_login_acl(user, ip)
|
||||
# 限制只能从认证来源登录
|
||||
auth_backend = getattr(user, 'backend', 'django.contrib.auth.backends.ModelBackend')
|
||||
self._check_auth_source_is_valid(user, auth_backend)
|
||||
self._check_password_require_reset_or_not(user)
|
||||
self._check_passwd_is_too_simple(user, password)
|
||||
|
||||
|
@ -154,7 +201,7 @@ class AuthMixin:
|
|||
request.session['auth_password'] = 1
|
||||
request.session['user_id'] = str(user.id)
|
||||
request.session['auto_login'] = auto_login
|
||||
request.session['auth_backend'] = auth_backend
|
||||
request.session['auth_backend'] = getattr(user, 'backend', settings.AUTH_BACKEND_MODEL)
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -17,14 +17,16 @@ from ..serializers import (
|
|||
AdHocDetailSerializer,
|
||||
)
|
||||
from ..tasks import run_ansible_task
|
||||
from orgs.mixins.api import OrgBulkModelViewSet
|
||||
from orgs.utils import current_org
|
||||
|
||||
__all__ = [
|
||||
'TaskViewSet', 'TaskRun', 'AdHocViewSet', 'AdHocRunHistoryViewSet'
|
||||
]
|
||||
|
||||
|
||||
class TaskViewSet(JMSBulkModelViewSet):
|
||||
queryset = Task.objects.all()
|
||||
class TaskViewSet(OrgBulkModelViewSet):
|
||||
model = Task
|
||||
filterset_fields = ("name",)
|
||||
search_fields = filterset_fields
|
||||
serializer_class = TaskSerializer
|
||||
|
|
|
@ -6,7 +6,7 @@ import sys
|
|||
from django.db import migrations
|
||||
|
||||
|
||||
default_id = '00000000-0000-0000-0000-000000000001'
|
||||
default_id = '00000000-0000-0000-0000-000000000002'
|
||||
|
||||
|
||||
def add_default_org(apps, schema_editor):
|
||||
|
|
|
@ -28,8 +28,8 @@ class Organization(models.Model):
|
|||
|
||||
ROOT_ID = '00000000-0000-0000-0000-000000000000'
|
||||
ROOT_NAME = _('GLOBAL')
|
||||
DEFAULT_ID = '00000000-0000-0000-0000-000000000001'
|
||||
DEFAULT_NAME = 'DEFAULT'
|
||||
DEFAULT_ID = '00000000-0000-0000-0000-000000000002'
|
||||
DEFAULT_NAME = 'Default'
|
||||
orgs_mapping = None
|
||||
|
||||
class Meta:
|
||||
|
@ -150,10 +150,7 @@ class Organization(models.Model):
|
|||
|
||||
@classmethod
|
||||
def get_user_all_orgs(cls, user):
|
||||
return [
|
||||
*cls.objects.filter(members=user).distinct(),
|
||||
cls.default()
|
||||
]
|
||||
return cls.objects.filter(members=user).distinct()
|
||||
|
||||
@classmethod
|
||||
def get_user_admin_orgs(cls, user):
|
||||
|
@ -363,13 +360,7 @@ class OrgMemberManager(models.Manager):
|
|||
if role in to_add:
|
||||
to_add[role].add(user)
|
||||
|
||||
self.remove_users_by_role(
|
||||
org,
|
||||
to_remove.users,
|
||||
to_remove.admins,
|
||||
to_remove.auditors
|
||||
)
|
||||
|
||||
# 先添加再移除 (防止用户角色由组织用户->组织管理员时从组织清除用户)
|
||||
self.add_users_by_role(
|
||||
org,
|
||||
to_add.users,
|
||||
|
@ -377,6 +368,13 @@ class OrgMemberManager(models.Manager):
|
|||
to_add.auditors
|
||||
)
|
||||
|
||||
self.remove_users_by_role(
|
||||
org,
|
||||
to_remove.users,
|
||||
to_remove.admins,
|
||||
to_remove.auditors
|
||||
)
|
||||
|
||||
def set_users_by_role(self, org, users=None, admins=None, auditors=None):
|
||||
"""
|
||||
给组织设置带角色的用户
|
||||
|
|
|
@ -89,6 +89,8 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
|
|||
"""
|
||||
将 资产 序列化成树的结构返回
|
||||
"""
|
||||
filterset_fields = ['hostname', 'ip', 'id', 'comment']
|
||||
search_fields = ['hostname', 'ip', 'comment']
|
||||
|
||||
def list(self, request: Request, *args, **kwargs):
|
||||
queryset = self.filter_queryset(self.get_queryset())
|
||||
|
@ -99,6 +101,3 @@ class AssetsTreeFormatMixin(SerializeToTreeNodeMixin):
|
|||
queryset = queryset[:999]
|
||||
data = self.serialize_assets(queryset, None)
|
||||
return Response(data=data)
|
||||
|
||||
# def get_serializer_class(self):
|
||||
# return EmptySerializer
|
||||
|
|
|
@ -82,7 +82,7 @@ class MyAllAssetsAsTreeApi(UserAllGrantedAssetsQuerysetMixin,
|
|||
RoleUserMixin,
|
||||
AssetsTreeFormatMixin,
|
||||
ListAPIView):
|
||||
search_fields = ['hostname', 'ip']
|
||||
pass
|
||||
|
||||
|
||||
class UserGrantedNodeAssetsForAdminApi(UserGrantedNodeAssetsMixin,
|
||||
|
|
|
@ -1,10 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
UNGROUPED_NODE_ID = "00000000-0000-0000-0000-000000000002"
|
||||
UNGROUPED_NODE_KEY = '-2'
|
||||
UNGROUPED_NODE_VALUE = _("Ungrouped")
|
||||
EMPTY_NODE_ID = "00000000-0000-0000-0000-000000000003"
|
||||
EMPTY_NODE_KEY = "-3"
|
||||
EMPTY_NODE_VALUE = _("Empty")
|
|
@ -111,6 +111,33 @@ class CommandViewSet(viewsets.ModelViewSet):
|
|||
filterset_class = CommandFilter
|
||||
ordering_fields = ('timestamp', )
|
||||
|
||||
def merge_all_storage_list(self, request, *args, **kwargs):
|
||||
merged_commands = []
|
||||
|
||||
storages = CommandStorage.objects.all()
|
||||
for storage in storages:
|
||||
qs = storage.get_command_queryset()
|
||||
commands = self.filter_queryset(qs)
|
||||
merged_commands.extend(commands)
|
||||
|
||||
merged_commands.sort(key=lambda command: command.timestamp, reverse=True)
|
||||
page = self.paginate_queryset(merged_commands)
|
||||
if page is not None:
|
||||
serializer = self.get_serializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
def list(self, request, *args, **kwargs):
|
||||
command_storage_id = self.request.query_params.get('command_storage_id')
|
||||
session_id = self.request.query_params.get('session_id')
|
||||
|
||||
if session_id and not command_storage_id:
|
||||
# 会话里的命令列表肯定会提供 session_id,这里防止 merge 的时候取全量的数据
|
||||
return self.merge_all_storage_list(request, *args, **kwargs)
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
def get_queryset(self):
|
||||
command_storage_id = self.request.query_params.get('command_storage_id')
|
||||
storage = CommandStorage.objects.get(id=command_storage_id)
|
||||
|
|
|
@ -117,11 +117,21 @@ class CommandStore():
|
|||
timestamp_range['lte'] = timestamp__lte
|
||||
|
||||
# 处理组织
|
||||
must_not = []
|
||||
should = []
|
||||
org_id = match.get('org_id')
|
||||
if org_id == '':
|
||||
|
||||
real_default_org_id = '00000000-0000-0000-0000-000000000002'
|
||||
if org_id in (real_default_org_id, ''):
|
||||
match.pop('org_id')
|
||||
must_not.append({'wildcard': {'org_id': '*'}})
|
||||
should.append({
|
||||
'bool':{
|
||||
'must_not': [
|
||||
{
|
||||
'wildcard': {'org_id': '*'}
|
||||
}
|
||||
]}
|
||||
})
|
||||
should.append({'match': {'org_id': real_default_org_id}})
|
||||
|
||||
# 构建 body
|
||||
body = {
|
||||
|
@ -130,7 +140,7 @@ class CommandStore():
|
|||
'must': [
|
||||
{'match': {k: v}} for k, v in match.items()
|
||||
],
|
||||
'must_not': must_not,
|
||||
'should': should,
|
||||
'filter': [
|
||||
{
|
||||
'term': {k: v}
|
||||
|
|
|
@ -48,9 +48,6 @@ class CommandFilter(filters.FilterSet):
|
|||
|
||||
@staticmethod
|
||||
def get_org_id():
|
||||
if current_org.is_default():
|
||||
org_id = ''
|
||||
else:
|
||||
org_id = current_org.id
|
||||
return org_id
|
||||
|
||||
|
|
|
@ -679,6 +679,21 @@ class User(AuthMixin, TokenMixin, RoleMixin, MFAMixin, AbstractUser):
|
|||
return
|
||||
return super(User, self).delete()
|
||||
|
||||
@classmethod
|
||||
def get_user_allowed_auth_backends(cls, username):
|
||||
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE or not username:
|
||||
# return settings.AUTHENTICATION_BACKENDS
|
||||
return None
|
||||
user = cls.objects.filter(username=username).first()
|
||||
if not user:
|
||||
return None
|
||||
return user.get_allowed_auth_backends()
|
||||
|
||||
def get_allowed_auth_backends(self):
|
||||
if not settings.ONLY_ALLOW_AUTH_FROM_SOURCE:
|
||||
return None
|
||||
return self.SOURCE_BACKEND_MAPPING.get(self.source, [])
|
||||
|
||||
class Meta:
|
||||
ordering = ['username']
|
||||
verbose_name = _("User")
|
||||
|
|
|
@ -15,11 +15,6 @@ class UserOrgSerializer(serializers.Serializer):
|
|||
is_root = serializers.BooleanField(read_only=True)
|
||||
|
||||
|
||||
class UserOrgLabelSerializer(serializers.Serializer):
|
||||
value = serializers.CharField(source='id')
|
||||
label = serializers.CharField(source='name')
|
||||
|
||||
|
||||
class UserUpdatePasswordSerializer(serializers.ModelSerializer):
|
||||
old_password = serializers.CharField(required=True, max_length=128, write_only=True)
|
||||
new_password = serializers.CharField(required=True, max_length=128, write_only=True)
|
||||
|
@ -89,7 +84,7 @@ class UserRoleSerializer(serializers.Serializer):
|
|||
|
||||
class UserProfileSerializer(UserSerializer):
|
||||
admin_or_audit_orgs = UserOrgSerializer(many=True, read_only=True)
|
||||
user_all_orgs = UserOrgLabelSerializer(many=True, read_only=True)
|
||||
user_all_orgs = UserOrgSerializer(many=True, read_only=True)
|
||||
current_org_roles = serializers.ListField(read_only=True)
|
||||
public_key_comment = serializers.CharField(
|
||||
source='get_public_key_comment', required=False, read_only=True, max_length=128
|
||||
|
|
Loading…
Reference in New Issue