Merge pull request #12566 from jumpserver/master

v3.10.2
pull/12580/head
Bryan 2024-01-17 07:34:28 -04:00 committed by GitHub
commit baa75dc735
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
85 changed files with 2305 additions and 1409 deletions

View File

@ -113,7 +113,7 @@ JumpServer是一款安全产品请参考 [基本安全建议](https://docs.ju
## License & Copyright
Copyright (c) 2014-2023 飞致云 FIT2CLOUD, All rights reserved.
Copyright (c) 2014-2024 飞致云 FIT2CLOUD, All rights reserved.
Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in
compliance with the License. You may obtain a copy of the License at

View File

@ -145,9 +145,9 @@ class AccountBackupHandler:
wb = Workbook(filename)
for sheet, data in data_map.items():
ws = wb.add_worksheet(str(sheet))
for row in data:
for col, _data in enumerate(row):
ws.write_string(0, col, _data)
for row_index, row_data in enumerate(data):
for col_index, col_data in enumerate(row_data):
ws.write_string(row_index, col_index, col_data)
wb.close()
files.append(filename)
timedelta = round((time.time() - time_start), 2)

View File

@ -161,7 +161,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
print("Account not found, deleted ?")
return
account.secret = recorder.new_secret
account.save(update_fields=['secret'])
account.date_updated = timezone.now()
account.save(update_fields=['secret', 'date_updated'])
def on_host_error(self, host, error, result):
recorder = self.name_recorder_mapper.get(host)
@ -228,8 +229,8 @@ class ChangeSecretManager(AccountBasePlaybookManager):
rows.insert(0, header)
wb = Workbook(filename)
ws = wb.add_worksheet('Sheet1')
for row in rows:
for col, data in enumerate(row):
ws.write_string(0, col, data)
for row_index, row_data in enumerate(rows):
for col_index, col_data in enumerate(row_data):
ws.write_string(row_index, col_index, col_data)
wb.close()
return True

View File

@ -21,7 +21,8 @@ def on_account_pre_save(sender, instance, **kwargs):
if instance.version == 0:
instance.version = 1
else:
instance.version = instance.history.count()
history_account = instance.history.first()
instance.version = history_account.version + 1 if history_account else 0
@merge_delay_run(ttl=5)

View File

@ -1,9 +1,19 @@
from celery import shared_task
import uuid
from collections import defaultdict
from celery import shared_task, current_task
from django.conf import settings
from django.db.models import Count
from django.utils.translation import gettext_noop, gettext_lazy as _
from accounts.const import AutomationTypes
from accounts.models import Account
from accounts.tasks.common import quickstart_automation_by_snapshot
from audits.const import ActivityChoices
from common.const.crontab import CRONTAB_AT_AM_TWO
from common.utils import get_logger
from ops.celery.decorator import register_as_period_task
from orgs.utils import tmp_to_root_org
logger = get_logger(__file__)
@ -29,3 +39,39 @@ def remove_accounts_task(gather_account_ids):
tp = AutomationTypes.remove_account
quickstart_automation_by_snapshot(task_name, tp, task_snapshot)
@shared_task(verbose_name=_('Clean historical accounts'))
@register_as_period_task(crontab=CRONTAB_AT_AM_TWO)
@tmp_to_root_org()
def clean_historical_accounts():
    """Periodic task: trim each account's stored history rows down to the
    configured limit, then record an activity log entry for the deletion.

    Runs daily (CRONTAB_AT_AM_TWO) in the root org context.
    """
    # Imported lazily to avoid a circular import with the audits app.
    from audits.signal_handlers import create_activities
    print("Clean historical accounts start.")
    # A limit of 999 or more is treated as "keep everything": skip cleaning.
    if settings.HISTORY_ACCOUNT_CLEAN_LIMIT >= 999:
        return
    limit = settings.HISTORY_ACCOUNT_CLEAN_LIMIT
    history_ids_to_be_deleted = []
    history_model = Account.history.model
    # Maps account id (str) -> list of that account's history row ids.
    history_id_mapper = defaultdict(list)
    # Account ids that have at least `limit` history rows (grouped by `id`).
    ids = history_model.objects.values('id').annotate(count=Count('id')) \
        .filter(count__gte=limit).values_list('id', flat=True)
    if not ids:
        return
    for i in history_model.objects.filter(id__in=ids):
        _id = str(i.id)
        history_id_mapper[_id].append(i.history_id)
    for history_ids in history_id_mapper.values():
        # Keep the first `limit` entries per account, delete the rest.
        # NOTE(review): relies on the history model's default ordering to
        # decide which rows are "kept" -- presumably newest first; confirm.
        history_ids_to_be_deleted.extend(history_ids[limit:])
    history_qs = history_model.objects.filter(history_id__in=history_ids_to_be_deleted)
    # Materialize the ids before .delete() empties the queryset.
    resource_ids = list(history_qs.values_list('history_id', flat=True))
    history_qs.delete()
    # Use the Celery task id when running inside a worker; otherwise a fresh UUID.
    task_id = current_task.request.id if current_task else str(uuid.uuid4())
    detail = gettext_noop('Remove historical accounts that are out of range.')
    create_activities(resource_ids, detail, task_id, action=ActivityChoices.task, org_id='')

View File

@ -21,7 +21,6 @@ from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend
from common.utils import get_logger, is_uuid
from orgs.mixins import generics
from orgs.mixins.api import OrgBulkModelViewSet
from ..mixin import NodeFilterMixin
from ...notifications import BulkUpdatePlatformSkipAssetUserMsg
logger = get_logger(__file__)
@ -86,7 +85,7 @@ class AssetFilterSet(BaseFilterSet):
return queryset.filter(protocols__name__in=value).distinct()
class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet):
"""
API endpoint that allows Asset to be viewed or edited.
"""
@ -114,9 +113,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet):
]
def get_queryset(self):
queryset = super().get_queryset() \
.prefetch_related('nodes', 'protocols') \
.select_related('platform', 'domain')
queryset = super().get_queryset()
if queryset.model is not Asset:
queryset = queryset.select_related('asset_ptr')
return queryset

View File

@ -20,14 +20,15 @@ class DomainViewSet(OrgBulkModelViewSet):
filterset_fields = ("name",)
search_fields = filterset_fields
ordering = ('name',)
serializer_classes = {
'default': serializers.DomainSerializer,
'list': serializers.DomainListSerializer,
}
def get_serializer_class(self):
if self.request.query_params.get('gateway'):
return serializers.DomainWithGatewaySerializer
return serializers.DomainSerializer
def get_queryset(self):
return super().get_queryset().prefetch_related('assets')
return super().get_serializer_class()
class GatewayViewSet(HostViewSet):

View File

@ -2,7 +2,7 @@ from typing import List
from rest_framework.request import Request
from assets.models import Node, Protocol
from assets.models import Node, Platform, Protocol
from assets.utils import get_node_from_request, is_query_node_all_assets
from common.utils import lazyproperty, timeit
@ -71,37 +71,49 @@ class SerializeToTreeNodeMixin:
return 'file'
@timeit
def serialize_assets(self, assets, node_key=None, pid=None):
if node_key is None:
get_pid = lambda asset: getattr(asset, 'parent_key', '')
else:
get_pid = lambda asset: node_key
def serialize_assets(self, assets, node_key=None, get_pid=None):
if not get_pid and not node_key:
get_pid = lambda asset, platform: getattr(asset, 'parent_key', '')
sftp_asset_ids = Protocol.objects.filter(name='sftp') \
.values_list('asset_id', flat=True)
sftp_asset_ids = list(sftp_asset_ids)
data = [
{
sftp_asset_ids = set(sftp_asset_ids)
platform_map = {p.id: p for p in Platform.objects.all()}
data = []
root_assets_count = 0
for asset in assets:
platform = platform_map.get(asset.platform_id)
if not platform:
continue
pid = node_key or get_pid(asset, platform)
if not pid:
continue
# 根节点最多显示 1000 个资产
if pid.isdigit():
if root_assets_count > 1000:
continue
root_assets_count += 1
data.append({
'id': str(asset.id),
'name': asset.name,
'title': f'{asset.address}\n{asset.comment}',
'pId': pid or get_pid(asset),
'title': f'{asset.address}\n{asset.comment}'.strip(),
'pId': pid,
'isParent': False,
'open': False,
'iconSkin': self.get_icon(asset),
'iconSkin': self.get_icon(platform),
'chkDisabled': not asset.is_active,
'meta': {
'type': 'asset',
'data': {
'platform_type': asset.platform.type,
'platform_type': platform.type,
'org_name': asset.org_name,
'sftp': asset.id in sftp_asset_ids,
'name': asset.name,
'address': asset.address
},
}
}
for asset in assets
]
})
return data

View File

@ -29,7 +29,10 @@ class AssetPlatformViewSet(JMSModelViewSet):
}
def get_queryset(self):
queryset = super().get_queryset()
# 因为没有走分页逻辑,所以需要这里 prefetch
queryset = super().get_queryset().prefetch_related(
'protocols', 'automation', 'labels', 'labels__label',
)
queryset = queryset.filter(type__in=AllTypes.get_types_values())
return queryset

View File

@ -126,6 +126,8 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi):
include_assets = self.request.query_params.get('assets', '0') == '1'
if not self.instance or not include_assets:
return Asset.objects.none()
if self.instance.is_org_root():
return Asset.objects.none()
if query_all:
assets = self.instance.get_all_assets()
else:

View File

@ -268,7 +268,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'category', 'category': category.value, '_type': category.value}
category_node = cls.choice_to_node(category, 'ROOT', meta=meta)
category_count = category_type_mapper.get(category, 0)
category_node['name'] += f'({category_count})'
category_node['name'] += f' ({category_count})'
nodes.append(category_node)
# Type 格式化
@ -277,7 +277,7 @@ class AllTypes(ChoicesMixin):
meta = {'type': 'type', 'category': category.value, '_type': tp.value}
tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta)
tp_count = category_type_mapper.get(category + '_' + tp, 0)
tp_node['name'] += f'({tp_count})'
tp_node['name'] += f' ({tp_count})'
platforms = tp_platforms.get(category + '_' + tp, [])
if not platforms:
tp_node['isParent'] = False
@ -286,7 +286,7 @@ class AllTypes(ChoicesMixin):
# Platform 格式化
for p in platforms:
platform_node = cls.platform_to_node(p, tp_node['id'], include_asset)
platform_node['name'] += f'({platform_count.get(p.id, 0)})'
platform_node['name'] += f' ({platform_count.get(p.id, 0)})'
nodes.append(platform_node)
return nodes

View File

@ -63,11 +63,10 @@ class NodeFilterBackend(filters.BaseFilterBackend):
query_all = is_query_node_all_assets(request)
if query_all:
return queryset.filter(
Q(nodes__key__istartswith=f'{node.key}:') |
Q(nodes__key__startswith=f'{node.key}:') |
Q(nodes__key=node.key)
).distinct()
else:
print("Query query origin: ", queryset.count())
return queryset.filter(nodes__key=node.key).distinct()

View File

@ -13,7 +13,7 @@ from django.db.transaction import atomic
from django.utils.translation import gettext_lazy as _, gettext
from common.db.models import output_as_string
from common.utils import get_logger
from common.utils import get_logger, timeit
from common.utils.lock import DistributedLock
from orgs.mixins.models import OrgManager, JMSOrgBaseModel
from orgs.models import Organization
@ -195,11 +195,6 @@ class FamilyMixin:
ancestor_keys = self.get_ancestor_keys(with_self=with_self)
return self.__class__.objects.filter(key__in=ancestor_keys)
# @property
# def parent_key(self):
# parent_key = ":".join(self.key.split(":")[:-1])
# return parent_key
def compute_parent_key(self):
return compute_parent_key(self.key)
@ -349,29 +344,26 @@ class NodeAllAssetsMappingMixin:
return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id)
@classmethod
@timeit
def generate_node_all_asset_ids_mapping(cls, org_id):
from .asset import Asset
logger.info(f'Generate node asset mapping: '
f'thread={threading.get_ident()} '
f'org_id={org_id}')
logger.info(f'Generate node asset mapping: org_id={org_id}')
t1 = time.time()
with tmp_to_org(org_id):
node_ids_key = Node.objects.annotate(
char_id=output_as_string('id')
).values_list('char_id', 'key')
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = Asset.nodes.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
node_id_ancestor_keys_mapping = {
node_id: cls.get_node_ancestor_keys(node_key, with_self=True)
for node_id, node_key in node_ids_key
}
# * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢)
nodes_asset_ids = cls.assets.through.objects.all() \
.annotate(char_node_id=output_as_string('node_id')) \
.annotate(char_asset_id=output_as_string('asset_id')) \
.values_list('char_node_id', 'char_asset_id')
nodeid_assetsid_mapping = defaultdict(set)
for node_id, asset_id in nodes_asset_ids:
nodeid_assetsid_mapping[node_id].add(asset_id)
@ -386,7 +378,7 @@ class NodeAllAssetsMappingMixin:
mapping[ancestor_key].update(asset_ids)
t3 = time.time()
logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2))
logger.info('Generate asset nodes mapping, DB query: {:.2f}s, mapping: {:.2f}s'.format(t2 - t1, t3 - t2))
return mapping
@ -436,6 +428,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin):
return asset_ids
@classmethod
@timeit
def get_nodes_all_assets(cls, *nodes):
from .asset import Asset
node_ids = set()
@ -559,11 +552,6 @@ class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin):
def __str__(self):
return self.full_value
# def __eq__(self, other):
# if not other:
# return False
# return self.id == other.id
#
def __gt__(self, other):
self_key = [int(k) for k in self.key.split(':')]
other_key = [int(k) for k in other.key.split(':')]

