From 2fcbfe9f2105c22d00308170cb2e50e2d1af0ceb Mon Sep 17 00:00:00 2001 From: fit2bot <68588906+fit2bot@users.noreply.github.com> Date: Tue, 2 Jan 2024 16:11:56 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=20tree=20nodes=20?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E5=A4=AA=E6=85=A2=20(#12472)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf: 优化 tree nodes 避免太慢 perf: 优化大量资产上的资产数生成比较慢 perf: 优化节点树 perf: 修改 tree nooooooooodes perf: 优化一些 api 比较大的问题 perf: 优化平台 api perf: 分页返回同步树 perf: 优化节点树 perf: 深度优化节点树 * perf: remove unused config --------- Co-authored-by: ibuler --- apps/assets/api/asset/asset.py | 7 +- apps/assets/api/domain.py | 9 +- apps/assets/api/mixin.py | 38 ++-- apps/assets/api/platform.py | 4 +- apps/assets/api/tree.py | 2 + apps/assets/const/types.py | 6 +- apps/assets/filters.py | 3 +- apps/assets/models/node.py | 34 ++-- apps/assets/serializers/domain.py | 26 ++- apps/assets/signal_handlers/asset.py | 7 +- .../signal_handlers/node_assets_amount.py | 4 +- .../signal_handlers/node_assets_mapping.py | 7 +- apps/common/api/mixin.py | 13 +- apps/common/decorators.py | 2 +- apps/common/signal_handlers.py | 6 +- apps/common/utils/common.py | 2 +- apps/common/utils/lock.py | 22 +-- apps/jumpserver/conf.py | 1 + apps/jumpserver/settings/custom.py | 1 + apps/jumpserver/settings/logging.py | 2 +- apps/orgs/signal_handlers/cache.py | 4 +- apps/perms/api/asset_permission.py | 6 +- apps/perms/api/user_permission/tree/asset.py | 9 +- .../user_permission/tree/node_with_asset.py | 164 +++++++++--------- apps/perms/models/asset_permission.py | 27 ++- apps/perms/serializers/permission.py | 35 +++- apps/perms/signal_handlers/refresh_perms.py | 11 +- apps/perms/utils/permission.py | 4 +- apps/perms/utils/user_perm.py | 94 ++++++++-- apps/perms/utils/user_perm_tree.py | 101 +++++++++-- apps/rbac/permissions.py | 2 +- apps/users/api/group.py | 7 +- apps/users/serializers/group.py | 17 +- utils/generate_fake_data/generate.py | 3 + .../generate_fake_data/resources/accounts.py | 32 ++++ utils/generate_fake_data/resources/assets.py | 13 +- utils/generate_fake_data/resources/base.py | 2 +- utils/generate_fake_data/resources/users.py | 17 +- 38 files changed, 508 insertions(+), 236 deletions(-) create mode 100644 utils/generate_fake_data/resources/accounts.py diff --git a/apps/assets/api/asset/asset.py b/apps/assets/api/asset/asset.py index eef18a773..b6e532b36 100644 --- a/apps/assets/api/asset/asset.py +++ b/apps/assets/api/asset/asset.py @@ -21,7 +21,6 @@ from common.drf.filters import BaseFilterSet, AttrRulesFilterBackend from common.utils import get_logger, is_uuid from orgs.mixins import generics from orgs.mixins.api import OrgBulkModelViewSet -from ..mixin import NodeFilterMixin from ...notifications import BulkUpdatePlatformSkipAssetUserMsg logger = get_logger(__file__) @@ -86,7 +85,7 @@ class AssetFilterSet(BaseFilterSet): return queryset.filter(protocols__name__in=value).distinct() -class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet): +class AssetViewSet(SuggestionMixin, OrgBulkModelViewSet): """ API endpoint that allows Asset to be viewed or edited. """ @@ -114,9 +113,7 @@ class AssetViewSet(SuggestionMixin, NodeFilterMixin, OrgBulkModelViewSet): ] def get_queryset(self): - queryset = super().get_queryset() \ - .prefetch_related('nodes', 'protocols') \ - .select_related('platform', 'domain') + queryset = super().get_queryset() if queryset.model is not Asset: queryset = queryset.select_related('asset_ptr') return queryset diff --git a/apps/assets/api/domain.py b/apps/assets/api/domain.py index 86097247c..46b586458 100644 --- a/apps/assets/api/domain.py +++ b/apps/assets/api/domain.py @@ -20,14 +20,15 @@ class DomainViewSet(OrgBulkModelViewSet): filterset_fields = ("name",) search_fields = filterset_fields ordering = ('name',) + serializer_classes = { + 'default': serializers.DomainSerializer, + 'list': serializers.DomainListSerializer, + } def get_serializer_class(self): if self.request.query_params.get('gateway'): return serializers.DomainWithGatewaySerializer - return serializers.DomainSerializer - - def get_queryset(self): - return super().get_queryset().prefetch_related('assets') + return super().get_serializer_class() class GatewayViewSet(HostViewSet): diff --git a/apps/assets/api/mixin.py b/apps/assets/api/mixin.py index d3ea4761d..25aefa62a 100644 --- a/apps/assets/api/mixin.py +++ b/apps/assets/api/mixin.py @@ -2,7 +2,7 @@ from typing import List from rest_framework.request import Request -from assets.models import Node, Protocol +from assets.models import Node, Platform, Protocol from assets.utils import get_node_from_request, is_query_node_all_assets from common.utils import lazyproperty, timeit @@ -71,37 +71,43 @@ class SerializeToTreeNodeMixin: return 'file' @timeit - def serialize_assets(self, assets, node_key=None, pid=None): - if node_key is None: - get_pid = lambda asset: getattr(asset, 'parent_key', '') - else: - get_pid = lambda asset: node_key + def serialize_assets(self, assets, node_key=None, get_pid=None): + if not get_pid and not node_key: + get_pid = lambda asset, platform: getattr(asset, 'parent_key', '') + sftp_asset_ids = Protocol.objects.filter(name='sftp') \ .values_list('asset_id', flat=True) - sftp_asset_ids = list(sftp_asset_ids) - data = [ - { + sftp_asset_ids = set(sftp_asset_ids) + platform_map = {p.id: p for p in Platform.objects.all()} + + data = [] + for asset in assets: + platform = platform_map.get(asset.platform_id) + if not platform: + continue + pid = node_key or get_pid(asset, platform) + if not pid or pid.isdigit(): + continue + data.append({ 'id': str(asset.id), 'name': asset.name, - 'title': f'{asset.address}\n{asset.comment}', - 'pId': pid or get_pid(asset), + 'title': f'{asset.address}\n{asset.comment}'.strip(), + 'pId': pid, 'isParent': False, 'open': False, - 'iconSkin': self.get_icon(asset), + 'iconSkin': self.get_icon(platform), 'chkDisabled': not asset.is_active, 'meta': { 'type': 'asset', 'data': { - 'platform_type': asset.platform.type, + 'platform_type': platform.type, 'org_name': asset.org_name, 'sftp': asset.id in sftp_asset_ids, 'name': asset.name, 'address': asset.address }, } - } - for asset in assets - ] + }) return data diff --git a/apps/assets/api/platform.py b/apps/assets/api/platform.py index 449b7b8a3..9f3a55bda 100644 --- a/apps/assets/api/platform.py +++ b/apps/assets/api/platform.py @@ -29,7 +29,9 @@ class AssetPlatformViewSet(JMSModelViewSet): } def get_queryset(self): - queryset = super().get_queryset() + queryset = super().get_queryset().prefetch_related( + 'protocols', 'automation' + ) queryset = queryset.filter(type__in=AllTypes.get_types_values()) return queryset diff --git a/apps/assets/api/tree.py b/apps/assets/api/tree.py index 762859ca1..277ab323e 100644 --- a/apps/assets/api/tree.py +++ b/apps/assets/api/tree.py @@ -126,6 +126,8 @@ class NodeChildrenAsTreeApi(SerializeToTreeNodeMixin, NodeChildrenApi): include_assets = self.request.query_params.get('assets', '0') == '1' if not self.instance or not include_assets: return Asset.objects.none() + if self.instance.is_org_root(): + return Asset.objects.none() if query_all: assets = self.instance.get_all_assets() else: diff --git a/apps/assets/const/types.py b/apps/assets/const/types.py index c33052c64..220d10731 100644 --- a/apps/assets/const/types.py +++ b/apps/assets/const/types.py @@ -268,7 +268,7 @@ class AllTypes(ChoicesMixin): meta = {'type': 'category', 'category': category.value, '_type': category.value} category_node = cls.choice_to_node(category, 'ROOT', meta=meta) category_count = category_type_mapper.get(category, 0) - category_node['name'] += f'({category_count})' + category_node['name'] += f' ({category_count})' nodes.append(category_node) # Type 格式化 @@ -277,7 +277,7 @@ class AllTypes(ChoicesMixin): meta = {'type': 'type', 'category': category.value, '_type': tp.value} tp_node = cls.choice_to_node(tp, category_node['id'], opened=False, meta=meta) tp_count = category_type_mapper.get(category + '_' + tp, 0) - tp_node['name'] += f'({tp_count})' + tp_node['name'] += f' ({tp_count})' platforms = tp_platforms.get(category + '_' + tp, []) if not platforms: tp_node['isParent'] = False @@ -286,7 +286,7 @@ class AllTypes(ChoicesMixin): # Platform 格式化 for p in platforms: platform_node = cls.platform_to_node(p, tp_node['id'], include_asset) - platform_node['name'] += f'({platform_count.get(p.id, 0)})' + platform_node['name'] += f' ({platform_count.get(p.id, 0)})' nodes.append(platform_node) return nodes diff --git a/apps/assets/filters.py b/apps/assets/filters.py index f1fe6d666..1475f9b2e 100644 --- a/apps/assets/filters.py +++ b/apps/assets/filters.py @@ -63,11 +63,10 @@ class NodeFilterBackend(filters.BaseFilterBackend): query_all = is_query_node_all_assets(request) if query_all: return queryset.filter( - Q(nodes__key__istartswith=f'{node.key}:') | + Q(nodes__key__startswith=f'{node.key}:') | Q(nodes__key=node.key) ).distinct() else: - print("Query query origin: ", queryset.count()) return queryset.filter(nodes__key=node.key).distinct() diff --git a/apps/assets/models/node.py b/apps/assets/models/node.py index d4fc8165d..7a15b9349 100644 --- a/apps/assets/models/node.py +++ b/apps/assets/models/node.py @@ -13,7 +13,7 @@ from django.db.transaction import atomic from django.utils.translation import gettext_lazy as _, gettext from common.db.models import output_as_string -from common.utils import get_logger +from common.utils import get_logger, timeit from common.utils.lock import DistributedLock from orgs.mixins.models import OrgManager, JMSOrgBaseModel from orgs.models import Organization @@ -195,11 +195,6 @@ class FamilyMixin: ancestor_keys = self.get_ancestor_keys(with_self=with_self) return self.__class__.objects.filter(key__in=ancestor_keys) - # @property - # def parent_key(self): - # parent_key = ":".join(self.key.split(":")[:-1]) - # return parent_key - def compute_parent_key(self): return compute_parent_key(self.key) @@ -349,29 +344,26 @@ class NodeAllAssetsMappingMixin: return 'ASSETS_ORG_NODE_ALL_ASSET_ids_MAPPING_{}'.format(org_id) @classmethod + @timeit def generate_node_all_asset_ids_mapping(cls, org_id): - from .asset import Asset - - logger.info(f'Generate node asset mapping: ' - f'thread={threading.get_ident()} ' - f'org_id={org_id}') + logger.info(f'Generate node asset mapping: org_id={org_id}') t1 = time.time() with tmp_to_org(org_id): node_ids_key = Node.objects.annotate( char_id=output_as_string('id') ).values_list('char_id', 'key') - # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢) - nodes_asset_ids = Asset.nodes.through.objects.all() \ - .annotate(char_node_id=output_as_string('node_id')) \ - .annotate(char_asset_id=output_as_string('asset_id')) \ - .values_list('char_node_id', 'char_asset_id') - node_id_ancestor_keys_mapping = { node_id: cls.get_node_ancestor_keys(node_key, with_self=True) for node_id, node_key in node_ids_key } + # * 直接取出全部. filter(node__org_id=org_id)(大规模下会更慢) + nodes_asset_ids = cls.assets.through.objects.all() \ + .annotate(char_node_id=output_as_string('node_id')) \ + .annotate(char_asset_id=output_as_string('asset_id')) \ + .values_list('char_node_id', 'char_asset_id') + nodeid_assetsid_mapping = defaultdict(set) for node_id, asset_id in nodes_asset_ids: nodeid_assetsid_mapping[node_id].add(asset_id) @@ -386,7 +378,7 @@ class NodeAllAssetsMappingMixin: mapping[ancestor_key].update(asset_ids) t3 = time.time() - logger.info('t1-t2(DB Query): {} s, t3-t2(Generate mapping): {} s'.format(t2 - t1, t3 - t2)) + logger.info('Generate asset nodes mapping, DB query: {:.2f}s, mapping: {:.2f}s'.format(t2 - t1, t3 - t2)) return mapping @@ -436,6 +428,7 @@ class NodeAssetsMixin(NodeAllAssetsMappingMixin): return asset_ids @classmethod + @timeit def get_nodes_all_assets(cls, *nodes): from .asset import Asset node_ids = set() @@ -559,11 +552,6 @@ class Node(JMSOrgBaseModel, SomeNodesMixin, FamilyMixin, NodeAssetsMixin): def __str__(self): return self.full_value - # def __eq__(self, other): - # if not other: - # return False - # return self.id == other.id - # def __gt__(self, other): self_key = [int(k) for k in self.key.split(':')] other_key = [int(k) for k in other.key.split(':')] diff --git a/apps/assets/serializers/domain.py b/apps/assets/serializers/domain.py index d2b3e3550..9ea603ac3 100644 --- a/apps/assets/serializers/domain.py +++ b/apps/assets/serializers/domain.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # +from django.db.models import Count from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -7,18 +8,15 @@ from common.serializers import ResourceLabelsMixin from common.serializers.fields import ObjectRelatedField from orgs.mixins.serializers import BulkOrgResourceModelSerializer from .gateway import GatewayWithAccountSecretSerializer -from ..models import Domain, Asset +from ..models import Domain -__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer'] +__all__ = ['DomainSerializer', 'DomainWithGatewaySerializer', 'DomainListSerializer'] class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): gateways = ObjectRelatedField( many=True, required=False, label=_('Gateway'), read_only=True, ) - assets = ObjectRelatedField( - many=True, required=False, queryset=Asset.objects, label=_('Asset') - ) class Meta: model = Domain @@ -30,7 +28,9 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): def to_representation(self, instance): data = super().to_representation(instance) - assets = data['assets'] + assets = data.get('assets') + if assets is None: + return data gateway_ids = [str(i['id']) for i in data['gateways']] data['assets'] = [i for i in assets if str(i['id']) not in gateway_ids] return data @@ -49,6 +49,20 @@ class DomainSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): return queryset +class DomainListSerializer(DomainSerializer): + assets_amount = serializers.IntegerField(label=_('Assets amount'), read_only=True) + + class Meta(DomainSerializer.Meta): + fields = list(set(DomainSerializer.Meta.fields + ['assets_amount']) - {'assets'}) + + @classmethod + def setup_eager_loading(cls, queryset): + queryset = queryset.annotate( + assets_amount=Count('assets'), + ) + return queryset + + class DomainWithGatewaySerializer(serializers.ModelSerializer): gateways = GatewayWithAccountSecretSerializer(many=True, read_only=True) diff --git a/apps/assets/signal_handlers/asset.py b/apps/assets/signal_handlers/asset.py index 9598bec5d..5ab4a4117 100644 --- a/apps/assets/signal_handlers/asset.py +++ b/apps/assets/signal_handlers/asset.py @@ -80,10 +80,11 @@ RELATED_NODE_IDS = '_related_node_ids' @receiver(pre_delete, sender=Asset) def on_asset_delete(instance: Asset, using, **kwargs): - logger.debug("Asset pre delete signal recv: {}".format(instance)) node_ids = Node.objects.filter(assets=instance) \ .distinct().values_list('id', flat=True) - setattr(instance, RELATED_NODE_IDS, node_ids) + node_ids = list(node_ids) + logger.debug("Asset pre delete signal recv: {}, node_ids: {}".format(instance, node_ids)) + setattr(instance, RELATED_NODE_IDS, list(node_ids)) m2m_changed.send( sender=Asset.nodes.through, instance=instance, reverse=False, model=Node, pk_set=node_ids, @@ -93,8 +94,8 @@ def on_asset_delete(instance: Asset, using, **kwargs): @receiver(post_delete, sender=Asset) def on_asset_post_delete(instance: Asset, using, **kwargs): - logger.debug("Asset post delete signal recv: {}".format(instance)) node_ids = getattr(instance, RELATED_NODE_IDS, []) + logger.debug("Asset post delete signal recv: {}, node_ids: {}".format(instance, node_ids)) if node_ids: m2m_changed.send( sender=Asset.nodes.through, instance=instance, reverse=False, diff --git a/apps/assets/signal_handlers/node_assets_amount.py b/apps/assets/signal_handlers/node_assets_amount.py index 79e2ed568..ea2b3ba8a 100644 --- a/apps/assets/signal_handlers/node_assets_amount.py +++ b/apps/assets/signal_handlers/node_assets_amount.py @@ -15,8 +15,8 @@ from ..tasks import check_node_assets_amount_task logger = get_logger(__file__) -@on_transaction_commit @receiver(m2m_changed, sender=Asset.nodes.through) +@on_transaction_commit def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): # 不允许 `pre_clear` ,因为该信号没有 `pk_set` # [官网](https://docs.djangoproject.com/en/3.1/ref/signals/#m2m-changed) @@ -37,7 +37,7 @@ def on_node_asset_change(sender, action, instance, reverse, pk_set, **kwargs): update_nodes_assets_amount(node_ids=node_ids) -@merge_delay_run(ttl=5) +@merge_delay_run(ttl=30) def update_nodes_assets_amount(node_ids=()): nodes = Node.objects.filter(id__in=node_ids) nodes = Node.get_ancestor_queryset(nodes) diff --git a/apps/assets/signal_handlers/node_assets_mapping.py b/apps/assets/signal_handlers/node_assets_mapping.py index 34b707dab..a53e534b6 100644 --- a/apps/assets/signal_handlers/node_assets_mapping.py +++ b/apps/assets/signal_handlers/node_assets_mapping.py @@ -21,7 +21,7 @@ logger = get_logger(__name__) node_assets_mapping_pub_sub = lazy(lambda: RedisPubSub('fm.node_asset_mapping'), RedisPubSub)() -@merge_delay_run(ttl=5) +@merge_delay_run(ttl=30) def expire_node_assets_mapping(org_ids=()): logger.debug("Recv asset nodes changed signal, expire memery node asset mapping") # 所有进程清除(自己的 memory 数据) @@ -53,8 +53,9 @@ def on_node_post_delete(sender, instance, **kwargs): @receiver(m2m_changed, sender=Asset.nodes.through) -def on_node_asset_change(sender, instance, **kwargs): - expire_node_assets_mapping(org_ids=(instance.org_id,)) +def on_node_asset_change(sender, instance, action='pre_remove', **kwargs): + if action.startswith('post'): + expire_node_assets_mapping(org_ids=(instance.org_id,)) @receiver(django_ready) diff --git a/apps/common/api/mixin.py b/apps/common/api/mixin.py index c87d5f22a..50bb1efbe 100644 --- a/apps/common/api/mixin.py +++ b/apps/common/api/mixin.py @@ -98,12 +98,17 @@ class QuerySetMixin: return queryset if self.action == 'metadata': queryset = queryset.none() - if self.action in ['list', 'metadata']: - serializer_class = self.get_serializer_class() - if serializer_class and hasattr(serializer_class, 'setup_eager_loading'): - queryset = serializer_class.setup_eager_loading(queryset) return queryset + def paginate_queryset(self, queryset): + page = super().paginate_queryset(queryset) + serializer_class = self.get_serializer_class() + if page and serializer_class and hasattr(serializer_class, 'setup_eager_loading'): + ids = [i.id for i in page] + page = self.get_queryset().filter(id__in=ids) + page = serializer_class.setup_eager_loading(page) + return page + class ExtraFilterFieldsMixin: """ diff --git a/apps/common/decorators.py b/apps/common/decorators.py index 24b2695d3..7269c9e0d 100644 --- a/apps/common/decorators.py +++ b/apps/common/decorators.py @@ -65,7 +65,7 @@ class EventLoopThread(threading.Thread): _loop_thread = EventLoopThread() -_loop_thread.setDaemon(True) +_loop_thread.daemon = True _loop_thread.start() executor = ThreadPoolExecutor( max_workers=10, diff --git a/apps/common/signal_handlers.py b/apps/common/signal_handlers.py index 588a63632..ba5a4ad92 100644 --- a/apps/common/signal_handlers.py +++ b/apps/common/signal_handlers.py @@ -62,7 +62,7 @@ def digest_sql_query(): method = current_request.method path = current_request.get_full_path() - print(">>> [{}] {}".format(method, path)) + print(">>>. [{}] {}".format(method, path)) for table_name, queries in table_queries.items(): if table_name.startswith('rbac_') or table_name.startswith('auth_permission'): continue @@ -77,9 +77,9 @@ def digest_sql_query(): sql = query['sql'] if not sql or not sql.startswith('SELECT'): continue - print('\t{}. {}'.format(i, sql)) + print('\t{}.[{}s] {}'.format(i, round(float(query['time']), 2), sql[:1000])) - logger.debug(">>> [{}] {}".format(method, path)) + # logger.debug(">>> [{}] {}".format(method, path)) for name, counter in counters: logger.debug("Query {:3} times using {:.2f}s {}".format( counter.counter, counter.time, name) diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index 798b4a94c..86b68bef7 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -220,7 +220,7 @@ def timeit(func): now = time.time() result = func(*args, **kwargs) using = (time.time() - now) * 1000 - msg = "End call {}, using: {:.1f}ms".format(name, using) + msg = "Ends call: {}, using: {:.1f}ms".format(name, using) logger.debug(msg) return result diff --git a/apps/common/utils/lock.py b/apps/common/utils/lock.py index 773647725..949d13c6b 100644 --- a/apps/common/utils/lock.py +++ b/apps/common/utils/lock.py @@ -1,18 +1,16 @@ -from functools import wraps import threading +from functools import wraps +from django.db import transaction from redis_lock import ( Lock as RedisLock, NotAcquired, UNLOCK_SCRIPT, EXTEND_SCRIPT, RESET_SCRIPT, RESET_ALL_SCRIPT ) -from redis import Redis -from django.db import transaction -from common.utils import get_logger -from common.utils.inspect import copy_function_args -from common.utils.connection import get_redis_client -from jumpserver.const import CONFIG from common.local import thread_local +from common.utils import get_logger +from common.utils.connection import get_redis_client +from common.utils.inspect import copy_function_args logger = get_logger(__file__) @@ -76,6 +74,7 @@ class DistributedLock(RedisLock): # 要创建一个新的锁对象 with self.__class__(**self.kwargs_copy): return func(*args, **kwds) + return inner @classmethod @@ -95,7 +94,6 @@ class DistributedLock(RedisLock): if self.locked(): owner_id = self.get_owner_id() local_owner_id = getattr(thread_local, self.name, None) - if local_owner_id and owner_id == local_owner_id: return True return False @@ -140,14 +138,16 @@ class DistributedLock(RedisLock): logger.debug(f'Released reentrant-lock: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') return else: - self._raise_exc_with_log(f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') + self._raise_exc_with_log( + f'Reentrant-lock is not acquired: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') def _release_on_reentrant_locked_by_me(self): logger.debug(f'Release reentrant-lock locked by me: lock_id={self.id} lock={self.name}') id = getattr(thread_local, self.name, None) if id != self.id: - raise PermissionError(f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') + raise PermissionError( + f'Reentrant-lock is not locked by me: lock_id={self.id} owner_id={self.get_owner_id()} lock={self.name}') try: # 这里要保证先删除 thread_local 的标记, delattr(thread_local, self.name) @@ -191,7 +191,7 @@ class DistributedLock(RedisLock): # 处理是否在事务提交时才释放锁 if self._release_on_transaction_commit: logger.debug( - f'Release lock on transaction commit ... :lock_id={self.id} lock={self.name}') + f'Release lock on transaction commit:lock_id={self.id} lock={self.name}') transaction.on_commit(_release) else: _release() diff --git a/apps/jumpserver/conf.py b/apps/jumpserver/conf.py index 66aba779e..60e3dfd0c 100644 --- a/apps/jumpserver/conf.py +++ b/apps/jumpserver/conf.py @@ -531,6 +531,7 @@ class Config(dict): 'SYSLOG_SOCKTYPE': 2, 'PERM_EXPIRED_CHECK_PERIODIC': 60 * 60, + 'PERM_TREE_REGEN_INTERVAL': 1, 'FLOWER_URL': "127.0.0.1:5555", 'LANGUAGE_CODE': 'zh', 'TIME_ZONE': 'Asia/Shanghai', diff --git a/apps/jumpserver/settings/custom.py b/apps/jumpserver/settings/custom.py index fad5f0e9d..311e54aa8 100644 --- a/apps/jumpserver/settings/custom.py +++ b/apps/jumpserver/settings/custom.py @@ -208,6 +208,7 @@ OPERATE_LOG_ELASTICSEARCH_CONFIG = CONFIG.OPERATE_LOG_ELASTICSEARCH_CONFIG MAX_LIMIT_PER_PAGE = CONFIG.MAX_LIMIT_PER_PAGE DEFAULT_PAGE_SIZE = CONFIG.DEFAULT_PAGE_SIZE +PERM_TREE_REGEN_INTERVAL = CONFIG.PERM_TREE_REGEN_INTERVAL # Magnus DB Port MAGNUS_ORACLE_PORTS = CONFIG.MAGNUS_ORACLE_PORTS diff --git a/apps/jumpserver/settings/logging.py b/apps/jumpserver/settings/logging.py index c4cf90ebb..bd44e22dc 100644 --- a/apps/jumpserver/settings/logging.py +++ b/apps/jumpserver/settings/logging.py @@ -21,7 +21,7 @@ LOGGING = { }, 'main': { 'datefmt': '%Y-%m-%d %H:%M:%S', - 'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s', + 'format': '%(asctime)s [%(levelname).4s] %(message)s', }, 'exception': { 'datefmt': '%Y-%m-%d %H:%M:%S', diff --git a/apps/orgs/signal_handlers/cache.py b/apps/orgs/signal_handlers/cache.py index b4e6001e3..4dc3c796e 100644 --- a/apps/orgs/signal_handlers/cache.py +++ b/apps/orgs/signal_handlers/cache.py @@ -75,7 +75,7 @@ model_cache_field_mapper = { class OrgResourceStatisticsRefreshUtil: @staticmethod - @merge_delay_run(ttl=5) + @merge_delay_run(ttl=30) def refresh_org_fields(org_fields=()): for org, cache_field_name in org_fields: OrgResourceStatisticsCache(org).expire(*cache_field_name) @@ -104,7 +104,7 @@ def on_post_delete_refresh_org_resource_statistics_cache(sender, instance, **kwa def _refresh_session_org_resource_statistics_cache(instance: Session): cache_field_name = [ 'total_count_online_users', 'total_count_online_sessions', - 'total_count_today_active_assets','total_count_today_failed_sessions' + 'total_count_today_active_assets', 'total_count_today_failed_sessions' ] org_cache = OrgResourceStatisticsCache(instance.org) diff --git a/apps/perms/api/asset_permission.py b/apps/perms/api/asset_permission.py index 750f41fb4..c72254d0d 100644 --- a/apps/perms/api/asset_permission.py +++ b/apps/perms/api/asset_permission.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # + from orgs.mixins.api import OrgBulkModelViewSet from perms import serializers from perms.filters import AssetPermissionFilter @@ -13,7 +14,10 @@ class AssetPermissionViewSet(OrgBulkModelViewSet): 资产授权列表的增删改查api """ model = AssetPermission - serializer_class = serializers.AssetPermissionSerializer + serializer_classes = { + 'default': serializers.AssetPermissionSerializer, + 'list': serializers.AssetPermissionListSerializer, + } filterset_class = AssetPermissionFilter search_fields = ('name',) ordering = ('name',) diff --git a/apps/perms/api/user_permission/tree/asset.py b/apps/perms/api/user_permission/tree/asset.py index c620356ad..e5e2ab4d3 100644 --- a/apps/perms/api/user_permission/tree/asset.py +++ b/apps/perms/api/user_permission/tree/asset.py @@ -1,16 +1,14 @@ from django.conf import settings from rest_framework.response import Response -from assets.models import Asset from assets.api import SerializeToTreeNodeMixin +from assets.models import Asset from common.utils import get_logger - -from ..assets import UserAllPermedAssetsApi from .mixin import RebuildTreeMixin +from ..assets import UserAllPermedAssetsApi logger = get_logger(__name__) - __all__ = [ 'UserAllPermedAssetsAsTreeApi', 'UserUngroupAssetsAsTreeApi', @@ -31,7 +29,7 @@ class AssetTreeMixin(RebuildTreeMixin, SerializeToTreeNodeMixin): if request.query_params.get('search'): """ 限制返回数量, 搜索的条件不精准时,会返回大量的无意义数据 """ assets = assets[:999] - data = self.serialize_assets(assets, None) + data = self.serialize_assets(assets, 'root') return Response(data=data) @@ -42,6 +40,7 @@ class UserAllPermedAssetsAsTreeApi(AssetTreeMixin, UserAllPermedAssetsApi): class UserUngroupAssetsAsTreeApi(UserAllPermedAssetsAsTreeApi): """ 用户 '未分组节点的资产(直接授权的资产)' 作为树 """ + def get_assets(self): if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: return super().get_assets() diff --git a/apps/perms/api/user_permission/tree/node_with_asset.py b/apps/perms/api/user_permission/tree/node_with_asset.py index 0289b9d63..17cdd7573 100644 --- a/apps/perms/api/user_permission/tree/node_with_asset.py +++ b/apps/perms/api/user_permission/tree/node_with_asset.py @@ -1,6 +1,4 @@ import abc -import re -from collections import defaultdict from urllib.parse import parse_qsl from django.conf import settings @@ -13,7 +11,6 @@ from rest_framework.response import Response from accounts.const import AliasAccount from assets.api import SerializeToTreeNodeMixin -from assets.const import AllTypes from assets.models import Asset from assets.utils import KubernetesTree from authentication.models import ConnectionToken @@ -38,21 +35,36 @@ class BaseUserNodeWithAssetAsTreeApi( SelfOrPKUserMixin, RebuildTreeMixin, SerializeToTreeNodeMixin, ListAPIView ): + page_limit = 10000 def list(self, request, *args, **kwargs): - nodes, assets = self.get_nodes_assets() - tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) - tree_assets = self.serialize_assets(assets, node_key=self.node_key_for_serialize_assets) - data = list(tree_nodes) + list(tree_assets) - return Response(data=data) + offset = int(request.query_params.get('offset', 0)) + page_assets = self.get_page_assets() + + if not offset: + nodes, assets = self.get_nodes_assets() + page = page_assets[:self.page_limit] + assets = [*assets, *page] + tree_nodes = self.serialize_nodes(nodes, with_asset_amount=True) + tree_assets = self.serialize_assets(assets, **self.serialize_asset_kwargs) + data = list(tree_nodes) + list(tree_assets) + else: + page = page_assets[offset:(offset + self.page_limit)] + data = self.serialize_assets(page, **self.serialize_asset_kwargs) if page else [] + offset += len(page) + headers = {'X-JMS-TREE-OFFSET': offset} if offset else {} + return Response(data=data, headers=headers) @abc.abstractmethod def get_nodes_assets(self): return [], [] - @lazyproperty - def node_key_for_serialize_assets(self): - return None + def get_page_assets(self): + return [] + + @property + def serialize_asset_kwargs(self): + return {} class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): @@ -61,7 +73,6 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): def get_nodes_assets(self): self.query_node_util = UserPermNodeUtil(self.request.user) - self.query_asset_util = UserPermAssetUtil(self.request.user) ung_nodes, ung_assets = self._get_nodes_assets_for_ungrouped() fav_nodes, fav_assets = self._get_nodes_assets_for_favorite() all_nodes, all_assets = self._get_nodes_assets_for_all() @@ -69,31 +80,37 @@ class UserPermedNodesWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): assets = list(ung_assets) + list(fav_assets) + list(all_assets) return nodes, assets + def get_page_assets(self): + return self.query_asset_util.get_all_assets().annotate(parent_key=F('nodes__key')) + @timeit def _get_nodes_assets_for_ungrouped(self): if not settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: return [], [] node = self.query_node_util.get_ungrouped_node() assets = self.query_asset_util.get_ungroup_assets() - assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \ - .prefetch_related('platform') + assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) return [node], assets + @lazyproperty + def query_asset_util(self): + return UserPermAssetUtil(self.user) + @timeit def _get_nodes_assets_for_favorite(self): node = self.query_node_util.get_favorite_node() assets = self.query_asset_util.get_favorite_assets() - assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) \ - .prefetch_related('platform') + assets = assets.annotate(parent_key=Value(node.key, output_field=CharField())) return [node], assets + @timeit def _get_nodes_assets_for_all(self): nodes = self.query_node_util.get_whole_tree_nodes(with_special=False) if settings.PERM_SINGLE_ASSET_TO_UNGROUP_NODE: assets = self.query_asset_util.get_perm_nodes_assets() else: - assets = self.query_asset_util.get_all_assets() - assets = assets.annotate(parent_key=F('nodes__key')).prefetch_related('platform') + assets = Asset.objects.none() + assets = assets.annotate(parent_key=F('nodes__key')) return nodes, assets @@ -103,6 +120,7 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): # 默认展开的节点key default_unfolded_node_key = None + @timeit def get_nodes_assets(self): query_node_util = UserPermNodeUtil(self.user) query_asset_util = UserPermAssetUtil(self.user) @@ -136,14 +154,14 @@ class UserPermedNodeChildrenWithAssetsAsTreeApi(BaseUserNodeWithAssetAsTreeApi): node_key = getattr(node, 'key', None) return node_key - @lazyproperty - def node_key_for_serialize_assets(self): - return self.query_node_key or self.default_unfolded_node_key + @property + def serialize_asset_kwargs(self): + return { + 'node_key': self.query_node_key or self.default_unfolded_node_key + } -class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi( - SelfOrPKUserMixin, SerializeToTreeNodeMixin, ListAPIView -): +class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi(BaseUserNodeWithAssetAsTreeApi): @property def is_sync(self): sync = self.request.query_params.get('sync', 0) @@ -151,66 +169,52 @@ class UserPermedNodeChildrenWithAssetsAsCategoryTreeApi( @property def tp(self): - return self.request.query_params.get('type') - - def get_assets(self): - query_asset_util = UserPermAssetUtil(self.user) - node = PermNode.objects.filter( - granted_node_rels__user=self.user, parent_key='').first() - if node: - __, assets = query_asset_util.get_node_all_assets(node.id) - else: - assets = Asset.objects.none() - return assets - - def to_tree_nodes(self, assets): - if not assets: - return [] - assets = assets.annotate(tp=F('platform__type')) - asset_type_map = defaultdict(list) - for asset in assets: - asset_type_map[asset.tp].append(asset) - tp = self.tp - if tp: - assets = asset_type_map.get(tp, []) - if not assets: - return [] - pid = f'ROOT_{str(assets[0].category).upper()}_{tp}' - return self.serialize_assets(assets, pid=pid) params = self.request.query_params - get_root = not list(filter(lambda x: params.get(x), ('type', 'n'))) - resource_platforms = assets.order_by('id').values_list('platform_id', flat=True) - node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=get_root) - pattern = re.compile(r'\(0\)?') - nodes = [] - for node in node_all: - meta = node.get('meta', {}) - if pattern.search(node['name']) or meta.get('type') == 'platform': - continue - _type = meta.get('_type') - if _type: - node['type'] = _type - meta.setdefault('data', {}) - node['meta'] = meta - nodes.append(node) + return [params.get('category'), params.get('type')] - if not self.is_sync: - return nodes + @lazyproperty + def query_asset_util(self): + return UserPermAssetUtil(self.user) - asset_nodes = [] - for node in nodes: - node['open'] = True - tp = node.get('meta', {}).get('_type') - if not tp: - continue - assets = asset_type_map.get(tp, []) - asset_nodes += self.serialize_assets(assets, pid=node['id']) - return nodes + asset_nodes + @timeit + def get_assets(self): + return self.query_asset_util.get_all_assets() - def list(self, request, *args, **kwargs): - assets = self.get_assets() - nodes = self.to_tree_nodes(assets) - return Response(data=nodes) + def _get_tree_nodes_async(self): + if not self.tp or not all(self.tp): + nodes = UserPermAssetUtil.get_type_nodes_tree_or_cached(self.user) + return nodes, [] + + category, tp = self.tp + assets = self.get_assets().filter(platform__type=tp, platform__category=category) + return [], assets + + def _get_tree_nodes_sync(self): + if self.request.query_params.get('lv'): + return [] + nodes = self.query_asset_util.get_type_nodes_tree() + return nodes, [] + + @property + def serialize_asset_kwargs(self): + return { + 'get_pid': lambda asset, platform: 'ROOT_{}_{}'.format(platform.category.upper(), platform.type), + } + + def serialize_nodes(self, nodes, with_asset_amount=False): + return nodes + + def get_nodes_assets(self): + if self.is_sync: + return self._get_tree_nodes_sync() + else: + return self._get_tree_nodes_async() + + def get_page_assets(self): + if self.is_sync: + return self.get_assets() + else: + return [] class UserGrantedK8sAsTreeApi(SelfOrPKUserMixin, ListAPIView): diff --git a/apps/perms/models/asset_permission.py b/apps/perms/models/asset_permission.py index 5ecf824c0..519b0adb8 100644 --- a/apps/perms/models/asset_permission.py +++ b/apps/perms/models/asset_permission.py @@ -8,7 +8,7 @@ from django.utils.translation import gettext_lazy as _ from accounts.const import AliasAccount from accounts.models import Account from assets.models import Asset -from common.utils import date_expired_default +from common.utils import date_expired_default, lazyproperty from common.utils.timezone import local_now from labels.mixins import LabeledMixin from orgs.mixins.models import JMSOrgBaseModel @@ -105,6 +105,22 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel): return True return False + @lazyproperty + def users_amount(self): + return self.users.count() + + @lazyproperty + def user_groups_amount(self): + return self.user_groups.count() + + @lazyproperty + def assets_amount(self): + return self.assets.count() + + @lazyproperty + def nodes_amount(self): + return self.nodes.count() + def get_all_users(self): from users.models import User user_ids = self.users.all().values_list('id', flat=True) @@ -143,11 +159,14 @@ class AssetPermission(LabeledMixin, JMSOrgBaseModel): @classmethod def get_all_users_for_perms(cls, perm_ids, flat=False): - user_ids = cls.users.through.objects.filter(assetpermission_id__in=perm_ids) \ + user_ids = cls.users.through.objects \ + .filter(assetpermission_id__in=perm_ids) \ .values_list('user_id', flat=True).distinct() - group_ids = cls.user_groups.through.objects.filter(assetpermission_id__in=perm_ids) \ + group_ids = cls.user_groups.through.objects \ + .filter(assetpermission_id__in=perm_ids) \ .values_list('usergroup_id', flat=True).distinct() - group_user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \ + group_user_ids = User.groups.through.objects \ + .filter(usergroup_id__in=group_ids) \ .values_list('user_id', flat=True).distinct() user_ids = set(user_ids) | set(group_user_ids) if flat: diff --git a/apps/perms/serializers/permission.py b/apps/perms/serializers/permission.py index d09c33538..36e0a6f33 100644 --- a/apps/perms/serializers/permission.py +++ b/apps/perms/serializers/permission.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -from django.db.models import Q +from django.db.models import Q, Count from django.utils.translation import gettext_lazy as _ from rest_framework import serializers @@ -14,7 +14,7 @@ from orgs.mixins.serializers import BulkOrgResourceModelSerializer from perms.models import ActionChoices, AssetPermission from users.models import User, UserGroup -__all__ = ["AssetPermissionSerializer", "ActionChoicesField"] +__all__ = ["AssetPermissionSerializer", "ActionChoicesField", "AssetPermissionListSerializer"] class ActionChoicesField(BitChoicesField): @@ -142,8 +142,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali def perform_display_create(instance, **kwargs): # 用户 users_to_set = User.objects.filter( - Q(name__in=kwargs.get("users_display")) - | Q(username__in=kwargs.get("users_display")) + Q(name__in=kwargs.get("users_display")) | + Q(username__in=kwargs.get("users_display")) ).distinct() instance.users.add(*users_to_set) # 用户组 @@ -153,8 +153,8 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali instance.user_groups.add(*user_groups_to_set) # 资产 assets_to_set = Asset.objects.filter( - Q(address__in=kwargs.get("assets_display")) - | Q(name__in=kwargs.get("assets_display")) + Q(address__in=kwargs.get("assets_display")) | + Q(name__in=kwargs.get("assets_display")) ).distinct() instance.assets.add(*assets_to_set) # 节点 @@ -180,3 +180,26 @@ class AssetPermissionSerializer(ResourceLabelsMixin, BulkOrgResourceModelSeriali instance = super().create(validated_data) self.perform_display_create(instance, **display) return instance + + +class AssetPermissionListSerializer(AssetPermissionSerializer): + users_amount = serializers.IntegerField(read_only=True, label=_("Users amount")) + user_groups_amount = serializers.IntegerField(read_only=True, label=_("User groups amount")) + assets_amount = serializers.IntegerField(read_only=True, label=_("Assets amount")) + nodes_amount = serializers.IntegerField(read_only=True, label=_("Nodes amount")) + + class Meta(AssetPermissionSerializer.Meta): + amount_fields = ["users_amount", "user_groups_amount", "assets_amount", "nodes_amount"] + remove_fields = {"users", "assets", "nodes", "user_groups"} + fields = list(set(AssetPermissionSerializer.Meta.fields + amount_fields) - remove_fields) + + @classmethod + def setup_eager_loading(cls, queryset): + """Perform necessary eager loading of data.""" + queryset = queryset.annotate( + users_amount=Count("users"), + user_groups_amount=Count("user_groups"), + assets_amount=Count("assets"), + nodes_amount=Count("nodes"), + ) + return queryset diff --git a/apps/perms/signal_handlers/refresh_perms.py b/apps/perms/signal_handlers/refresh_perms.py index 88a5fe674..5387bc33f 100644 --- a/apps/perms/signal_handlers/refresh_perms.py +++ b/apps/perms/signal_handlers/refresh_perms.py @@ -3,15 +3,13 @@ from django.db.models.signals import m2m_changed, pre_delete, pre_save, post_save from django.dispatch import receiver -from users.models import User, UserGroup from assets.models import Asset -from common.utils import get_logger, get_object_or_none -from common.exceptions import M2MReverseNotAllowed from common.const.signals import POST_ADD, POST_REMOVE, POST_CLEAR - +from common.exceptions import M2MReverseNotAllowed +from common.utils import get_logger, get_object_or_none from perms.models import AssetPermission from perms.utils import UserPermTreeExpireUtil - +from users.models import User, UserGroup logger = get_logger(__file__) @@ -38,7 +36,7 @@ def on_user_groups_change(sender, instance, action, reverse, pk_set, **kwargs): group = UserGroup.objects.get(id=list(group_ids)[0]) org_id = group.org_id - has_group_perm = AssetPermission.user_groups.through.objects\ + has_group_perm = AssetPermission.user_groups.through.objects \ .filter(usergroup_id__in=group_ids).exists() if not has_group_perm: return @@ -115,6 +113,7 @@ def on_asset_permission_user_groups_changed(sender, instance, action, pk_set, re def on_node_asset_change(action, instance, reverse, pk_set, **kwargs): if not need_rebuild_mapping_node(action): return + print("Asset node changed: ", action) if reverse: asset_ids = pk_set node_ids = [instance.id] diff --git a/apps/perms/utils/permission.py b/apps/perms/utils/permission.py index b9b5b01be..036dabef2 100644 --- a/apps/perms/utils/permission.py +++ b/apps/perms/utils/permission.py @@ -1,8 +1,7 @@ from django.db.models import QuerySet from assets.models import Node, Asset -from common.utils import get_logger - +from common.utils import get_logger, timeit from perms.models import AssetPermission logger = get_logger(__file__) @@ -13,6 +12,7 @@ __all__ = ['AssetPermissionUtil'] class AssetPermissionUtil(object): """ 资产授权相关的方法工具 """ + @timeit def get_permissions_for_user(self, user, with_group=True, flat=False): """ 获取用户的授权规则 """ perm_ids = set() diff --git a/apps/perms/utils/user_perm.py b/apps/perms/utils/user_perm.py index 42d6cc462..0ad02fb25 100644 --- a/apps/perms/utils/user_perm.py +++ b/apps/perms/utils/user_perm.py @@ -1,13 +1,22 @@ -from django.conf import settings -from django.db.models import Q +import json +import re +from django.conf import settings +from django.core.cache import cache +from django.db.models import Q +from rest_framework.utils.encoders import JSONEncoder + +from assets.const import AllTypes from assets.models import FavoriteAsset, Asset -from common.utils.common import timeit +from common.utils.common import timeit, get_logger +from orgs.utils import current_org, tmp_to_root_org from perms.models import PermNode, UserAssetGrantedTreeNodeRelation from .permission import AssetPermissionUtil __all__ = ['AssetPermissionPermAssetUtil', 'UserPermAssetUtil', 'UserPermNodeUtil'] +logger = get_logger(__name__) + class AssetPermissionPermAssetUtil: @@ -16,29 +25,32 @@ class AssetPermissionPermAssetUtil: def get_all_assets(self): """ 获取所有授权的资产 """ - node_asset_ids = self.get_perm_nodes_assets(flat=True) - direct_asset_ids = self.get_direct_assets(flat=True) - asset_ids = list(node_asset_ids) + list(direct_asset_ids) - assets = Asset.objects.filter(id__in=asset_ids) - return assets + node_assets = self.get_perm_nodes_assets() + direct_assets = self.get_direct_assets() + # 比原来的查到所有 asset id 再搜索块很多,因为当资产量大的时候,搜索会很慢 + return (node_assets | direct_assets).distinct() + @timeit def get_perm_nodes_assets(self, flat=False): """ 获取所有授权节点下的资产 """ from assets.models import Node - nodes = Node.objects.prefetch_related('granted_by_permissions').filter( - granted_by_permissions__in=self.perm_ids).only('id', 'key') + nodes = Node.objects \ + .prefetch_related('granted_by_permissions') \ + .filter(granted_by_permissions__in=self.perm_ids) \ + .only('id', 'key') assets = PermNode.get_nodes_all_assets(*nodes) if flat: - return assets.values_list('id', flat=True) + return set(assets.values_list('id', flat=True)) return assets + @timeit def get_direct_assets(self, flat=False): """ 获取直接授权的资产 """ assets = Asset.objects.order_by() \ .filter(granted_by_permissions__id__in=self.perm_ids) \ .distinct() if flat: - return assets.values_list('id', flat=True) + return set(assets.values_list('id', flat=True)) return assets @@ -52,12 +64,62 @@ class UserPermAssetUtil(AssetPermissionPermAssetUtil): def get_ungroup_assets(self): return self.get_direct_assets() + @timeit def get_favorite_assets(self): - assets = self.get_all_assets() + assets = Asset.objects.all().valid() asset_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True) assets = assets.filter(id__in=list(asset_ids)) return assets + def get_type_nodes_tree(self): + assets = self.get_all_assets() + resource_platforms = assets.order_by('id').values_list('platform_id', flat=True) + node_all = AllTypes.get_tree_nodes(resource_platforms, get_root=True) + pattern = re.compile(r'\(0\)?') + nodes = [] + for node in node_all: + meta = node.get('meta', {}) + if pattern.search(node['name']) or meta.get('type') == 'platform': + continue + _type = meta.get('_type') + if _type: + node['type'] = _type + node['category'] = meta.get('category') + meta.setdefault('data', {}) + node['meta'] = meta + nodes.append(node) + return nodes + + @classmethod + def get_type_nodes_tree_or_cached(cls, user): + key = f'perms:type-nodes-tree:{user.id}:{current_org.id}' + nodes = cache.get(key) + if nodes is None: + nodes = cls(user).get_type_nodes_tree() + nodes_json = json.dumps(nodes, cls=JSONEncoder) + cache.set(key, nodes_json, 60 * 60 * 24) + else: + nodes = json.loads(nodes) + return nodes + + def refresh_type_nodes_tree_cache(self): + logger.debug("Refresh type nodes tree cache") + key = f'perms:type-nodes-tree:{self.user.id}:{current_org.id}' + cache.delete(key) + + def refresh_favorite_assets(self): + favor_ids = FavoriteAsset.objects.filter(user=self.user).values_list('asset_id', flat=True) + favor_ids = set(favor_ids) + + with tmp_to_root_org(): + valid_ids = self.get_all_assets() \ + .filter(id__in=favor_ids) \ + .values_list('id', flat=True) + valid_ids = set(valid_ids) + + invalid_ids = favor_ids - valid_ids + FavoriteAsset.objects.filter(user=self.user, asset_id__in=invalid_ids).delete() + def get_node_assets(self, key): node = PermNode.objects.get(key=key) node.compute_node_from_and_assets_amount(self.user) @@ -134,7 +196,11 @@ class UserPermNodeUtil: self.perm_ids = AssetPermissionUtil().get_permissions_for_user(self.user, flat=True) def get_favorite_node(self): - assets_amount = UserPermAssetUtil(self.user).get_favorite_assets().count() + favor_ids = FavoriteAsset.objects \ + .filter(user=self.user) \ + .values_list('asset_id') \ + .distinct() + assets_amount = Asset.objects.all().valid().filter(id__in=favor_ids).count() return PermNode.get_favorite_node(assets_amount) def get_ungrouped_node(self): diff --git a/apps/perms/utils/user_perm_tree.py b/apps/perms/utils/user_perm_tree.py index b88b6db26..13577a9b1 100644 --- a/apps/perms/utils/user_perm_tree.py +++ b/apps/perms/utils/user_perm_tree.py @@ -3,11 +3,12 @@ from collections import defaultdict from django.conf import settings from django.core.cache import cache +from django.db import transaction from assets.models import Asset from assets.utils import NodeAssetsUtil from common.db.models import output_as_string -from common.decorators import on_transaction_commit +from common.decorators import on_transaction_commit, merge_delay_run from common.utils import get_logger from common.utils.common import lazyproperty, timeit from orgs.models import Organization @@ -23,6 +24,7 @@ from perms.models import ( PermNode ) from users.models import User +from . import UserPermAssetUtil from .permission import AssetPermissionUtil logger = get_logger(__name__) @@ -50,24 +52,74 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin): def __init__(self, user): self.user = user - self.orgs = self.user.orgs.distinct() - self.org_ids = [str(o.id) for o in self.orgs] + + @lazyproperty + def orgs(self): + return self.user.orgs.distinct() + + @lazyproperty + def org_ids(self): + return [str(o.id) for o in self.orgs] @lazyproperty def cache_key_user(self): return self.get_cache_key(self.user.id) + @lazyproperty + def cache_key_time(self): + key = 'perms.user.node_tree.built_time.{}'.format(self.user.id) + return key + @timeit def refresh_if_need(self, force=False): - self._clean_user_perm_tree_for_legacy_org() + built_just_now = cache.get(self.cache_key_time) + if built_just_now: + logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now)) + return to_refresh_orgs = self.orgs if force else self._get_user_need_refresh_orgs() if not to_refresh_orgs: logger.info('Not have to refresh orgs') return - with UserGrantedTreeRebuildLock(self.user.id): + logger.info("Delay refresh user orgs: {} {}".format(self.user, [o.name for o in to_refresh_orgs])) + refresh_user_orgs_perm_tree(user_orgs=((self.user, tuple(to_refresh_orgs)),)) + refresh_user_favorite_assets(users=(self.user,)) + + @timeit + def refresh_tree_manual(self): + built_just_now = cache.get(self.cache_key_time) + if built_just_now: + logger.info('Refresh just now, pass: {}'.format(built_just_now)) + return + to_refresh_orgs = self._get_user_need_refresh_orgs() + if not to_refresh_orgs: + logger.info('Not have to refresh orgs for user: {}'.format(self.user)) + return + self.perform_refresh_user_tree(to_refresh_orgs) + + @timeit + def perform_refresh_user_tree(self, to_refresh_orgs): + # 再判断一次,毕竟构建树比较慢 + built_just_now = cache.get(self.cache_key_time) + if built_just_now: + logger.info('Refresh user perm tree just now, pass: {}'.format(built_just_now)) + return + + self._clean_user_perm_tree_for_legacy_org() + ttl = settings.PERM_TREE_REGEN_INTERVAL + cache.set(self.cache_key_time, int(time.time()), ttl) + + lock = UserGrantedTreeRebuildLock(self.user.id) + got = lock.acquire(blocking=False) + if not got: + logger.info('User perm tree rebuild lock not acquired, pass') + return + + try: for org in to_refresh_orgs: self._rebuild_user_perm_tree_for_org(org) - self._mark_user_orgs_refresh_finished(to_refresh_orgs) + self._mark_user_orgs_refresh_finished(to_refresh_orgs) + finally: + lock.release() def _rebuild_user_perm_tree_for_org(self, org): with tmp_to_org(org): @@ -75,7 +127,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin): UserPermTreeBuildUtil(self.user).rebuild_user_perm_tree() end = time.time() logger.info( - 'Refresh user [{user}] org [{org}] perm tree, user {use_time:.2f}s' + 'Refresh user perm tree: [{user}] org [{org}] {use_time:.2f}s' ''.format(user=self.user, org=org, use_time=end - start) ) @@ -90,7 +142,7 @@ class UserPermTreeRefreshUtil(_UserPermTreeCacheMixin): cached_org_ids = self.client.smembers(self.cache_key_user) cached_org_ids = {oid.decode() for oid in cached_org_ids} to_refresh_org_ids = set(self.org_ids) - cached_org_ids - to_refresh_orgs = Organization.objects.filter(id__in=to_refresh_org_ids) + to_refresh_orgs = list(Organization.objects.filter(id__in=to_refresh_org_ids)) logger.info(f'Need to refresh orgs: {to_refresh_orgs}') return to_refresh_orgs @@ -128,7 +180,8 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin): self.expire_perm_tree_for_user_groups_orgs(group_ids, org_ids) def expire_perm_tree_for_user_groups_orgs(self, group_ids, org_ids): - user_ids = User.groups.through.objects.filter(usergroup_id__in=group_ids) \ + user_ids = User.groups.through.objects \ + .filter(usergroup_id__in=group_ids) \ .values_list('user_id', flat=True).distinct() self.expire_perm_tree_for_users_orgs(user_ids, org_ids) @@ -151,6 +204,21 @@ class UserPermTreeExpireUtil(_UserPermTreeCacheMixin): logger.info('Expire all user perm tree') +@merge_delay_run(ttl=20) +def refresh_user_orgs_perm_tree(user_orgs=()): + for user, orgs in user_orgs: + util = UserPermTreeRefreshUtil(user) + util.perform_refresh_user_tree(orgs) + + +@merge_delay_run(ttl=20) +def refresh_user_favorite_assets(users=()): + for user in users: + util = UserPermAssetUtil(user) + util.refresh_favorite_assets() + util.refresh_type_nodes_tree_cache() + + class UserPermTreeBuildUtil(object): node_only_fields = ('id', 'key', 'parent_key', 'org_id') @@ -161,13 +229,14 @@ class UserPermTreeBuildUtil(object): self._perm_nodes_key_node_mapper = {} def rebuild_user_perm_tree(self): - self.clean_user_perm_tree() - if not self.user_perm_ids: - logger.info('User({}) not have permissions'.format(self.user)) - return - self.compute_perm_nodes() - self.compute_perm_nodes_asset_amount() - self.create_mapping_nodes() + with transaction.atomic(): + self.clean_user_perm_tree() + if not self.user_perm_ids: + logger.info('User({}) not have permissions'.format(self.user)) + return + self.compute_perm_nodes() + self.compute_perm_nodes_asset_amount() + self.create_mapping_nodes() def clean_user_perm_tree(self): UserAssetGrantedTreeNodeRelation.objects.filter(user=self.user).delete() diff --git a/apps/rbac/permissions.py b/apps/rbac/permissions.py index 7c3f21610..cffc22e71 100644 --- a/apps/rbac/permissions.py +++ b/apps/rbac/permissions.py @@ -139,7 +139,7 @@ class RBACPermission(permissions.DjangoModelPermissions): if isinstance(perms, str): perms = [perms] has = request.user.has_perms(perms) - logger.debug('View require perms: {}, result: {}'.format(perms, has)) + logger.debug('Api require perms: {}, result: {}'.format(perms, has)) return has def has_object_permission(self, request, view, obj): diff --git a/apps/users/api/group.py b/apps/users/api/group.py index 6cccfb4a2..1e34a7caa 100644 --- a/apps/users/api/group.py +++ b/apps/users/api/group.py @@ -6,7 +6,7 @@ from rest_framework.response import Response from orgs.mixins.api import OrgBulkModelViewSet from ..models import UserGroup, User -from ..serializers import UserGroupSerializer +from ..serializers import UserGroupSerializer, UserGroupListSerializer __all__ = ['UserGroupViewSet'] @@ -15,7 +15,10 @@ class UserGroupViewSet(OrgBulkModelViewSet): model = UserGroup filterset_fields = ("name",) search_fields = filterset_fields - serializer_class = UserGroupSerializer + serializer_classes = { + 'default': UserGroupSerializer, + 'list': UserGroupListSerializer, + } ordering = ('name',) rbac_perms = ( ("add_all_users", "users.add_usergroup"), diff --git a/apps/users/serializers/group.py b/apps/users/serializers/group.py index 2276b6e0e..b87ba4a0e 100644 --- a/apps/users/serializers/group.py +++ b/apps/users/serializers/group.py @@ -2,6 +2,7 @@ # from django.db.models import Count from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers from common.serializers.fields import ObjectRelatedField from common.serializers.mixin import ResourceLabelsMixin @@ -10,7 +11,7 @@ from .. import utils from ..models import User, UserGroup __all__ = [ - 'UserGroupSerializer', + 'UserGroupSerializer', 'UserGroupListSerializer', ] @@ -29,7 +30,6 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): fields = fields_mini + fields_small + ['users', 'labels'] extra_kwargs = { 'created_by': {'label': _('Created by'), 'read_only': True}, - 'users_amount': {'label': _('Users amount')}, 'id': {'label': _('ID')}, } @@ -45,6 +45,17 @@ class UserGroupSerializer(ResourceLabelsMixin, BulkOrgResourceModelSerializer): @classmethod def setup_eager_loading(cls, queryset): """ Perform necessary eager loading of data. """ - queryset = queryset.prefetch_related('users', 'labels', 'labels__label') \ + queryset = queryset.prefetch_related('labels', 'labels__label') \ .annotate(users_amount=Count('users')) return queryset + + +class UserGroupListSerializer(UserGroupSerializer): + users_amount = serializers.IntegerField(label=_('Users amount'), read_only=True) + + class Meta(UserGroupSerializer.Meta): + fields = list(set(UserGroupSerializer.Meta.fields + ['users_amount']) - {'users'}) + extra_kwargs = { + **UserGroupSerializer.Meta.extra_kwargs, + 'users_amount': {'label': _('Users amount')}, + } diff --git a/utils/generate_fake_data/generate.py b/utils/generate_fake_data/generate.py index 0931cca8f..cda1bbaed 100644 --- a/utils/generate_fake_data/generate.py +++ b/utils/generate_fake_data/generate.py @@ -17,6 +17,7 @@ from resources.assets import AssetsGenerator, NodesGenerator, PlatformGenerator from resources.users import UserGroupGenerator, UserGenerator from resources.perms import AssetPermissionGenerator from resources.terminal import CommandGenerator, SessionGenerator +from resources.accounts import AccountGenerator resource_generator_mapper = { 'asset': AssetsGenerator, @@ -27,6 +28,7 @@ resource_generator_mapper = { 'asset_permission': AssetPermissionGenerator, 'command': CommandGenerator, 'session': SessionGenerator, + 'account': AccountGenerator, 'all': None # 'stat': StatGenerator } @@ -45,6 +47,7 @@ def main(): parser.add_argument('-o', '--org', type=str, default='') args = parser.parse_args() resource, count, batch_size, org_id = args.resource, args.count, args.batch_size, args.org + resource = resource.lower().rstrip('s') generator_cls = [] if resource == 'all': diff --git a/utils/generate_fake_data/resources/accounts.py b/utils/generate_fake_data/resources/accounts.py new file mode 100644 index 000000000..9b0fc75d5 --- /dev/null +++ b/utils/generate_fake_data/resources/accounts.py @@ -0,0 +1,32 @@ +import random + +import forgery_py + +from accounts.models import Account +from assets.models import Asset +from .base import FakeDataGenerator + + +class AccountGenerator(FakeDataGenerator): + resource = 'account' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.assets = list(list(Asset.objects.all()[:5000])) + + def do_generate(self, batch, batch_size): + accounts = [] + for i in batch: + asset = random.choice(self.assets) + name = forgery_py.internet.user_name(True) + '-' + str(i) + d = { + 'username': name, + 'name': name, + 'asset': asset, + 'secret': name, + 'secret_type': 'password', + 'is_active': True, + 'privileged': False, + } + accounts.append(Account(**d)) + Account.objects.bulk_create(accounts, ignore_conflicts=True) diff --git a/utils/generate_fake_data/resources/assets.py b/utils/generate_fake_data/resources/assets.py index 517122c9b..a8e3ce239 100644 --- a/utils/generate_fake_data/resources/assets.py +++ b/utils/generate_fake_data/resources/assets.py @@ -48,7 +48,7 @@ class AssetsGenerator(FakeDataGenerator): def pre_generate(self): self.node_ids = list(Node.objects.all().values_list('id', flat=True)) - self.platform_ids = list(Platform.objects.all().values_list('id', flat=True)) + self.platform_ids = list(Platform.objects.filter(category='host').values_list('id', flat=True)) def set_assets_nodes(self, assets): for asset in assets: @@ -72,6 +72,17 @@ class AssetsGenerator(FakeDataGenerator): assets.append(Asset(**data)) creates = Asset.objects.bulk_create(assets, ignore_conflicts=True) self.set_assets_nodes(creates) + self.set_asset_platform(creates) + + @staticmethod + def set_asset_platform(assets): + protocol = random.choice(['ssh', 'rdp', 'telnet', 'vnc']) + protocols = [] + + for asset in assets: + port = 22 if protocol == 'ssh' else 3389 + protocols.append(Protocol(asset=asset, name=protocol, port=port)) + Protocol.objects.bulk_create(protocols, ignore_conflicts=True) def after_generate(self): pass diff --git a/utils/generate_fake_data/resources/base.py b/utils/generate_fake_data/resources/base.py index 39942d5f1..dd02c65f7 100644 --- a/utils/generate_fake_data/resources/base.py +++ b/utils/generate_fake_data/resources/base.py @@ -41,7 +41,7 @@ class FakeDataGenerator: start = time.time() self.do_generate(batch, self.batch_size) end = time.time() - using = end - start + using = round(end - start, 3) from_size = created created += len(batch) print('Generate %s: %s-%s [%s]' % (self.resource, from_size, created, using)) diff --git a/utils/generate_fake_data/resources/users.py b/utils/generate_fake_data/resources/users.py index c05e0793b..5ed159b53 100644 --- a/utils/generate_fake_data/resources/users.py +++ b/utils/generate_fake_data/resources/users.py @@ -1,9 +1,11 @@ -from random import choice, sample +from random import sample + import forgery_py -from .base import FakeDataGenerator - +from orgs.utils import current_org +from rbac.models import RoleBinding, Role from users.models import * +from .base import FakeDataGenerator class UserGroupGenerator(FakeDataGenerator): @@ -47,3 +49,12 @@ class UserGenerator(FakeDataGenerator): users.append(u) users = User.objects.bulk_create(users, ignore_conflicts=True) self.set_groups(users) + self.set_to_org(users) + + def set_to_org(self, users): + bindings = [] + role = Role.objects.get(name='OrgUser') + for u in users: + b = RoleBinding(user=u, role=role, org_id=current_org.id, scope='org') + bindings.append(b) + RoleBinding.objects.bulk_create(bindings, ignore_conflicts=True)