View File

@ -1,8 +1,8 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.request import Request
from common.utils import get_logger
from assets.models import Node
from common.utils import get_logger
logger = get_logger(__name__)
@ -28,6 +28,7 @@ class AssetPaginationBase(LimitOffsetPagination):
'key', 'all', 'show_current_asset',
'cache_policy', 'display', 'draw',
'order', 'node', 'node_id', 'fields_size',
'asset'
}
for k, v in self._request.query_params.items():
if k not in exclude_query_params and v is not None:

View File

@ -100,7 +100,10 @@ class AssetAccountSerializer(AccountSerializer):
class Meta(AccountSerializer.Meta):
fields = [
f for f in AccountSerializer.Meta.fields
if f not in ['spec_info']
if f not in [
'spec_info', 'connectivity', 'labels', 'created_by',
'date_update', 'date_created'
]
]
extra_kwargs = {
**AccountSerializer.Meta.extra_kwargs,
@ -203,9 +206,12 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('domain', 'nodes', 'protocols', ) \
.prefetch_related('platform', 'platform__automation') \
.prefetch_related('labels', 'labels__label') \
.annotate(category=F("platform__category")) \
.annotate(type=F("platform__type"))
if queryset.model is Asset:
queryset = queryset.prefetch_related('labels__label', 'labels')
else:
queryset = queryset.prefetch_related('asset_ptr__labels__label', 'asset_ptr__labels')
return queryset
@staticmethod
@ -375,7 +381,6 @@ class AssetSerializer(BulkOrgResourceModelSerializer, ResourceLabelsMixin, Writa
class DetailMixin(serializers.Serializer):
accounts = AssetAccountSerializer(many=True, required=False, label=_('Accounts'))
spec_info = MethodSerializer(label=_('Spec info'), read_only=True)
gathered_info = MethodSerializer(label=_('Gathered info'), read_only=True)
auto_config = serializers.DictField(read_only=True, label=_('Auto info'))
@ -390,8 +395,7 @@ class DetailMixin(serializers.Serializer):
def get_field_names(self, declared_fields, info):
names = super().get_field_names(declared_fields, info)
names.extend([
'accounts', 'gathered_info', 'spec_info',
'auto_config',
'gathered_info', 'spec_info', 'auto_config',
])
return names

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@ -7,18 +8,15 @@ from common.serializers import ResourceLabelsMixin
from common.serializers.fields import ObjectRelatedField
from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from .gateway import GatewayWithAccountSecretSerializer
from ..models import Domain, Asset
from ..models import Domain
__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer']
__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer', 'DomainListSerializer']
class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
gateways = ObjectRelatedField(
many=True, required=False, label=_('Gateway'), read_only=True,
)
assets = ObjectRelatedField(
many=True, required=False, queryset=Asset.objects, label=_('Asset')
)
class Meta:
model = Domain
@ -30,7 +28,9 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
def to_representation(self, instance):
data = super().to_representation(instance)
assets = data['assets']
assets = data.get('assets')
if assets is None:
return data
gateway_ids = [str(i['id']) for i in data['gateways']]
data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids]
return data
@ -49,6 +49,20 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
return queryset
class DomainListSerializer(DomainSerializer):
    """List-view variant of DomainSerializer: replaces the heavy `assets`
    relation with a cheap annotated count (`assets_amount`)."""

    # Filled by the Count('assets') annotation in setup_eager_loading.
    assets_amount = serializers.IntegerField(label=_('Assets amount'), read_only=True)

    class Meta(DomainSerializer.Meta):
        # Drop 'assets' (expensive to serialize in lists) and expose the count instead.
        fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'})

    @classmethod
    def setup_eager_loading(cls, queryset):
        """Annotate each domain with its related-asset count so list
        rendering avoids one query per row."""
        queryset = queryset.annotate(
            assets_amount=Count('assets'),
        )
        return queryset
class DomainWithGatewaySerializer(serializers.ModelSerializer):
gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True)

View File

@ -191,7 +191,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
def add_type_choices(self, name, label):
tp = self.fields['type']
tp.choices[name] = label
tp.choice_mapper[name] = label
tp.choice_strings_to_values[name] = label
@lazyproperty
@ -200,12 +199,6 @@ class PlatformSerializer(ResourceLabelsMixin, WritableNestedModelSerializer):
constraints = AllTypes.get_constraints(category, tp)
return constraints
@classmethod
def setup_eager_loading(cls, queryset):
queryset = queryset.prefetch_related('protocols', 'automation') \
.prefetch_related('labels', 'labels__label')
return queryset
def validate_protocols(self, protocols):
if not protocols:
raise serializers.ValidationError(_("Protocols is required"))

View File

@ -80,10 +80,11 @@ RELATED_NODE_IDS = '_related_node_ids'
@receiver(pre_delete, sender=Asset)
def on_asset_delete(instance: Asset, using, **kwargs):
logger.debug("Asset pre delete signal recv: {}".format(instance))
node_ids = Node.objects.filter(assets=instance) \
.distinct().values_list('id', flat=True)
setattr(instance, RELATED_NODE_IDS, node_ids)
node_ids = list(node_ids)
logger.debug("Asset pre delete signal recv: {}, node_ids: {}".format(instance, node_ids))
setattr(instance, RELATED_NODE_IDS, list(node_ids))
m2m_changed.send(
sender=Asset.nodes.through, instance=instance,
reverse=False, model=Node, pk_set=node_ids,
@ -93,8 +94,8 @@ def on_asset_delete(instance: Asset, using, **kwargs):
@receiver(post_delete, sender=Asset)
def on_asset_post_delete(instance: Asset, using, **kwargs):
logger.debug("Asset post delete signal recv: {}".format(instance))
node_ids = getattr(instance, RELATED_NODE_IDS, [])
logger.debug("Asset post delete signal recv: {}, node_ids: {}".format(instance, node_ids))
if node_ids:
m2m_changed.send(
sender=Asset.nodes.through, instance=instance, reverse=False,

View File

@ -15,8 +15,8 @@ from ..tasks import check_node_assets_amount_task
logger = get_logger(__file__)
@on_transaction_commit
@receiver(m2m_changed, sender=Asset.nodes.through)
@on_transaction_commit
def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
# 不允许 `pre_clear` ,因为该信号没有 `pk_set`
# [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed)
@ -37,7 +37,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs):
update_nodes_assets_amount(node_ids=node_ids)
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def update_nodes_assets_amount(node_ids=()):
nodes = Node.objects.filter(id__in=node_ids)
nodes = Node.get_ancestor_queryset(nodes)

View File

@ -21,7 +21,7 @@ logger = get_logger(__name__)
node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)()
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def expire_node_assets_mapping(org_ids=()):
logger.debug("Recv asset nodes changed signal, expire memery node asset mapping")
# 所有进程清除(自己的 memory 数据)
@ -53,8 +53,9 @@ def on_node_post_delete(sender, instance, **kwargs):
@receiver(m2m_changed, sender=Asset.nodes.through)
def on_node_asset_change(sender, instance, **kwargs):
expire_node_assets_mapping(org_ids=(instance.org_id,))
def on_node_asset_change(sender, instance, action='pre_remove', **kwargs):
if action.startswith('post'):
expire_node_assets_mapping(org_ids=(instance.org_id,))
@receiver(django_ready)

View File

@ -2,6 +2,7 @@
from django.urls import path
from rest_framework_bulk.routes import BulkRouter
from labels.api import LabelViewSet
from .. import api
app_name = 'assets'
@ -22,6 +23,7 @@ router.register(r'domains', api.DomainViewSet, 'domain')
router.register(r'gateways', api.GatewayViewSet, 'gateway')
router.register(r'favorite-assets', api.FavoriteAssetViewSet, 'favorite-asset')
router.register(r'protocol-settings', api.PlatformProtocolViewSet, 'protocol-setting')
router.register(r'labels', LabelViewSet, 'label')
urlpatterns = [
# path('assets/<uuid:pk>/gateways/', api.AssetGatewayListApi.as_view(), name='asset-gateway-list'),

View File

@ -4,7 +4,6 @@ from urllib.parse import urlencode, urlparse
from kubernetes import client
from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
from kubernetes.client.exceptions import ApiException
from sshtunnel import SSHTunnelForwarder, BaseSSHTunnelForwarderError
from common.utils import get_logger
@ -88,8 +87,9 @@ class KubernetesClient:
if hasattr(self, func_name):
try:
data = getattr(self, func_name)(*args)
except ApiException as e:
logger.error(e.reason)
except Exception as e:
logger.error(e)
raise e
if self.server:
self.server.stop()

View File

@ -5,6 +5,7 @@ from importlib import import_module
from django.conf import settings
from django.db.models import F, Value, CharField, Q
from django.db.models.functions import Cast
from django.http import HttpResponse, FileResponse
from django.utils.encoding import escape_uri_path
from rest_framework import generics
@ -40,6 +41,7 @@ from .serializers import (
PasswordChangeLogSerializer, ActivityUnionLogSerializer,
FileSerializer, UserSessionSerializer
)
from .utils import construct_userlogin_usernames
logger = get_logger(__name__)
@ -125,15 +127,16 @@ class UserLoginCommonMixin:
class UserLoginLogViewSet(UserLoginCommonMixin, OrgReadonlyModelViewSet):
@staticmethod
def get_org_members():
users = current_org.get_members().values_list('username', flat=True)
def get_org_member_usernames():
user_queryset = current_org.get_members()
users = construct_userlogin_usernames(user_queryset)
return users
def get_queryset(self):
queryset = super().get_queryset()
if current_org.is_root():
return queryset
users = self.get_org_members()
users = self.get_org_member_usernames()
queryset = queryset.filter(username__in=users)
return queryset
@ -163,7 +166,7 @@ class ResourceActivityAPIView(generics.ListAPIView):
q |= Q(user=str(user))
queryset = OperateLog.objects.filter(q, org_q).annotate(
r_type=Value(ActivityChoices.operate_log, CharField()),
r_detail_id=F('id'), r_detail=Value(None, CharField()),
r_detail_id=Cast(F('id'), CharField()), r_detail=Value(None, CharField()),
r_user=F('user'), r_action=F('action'),
).values(*fields)[:limit]
return queryset

View File

@ -4,6 +4,8 @@ from itertools import chain
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.db import models
from django.db.models import F, Value, CharField
from django.db.models.functions import Concat
from common.db.fields import RelatedManager
from common.utils import validate_ip, get_ip_city, get_logger
@ -115,3 +117,12 @@ def model_to_dict_for_operate_log(
get_related_values(f)
return data
def construct_userlogin_usernames(user_queryset):
    """Return every username in `user_queryset` in both stored formats:
    the plain 'username' and the combined 'name(username)' form.

    Login logs may record either form, so callers filter with the union.
    """
    usernames_original = user_queryset.values_list('username', flat=True)
    # Build 'name(username)' in the database via Concat, one value per user.
    usernames_combined = user_queryset.annotate(
        usernames_combined_field=Concat(F('name'), Value('('), F('username'), Value(')'), output_field=CharField())
    ).values_list("usernames_combined_field", flat=True)
    usernames = list(chain(usernames_original, usernames_combined))
    return usernames

View File

@ -90,6 +90,6 @@ class MFAChallengeVerifyApi(AuthMixin, CreateAPIView):
return Response({'msg': 'ok'})
except errors.AuthFailedError as e:
data = {"error": e.error, "msg": e.msg}
raise ValidationError(data)
return Response(data, status=401)
except errors.NeedMoreInfoError as e:
return Response(e.as_data(), status=200)

View File

@ -15,12 +15,11 @@ from authentication.mixins import authenticate
from authentication.serializers import (
PasswordVerifySerializer, ResetPasswordCodeSerializer
)
from authentication.utils import check_user_property_is_correct
from common.permissions import IsValidUser
from common.utils import get_object_or_none
from common.utils.random import random_string
from common.utils.verify_code import SendAndVerifyCodeUtil
from settings.utils import get_login_title
from users.models import User
class UserResetPasswordSendCodeApi(CreateAPIView):
@ -28,8 +27,8 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
serializer_class = ResetPasswordCodeSerializer
@staticmethod
def is_valid_user(**kwargs):
user = get_object_or_none(User, **kwargs)
def is_valid_user(username, **properties):
user = check_user_property_is_correct(username, **properties)
if not user:
err_msg = _('User does not exist: {}').format(_("No user matched"))
return None, err_msg
@ -56,7 +55,6 @@ class UserResetPasswordSendCodeApi(CreateAPIView):
target = serializer.validated_data[form_type]
if form_type == 'sms':
query_key = 'phone'
target = target.lstrip('+')
else:
query_key = form_type
user, err = self.is_valid_user(username=username, **{query_key: target})

View File

@ -7,8 +7,9 @@ from django.conf import settings
from django.utils.translation import gettext_lazy as _
from audits.const import DEFAULT_CITY
from users.models import User
from audits.models import UserLoginLog
from common.utils import get_logger
from common.utils import get_logger, get_object_or_none
from common.utils import validate_ip, get_ip_city, get_request_ip
from .notifications import DifferentCityLoginMessage
@ -24,9 +25,10 @@ def check_different_city_login_if_need(user, request):
is_private = ipaddress.ip_address(ip).is_private
if is_private:
return
usernames = [user.username, f"{user.name}({user.username})"]
last_user_login = UserLoginLog.objects.exclude(
city__in=city_white
).filter(username=user.username, status=True).first()
).filter(username__in=usernames, status=True).first()
if not last_user_login:
return
@ -59,3 +61,12 @@ def build_absolute_uri_for_oidc(request, path=None):
redirect_uri = urljoin(settings.BASE_SITE_URL, path)
return redirect_uri
return build_absolute_uri(request, path=path)
def check_user_property_is_correct(username, **properties):
    """Look up the user by username and verify the given attributes.

    Returns the User instance when every keyword property matches the
    user's attribute value; returns None when the user does not exist or
    any property mismatches.
    """
    user = get_object_or_none(User, username=username)
    if user is None:
        return None
    mismatched = any(
        getattr(user, attr, None) != expected
        for attr, expected in properties.items()
    )
    return None if mismatched else user

View File

@ -98,12 +98,19 @@ class QuerySetMixin:
return queryset
if self.action == 'metadata':
queryset = queryset.none()
if self.action in ['list', 'metadata']:
serializer_class = self.get_serializer_class()
if serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
queryset = serializer_class.setup_eager_loading(queryset)
return queryset
def paginate_queryset(self, queryset):
    """Paginate as usual, then re-fetch the page's objects through the
    serializer's eager-loading hook so list rendering avoids N+1 queries.

    The re-fetched queryset does not preserve the page's order, so the
    original order is restored via an id -> object map.
    """
    page = super().paginate_queryset(queryset)
    serializer_class = self.get_serializer_class()
    if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'):
        # Remember the paginated order by string id.
        ids = [str(obj.id) for obj in page]
        # Re-query just this page with prefetch/annotate applied.
        page = self.get_queryset().filter(id__in=ids)
        page = serializer_class.setup_eager_loading(page)
        page_mapper = {str(obj.id): obj for obj in page}
        # Restore original ordering; drop ids the re-query no longer returns.
        page = [page_mapper.get(_id) for _id in ids if _id in page_mapper]
    return page
class ExtraFilterFieldsMixin:
"""

View File

@ -65,7 +65,7 @@ class EventLoopThread(threading.Thread):
_loop_thread = EventLoopThread()
_loop_thread.setDaemon(True)
_loop_thread.daemon = True
_loop_thread.start()
executor = ThreadPoolExecutor(
max_workers=10,

View File

@ -219,11 +219,11 @@ class LabelFilterBackend(filters.BaseFilterBackend):
if not hasattr(queryset, 'model'):
return queryset
if not hasattr(queryset.model, 'labels'):
if not hasattr(queryset.model, 'label_model'):
return queryset
model = queryset.model
labeled_resource_cls = model._labels.field.related_model
model = queryset.model.label_model()
labeled_resource_cls = model.labels.field.related_model
app_label = model._meta.app_label
model_name = model._meta.model_name

View File

@ -14,6 +14,7 @@ class CeleryBaseService(BaseService):
print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))
ansible_config_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'ansible.cfg')
ansible_modules_path = os.path.join(settings.APPS_DIR, 'ops', 'ansible', 'modules')
os.environ.setdefault('LC_ALL', 'C.UTF-8')
os.environ.setdefault('PYTHONOPTIMIZE', '1')
os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
os.environ.setdefault('ANSIBLE_CONFIG', ansible_config_path)

View File

@ -1,16 +1,14 @@
import requests
import mistune
from rest_framework.exceptions import APIException
import requests
from django.utils.translation import gettext_lazy as _
from rest_framework.exceptions import APIException
from users.utils import construct_user_email
from common.utils.common import get_logger
from jumpserver.utils import get_current_request
from users.utils import construct_user_email
logger = get_logger(__name__)
SLACK_REDIRECT_URI_SESSION_KEY = '_slack_redirect_uri'
@ -22,15 +20,15 @@ class URL:
AUTH_TEST = 'https://slack.com/api/auth.test'
class SlackRenderer(mistune.Renderer):
def header(self, text, level, raw=None):
class SlackRenderer(mistune.HTMLRenderer):
def heading(self, text, level):
return '*' + text + '*\n'
def double_emphasis(self, text):
def strong(self, text):
return '*' + text + '*'
def list(self, body, ordered=True):
lines = body.split('\n')
def list(self, text, **kwargs):
lines = text.split('\n')
for i, line in enumerate(lines):
if not line:
continue
@ -41,9 +39,9 @@ class SlackRenderer(mistune.Renderer):
def block_code(self, code, lang=None):
return f'`{code}`'
def link(self, link, title, content):
if title or content:
label = str(title or content).strip()
def link(self, link, text=None, title=None):
if title or text:
label = str(title or text).strip()
return f'<{link}|{label}>'
return f'<{link}>'

View File

@ -394,20 +394,20 @@ class CommonBulkModelSerializer(CommonBulkSerializerMixin, serializers.ModelSeri
class ResourceLabelsMixin(serializers.Serializer):
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True)
labels = LabelRelatedField(many=True, label=_('Labels'), required=False, allow_null=True, source='res_labels')
def update(self, instance, validated_data):
labels = validated_data.pop('labels', None)
labels = validated_data.pop('res_labels', None)
res = super().update(instance, validated_data)
if labels is not None:
instance.labels.set(labels, bulk=False)
instance.res_labels.set(labels, bulk=False)
return res
def create(self, validated_data):
labels = validated_data.pop('labels', None)
labels = validated_data.pop('res_labels', None)
instance = super().create(validated_data)
if labels is not None:
instance.labels.set(labels, bulk=False)
instance.res_labels.set(labels, bulk=False)
return instance
@classmethod

View File

@ -62,14 +62,14 @@ def digest_sql_query():
method = current_request.method
path = current_request.get_full_path()
print(">>> [{}] {}".format(method, path))
print(">>>. [{}] {}".format(method, path))
for table_name, queries in table_queries.items():
if table_name.startswith('rbac_') or table_name.startswith('auth_permission'):
continue
for query in queries:
sql = query['sql']
print(" # {}: {}".format(query['time'], sql))
print(" # {}: {}".format(query['time'], sql[:1000]))
if len(queries) < 3:
continue
print("- Table: {}".format(table_name))
@ -77,9 +77,9 @@ def digest_sql_query():
sql = query['sql']
if not sql or not sql.startswith('SELECT'):
continue
print('\t{}. {}'.format(i, sql))
print('\t{}.[{}s] {}'.format(i, round(float(query['time']), 2), sql[:1000]))
logger.debug(">>> [{}] {}".format(method, path))
# logger.debug(">>> [{}] {}".format(method, path))
for name, counter in counters:
logger.debug("Query {:3} times using {:.2f}s {}".format(
counter.counter, counter.time, name)

View File

@ -2,7 +2,7 @@ import os
from celery import shared_task
from django.conf import settings
from django.core.mail import send_mail, EmailMultiAlternatives
from django.core.mail import send_mail, EmailMultiAlternatives, get_connection
from django.utils.translation import gettext_lazy as _
import jms_storage
@ -11,6 +11,16 @@ from .utils import get_logger
logger = get_logger(__file__)
def get_email_connection(**kwargs):
    """Return a Django mail connection for the configured EMAIL_PROTOCOL.

    Maps the protocol name ('smtp' or 'exchange') to its backend dotted
    path; any extra kwargs are forwarded to Django's get_connection().
    An unknown protocol yields backend=None, i.e. Django's default backend.
    """
    backend_by_protocol = {
        'smtp': 'django.core.mail.backends.smtp.EmailBackend',
        'exchange': 'jumpserver.rewriting.exchange.EmailBackend'
    }
    backend_path = backend_by_protocol.get(settings.EMAIL_PROTOCOL)
    return get_connection(backend=backend_path, **kwargs)
def task_activity_callback(self, subject, message, recipient_list, *args, **kwargs):
from users.models import User
email_list = recipient_list
@ -40,7 +50,7 @@ def send_mail_async(*args, **kwargs):
args = tuple(args)
try:
return send_mail(*args, **kwargs)
return send_mail(connection=get_email_connection(), *args, **kwargs)
except Exception as e:
logger.error("Sending mail error: {}".format(e))
@ -55,7 +65,8 @@ def send_mail_attachment_async(subject, message, recipient_list, attachment_list
subject=subject,
body=message,
from_email=from_email,
to=recipient_list
to=recipient_list,
connection=get_email_connection(),
)
for attachment in attachment_list:
email.attach_file(attachment)

View File

@ -220,7 +220,7 @@ def timeit(func):
now = time.time()
result = func(*args, **kwargs)
using = (time.time() - now) * 1000
msg = "End call {}, using: {:.1f}ms".format(name, using)
msg = "Ends call: {}, using: {:.1f}ms".format(name, using)
logger.debug(msg)
return result

View File

@ -1,18 +1,16 @@
from functools import wraps
import threading
from functools import wraps
from django.db import transaction
from redis_lock import (
Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT,
EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT
)
from redis import Redis
from django.db import transaction
from common.utils import get_logger
from common.utils.inspect import copy_function_args
from common.utils.connection import get_redis_client
from jumpserver.const import CONFIG
from common.local import thread_local
from common.utils import get_logger
from common.utils.connection import get_redis_client
from common.utils.inspect import copy_function_args
logger = get_logger(__file__)
@ -76,6 +74,7 @@ class DistributedLock(RedisLock):
# 要创建一个新的锁对象
with self.__class__(**self.kwargs_copy):
return func(*args, **kwds)
return inner
@classmethod
@ -95,7 +94,6 @@ class DistributedLock(RedisLock):
if self.locked():
owner_id = self.get_owner_id()
local_owner_id = getattr(thread_local, self.name, None)
if local_owner_id and owner_id == local_owner_id:
return True
return False
@ -140,14 +138,16 @@ class DistributedLock(RedisLock):
logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
return
else:
self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
self._raise_exc_with_log(
f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
def _release_on_reentrant_locked_by_me(self):
logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}')
id = getattr(thread_local, self.name, None)
if id != self.id:
raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
raise PermissionError(
f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}')
try:
# 这里要保证先删除 thread_local 的标记,
delattr(thread_local, self.name)
@ -191,7 +191,7 @@ class DistributedLock(RedisLock):
# 处理是否在事务提交时才释放锁
if self._release_on_transaction_commit:
logger.debug(
f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name}')
f'Release lock on transaction commit:lock_id={self.id} lock={self.name}')
transaction.on_commit(_release)
else:
_release()

View File

@ -17,6 +17,7 @@ from assets.models import Asset
from audits.api import OperateLogViewSet
from audits.const import LoginStatusChoices
from audits.models import UserLoginLog, PasswordChangeLog, OperateLog, FTPLog, JobLog
from audits.utils import construct_userlogin_usernames
from common.utils import lazyproperty
from common.utils.timezone import local_now, local_zero_hour
from ops.const import JobStatus
@ -79,7 +80,7 @@ class DateTimeMixin:
if not self.org.is_root():
if query_params == 'username':
query = {
f'{query_params}__in': users.values_list('username', flat=True)
f'{query_params}__in': construct_userlogin_usernames(users)
}
else:
query = {

View File

@ -17,7 +17,7 @@ import re
import sys
import types
from importlib import import_module
from urllib.parse import urljoin, urlparse
from urllib.parse import urljoin, urlparse, quote
import yaml
from django.urls import reverse_lazy
@ -261,6 +261,8 @@ class Config(dict):
'VAULT_HCP_TOKEN': '',
'VAULT_HCP_MOUNT_POINT': 'jumpserver',
'HISTORY_ACCOUNT_CLEAN_LIMIT': 999,
# Cache login password
'CACHE_LOGIN_PASSWORD_ENABLED': False,
'CACHE_LOGIN_PASSWORD_TTL': 60 * 60 * 24,
@ -280,6 +282,7 @@ class Config(dict):
'AUTH_LDAP_SYNC_INTERVAL': None,
'AUTH_LDAP_SYNC_CRONTAB': None,
'AUTH_LDAP_SYNC_ORG_IDS': ['00000000-0000-0000-0000-000000000002'],
'AUTH_LDAP_SYNC_RECEIVERS': [],
'AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS': False,
'AUTH_LDAP_OPTIONS_OPT_REFERRALS': -1,
@ -325,6 +328,7 @@ class Config(dict):
'RADIUS_SERVER': 'localhost',
'RADIUS_PORT': 1812,
'RADIUS_SECRET': '',
'RADIUS_ATTRIBUTES': {},
'RADIUS_ENCRYPT_PASSWORD': True,
'OTP_IN_RADIUS': False,
@ -451,6 +455,7 @@ class Config(dict):
'CUSTOM_SMS_REQUEST_METHOD': 'get',
# Email
'EMAIL_PROTOCOL': 'smtp',
'EMAIL_CUSTOM_USER_CREATED_SUBJECT': _('Create account successfully'),
'EMAIL_CUSTOM_USER_CREATED_HONORIFIC': _('Hello'),
'EMAIL_CUSTOM_USER_CREATED_BODY': _('Your account has been created successfully'),
@ -531,6 +536,7 @@ class Config(dict):
'SYSLOG_SOCKTYPE': 2,
'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60,
'PERM_TREE_REGEN_INTERVAL': 1,
'FLOWER_URL': "127.0.0.1:5555",
'LANGUAGE_CODE': 'zh',
'TIME_ZONE': 'Asia/Shanghai',
@ -693,6 +699,13 @@ class Config(dict):
if openid_config:
self.set_openid_config(openid_config)
def compatible_redis(self):
redis_config = {
'REDIS_PASSWORD': quote(str(self.REDIS_PASSWORD)),
}
for key, value in redis_config.items():
self[key] = value
def compatible(self):
"""
对配置做兼容处理
@ -704,6 +717,8 @@ class Config(dict):
"""
# 兼容 OpenID 配置
self.compatible_auth_openid()
# 兼容 Redis 配置
self.compatible_redis()
def convert_type(self, k, v):
default_value = self.defaults.get(k)

View File

@ -0,0 +1,104 @@
import urllib3
from urllib3.exceptions import InsecureRequestWarning
from django.core.mail.backends.base import BaseEmailBackend
from django.core.mail.message import sanitize_address
from django.conf import settings
from exchangelib import Account, Credentials, Configuration, DELEGATE
from exchangelib import Mailbox, Message, HTMLBody, FileAttachment
from exchangelib import BaseProtocol, NoVerifyHTTPAdapter
from exchangelib.errors import TransportError
urllib3.disable_warnings(InsecureRequestWarning)
BaseProtocol.HTTP_ADAPTER_CLS = NoVerifyHTTPAdapter
class EmailBackend(BaseEmailBackend):
    """Django email backend that delivers mail through Exchange Web Services.

    Uses exchangelib with DELEGATE access; connection parameters fall back
    to the project's EMAIL_HOST / EMAIL_HOST_USER / EMAIL_HOST_PASSWORD
    settings when not given explicitly.
    """

    def __init__(
            self,
            service_endpoint=None,
            username=None,
            password=None,
            fail_silently=False,
            **kwargs,
    ):
        super().__init__(fail_silently=fail_silently)
        # NOTE(review): EMAIL_HOST is used as the EWS service endpoint —
        # presumably a full EWS URL is configured there; confirm in deploy docs.
        self.service_endpoint = service_endpoint or settings.EMAIL_HOST
        self.username = settings.EMAIL_HOST_USER if username is None else username
        self.password = settings.EMAIL_HOST_PASSWORD if password is None else password
        # Lazily-created exchangelib Account; None until open() succeeds.
        self._connection = None

    def open(self):
        """Open the Exchange account connection.

        Returns True if a new connection was created, False if one already
        existed, and None when connecting failed and fail_silently is set.
        """
        if self._connection:
            return False
        try:
            config = Configuration(
                service_endpoint=self.service_endpoint, credentials=Credentials(
                    username=self.username, password=self.password
                )
            )
            self._connection = Account(self.username, config=config, access_type=DELEGATE)
            return True
        except TransportError:
            if not self.fail_silently:
                raise

    def close(self):
        # Drop the cached Account; exchangelib keeps no socket to close here.
        self._connection = None

    def send_messages(self, email_messages):
        """Send a list of Django EmailMessage objects; return the number sent."""
        if not email_messages:
            return 0
        new_conn_created = self.open()
        # new_conn_created is None means open() failed silently — give up.
        if not self._connection or new_conn_created is None:
            return 0
        num_sent = 0
        for message in email_messages:
            sent = self._send(message)
            if sent:
                num_sent += 1
        # Only tear down a connection this call created itself.
        if new_conn_created:
            self.close()
        return num_sent

    def _send(self, email_message):
        """Convert one Django message to an exchangelib Message and send it.

        Returns True on success, False when there are no recipients or the
        send failed with fail_silently enabled.
        """
        if not email_message.recipients():
            return False
        encoding = settings.DEFAULT_CHARSET
        from_email = sanitize_address(email_message.from_email, encoding)
        recipients = [
            Mailbox(email_address=sanitize_address(addr, encoding)) for addr in email_message.recipients()
        ]
        try:
            message_body = email_message.body
            # NOTE(review): assumes messages are EmailMultiAlternatives-like
            # (plain EmailMessage has no `alternatives` attribute) — confirm callers.
            alternatives = email_message.alternatives or []
            attachments = []
            for attachment in email_message.attachments or []:
                name, content, mimetype = attachment
                if isinstance(content, str):
                    content = content.encode(encoding)
                attachments.append(
                    FileAttachment(name=name, content=content, content_type=mimetype)
                )
            # Prefer the HTML alternative as the body when one is present.
            for alternative in alternatives:
                if alternative[1] == 'text/html':
                    message_body = HTMLBody(alternative[0])
                    break
            email_message = Message(
                account=self._connection, subject=email_message.subject,
                body=message_body, to_recipients=recipients, sender=from_email,
                attachments=[]
            )
            email_message.attach(attachments)
            email_message.send_and_save()
        except Exception as error:
            if not self.fail_silently:
                raise error
            return False
        return True

View File

@ -0,0 +1,14 @@
from private_storage.servers import NginxXAccelRedirectServer, DjangoServer
class StaticFileServer(object):
    """Dispatch private-media requests to the appropriate file server."""

    @staticmethod
    def serve(private_file):
        # Gzip-compressed session recordings are not decoded correctly by
        # the browser when served through nginx, which breaks online
        # playback — so only .mp4 recordings go through nginx X-Accel,
        # everything else is served by Django directly.
        path = private_file.full_path
        handler = NginxXAccelRedirectServer if path.endswith('.mp4') else DjangoServer
        return handler.serve(private_file)

View File

@ -50,6 +50,7 @@ AUTH_LDAP_SYNC_IS_PERIODIC = CONFIG.AUTH_LDAP_SYNC_IS_PERIODIC
AUTH_LDAP_SYNC_INTERVAL = CONFIG.AUTH_LDAP_SYNC_INTERVAL
AUTH_LDAP_SYNC_CRONTAB = CONFIG.AUTH_LDAP_SYNC_CRONTAB
AUTH_LDAP_SYNC_ORG_IDS = CONFIG.AUTH_LDAP_SYNC_ORG_IDS
AUTH_LDAP_SYNC_RECEIVERS = CONFIG.AUTH_LDAP_SYNC_RECEIVERS
AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS = CONFIG.AUTH_LDAP_USER_LOGIN_ONLY_IN_USERS
# ==============================================================================
@ -99,6 +100,8 @@ AUTH_RADIUS_BACKEND = 'authentication.backends.radius.RadiusBackend'
RADIUS_SERVER = CONFIG.RADIUS_SERVER
RADIUS_PORT = CONFIG.RADIUS_PORT
RADIUS_SECRET = CONFIG.RADIUS_SECRET
# https://github.com/robgolding/django-radius/blob/develop/radiusauth/backends/radius.py#L15-L52
RADIUS_ATTRIBUTES = CONFIG.RADIUS_ATTRIBUTES
# CAS Auth
AUTH_CAS = CONFIG.AUTH_CAS
@ -190,6 +193,8 @@ VAULT_HCP_HOST = CONFIG.VAULT_HCP_HOST
VAULT_HCP_TOKEN = CONFIG.VAULT_HCP_TOKEN
VAULT_HCP_MOUNT_POINT = CONFIG.VAULT_HCP_MOUNT_POINT
HISTORY_ACCOUNT_CLEAN_LIMIT = CONFIG.HISTORY_ACCOUNT_CLEAN_LIMIT
# Other setting
# 这个是 User Login Private Token
TOKEN_EXPIRATION = CONFIG.TOKEN_EXPIRATION

View File

@ -312,12 +312,15 @@ STATICFILES_DIRS = (
os.path.join(BASE_DIR, "static"),
)
# Media files (File, ImageField) will be save these
# Media files (File, ImageField) will be saved under these paths
MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(PROJECT_DIR, 'data', 'media').replace('\\', '/') + '/'
PRIVATE_STORAGE_ROOT = MEDIA_ROOT
PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_access'
PRIVATE_STORAGE_INTERNAL_URL = '/private-media/'
PRIVATE_STORAGE_SERVER = 'jumpserver.rewriting.storage.servers.StaticFileServer'
# Use django-bootstrap-form to format template, input max width arg
# BOOTSTRAP_COLUMN_COUNT = 11
@ -326,6 +329,7 @@ PRIVATE_STORAGE_AUTH_FUNCTION = 'jumpserver.rewriting.storage.permissions.allow_
FIXTURE_DIRS = [os.path.join(BASE_DIR, 'fixtures'), ]
# Email config
EMAIL_PROTOCOL = CONFIG.EMAIL_PROTOCOL
EMAIL_HOST = CONFIG.EMAIL_HOST
EMAIL_PORT = CONFIG.EMAIL_PORT
EMAIL_HOST_USER = CONFIG.EMAIL_HOST_USER

View File

@ -208,6 +208,7 @@ OPERATE_LOG_ELASTICSEARCH_CONFIG = CONFIG.OPERATE_LOG_ELASTICSEARCH_CONFIG
MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE
DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE
PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL
# Magnus DB Port
MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS

View File

@ -21,7 +21,7 @@ LOGGING = {
},
'main': {
'datefmt': '%Y-%m-%d %H:%M:%S',
'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s',
'format': '%(asctime)s [%(levelname).4s] %(message)s',
},
'exception': {
'datefmt': '%Y-%m-%d %H:%M:%S',

View File

@ -73,7 +73,7 @@ class LabelContentTypeResourceViewSet(JMSModelViewSet):
queryset = model.objects.all()
if bound == '1':
queryset = queryset.filter(id__in=list(res_ids))
elif bound == '0':
else:
queryset = queryset.exclude(id__in=list(res_ids))
keyword = self.request.query_params.get('search')
if keyword:
@ -90,9 +90,10 @@ class LabelContentTypeResourceViewSet(JMSModelViewSet):
LabeledResource.objects \
.filter(res_type=content_type, label=label) \
.exclude(res_id__in=res_ids).delete()
resources = []
for res_id in res_ids:
resources.append(LabeledResource(res_type=content_type, res_id=res_id, label=label, org_id=current_org.id))
resources = [
LabeledResource(res_type=content_type, res_id=res_id, label=label, org_id=current_org.id)
for res_id in res_ids
]
LabeledResource.objects.bulk_create(resources, ignore_conflicts=True)
return Response({"total": len(res_ids)})
@ -129,15 +130,22 @@ class LabeledResourceViewSet(OrgBulkModelViewSet):
}
ordering_fields = ('res_type', 'date_created')
# Todo: 这里需要优化,查询 sql 太多
def filter_search(self, queryset):
keyword = self.request.query_params.get('search')
if not keyword:
return queryset
keyword = keyword.strip().lower()
matched = []
for instance in queryset:
if keyword.lower() in str(instance.resource).lower():
matched.append(instance.id)
offset = 0
limit = 10000
while True:
page = queryset[offset:offset + limit]
if not page:
break
offset += limit
for instance in page:
if keyword in str(instance.resource).lower():
matched.append(instance.id)
return queryset.filter(id__in=matched)
def get_queryset(self):

View File

@ -1,21 +1,38 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from django.db.models import OneToOneField
from common.utils import lazyproperty
from .models import LabeledResource
__all__ = ['LabeledMixin']
class LabeledMixin(models.Model):
_labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
labels = GenericRelation(LabeledResource, object_id_field='res_id', content_type_field='res_type')
class Meta:
abstract = True
@property
def labels(self):
return self._labels
@classmethod
def label_model(cls):
pk_field = cls._meta.pk
model = cls
if isinstance(pk_field, OneToOneField):
model = pk_field.related_model
return model
@labels.setter
def labels(self, value):
self._labels.set(value, bulk=False)
@lazyproperty
def real(self):
pk_field = self._meta.pk
if isinstance(pk_field, OneToOneField):
return getattr(self, pk_field.name)
return self
@property
def res_labels(self):
return self.real.labels
@res_labels.setter
def res_labels(self, value):
self.real.labels.set(value, bulk=False)

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71d292647cf751c002b459449c7bebf4d2bf5a3933748387e7c2f80a7111302e
size 169602
oid sha256:7879f4eeb499e920ad6c4bfdb0b1f334936ca344c275be056f12fcf7485f2bf6
size 170948

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80dd11dde678e4f9b64df18906175125218fd9f719bfe9aaa667ad6e2d055d40
size 139012
oid sha256:19d3a111cc245f9a9d36b860fd95447df916ad66c918bef672bacdad6bc77a8f
size 140119

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,21 @@ import time
import paramiko
from sshtunnel import SSHTunnelForwarder
from packaging import version
# Paramiko releases after 2.8.1 changed the default public-key algorithm
# preference; this restores an explicit ordering with the classic "ssh-rsa"
# ahead of the rsa-sha2-* variants — presumably to keep key auth working
# against older SSH servers (TODO confirm on paramiko upgrades).
# NOTE(review): this monkeypatches a private paramiko attribute
# (Transport._preferred_pubkeys); re-verify whenever paramiko is bumped.
if version.parse(paramiko.__version__) > version.parse("2.8.1"):
    _preferred_pubkeys = (
        "ssh-ed25519",
        "ecdsa-sha2-nistp256",
        "ecdsa-sha2-nistp384",
        "ecdsa-sha2-nistp521",
        "ssh-rsa",
        "rsa-sha2-256",
        "rsa-sha2-512",
        "ssh-dss",
    )
    paramiko.transport.Transport._preferred_pubkeys = _preferred_pubkeys
def common_argument_spec():
options = dict(

View File

@ -75,7 +75,7 @@ model_cache_field_mapper = {
class OrgResourceStatisticsRefreshUtil:
@staticmethod
@merge_delay_run(ttl=5)
@merge_delay_run(ttl=30)
def refresh_org_fields(org_fields=()):
for org, cache_field_name in org_fields:
OrgResourceStatisticsCache(org).expire(*cache_field_name)
@ -104,7 +104,7 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa
def _refresh_session_org_resource_statistics_cache(instance: Session):
cache_field_name = [
'total_count_online_users', 'total_count_online_sessions',
'total_count_today_active_assets','total_count_today_failed_sessions'
'total_count_today_active_assets', 'total_count_today_failed_sessions'
]
org_cache = OrgResourceStatisticsCache(instance.org)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
#
from orgs.mixins.api import OrgBulkModelViewSet
from perms import serializers
from perms.filters import AssetPermissionFilter
@ -13,7 +14,10 @@ class AssetPermissionViewSet(OrgBulkModelViewSet):
资产授权列表的增删改查api
"""
model = AssetPermission
serializer_class = serializers.AssetPermissionSerializer
serializer_classes = {
'default': serializers.AssetPermissionSerializer,
'list': serializers.AssetPermissionListSerializer,
}
filterset_class = AssetPermissionFilter
search_fields = ('name',)
ordering = ('name',)

View File

@ -7,8 +7,7 @@ from assets.models import Asset, Node
from common.utils import get_logger, lazyproperty, is_uuid
from orgs.utils import tmp_to_root_org
from perms import serializers
from perms.pagination import AllPermedAssetPagination
from perms.pagination import NodePermedAssetPagination
from perms.pagination import NodePermedAssetPagination, AllPermedAssetPagination
from perms.utils import UserPermAssetUtil, PermAssetDetailUtil
from .mixin import (
SelfOrPKUserMixin

View File

@ -1,16 +1,14 @@
from django.conf import settings
from rest_framework.response import Response
from assets.models import Asset
from assets.api import SerializeToTreeNodeMixin
from assets.models import Asset
from common.utils import get_logger
from ..assets import UserAllPermedAssetsApi
from .mixin import RebuildTreeMixin
from ..assets import UserAllPermedAssetsApi
logger = get_logger(__name__)
__all__ = [
'UserAllPermedAssetsAsTreeApi',
'UserUngroupAssetsAsTreeApi',
@ -31,7 +29,7 @@ class AssetTreeMixin(RebuildTreeMixin, SerializeToTreeNodeMixin):
if request.query_params.get('search'):
""" 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """
assets = assets[:999]
data = self.serialize_assets(assets, None)
data = self.serialize_assets(assets, 'root')
return Response(data=data)
@ -42,6 +40,7 @@ class UserAllPermedAssetsAsTreeApi(AssetTreeMixin, UserAllPermedAssetsApi):
class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi):
""" 用户 '未分组节点的资产(直接授权的资产)' 作为树 """
def get_assets(self):
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return super().get_assets()

View File

@ -1,6 +1,4 @@
import abc
import re
from collections import defaultdict
from urllib.parse import parse_qsl
from django.conf import settings
@ -13,10 +11,10 @@ from rest_framework.response import Response
from accounts.const import AliasAccount
from assets.api import SerializeToTreeNodeMixin
from assets.const import AllTypes
from assets.models import Asset
from assets.utils import KubernetesTree
from authentication.models import ConnectionToken
from common.exceptions import JMSException
from common.utils import get_object_or_none, lazyproperty
from common.utils.common import timeit
from perms.hands import Node
@ -38,21 +36,36 @@ class BaseUserNodeWithAssetAsTreeApi(
SelfOrPKUserMixin, RebuildTreeMixin,
SerializeToTreeNodeMixin, ListAPIView
):
page_limit = 10000
def list(self, request, *args, **kwargs):
nodes, assets = self.get_nodes_assets()
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, node_key=self.node_key_for_serialize_assets)
data = list(tree_nodes) + list(tree_assets)
return Response(data=data)
offset = int(request.query_params.get('offset', 0))
page_assets = self.get_page_assets()
if not offset:
nodes, assets = self.get_nodes_assets()
page = page_assets[:self.page_limit]
assets = [*assets, *page]
tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True)
tree_assets = self.serialize_assets(assets, **self.serialize_asset_kwargs)
data = list(tree_nodes) + list(tree_assets)
else:
page = page_assets[offset:(offset + self.page_limit)]
data = self.serialize_assets(page, **self.serialize_asset_kwargs) if page else []
offset += len(page)
headers = {'X-JMS-TREE-OFFSET': offset} if offset else {}
return Response(data=data, headers=headers)
@abc.abstractmethod
def get_nodes_assets(self):
return [], []
@lazyproperty
def node_key_for_serialize_assets(self):
return None
def get_page_assets(self):
return []
@property
def serialize_asset_kwargs(self):
return {}
class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
@ -61,7 +74,6 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
def get_nodes_assets(self):
self.query_node_util = UserPermNodeUtil(self.request.user)
self.query_asset_util = UserPermAssetUtil(self.request.user)
ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped()
fav_nodes, fav_assets = self._get_nodes_assets_for_favorite()
all_nodes, all_assets = self._get_nodes_assets_for_all()
@ -69,31 +81,37 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
assets = list(ung_assets) + list(fav_assets) + list(all_assets)
return nodes, assets
def get_page_assets(self):
return self.query_asset_util.get_all_assets().annotate(parent_key=F('nodes__key'))
@timeit
def _get_nodes_assets_for_ungrouped(self):
if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
return [], []
node = self.query_node_util.get_ungrouped_node()
assets = self.query_asset_util.get_ungroup_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
.prefetch_related('platform')
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
return [node], assets
@lazyproperty
def query_asset_util(self):
return UserPermAssetUtil(self.user)
@timeit
def _get_nodes_assets_for_favorite(self):
node = self.query_node_util.get_favorite_node()
assets = self.query_asset_util.get_favorite_assets()
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \
.prefetch_related('platform')
assets = assets.annotate(parent_key=Value(node.key, output_field=CharField()))
return [node], assets
@timeit
def _get_nodes_assets_for_all(self):
nodes = self.query_node_util.get_whole_tree_nodes(with_special=False)
if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE:
assets = self.query_asset_util.get_perm_nodes_assets()
else:
assets = self.query_asset_util.get_all_assets()
assets = assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform')
assets = Asset.objects.none()
assets = assets.annotate(parent_key=F('nodes__key'))
return nodes, assets
@ -103,6 +121,7 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
# 默认展开的节点key
default_unfolded_node_key = None
@timeit
def get_nodes_assets(self):
query_node_util = UserPermNodeUtil(self.user)
query_asset_util = UserPermAssetUtil(self.user)
@ -136,14 +155,14 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi):
node_key = getattr(node, 'key', None)
return node_key
@lazyproperty
def node_key_for_serialize_assets(self):
return self.query_node_key or self.default_unfolded_node_key
@property
def serialize_asset_kwargs(self):
return {
'node_key': self.query_node_key or self.default_unfolded_node_key
}
class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
SelfOrPKUserMixin, SerializeToTreeNodeMixin, ListAPIView
):
class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsTreeApi):
@property
def is_sync(self):
sync = self.request.query_params.get('sync', 0)
@ -151,66 +170,54 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(
@property
def tp(self):
return self.request.query_params.get('type')
def get_assets(self):
query_asset_util = UserPermAssetUtil(self.user)
node = PermNode.objects.filter(
granted_node_rels__user=self.user, parent_key='').first()
if node:
__, assets = query_asset_util.get_node_all_assets(node.id)
else:
assets = Asset.objects.none()
return assets
def to_tree_nodes(self, assets):
if not assets:
return []
assets = assets.annotate(tp=F('platform__type'))
asset_type_map = defaultdict(list)
for asset in assets:
asset_type_map[asset.tp].append(asset)
tp = self.tp
if tp:
assets = asset_type_map.get(tp, [])
if not assets:
return []
pid = f'ROOT_{str(assets[0].category).upper()}_{tp}'
return self.serialize_assets(assets, pid=pid)
params = self.request.query_params
get_root = not list(filter(lambda x: params.get(x), ('type', 'n')))
resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=get_root)
pattern = re.compile(r'\(0\)?')
nodes = []
for node in node_all:
meta = node.get('meta', {})
if pattern.search(node['name']) or meta.get('type') == 'platform':
continue
_type = meta.get('_type')
if _type:
node['type'] = _type
meta.setdefault('data', {})
node['meta'] = meta
nodes.append(node)
return [params.get('category'), params.get('type')]
if not self.is_sync:
return nodes
@lazyproperty
def query_asset_util(self):
return UserPermAssetUtil(self.user)
asset_nodes = []
for node in nodes:
node['open'] = True
tp = node.get('meta', {}).get('_type')
if not tp:
continue
assets = asset_type_map.get(tp, [])
asset_nodes += self.serialize_assets(assets, pid=node['id'])
return nodes + asset_nodes
@timeit
def get_assets(self):
return self.query_asset_util.get_all_assets()
def list(self, request, *args, **kwargs):
assets = self.get_assets()
nodes = self.to_tree_nodes(assets)
return Response(data=nodes)
def _get_tree_nodes_async(self):
if self.request.query_params.get('lv') == '0':
return [], []
if not self.tp or not all(self.tp):
nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user)
return nodes, []
category, tp = self.tp
assets = self.get_assets().filter(platform__type=tp, platform__category=category)
return [], assets
def _get_tree_nodes_sync(self):
if self.request.query_params.get('lv'):
return []
nodes = self.query_asset_util.get_type_nodes_tree()
return nodes, []
@property
def serialize_asset_kwargs(self):
return {
'get_pid': lambda asset, platform: 'ROOT_{}_{}'.format(platform.category.upper(), platform.type),
}
def serialize_nodes(self, nodes, with_asset_amount=False):
return nodes
def get_nodes_assets(self):
if self.is_sync:
return self._get_tree_nodes_sync()
else:
return self._get_tree_nodes_async()
def get_page_assets(self):
if self.is_sync:
return self.get_assets()
else:
return []
class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
@ -258,5 +265,8 @@ class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView):
if not any([namespace, pod]) and not key:
asset_node = k8s_tree_instance.as_asset_tree_node()
tree.append(asset_node)
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
try:
tree.extend(k8s_tree_instance.async_tree_node(namespace, pod))
return Response(data=tree)
except Exception as e:
raise JMSException(e)

View File

@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _
from accounts.const import AliasAccount
from accounts.models import Account
from assets.models import Asset
from common.utils import date_expired_default
from common.utils import date_expired_default, lazyproperty
from common.utils.timezone import local_now
from labels.mixins import LabeledMixin
from orgs.mixins.models import JMSOrgBaseModel
@ -105,6 +105,22 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
return True
return False
@lazyproperty
def users_amount(self):
    # Number of users directly assigned to this permission
    # (members of assigned user groups are counted separately).
    return self.users.count()
@lazyproperty
def user_groups_amount(self):
    # Number of user groups assigned to this permission.
    return self.user_groups.count()
@lazyproperty
def assets_amount(self):
    # Number of assets directly assigned to this permission
    # (assets granted via nodes are counted separately).
    return self.assets.count()
@lazyproperty
def nodes_amount(self):
    # Number of asset-tree nodes assigned to this permission.
    return self.nodes.count()
def get_all_users(self):
from users.models import User
user_ids = self.users.all().values_list('id', flat=True)
@ -114,7 +130,7 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
qs1_ids = User.objects.filter(id__in=user_ids).distinct().values_list('id', flat=True)
qs2_ids = User.objects.filter(groups__id__in=group_ids).distinct().values_list('id', flat=True)
qs_ids = list(qs1_ids) + list(qs2_ids)
qs = User.objects.filter(id__in=qs_ids)
qs = User.objects.filter(id__in=qs_ids, is_service_account=False)
return qs
def get_all_assets(self, flat=False):
@ -143,11 +159,14 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel):
@classmethod
def get_all_users_for_perms(cls, perm_ids, flat=False):
user_ids = cls.users.through.objects.filter(assetpermission_id__in=perm_ids) \
user_ids = cls.users.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('user_id', flat=True).distinct()
group_ids = cls.user_groups.through.objects.filter(assetpermission_id__in=perm_ids) \
group_ids = cls.user_groups.through.objects \
.filter(assetpermission_id__in=perm_ids) \
.values_list('usergroup_id', flat=True).distinct()
group_user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
group_user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct()
user_ids = set(user_ids) | set(group_user_ids)
if flat:

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Q
from django.db.models import Q, Count
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
@ -14,7 +14,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer
from perms.models import ActionChoices, AssetPermission
from users.models import User, UserGroup
__all__ = ["AssetPermissionSerializer", "ActionChoicesField"]
__all__ = ["AssetPermissionSerializer", "ActionChoicesField", "AssetPermissionListSerializer"]
class ActionChoicesField(BitChoicesField):
@ -142,8 +142,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
def perform_display_create(instance, **kwargs):
# 用户
users_to_set = User.objects.filter(
Q(name__in=kwargs.get("users_display"))
| Q(username__in=kwargs.get("users_display"))
Q(name__in=kwargs.get("users_display")) |
Q(username__in=kwargs.get("users_display"))
).distinct()
instance.users.add(*users_to_set)
# 用户组
@ -153,8 +153,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance.user_groups.add(*user_groups_to_set)
# 资产
assets_to_set = Asset.objects.filter(
Q(address__in=kwargs.get("assets_display"))
| Q(name__in=kwargs.get("assets_display"))
Q(address__in=kwargs.get("assets_display")) |
Q(name__in=kwargs.get("assets_display"))
).distinct()
instance.assets.add(*assets_to_set)
# 节点
@ -180,3 +180,27 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali
instance = super().create(validated_data)
self.perform_display_create(instance, **display)
return instance
class AssetPermissionListSerializer(AssetPermissionSerializer):
    """Lightweight serializer for the asset-permission list endpoint.

    Replaces the heavy m2m relation fields (users/user_groups/assets/nodes)
    with pre-annotated integer counts so listing avoids per-row queries.
    """
    users_amount = serializers.IntegerField(read_only=True, label=_("Users amount"))
    user_groups_amount = serializers.IntegerField(read_only=True, label=_("User groups amount"))
    assets_amount = serializers.IntegerField(read_only=True, label=_("Assets amount"))
    nodes_amount = serializers.IntegerField(read_only=True, label=_("Nodes amount"))

    class Meta(AssetPermissionSerializer.Meta):
        # Swap the relation fields for the annotated amount fields.
        amount_fields = ["users_amount", "user_groups_amount", "assets_amount", "nodes_amount"]
        remove_fields = {"users", "assets", "nodes", "user_groups"}
        fields = list(set(AssetPermissionSerializer.Meta.fields + amount_fields) - remove_fields)

    @classmethod
    def setup_eager_loading(cls, queryset):
        """Perform necessary eager loading of data.

        Counts use distinct=True: combining several multi-valued
        annotations in a single query joins all the relations together,
        which would otherwise inflate every count (documented Django
        aggregation pitfall).
        """
        queryset = queryset \
            .prefetch_related('labels', 'labels__label') \
            .annotate(
                users_amount=Count("users", distinct=True),
                user_groups_amount=Count("user_groups", distinct=True),
                assets_amount=Count("assets", distinct=True),
                nodes_amount=Count("nodes", distinct=True),
            )
        return queryset

View File

@ -3,15 +3,13 @@
from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save
from django.dispatch import receiver
from users.models import User, UserGroup
from assets.models import Asset
from common.utils import get_logger, get_object_or_none
from common.exceptions import M2MReverseNotAllowed
from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR
from common.exceptions import M2MReverseNotAllowed
from common.utils import get_logger, get_object_or_none
from perms.models import AssetPermission
from perms.utils import UserPermTreeExpireUtil
from users.models import User, UserGroup
logger = get_logger(__file__)
@ -38,7 +36,7 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs):
group = UserGroup.objects.get(id=list(group_ids)[0])
org_id = group.org_id
has_group_perm = AssetPermission.user_groups.through.objects\
has_group_perm = AssetPermission.user_groups.through.objects \
.filter(usergroup_id__in=group_ids).exists()
if not has_group_perm:
return
@ -115,6 +113,7 @@ def on_asset_permission_user_groups_changed(sender, instance, action, pk_set, re
def on_node_asset_change(action, instance, reverse, pk_set, **kwargs):
if not need_rebuild_mapping_node(action):
return
print("Asset node changed: ", action)
if reverse:
asset_ids = pk_set
node_ids = [instance.id]

View File

@ -1,8 +1,7 @@
from django.db.models import QuerySet
from assets.models import Node, Asset
from common.utils import get_logger
from common.utils import get_logger, timeit
from perms.models import AssetPermission
logger = get_logger(__file__)
@ -13,6 +12,7 @@ __all__ = ['AssetPermissionUtil']
class AssetPermissionUtil(object):
""" 资产授权相关的方法工具 """
@timeit
def get_permissions_for_user(self, user, with_group=True, flat=False):
""" 获取用户的授权规则 """
perm_ids = set()

View File

@ -1,13 +1,22 @@
from django.conf import settings
from django.db.models import Q
import json
import re
from django.conf import settings
from django.core.cache import cache
from django.db.models import Q
from rest_framework.utils.encoders import JSONEncoder
from assets.const import AllTypes
from assets.models import FavoriteAsset, Asset
from common.utils.common import timeit
from common.utils.common import timeit, get_logger
from orgs.utils import current_org, tmp_to_root_org
from perms.models import PermNode, UserAssetGrantedTreeNodeRelation
from .permission import AssetPermissionUtil
__all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil']
logger = get_logger(__name__)
class AssetPermissionPermAssetUtil:
@ -15,30 +24,35 @@ class AssetPermissionPermAssetUtil:
self.perm_ids = perm_ids
def get_all_assets(self):
""" 获取所有授权的资产 """
node_asset_ids = self.get_perm_nodes_assets(flat=True)
direct_asset_ids = self.get_direct_assets(flat=True)
asset_ids = list(node_asset_ids) + list(direct_asset_ids)
assets = Asset.objects.filter(id__in=asset_ids)
return assets
node_assets = self.get_perm_nodes_assets()
direct_assets = self.get_direct_assets()
# 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢
return (node_assets | direct_assets).distinct()
@timeit
def get_perm_nodes_assets(self, flat=False):
""" 获取所有授权节点下的资产 """
from assets.models import Node
nodes = Node.objects.prefetch_related('granted_by_permissions').filter(
granted_by_permissions__in=self.perm_ids).only('id', 'key')
from ..models import AssetPermission
nodes_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('nodes', flat=True)
nodes = Node.objects.filter(id__in=nodes_ids).only('id', 'key')
assets = PermNode.get_nodes_all_assets(*nodes)
if flat:
return assets.values_list('id', flat=True)
return set(assets.values_list('id', flat=True))
return assets
@timeit
def get_direct_assets(self, flat=False):
""" 获取直接授权的资产 """
assets = Asset.objects.order_by() \
.filter(granted_by_permissions__id__in=self.perm_ids) \
.distinct()
from ..models import AssetPermission
asset_ids = AssetPermission.objects \
.filter(id__in=self.perm_ids) \
.values_list('assets', flat=True)
assets = Asset.objects.filter(id__in=asset_ids).distinct()
if flat:
return assets.values_list('id', flat=True)
return set(assets.values_list('id', flat=True))
return assets
@ -52,12 +66,62 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
def get_ungroup_assets(self):
return self.get_direct_assets()
@timeit
def get_favorite_assets(self):
assets = self.get_all_assets()
assets = Asset.objects.all().valid()
asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
assets = assets.filter(id__in=list(asset_ids))
return assets
def get_type_nodes_tree(self):
    """Build a category/type tree (list of node dicts) covering only the
    platforms of assets this user is granted.

    Returns a list of plain dicts suitable for JSON serialization.
    """
    assets = self.get_all_assets()
    # Only the platform ids are needed to decide which type nodes to keep.
    resource_platforms = assets.order_by('id').values_list('platform_id', flat=True)
    node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=True)
    # Matches names containing "(0" / "(0)" — i.e. empty nodes rendered with a zero count.
    pattern = re.compile(r'\(0\)?')
    nodes = []
    for node in node_all:
        meta = node.get('meta', {})
        # Skip empty nodes and raw platform leaves; keep category/type levels.
        if pattern.search(node['name']) or meta.get('type') == 'platform':
            continue
        _type = meta.get('_type')
        if _type:
            # Promote type/category out of meta for the front-end tree.
            node['type'] = _type
            node['category'] = meta.get('category')
            meta.setdefault('data', {})
        node['meta'] = meta
        nodes.append(node)
    return nodes
@classmethod
def get_type_nodes_tree_or_cached(cls, user):
    """Return the user's type-nodes tree, serving from cache when possible.

    The cache stores a JSON string (1-day TTL), keyed per user and per
    current org; a miss builds the tree and returns the fresh list directly.
    """
    key = f'perms:type-nodes-tree:{user.id}:{current_org.id}'
    nodes = cache.get(key)
    if nodes is None:
        nodes = cls(user).get_type_nodes_tree()
        # JSONEncoder handles DRF/Django types (UUIDs, datetimes) in the dicts.
        nodes_json = json.dumps(nodes, cls=JSONEncoder)
        cache.set(key, nodes_json, 60 * 60 * 24)
    else:
        # Cache hit: value is the JSON string stored above.
        nodes = json.loads(nodes)
    return nodes
def refresh_type_nodes_tree_cache(self):
    """Invalidate this user's cached type-nodes tree for the current org.

    Key format mirrors get_type_nodes_tree_or_cached.
    """
    logger.debug("Refresh type nodes tree cache")
    cache.delete(f'perms:type-nodes-tree:{self.user.id}:{current_org.id}')
def refresh_favorite_assets(self):
    """Remove favorite records whose asset is no longer granted to the user."""
    favor_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True)
    favor_ids = set(favor_ids)
    with tmp_to_root_org():
        # Evaluate under the root org so granted assets from every org are visible;
        # set() forces the queryset inside the context manager.
        valid_ids = self.get_all_assets() \
            .filter(id__in=favor_ids) \
            .values_list('id', flat=True)
        valid_ids = set(valid_ids)
    # NOTE(review): the delete runs outside the root-org context — confirm
    # FavoriteAsset is not org-scoped, otherwise some rows may be missed.
    invalid_ids = favor_ids - valid_ids
    FavoriteAsset.objects.filter(user=self.user, asset_id__in=invalid_ids).delete()
def get_node_assets(self, key):
node = PermNode.objects.get(key=key)
node.compute_node_from_and_assets_amount(self.user)
@ -90,6 +154,7 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil):
assets = assets.filter(nodes__id=node.id).order_by().distinct()
return assets
@timeit
def _get_indirect_perm_node_all_assets(self, node):
""" 获取间接授权节点下的所有资产
此算法依据 `UserAssetGrantedTreeNodeRelation` 的数据查询
@ -134,7 +199,11 @@ class UserPermNodeUtil:
self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True)
def get_favorite_node(self):
assets_amount = UserPermAssetUtil(self.user).get_favorite_assets().count()
favor_ids = FavoriteAsset.objects \
.filter(user=self.user) \
.values_list('asset_id') \
.distinct()
assets_amount = Asset.objects.all().valid().filter(id__in=favor_ids).count()
return PermNode.get_favorite_node(assets_amount)
def get_ungrouped_node(self):

View File

@ -3,11 +3,12 @@ from collections import defaultdict
from django.conf import settings
from django.core.cache import cache
from django.db import transaction
from assets.models import Asset
from assets.utils import NodeAssetsUtil
from common.db.models import output_as_string
from common.decorators import on_transaction_commit
from common.decorators import on_transaction_commit, merge_delay_run
from common.utils import get_logger
from common.utils.common import lazyproperty, timeit
from orgs.models import Organization
@ -23,6 +24,7 @@ from perms.models import (
PermNode
)
from users.models import User
from . import UserPermAssetUtil
from .permission import AssetPermissionUtil
logger = get_logger(__name__)
@ -50,24 +52,74 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
def __init__(self, user):
self.user = user
self.orgs = self.user.orgs.distinct()
self.org_ids = [str(o.id) for o in self.orgs]
@lazyproperty
def orgs(self):
return self.user.orgs.distinct()
@lazyproperty
def org_ids(self):
return [str(o.id) for o in self.orgs]
@lazyproperty
def cache_key_user(self):
return self.get_cache_key(self.user.id)
@lazyproperty
def cache_key_time(self):
key = 'perms.user.node_tree.built_time.{}'.format(self.user.id)
return key
@timeit
def refresh_if_need(self, force=False):
self._clean_user_perm_tree_for_legacy_org()
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs()
if not to_refresh_orgs:
logger.info('Not have to refresh orgs')
return
with UserGrantedTreeRebuildLock(self.user.id):
logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs]))
refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),))
refresh_user_favorite_assets(users=(self.user,))
@timeit
def refresh_tree_manual(self):
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh just now, pass: {}'.format(built_just_now))
return
to_refresh_orgs = self._get_user_need_refresh_orgs()
if not to_refresh_orgs:
logger.info('Not have to refresh orgs for user: {}'.format(self.user))
return
self.perform_refresh_user_tree(to_refresh_orgs)
@timeit
def perform_refresh_user_tree(self, to_refresh_orgs):
# 再判断一次,毕竟构建树比较慢
built_just_now = cache.get(self.cache_key_time)
if built_just_now:
logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now))
return
self._clean_user_perm_tree_for_legacy_org()
ttl = settings.PERM_TREE_REGEN_INTERVAL
cache.set(self.cache_key_time, int(time.time()), ttl)
lock = UserGrantedTreeRebuildLock(self.user.id)
got = lock.acquire(blocking=False)
if not got:
logger.info('User perm tree rebuild lock not acquired, pass')
return
try:
for org in to_refresh_orgs:
self._rebuild_user_perm_tree_for_org(org)
self._mark_user_orgs_refresh_finished(to_refresh_orgs)
self._mark_user_orgs_refresh_finished(to_refresh_orgs)
finally:
lock.release()
def _rebuild_user_perm_tree_for_org(self, org):
with tmp_to_org(org):
@ -75,7 +127,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree()
end = time.time()
logger.info(
'Refresh user [{user}] org [{org}] perm tree, user {use_time:.2f}s'
'Refresh user perm tree: [{user}] org [{org}] {use_time:.2f}s'
''.format(user=self.user, org=org, use_time=end - start)
)
@ -90,7 +142,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin):
cached_org_ids = self.client.smembers(self.cache_key_user)
cached_org_ids = {oid.decode() for oid in cached_org_ids}
to_refresh_org_ids = set(self.org_ids) - cached_org_ids
to_refresh_orgs = Organization.objects.filter(id__in=to_refresh_org_ids)
to_refresh_orgs = list(Organization.objects.filter(id__in=to_refresh_org_ids))
logger.info(f'Need to refresh orgs: {to_refresh_orgs}')
return to_refresh_orgs
@ -128,7 +180,8 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids)
def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids):
user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \
user_ids = User.groups.through.objects \
.filter(usergroup_id__in=group_ids) \
.values_list('user_id', flat=True).distinct()
self.expire_perm_tree_for_users_orgs(user_ids, org_ids)
@ -151,6 +204,21 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin):
logger.info('Expire all user perm tree')
@merge_delay_run(ttl=20)
def refresh_user_orgs_perm_tree(user_orgs=()):
    """Rebuild the perm tree for each (user, orgs) pair.

    Calls within the ttl window are batched by merge_delay_run, so
    ``user_orgs`` may accumulate pairs from several callers.
    """
    for user, orgs in user_orgs:
        UserPermTreeRefreshUtil(user).perform_refresh_user_tree(orgs)
@merge_delay_run(ttl=20)
def refresh_user_favorite_assets(users=()):
    """Re-validate favorites and drop the cached type-nodes tree for each user.

    Batched by merge_delay_run like refresh_user_orgs_perm_tree.
    """
    for user in users:
        helper = UserPermAssetUtil(user)
        helper.refresh_favorite_assets()
        helper.refresh_type_nodes_tree_cache()
class UserPermTreeBuildUtil(object):
node_only_fields = ('id', 'key', 'parent_key', 'org_id')
@ -161,13 +229,14 @@ class UserPermTreeBuildUtil(object):
self._perm_nodes_key_node_mapper = {}
def rebuild_user_perm_tree(self):
self.clean_user_perm_tree()
if not self.user_perm_ids:
logger.info('User({}) not have permissions'.format(self.user))
return
self.compute_perm_nodes()
self.compute_perm_nodes_asset_amount()
self.create_mapping_nodes()
with transaction.atomic():
self.clean_user_perm_tree()
if not self.user_perm_ids:
logger.info('User({}) not have permissions'.format(self.user))
return
self.compute_perm_nodes()
self.compute_perm_nodes_asset_amount()
self.create_mapping_nodes()
def clean_user_perm_tree(self):
UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete()

View File

@ -139,7 +139,7 @@ class RBACPermission(permissions.DjangoModelPermissions):
if isinstance(perms, str):
perms = [perms]
has = request.user.has_perms(perms)
logger.debug('View require perms: {}, result: {}'.format(perms, has))
logger.debug('Api require perms: {}, result: {}'.format(perms, has))
return has
def has_object_permission(self, request, view, obj):

View File

@ -4,11 +4,12 @@
from smtplib import SMTPSenderRefused
from django.conf import settings
from django.core.mail import send_mail, get_connection
from django.core.mail import send_mail
from django.utils.translation import gettext_lazy as _
from rest_framework.views import Response, APIView
from common.utils import get_logger
from common.tasks import get_email_connection as get_connection
from .. import serializers
logger = get_logger(__file__)

View File

@ -137,7 +137,7 @@ class LDAPUserImportAPI(APIView):
return Response({'msg': _('Get ldap users is None')}, status=400)
orgs = self.get_orgs()
errors = LDAPImportUtil().perform_import(users, orgs)
new_users, errors = LDAPImportUtil().perform_import(users, orgs)
if errors:
return Response({'errors': errors}, status=400)

View File

@ -0,0 +1,36 @@
from django.template.loader import render_to_string
from django.utils.translation import gettext as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from notifications.notifications import UserMessage
logger = get_logger(__file__)
class LDAPImportMessage(UserMessage):
    """Per-user notification summarizing the results of an LDAP user-import run."""

    def __init__(self, user, extra_kwargs):
        super().__init__(user)
        # pop() with defaults tolerates callers that omit any of these keys.
        self.orgs = extra_kwargs.pop('orgs', [])
        # NOTE(review): end_time and start_time are stored but never rendered —
        # get_html_msg uses time_start_display and local_now_display() instead.
        # Confirm whether they can be dropped.
        self.end_time = extra_kwargs.pop('end_time', '')
        self.start_time = extra_kwargs.pop('start_time', '')
        self.time_start_display = extra_kwargs.pop('time_start_display', '')
        self.new_users = extra_kwargs.pop('new_users', [])
        self.errors = extra_kwargs.pop('errors', [])
        self.cost_time = extra_kwargs.pop('cost_time', '')

    def get_html_msg(self) -> dict:
        """Render subject and HTML body from the ldap import-result template."""
        subject = _('Notification of Synchronized LDAP User Task Results')
        context = {
            'orgs': self.orgs,
            'start_time': self.time_start_display,
            # End time is the moment the message is built, not the task's end_time.
            'end_time': local_now_display(),
            'cost_time': self.cost_time,
            'users': self.new_users,
            'errors': self.errors
        }
        message = render_to_string('ldap/_msg_import_ldap_user.html', context)
        return {
            'subject': subject,
            'message': message
        }

View File

@ -77,6 +77,9 @@ class LDAPSettingSerializer(serializers.Serializer):
required=False, label=_('Connect timeout (s)'),
)
AUTH_LDAP_SEARCH_PAGED_SIZE = serializers.IntegerField(required=False, label=_('Search paged size (piece)'))
AUTH_LDAP_SYNC_RECEIVERS = serializers.ListField(
required=False, label=_('Recipient'), max_length=36
)
AUTH_LDAP = serializers.BooleanField(required=False, label=_('Enable LDAP auth'))

View File

@ -55,6 +55,17 @@ class VaultSettingSerializer(serializers.Serializer):
max_length=256, allow_blank=True, required=False, label=_('Mount Point')
)
HISTORY_ACCOUNT_CLEAN_LIMIT = serializers.IntegerField(
default=999, max_value=999, min_value=1,
required=False, label=_('Historical accounts retained count'),
help_text=_(
'If the specific value is less than 999, '
'the system will automatically perform a task every night: '
'check and delete historical accounts that exceed the predetermined number. '
'If the value reaches or exceeds 999, no historical account deletion will be performed.'
)
)
class ChatAISettingSerializer(serializers.Serializer):
PREFIX_TITLE = _('Chat AI')

View File

@ -1,11 +1,12 @@
# coding: utf-8
#
from django.db import models
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.serializers.fields import EncryptedField
__all__ = [
'MailTestSerializer', 'EmailSettingSerializer',
'EmailContentSettingSerializer', 'SMSBackendSerializer',
@ -18,14 +19,20 @@ class MailTestSerializer(serializers.Serializer):
class EmailSettingSerializer(serializers.Serializer):
# encrypt_fields 现在使用 write_only 来判断了
PREFIX_TITLE = _('Email')
EMAIL_HOST = serializers.CharField(max_length=1024, required=True, label=_("SMTP host"))
EMAIL_PORT = serializers.CharField(max_length=5, required=True, label=_("SMTP port"))
EMAIL_HOST_USER = serializers.CharField(max_length=128, required=True, label=_("SMTP account"))
class EmailProtocol(models.TextChoices):
smtp = 'smtp', _('SMTP')
exchange = 'exchange', _('EXCHANGE')
EMAIL_PROTOCOL = serializers.ChoiceField(
choices=EmailProtocol.choices, label=_("Protocol"), default=EmailProtocol.smtp
)
EMAIL_HOST = serializers.CharField(max_length=1024, required=True, label=_("Host"))
EMAIL_PORT = serializers.CharField(max_length=5, required=True, label=_("Port"))
EMAIL_HOST_USER = serializers.CharField(max_length=128, required=True, label=_("Account"))
EMAIL_HOST_PASSWORD = EncryptedField(
max_length=1024, required=False, label=_("SMTP password"),
max_length=1024, required=False, label=_("Password"),
help_text=_("Tips: Some provider use token except password")
)
EMAIL_FROM = serializers.CharField(

View File

@ -1,15 +1,19 @@
# coding: utf-8
#
import time
from celery import shared_task
from django.conf import settings
from django.utils.translation import gettext_lazy as _
from common.utils import get_logger
from common.utils.timezone import local_now_display
from ops.celery.decorator import after_app_ready_start
from ops.celery.utils import (
create_or_update_celery_periodic_tasks, disable_celery_periodic_task
)
from orgs.models import Organization
from settings.notifications import LDAPImportMessage
from users.models import User
from ..utils import LDAPSyncUtil, LDAPServerUtil, LDAPImportUtil
__all__ = ['sync_ldap_user', 'import_ldap_user_periodic', 'import_ldap_user']
@ -23,6 +27,8 @@ def sync_ldap_user():
@shared_task(verbose_name=_('Periodic import ldap user'))
def import_ldap_user():
start_time = time.time()
time_start_display = local_now_display()
logger.info("Start import ldap user task")
util_server = LDAPServerUtil()
util_import = LDAPImportUtil()
@ -35,11 +41,26 @@ def import_ldap_user():
org_ids = [Organization.DEFAULT_ID]
default_org = Organization.default()
orgs = list(set([Organization.get_instance(org_id, default=default_org) for org_id in org_ids]))
errors = util_import.perform_import(users, orgs)
new_users, errors = util_import.perform_import(users, orgs)
if errors:
logger.error("Imported LDAP users errors: {}".format(errors))
else:
logger.info('Imported {} users successfully'.format(len(users)))
if settings.AUTH_LDAP_SYNC_RECEIVERS:
user_ids = settings.AUTH_LDAP_SYNC_RECEIVERS
recipient_list = User.objects.filter(id__in=list(user_ids))
end_time = time.time()
extra_kwargs = {
'orgs': orgs,
'end_time': end_time,
'start_time': start_time,
'time_start_display': time_start_display,
'new_users': new_users,
'errors': errors,
'cost_time': end_time - start_time,
}
for user in recipient_list:
LDAPImportMessage(user, extra_kwargs).publish()
@shared_task(verbose_name=_('Registration periodic import ldap user task'))

View File

@ -0,0 +1,30 @@
{% load i18n %}
<p>{% trans "Sync task Finish" %}</p>
<b>{% trans 'Time' %}:</b>
<ul>
<li>{% trans 'Date start' %}: {{ start_time }}</li>
<li>{% trans 'Date end' %}: {{ end_time }}</li>
<li>{% trans 'Time cost' %}: {{ cost_time| floatformat:0 }}s</li>
</ul>
<b>{% trans "Synced Organization" %}:</b>
<ul>
{% for org in orgs %}
<li>{{ org }}</li>
{% endfor %}
</ul>
<b>{% trans "Synced User" %}:</b>
<ul>
{% for user in users %}
<li>{{ user }}</li>
{% endfor %}
</ul>
{% if errors %}
<b>{% trans 'Error' %}:</b>
<ul>
{% for error in errors %}
<li>{{ error }}</li>
{% endfor %}
</ul>
{% endif %}

View File

@ -400,11 +400,14 @@ class LDAPImportUtil(object):
logger.info('Start perform import ldap users, count: {}'.format(len(users)))
errors = []
objs = []
new_users = []
group_users_mapper = defaultdict(set)
for user in users:
groups = user.pop('groups', [])
try:
obj, created = self.update_or_create(user)
if created:
new_users.append(obj)
objs.append(obj)
except Exception as e:
errors.append({user['username']: str(e)})
@ -421,14 +424,13 @@ class LDAPImportUtil(object):
for org in orgs:
self.bind_org(org, objs, group_users_mapper)
logger.info('End perform import ldap users')
return errors
return new_users, errors
@staticmethod
def exit_user_group(user_groups_mapper):
def exit_user_group(self, user_groups_mapper):
# 通过对比查询本次导入用户需要移除的用户组
group_remove_users_mapper = defaultdict(set)
for user, current_groups in user_groups_mapper.items():
old_groups = set(user.groups.all())
old_groups = set(user.groups.filter(name__startswith=self.user_group_name_prefix))
exit_groups = old_groups - current_groups
logger.debug(f'Ldap user {user} exits user groups {exit_groups}')
for g in exit_groups:

View File

@ -4,6 +4,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import MethodNotAllowed
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from audits.handler import create_or_update_operate_log
@ -41,7 +42,6 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
ordering = ('-date_created',)
rbac_perms = {
'open': 'tickets.view_ticket',
'bulk': 'tickets.change_ticket',
}
def retrieve(self, request, *args, **kwargs):
@ -122,7 +122,7 @@ class TicketViewSet(CommonApiMixin, viewsets.ModelViewSet):
self._record_operate_log(instance, TicketAction.close)
return Response('ok')
@action(detail=False, methods=[PUT], permission_classes=[RBACPermission, ])
@action(detail=False, methods=[PUT], permission_classes=[IsAuthenticated, ])
def bulk(self, request, *args, **kwargs):
self.ticket_not_allowed()

View File

@ -6,7 +6,7 @@ from rest_framework.response import Response
from orgs.mixins.api import OrgBulkModelViewSet
from ..models import UserGroup, User
from ..serializers import UserGroupSerializer
from ..serializers import UserGroupSerializer, UserGroupListSerializer
__all__ = ['UserGroupViewSet']
@ -15,7 +15,10 @@ class UserGroupViewSet(OrgBulkModelViewSet):
model = UserGroup
filterset_fields = ("name",)
search_fields = filterset_fields
serializer_class = UserGroupSerializer
serializer_classes = {
'default': UserGroupSerializer,
'list': UserGroupListSerializer,
}
ordering = ('name',)
rbac_perms = (
("add_all_users", "users.add_usergroup"),

View File

@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
#
from django.db.models import Count
from django.db.models import Count, Q
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from common.serializers.fields import ObjectRelatedField
from common.serializers.mixin import ResourceLabelsMixin
@ -10,7 +11,7 @@ from .. import utils
from ..models import User, UserGroup
__all__ = [
'UserGroupSerializer',
'UserGroupSerializer', 'UserGroupListSerializer',
]
@ -29,7 +30,6 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
fields = fields_mini + fields_small + ['users', 'labels']
extra_kwargs = {
'created_by': {'label': _('Created by'), 'read_only': True},
'users_amount': {'label': _('Users amount')},
'id': {'label': _('ID')},
}
@ -45,6 +45,17 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer):
@classmethod
def setup_eager_loading(cls, queryset):
""" Perform necessary eager loading of data. """
queryset = queryset.prefetch_related('users', 'labels', 'labels__label') \
.annotate(users_amount=Count('users'))
queryset = queryset.prefetch_related('labels', 'labels__label') \
.annotate(users_amount=Count('users', filter=Q(users__is_service_account=False)))
return queryset
class UserGroupListSerializer(UserGroupSerializer):
    """List-view variant: drops the heavy 'users' relation and exposes a count."""
    # Value is annotated onto the queryset by setup_eager_loading; read-only.
    users_amount = serializers.IntegerField(label=_('Users amount'), read_only=True)

    class Meta(UserGroupSerializer.Meta):
        # set() both dedups and removes 'users'; resulting field ORDER is unspecified.
        fields = list(set(UserGroupSerializer.Meta.fields + ['users_amount']) - {'users'})
        extra_kwargs = {
            **UserGroupSerializer.Meta.extra_kwargs,
            'users_amount': {'label': _('Users amount')},
        }

View File

@ -163,9 +163,9 @@ def on_openid_create_or_update_user(sender, request, user, created, name, userna
user.save()
@shared_task(verbose_name=_('Clean audits session task log'))
@shared_task(verbose_name=_('Clean up expired user sessions'))
@register_as_period_task(crontab=CRONTAB_AT_PM_TWO)
def clean_audits_log_period():
def clean_expired_user_session_period():
UserSession.clear_expired_sessions()

View File

@ -12,6 +12,7 @@ from django.utils.translation import gettext as _
from django.views.generic import FormView, RedirectView
from authentication.errors import IntervalTooShort
from authentication.utils import check_user_property_is_correct
from common.utils import FlashMessageUtil, get_object_or_none, random_string
from common.utils.verify_code import SendAndVerifyCodeUtil
from users.notifications import ResetPasswordSuccessMsg
@ -148,7 +149,6 @@ class UserForgotPasswordView(FormView):
query_key = form_type
if form_type == 'sms':
query_key = 'phone'
target = target.lstrip('+')
try:
self.safe_verify_code(token, target, form_type, code)
@ -158,7 +158,7 @@ class UserForgotPasswordView(FormView):
form.add_error('code', str(e))
return super().form_invalid(form)
user = get_object_or_none(User, **{'username': username, query_key: target})
user = check_user_property_is_correct(username, **{query_key: target})
if not user:
form.add_error('code', _('No user matched'))
return super().form_invalid(form)

1501
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -78,7 +78,7 @@ geoip2 = "4.7.0"
ipip-ipdb = "1.6.1"
pywinrm = "0.4.3"
python-nmap = "0.7.1"
django = "4.1.10"
django = "4.1.13"
django-bootstrap3 = "23.4"
django-filter = "23.2"
django-formtools = "2.4.1"
@ -97,7 +97,7 @@ drf-yasg = "1.21.7"
coreapi = "2.3.3"
coreschema = "0.0.4"
openapi-codec = "1.3.2"
pillow = "10.0.0"
pillow = "10.0.1"
pytz = "2023.3"
django-proxy = "1.2.2"
python-daemon = "3.0.1"
@ -127,7 +127,7 @@ python-redis-lock = "4.0.0"
pyopenssl = "23.2.0"
redis = "4.6.0"
pymongo = "4.4.1"
pyfreerdp = "0.0.1"
pyfreerdp = "0.0.2"
ipython = "8.14.0"
forgerypy3 = "0.3.1"
django-debug-toolbar = "4.1.0"
@ -143,9 +143,10 @@ fido2 = "^1.1.2"
ua-parser = "^0.18.0"
user-agents = "^2.2.0"
django-cors-headers = "^4.3.0"
mistune = "0.8.4"
mistune = "2.0.3"
openai = "^1.3.7"
xlsxwriter = "^3.1.9"
exchangelib = "^5.1.0"
[tool.poetry.group.xpack.dependencies]
@ -154,8 +155,7 @@ azure-mgmt-subscription = "3.1.1"
azure-identity = "1.13.0"
azure-mgmt-compute = "30.0.0"
azure-mgmt-network = "23.1.0"
google-cloud-compute = "1.13.0"
grpcio = "1.56.2"
google-cloud-compute = "1.15.0"
alibabacloud-dysmsapi20170525 = "2.0.24"
python-novaclient = "18.3.0"
python-keystoneclient = "5.1.0"

View File

@ -17,6 +17,7 @@ from resources.assets import AssetsGenerator, NodesGenerator, PlatformGenerator
from resources.users import UserGroupGenerator, UserGenerator
from resources.perms import AssetPermissionGenerator
from resources.terminal import CommandGenerator, SessionGenerator
from resources.accounts import AccountGenerator
resource_generator_mapper = {
'asset': AssetsGenerator,
@ -27,6 +28,7 @@ resource_generator_mapper = {
'asset_permission': AssetPermissionGenerator,
'command': CommandGenerator,
'session': SessionGenerator,
'account': AccountGenerator,
'all': None
# 'stat': StatGenerator
}
@ -45,6 +47,7 @@ def main():
parser.add_argument('-o', '--org', type=str, default='')
args = parser.parse_args()
resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org
resource = resource.lower().rstrip('s')
generator_cls = []
if resource == 'all':

View File

@ -0,0 +1,32 @@
import random
import forgery_py
from accounts.models import Account
from assets.models import Asset
from .base import FakeDataGenerator
class AccountGenerator(FakeDataGenerator):
    """Bulk-generate fake accounts, each attached to a random existing asset."""
    resource = 'account'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cap the candidate pool at 5000 assets.
        # Fix: a single list() call suffices (was the redundant list(list(...))).
        self.assets = list(Asset.objects.all()[:5000])

    def do_generate(self, batch, batch_size):
        """Create one Account per id in ``batch``; conflicts are ignored."""
        accounts = []
        for i in batch:
            asset = random.choice(self.assets)
            # Suffix with the batch index to keep usernames unique per run.
            name = forgery_py.internet.user_name(True) + '-' + str(i)
            d = {
                'username': name,
                'name': name,
                'asset': asset,
                'secret': name,
                'secret_type': 'password',
                'is_active': True,
                'privileged': False,
            }
            accounts.append(Account(**d))
        Account.objects.bulk_create(accounts, ignore_conflicts=True)

View File

@ -48,7 +48,7 @@ class AssetsGenerator(FakeDataGenerator):
def pre_generate(self):
self.node_ids = list(Node.objects.all().values_list('id', flat=True))
self.platform_ids = list(Platform.objects.all().values_list('id', flat=True))
self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True))
def set_assets_nodes(self, assets):
for asset in assets:
@ -72,6 +72,17 @@ class AssetsGenerator(FakeDataGenerator):
assets.append(Asset(**data))
creates = Asset.objects.bulk_create(assets, ignore_conflicts=True)
self.set_assets_nodes(creates)
self.set_asset_platform(creates)
@staticmethod
def set_asset_platform(assets):
    """Give every asset in the batch one randomly chosen protocol.

    A single protocol is picked once for the whole batch (intentional:
    keeps the batch homogeneous). Fix: telnet/vnc previously fell through
    to port 3389; each protocol now gets its standard default port.
    """
    default_ports = {'ssh': 22, 'rdp': 3389, 'telnet': 23, 'vnc': 5900}
    protocol = random.choice(list(default_ports))
    port = default_ports[protocol]
    protocols = [Protocol(asset=asset, name=protocol, port=port) for asset in assets]
    Protocol.objects.bulk_create(protocols, ignore_conflicts=True)
def after_generate(self):
pass

View File

@ -41,7 +41,7 @@ class FakeDataGenerator:
start = time.time()
self.do_generate(batch, self.batch_size)
end = time.time()
using = end - start
using = round(end - start, 3)
from_size = created
created += len(batch)
print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using))

View File

@ -1,9 +1,11 @@
from random import choice, sample
from random import sample
import forgery_py
from .base import FakeDataGenerator
from orgs.utils import current_org
from rbac.models import RoleBinding, Role
from users.models import *
from .base import FakeDataGenerator
class UserGroupGenerator(FakeDataGenerator):
@ -47,3 +49,12 @@ class UserGenerator(FakeDataGenerator):
users.append(u)
users = User.objects.bulk_create(users, ignore_conflicts=True)
self.set_groups(users)
self.set_to_org(users)
def set_to_org(self, users):
    """Bind each created user to the current org with the OrgUser role."""
    role = Role.objects.get(name='OrgUser')
    bindings = [
        RoleBinding(user=user, role=role, org_id=current_org.id, scope='org')
        for user in users
    ]
    RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)